Skip to content

Commit

Permalink
LEXIO-38100 Refactor ID setting algorithm (#11)
Browse files Browse the repository at this point in the history
* refactor id algo

* Version bumped to 0.11.0

Co-authored-by: ns-circle-ci <[email protected]>
  • Loading branch information
jdrake and ns-circle-ci authored May 25, 2022
1 parent 43b49ac commit e7ad07e
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 38 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pysaql"
version = "0.10.0"
version = "0.11.0"
description = "Python SAQL query builder"
authors = ["Jonathan Drake <[email protected]>"]
license = "BSD-3-Clause"
Expand Down
2 changes: 1 addition & 1 deletion pysaql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Python SAQL query builder"""

__version__ = "0.10.0"
__version__ = "0.11.0"
85 changes: 49 additions & 36 deletions pysaql/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .enums import FillDateTypeString, JoinType, Order
from .scalar import BinaryOperation, field, Scalar
from .util import stringify, stringify_list
from .util import flatten, stringify, stringify_list

# NOTE(review): Python only honors the lowercase ``__all__`` when resolving
# ``from module import *``; this upper-case ``__ALL__`` is an ordinary module
# variable and does not restrict what a star-import exports. Confirm whether
# ``__all__`` was intended (and whether other public helpers belong in it).
__ALL__ = ["load", "cogroup"]

Expand All @@ -22,6 +22,15 @@ class StreamStatement(ABC):

stream: "Stream"

def get_streams(self) -> list[Stream]:
    """Return the streams nested within this statement as a flat list

    The base statement wraps no other streams, so the result is always
    empty; statement types that do wrap streams provide their own version.

    Returns:
        empty list of streams
    """
    no_streams: list[Stream] = []
    return no_streams


class Stream:
"""Base class for a SAQL data stream"""
Expand All @@ -44,39 +53,14 @@ def ref(self) -> str:
"""Stream reference in the SAQL query"""
return f"q{self._id}"

def increment_id(self, incr: int) -> int:
"""Increment the stream ID
This should not be called by clients.
Args:
incr: Value to increment
def get_streams(self) -> list[Stream]:
"""Get a flat list of streams nested within this stream
Returns:
new stream ID
list of streams
"""
max_id = 0
i = 0
for statement in self._statements:
if isinstance(statement, LoadStatement):
statement.stream._id += incr + i
max_id = max(max_id, statement.stream._id)
i += 1
elif isinstance(statement, (CogroupStatement, UnionStatement)):
# For cogroup and union statements, leave the left-most (first) branch alone
if isinstance(statement, CogroupStatement):
streams = [stream for (stream, _) in statement.streams[1:]]
else:
streams = list(statement.streams[1:])

for stream in streams:
stream.increment_id(incr + i)
max_id = max(max_id, stream._id)
i += 1

self._id = max_id + 1
return self._id
return flatten([s.get_streams() for s in self._statements])

def add_statement(self, statement: StreamStatement) -> None:
"""Add a statement to the stream
Expand All @@ -86,6 +70,9 @@ def add_statement(self, statement: StreamStatement) -> None:
"""
self._statements.append(statement)
# Update all stream IDs
for i, s in enumerate(flatten(statement.get_streams())):
s._id = i

def field(self, name: str) -> field:
"""Create a new field object scoped to this stream
Expand Down Expand Up @@ -215,6 +202,15 @@ def __str__(self) -> str:
"""Cast this load statement to a string"""
return f'{self.stream.ref} = load "{self.name}";'

def get_streams(self) -> list[Stream]:
    """Return the stream this load statement assigns to, wrapped in a list

    Returns:
        one-element list holding this statement's own stream
    """
    own_stream = self.stream
    return [own_stream]


class ProjectionStatement(StreamStatement):
"""Statement to project columns from a stream"""
Expand Down Expand Up @@ -402,6 +398,20 @@ def __str__(self) -> str:

return "\n".join(lines)

def get_streams(self) -> list[Stream]:
    """Collect every stream reachable from this cogroup statement

    The streams nested inside each cogrouped branch come first, in branch
    order, followed by the stream this statement itself belongs to.

    Returns:
        flat list of streams
    """
    # Each entry of ``self.streams`` is a pair; only its first element (the
    # stream) is relevant here.
    branch_streams = [branch.get_streams() for branch, _ in self.streams]
    own_stream = [self.stream]
    return flatten([branch_streams, own_stream])


class UnionStatement(StreamStatement):
"""Statement to combine (union) two or more streams with the same structure into one"""
Expand Down Expand Up @@ -436,6 +446,15 @@ def __str__(self) -> str:
lines.append(f"{self.stream.ref} = union {', '.join(stream_refs)};")
return "\n".join(lines)

def get_streams(self) -> list[Stream]:
    """Collect every stream reachable from this union statement

    The streams nested inside each unioned member come first, in member
    order, followed by the stream this statement itself belongs to.

    Returns:
        flat list of streams
    """
    member_streams = [member.get_streams() for member in self.streams]
    own_stream = [self.stream]
    return flatten([member_streams, own_stream])


class FillStatement(StreamStatement):
"""Statement to fill a data stream with missing dates"""
Expand Down Expand Up @@ -506,9 +525,6 @@ def cogroup(
"""
stream = Stream()
stream.add_statement(CogroupStatement(stream, streams, join_type))
# Increment stream IDs for all streams contained in this cogroup statement.
# We'll use the ID of the first stream as the basis for incrementing.
stream.increment_id(streams[0][0]._id)
return stream


Expand All @@ -527,7 +543,4 @@ def union(*streams: Stream) -> Stream:
"""
stream = Stream()
stream.add_statement(UnionStatement(stream, streams))
# Increment stream IDs for all streams contained in this union statement.
# We'll use the ID of the first stream as the basis for incrementing.
stream.increment_id(streams[0]._id)
return stream
17 changes: 17 additions & 0 deletions pysaql/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,20 @@ def stringify_list(seq: Sequence) -> str:
"""
seq = [seq] if not isinstance(seq, (list, tuple, set)) else seq
return f"({', '.join(str(s) for s in seq)})" if len(seq) > 1 else str(seq[0])


def flatten(seq: list) -> list:
    """Flatten an arbitrarily nested list into a single flat list

    Only ``list`` instances are expanded; every other item (including tuples
    and strings) is kept as-is. Item order is preserved.

    The traversal is iterative, so neither the length of the input nor its
    nesting depth is limited by the interpreter's recursion limit, and the
    work done is linear in the total number of items.

    Args:
        seq: Sequence of items, possibly containing nested lists

    Returns:
        new flat list of items (never the input object itself)
    """
    flat: list = []
    # Stack of partially consumed iterators: the top one is the list currently
    # being walked; the ones below resume where they left off once it is done.
    pending = [iter(seq or [])]
    while pending:
        for item in pending[-1]:
            if isinstance(item, list):
                # Descend into the nested list before continuing the parent.
                pending.append(iter(item))
                break
            flat.append(item)
        else:
            # Current list exhausted without finding another nested list.
            pending.pop()
    return flat
10 changes: 10 additions & 0 deletions tests/unit/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,13 @@ def test_stringify_list__one():
def test_stringify_list__multiple():
    """Should stringify a list with multiple items as a parenthesized, comma-separated group"""
    assert mod_ut.stringify_list(["foo", "bar"]) == "(foo, bar)"


def test_flatten__empty():
    """Flattening an empty list should yield an empty list"""
    flattened = mod_ut.flatten([])
    assert flattened == []


def test_flatten__nested():
    """Flattening a nested list should yield its items in order at one level"""
    nested = [1, [2, [3, [4, 5]], 6], 7]
    expected = [1, 2, 3, 4, 5, 6, 7]
    assert mod_ut.flatten(nested) == expected

0 comments on commit e7ad07e

Please sign in to comment.