-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
168 lines (129 loc) · 4.31 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import math
import queue
import random
import time
import threading
from typing import Iterator, Optional, Tuple
from rich import box
from rich.align import Align
from rich.layout import Layout
from rich.console import Console
from rich.live import Live
from rich.progress_bar import ProgressBar
from rich.table import Table
from rich.columns import Columns
from river.compose import TransformerUnion
from river.feature_extraction import Agg
# A stream is an iterator of records; every record is a plain dict
# (see silly_stream, and ETL which feeds each record to ``learn_one``).
Stream = Iterator[dict]


def silly_stream() -> Stream:
    """Yield an endless stream of toy records, pausing 0.1 s after each one.

    Each record is a dict with keys, in this order:

    - ``t``: the ``time.time()`` timestamp at which the record was built;
    - ``c``: randomly ``"a"`` or ``"b"``;
    - ``d``: randomly ``"c"`` or ``"d"``;
    - ``x`` / ``y``: ``s * (1 + cos(t))`` and ``s * (1 + sin(t))`` where ``s`` is
      ``-10`` for ``c == "a"`` and ``10`` for ``c == "b"``. Over time the point
      ``(x, y)`` therefore lies on a circle of radius 10 centred at ``(s, s)``.
    """
    sign_by_c = {"a": -10, "b": 10}
    while True:
        t = time.time()
        c = random.choice(["a", "b"])
        yield {
            "t": t,
            "c": c,
            "d": random.choice(["c", "d"]),
            "x": sign_by_c[c] * (1 + math.cos(t)),
            "y": sign_by_c[c] * (1 + math.sin(t)),
        }
        time.sleep(0.1)
class Buffer(threading.Thread):
    """Background thread that pulls records from a stream into a thread-safe queue.

    Iterating over a ``Buffer`` drains the records queued so far, and ``len()``
    reports how many are waiting, which lets a consumer (e.g. ``ETL``) measure
    its progress.

    Parameters
    ----------
    stream
        The iterator records are pulled from with ``next()``.
    """

    def __init__(self, stream: Stream):
        super().__init__()
        self.running = True
        self.stream = stream
        self.records = queue.Queue()

    def stop(self):
        """Ask the thread to exit; the flag is checked before each pull from the stream."""
        self.running = False

    def run(self):
        """Move records from the stream into the queue until stopped or exhausted."""
        while self.running:
            try:
                record = next(self.stream)
            except StopIteration:
                # A finite stream has run dry: finish quietly instead of letting
                # StopIteration escape run() and kill the thread with a traceback.
                break
            self.records.put(record)

    def __iter__(self):
        """Yield the records queued so far, without blocking, then stop."""
        while True:
            try:
                # get_nowait() avoids the race between empty() and a blocking
                # get(), which could hang if another consumer took the last item.
                record = self.records.get_nowait()
            except queue.Empty:
                return
            yield record

    def __len__(self):
        """Return the approximate number of queued records (``Queue.qsize``)."""
        return self.records.qsize()
def _river_agg_to_rich_table(agg: Agg) -> Table:
    """Render a river ``Agg`` as a rich table holding one row per group.

    The table is titled with ``agg.feature_name``. It has one centred,
    non-wrapping column per group-by field in ``agg.by``, followed by a column
    named after ``agg.on`` holding each group's current statistic, formatted
    with thousands separators and 5 decimals.
    """
    table = Table(title=agg.feature_name, box=box.SIMPLE_HEAD)

    for field_name in agg.by:
        table.add_column(field_name, justify="center", no_wrap=True)
    table.add_column(agg.on)

    # Copy the groups before iterating: the aggregate is updated by another
    # thread, and a dict whose size changes mid-iteration raises RuntimeError.
    snapshot = list(agg.groups.items())
    for group_key, statistic in snapshot:
        table.add_row(*group_key, f"{statistic.get():,.5f}")

    return table
