Skip to content

Commit

Permalink
Stream DB events rather than loading the entire table in memory
Browse files Browse the repository at this point in the history
  • Loading branch information
ESultanik committed Apr 13, 2021
1 parent 7220475 commit 3d4acd9
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions polytracker/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.session import Session
from sqlalchemy.orm.query import Query
from sqlalchemy import (
BigInteger,
BLOB,
Expand Down Expand Up @@ -60,6 +61,17 @@ class EdgeType(IntEnum):
BACKWARD = 1


def stream_results(query: Query, window_size: int = 1000) -> Iterator:
start = 0
while True:
stop = start + window_size
results = query.slice(start, stop).all()
if len(results) == 0:
break
yield from results
start += window_size


class DBInput(Base, Input): # type: ignore
__tablename__ = "input"
uid = Column("id", Integer, primary_key=True)
Expand Down Expand Up @@ -460,13 +472,13 @@ def __len__(self) -> int:
return self.session.query(DBTraceEvent).count()

def __iter__(self) -> Iterator[TraceEvent]:
return iter(
self.session.query(DBTraceEvent).order_by(DBTraceEvent.event_id.asc()).all()
return stream_results(
self.session.query(DBTraceEvent).order_by(DBTraceEvent.event_id.asc())
)

@property
def functions(self) -> Iterable[Function]:
return self.session.query(DBFunction).all()
return self.session.query(DBFunction)

def get_function(self, name: str) -> Function:
try:
Expand All @@ -482,18 +494,20 @@ def has_function(self, name: str) -> bool:

@property
def basic_blocks(self) -> Iterable[BasicBlock]:
return self.session.query(DBBasicBlock).all()
return self.session.query(DBBasicBlock)

def access_sequence(self) -> Iterator[TaintAccess]:
yield from self.session.query(DBTaintAccess).order_by(DBTaintAccess.event_id.asc()).all()
yield from stream_results(
self.session.query(DBTaintAccess).order_by(DBTaintAccess.event_id.asc())
)

@property
def num_accesses(self) -> int:
return self.session.query(DBTaintAccess).count()

@property
def inputs(self) -> Iterable[Input]:
return self.session.query(DBInput).all()
return self.session.query(DBInput)

def __getitem__(self, uid: int) -> TraceEvent:
raise NotImplementedError()
Expand Down

0 comments on commit 3d4acd9

Please sign in to comment.