'''
Example of an asynchronous pipeline for generating daily trade candles.

You can run this example either against your own data,
or create a sample database using the following commands:

	docker run -it --name trade-candles-db -e POSTGRES_PASSWORD=password -p 5432:5432 postgres:14

Exec into docker container:

	docker exec -it trade-candles-db psql -U postgres

Create table:

	CREATE TABLE "trades" (
		time TIMESTAMP NOT NULL,
		ticker VARCHAR NOT NULL,
		price FLOAT NOT NULL,
		quantity BIGINT NOT NULL
	);

Generate some artificial data:

	WITH daily_ranges AS (
		SELECT
			daily_ranges.*
		FROM generate_series(TIMESTAMP '2021-12-1', TIMESTAMP '2021-12-31', '1 DAY') AS dates
		CROSS JOIN (SELECT ticker, base_price FROM (VALUES ('AAPL', 173), ('MSFT', 310), ('TSLA', 1049)) AS t (ticker, base_price)) AS tickers,
		LATERAL (
			SELECT
				dates,
				ticker,
				base_price * (1 + (random() / (random() * 10))) AS "r1",
				base_price * (1 + (random() / (random() * 10))) AS "r2"
		) AS "daily_ranges"
	)
	INSERT INTO "trades" (
		SELECT
			trades.*
		FROM daily_ranges,
		LATERAL (
			SELECT
				dates + ('14:30' + random() * (dates + '8 hours' - dates)),
				ticker,
				round((LEAST(r1, r2) + (abs(r1 - r2) * random()))::numeric, 4)::float AS "price",
				(random() * 10000)::bigint AS "quantity"
			FROM generate_series(100, (random() * 10000)::int)
		) AS trades
	);

Run the code:

	python full_pipeline_example.py

Note that you need to have `asyncpg` installed.
You can install it with `pip install asyncpg`.
'''
#!/usr/bin/env python3.10
from typing import Any, AsyncIterator, Optional, TypeAlias, TypedDict
import asyncio
import collections
import datetime
import pprint

import asyncpg


async def acquire_data(
	credentials: dict[str, Any],
	start: datetime.datetime = datetime.datetime(2021, 12, 1),
	end: datetime.datetime = datetime.datetime(2021, 12, 31),
) -> AsyncIterator[dict[str, Any]]:
	'''
	Stream trade rows from the "trades" table, ordered by time ascending.

	:param credentials: keyword arguments passed straight to `asyncpg.connect`
		(host, port, user, password, database, ...)
	:param start: inclusive lower bound on the trade `time` column
		(defaults to the range used by the sample data in the module docstring)
	:param end: inclusive upper bound on the trade `time` column
	:return: async iterator of rows with keys: ticker, time, quantity, price
	'''
	connection = await asyncpg.connect(**credentials)
	try:
		# A transaction is required by asyncpg for server-side cursors.
		async with connection.transaction():
			async for row in connection.cursor(
				'''
				SELECT
						ticker,
						time,
						quantity,
						price
				FROM "trades"
				WHERE
						time BETWEEN $1 AND $2
				ORDER BY time ASC
				''',
				start,
				end,
			):
				yield row
	finally:
		# BUG FIX: the connection was previously never closed, leaking it
		# whenever the generator finished or was closed early.
		await connection.close()


Trade: TypeAlias = dict[str, Any]


async def group_by_date_and_ticker(
	stream: AsyncIterator[
		Trade
	],
) -> AsyncIterator[list[Trade]]:
	'''
	Batch a time-ordered trade stream into per-(date, ticker) lists.

	Whenever the incoming trade's date differs from the previous one, all
	buffered lists (one per ticker) are yielded and the buffer is reset, so
	each yielded list holds the trades of a single ticker on a single date.
	The input must be ordered by time for the date-change detection to work.
	'''
	current_date: Optional[datetime.date] = None
	# Buffer for the date currently being read:
	# ticker -> trades of that ticker on that date.
	# defaultdict(list) spares us an explicit "key missing" branch.
	buckets: dict[str, list[Trade]] = collections.defaultdict(list)

	async for trade in stream:
		trade_date = trade['time'].date()
		if trade_date != current_date:
			# Date boundary crossed: flush everything buffered so far.
			for bucket in buckets.values():
				yield bucket
			buckets.clear()

		buckets[trade['ticker']].append(trade)
		current_date = trade_date

	# The stream is exhausted; flush the buffer for the final date.
	for bucket in buckets.values():
		yield bucket


# That's how the output of the aggregation should look like
class TradeCandle(TypedDict):
	'''Daily OHLCV summary of all trades for one ticker on one date.'''
	date: datetime.date  # calendar day the candle covers
	ticker: str          # instrument symbol, e.g. 'AAPL'
	open: float          # price of the first trade of the day
	high: float          # highest trade price of the day
	low: float           # lowest trade price of the day
	close: float         # price of the last trade of the day
	volume: int          # total quantity traded over the day


async def aggregate_into_trade_candles(
	# Mind the use of "Trade" TypeAlias from the previous code block
	stream: AsyncIterator[list[Trade]],
) -> AsyncIterator[TradeCandle]:
	'''
	Collapse each per-(date, ticker) list of trades into one OHLCV candle.

	Each input list is expected to be time-ordered and to contain trades of a
	single ticker on a single date (as produced by `group_by_date_and_ticker`).
	Empty lists are skipped.
	'''
	async for trades in stream:
		# If we receive empty list, skip it
		if not trades:
			continue
		# We compute volume, high and low in loop so that we iterate
		# over the data only once.
		# We could use something like:
		#   high_price = max(trades, key = operator.itemgetter('price'))
		#   low_price = min(trades, key = operator.itemgetter('price'))
		#   volume = sum(trades, key = operator.itemgetter('quantity'))
		# But that would iterate over the data three times
		# Note that `operator.itemgetter(key)` is same as `lambda trade: trade[key]`
		high_price: Optional[float] = None
		low_price: Optional[float] = None
		volume: int = 0
		for trade in trades:
			if low_price is None or low_price > trade['price']:
				low_price = trade['price']

			if high_price is None or high_price < trade['price']:
				high_price = trade['price']
			volume += trade['quantity']

		yield TradeCandle(
			# Trades still have exact time during the day, we need to extract the date part
			date = trades[0]['time'].date(),
			ticker = trades[0]['ticker'],
			# Price of the first trade
			open = trades[0]['price'],
			high = high_price,
			low = low_price,
			# Price of the last trade
			close = trades[-1]['price'],
			# BUG FIX: volume was computed above but never put into the candle,
			# even though TradeCandle declares a required `volume` key.
			volume = volume,
		)


async def print_output(stream: AsyncIterator[TradeCandle]) -> None:
	'''Pipeline sink: pretty-print every candle from the stream to stdout.'''
	async for candle in stream:
		pprint.pprint(candle)


if __name__ == '__main__':
	# Connection settings for the sample docker container (see module docstring).
	credentials = {
		'host': 'localhost',
		'port': 5432,
		'user': 'postgres',
		'password': 'password',
		'database': 'postgres',
	}
	# Compose the pipeline stage by stage: source -> grouping -> aggregation -> sink.
	trades = acquire_data(credentials = credentials)
	daily_groups = group_by_date_and_ticker(trades)
	candles = aggregate_into_trade_candles(daily_groups)
	asyncio.run(print_output(candles))