class ETL(threading.Thread):
    """Background thread that feeds every record of a stream into a set of aggregations.

    The aggregations are combined in a river ``TransformerUnion`` so that one
    ``learn_one`` call per record updates all of them. ``n`` counts the records
    processed so far.

    Parameters
    ----------
    *aggs
        The aggregations to maintain.
    stream
        Iterable the records are read from. Each ``for`` pass over it is expected
        to yield the records currently available (a ``Buffer`` behaves this way).
        If it also supports ``len()``, ``_percent_processed`` can report progress.
    poll_interval
        Seconds to sleep after a pass that yielded no record, so that an idle ETL
        does not spin at full CPU while waiting for the stream to refill.
    """

    def __init__(self, *aggs, stream: Stream, poll_interval: float = 0.01):
        super().__init__()
        self.running = True
        self.stream = stream
        self.agg = TransformerUnion(*aggs)
        self.n = 0
        self.poll_interval = poll_interval

    def stop(self):
        """Ask the thread to exit once the record currently being processed is done."""
        self.running = False

    @property
    def _percent_processed(self) -> Optional[float]:
        """Return the fraction (between 0 and 1) of known records already processed.

        This only makes sense if the stream has some notion of state: it only works
        if the stream can tell us, via ``len()``, how many records it is still
        holding for us to process. Raw streams do not do this, in which case
        ``None`` is returned; a buffer needs to be wrapped around the stream for
        this to be possible. When nothing has been processed and nothing is
        pending, the ETL is caught up and ``1.0`` is returned instead of dividing
        by zero.
        """
        done = self.n
        try:
            pending = len(self.stream)
        except TypeError:
            # len() signals a missing __len__ with TypeError, not AttributeError.
            return None
        total = done + pending
        return done / total if total else 1.0

    def run(self):
        """Process records from the stream until stop() is called."""
        while self.running:
            processed_any = False
            for record in self.stream:
                self.agg.learn_one(record)
                self.n += 1
                processed_any = True
                # Checked after processing, so a record already pulled from the
                # stream is never dropped, and so stop() also takes effect on a
                # stream that never runs dry.
                if not self.running:
                    break
            if not processed_any:
                # Nothing available right now: wait a little instead of spinning.
                time.sleep(self.poll_interval)
class Display(threading.Thread):
    """Background thread that renders an ETL's aggregations in a full-screen live view.

    Parameters
    ----------
    etl
        The ETL to show. The view has one table per aggregation in
        ``etl.agg.transformers``, plus a progress bar built from ``etl.n`` when
        the ETL can compute its progress.
    refresh_per_second
        How many times per second the view is rebuilt; also passed to rich's
        ``Live`` as its refresh rate.
    """

    def __init__(self, etl: ETL, refresh_per_second: float = 10):
        super().__init__()
        self.running = True
        self.etl = etl
        self.refresh_per_second = refresh_per_second

    def stop(self):
        """Ask the render loop to exit; the live screen is closed afterwards."""
        self.running = False

    def _make_tables(self) -> list:
        """Return one rich table per aggregation held by the ETL."""
        return [
            _river_agg_to_rich_table(agg)
            for agg in self.etl.agg.transformers.values()
        ]

    def _make_layout(self) -> Layout:
        """Build a fresh layout: centred tables on top, a thin progress strip below."""
        layout = Layout()
        layout.split_column(
            Layout(name="upper", ratio=99),
            Layout(name="lower", ratio=1),
        )
        # Only draw a progress bar when the ETL can compute its progress
        # (``_percent_processed`` is not None), i.e. its stream supports len().
        if self.etl._percent_processed is not None:
            done = self.etl.n
            layout["lower"].update(
                ProgressBar(total=done + len(self.etl.stream), completed=done)
            )
        layout["upper"].update(
            Align(Columns(self._make_tables()), align="center", vertical="middle")
        )
        return layout

    def run(self):
        """Keep the live view up to date until stop() is called."""
        pause = 1 / self.refresh_per_second
        with Live(
            refresh_per_second=self.refresh_per_second, screen=True, transient=True
        ) as live:
            while self.running:
                live.update(self._make_layout())
                # Without a pause this loop rebuilds the layout as fast as it can,
                # burning a core and contending for the GIL with the ETL and
                # Buffer threads.
                time.sleep(pause)
from river import stats

# Demo: buffer a toy stream, aggregate it in the background, then show the
# aggregates live. Every started thread is recorded so that the `finally`
# clause stops and joins it even if the demo is interrupted (e.g. Ctrl-C);
# otherwise the non-daemon Buffer thread, which loops until stopped, would keep
# the interpreter from exiting.
_started = []
try:
    buffer = Buffer(silly_stream())
    buffer.start()
    _started.append(buffer)
    # Let some records pile up in the buffer before anything consumes them.
    time.sleep(5)

    etl = ETL(
        Agg(on="x", by="c", how=stats.Mean()),
        Agg(on="x", by="d", how=stats.Mean()),
        stream=buffer,
    )
    etl.start()
    _started.append(etl)
    time.sleep(3)
    # NOTE(review): the ETL is stopped before the display starts, so the tables
    # stay frozen while the buffer keeps filling — confirm this is intended.
    etl.stop()

    display = Display(etl)
    display.start()
    _started.append(display)
    time.sleep(10)
finally:
    # Stop in reverse start order (display, etl, buffer) and wait for each
    # thread, so the live screen is torn down before the process exits.
    for thread in reversed(_started):
        thread.stop()
        thread.join()