Skip to content

Commit

Permalink
Fix RecordReader not reading from stdin by default (#94)
Browse files Browse the repository at this point in the history
Calling `RecordReader()` without arguments should always default to stdin
  • Loading branch information
yunzheng authored Oct 27, 2023
1 parent 6144cf4 commit 9cad89f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
2 changes: 1 addition & 1 deletion flow/record/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,7 @@ def RecordAdapter(
if sub_adapter:
cls_url = sub_adapter + "://" + cls_url
if out is False:
if url in ("-", ""):
if url in ("-", "", None) and fileobj is None:
# For reading stdin, we cannot rely on an extension to know what sort of stream is incoming. Thus, we will
# treat it as a 'fileobj', where we can peek into the stream and try to select the appropriate adapter.
fileobj = getattr(sys.stdin, "buffer", sys.stdin)
Expand Down
38 changes: 38 additions & 0 deletions tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,5 +646,43 @@ def test_fieldtype_typedlist_net_ipaddress():
assert issubclass(fieldtype("net.ipaddress[]"), fieldtypes.FieldType)


def test_record_reader_default_stdin(tmp_path):
"""RecordWriter should default to stdin if no path is given"""
TestRecord = RecordDescriptor(
"test/record",
[
("string", "text"),
],
)

# write some records
records_path = tmp_path / "test.records"
with RecordWriter(records_path) as writer:
writer.write(TestRecord("foo"))

# Test stdin
with patch("sys.stdin", BytesIO(records_path.read_bytes())):
with RecordReader() as reader:
for record in reader:
assert record.text == "foo"


def test_record_writer_default_stdout(capsysbinary):
"""RecordWriter should default to stdout if no path is given"""
TestRecord = RecordDescriptor(
"test/record",
[
("string", "text"),
],
)

# write a record to stdout
with RecordWriter() as writer:
writer.write(TestRecord("foo"))

stdout = capsysbinary.readouterr().out
assert stdout.startswith(b"\x00\x00\x00\x0f\xc4\rRECORDSTREAM\n")


if __name__ == "__main__":
__import__("standalone_test").main(globals())

0 comments on commit 9cad89f

Please sign in to comment.