Skip to content

Commit

Permalink
Fast header-only streaming log reads (#468)
Browse files Browse the repository at this point in the history
* Improve header_only log reading performance

* Revert changes to pyproject.toml

* Preserve exception information

---------

Co-authored-by: jjallaire <[email protected]>
  • Loading branch information
MSchmatzAISI and jjallaire authored Sep 20, 2024
1 parent 0e45169 commit 8844754
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 50 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ debugpy
docstring-parser>=0.16
fsspec
httpx
json-stream
ijson
jsonlines
jsonpatch
jsonschema
Expand Down
132 changes: 85 additions & 47 deletions src/inspect_ai/log/_file.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
import re
from pathlib import Path
from typing import Any, Callable, Literal, cast
from typing import Any, Callable, Literal, cast, get_args
from urllib.parse import urlparse

import json_stream # type: ignore
import ijson # type: ignore
from ijson import IncompleteJSONError
from pydantic import BaseModel
from pydantic_core import from_json, to_json

Expand Down Expand Up @@ -152,6 +153,72 @@ def eval_log_json(log: EvalLog) -> str:
).decode()


def _validate_version(ver: int) -> None:
    """Reject log files written with a newer schema than this reader supports.

    Args:
        ver: The ``version`` field read from the log file.

    Raises:
        ValueError: If ``ver`` exceeds ``LOG_SCHEMA_VERSION``.
    """
    if ver <= LOG_SCHEMA_VERSION:
        return
    raise ValueError(f"Unable to read version {ver} of log format.")


def _read_header_streaming(log_file: str) -> EvalLog:
    """Read an eval log's header fields without materializing the "samples" section.

    Streams the JSON with ijson in two passes over the same open file: first a
    low-level event pass to extract the version number and detect the optional
    "results" and "error" sections, then (after rewinding) a key/value pass
    that parses each header field and breaks before reaching "samples".

    Args:
        log_file: Location of the log, opened via the module's `file()` helper.

    Returns:
        EvalLog populated with header fields only; `samples` is never parsed.

    Raises:
        ValueError: If no "version" field is found before "samples", or the
            version is newer than this reader supports.
    """
    with file(log_file, "r") as f:
        # Do low-level parsing to get the version number and also
        # detect the presence of results or error sections
        version: int | None = None
        has_results = False
        has_error = False

        for prefix, event, value in ijson.parse(f):
            if (prefix, event) == ("version", "number"):
                version = value
            elif (prefix, event) == ("results", "start_map"):
                has_results = True
            elif (prefix, event) == ("error", "start_map"):
                has_error = True
            elif prefix == "samples":
                # Break as soon as we hit samples as that can be very large
                break

        if version is None:
            raise ValueError("Unable to read version of log format.")

        _validate_version(version)
        # Normalize to the schema version we will return in the EvalLog
        version = LOG_SCHEMA_VERSION

        # Rewind the file to the beginning to re-parse the contents of fields
        f.seek(0)

        # Parse the log file, stopping before parsing samples.
        # NOTE(review): the locals bound below (status/eval/plan/stats, plus
        # results/error when their flags are set) are assumed to always appear
        # in the file ahead of "samples"; if one were missing, the EvalLog(...)
        # construction would raise NameError — confirm the writer guarantees
        # this field order (see the field-order warning on EvalLog).
        for k, v in ijson.kvitems(f, ""):
            if k == "status":
                assert v in get_args(
                    Literal["started", "success", "cancelled", "error"]
                )
                status: Literal["started", "success", "cancelled", "error"] = v
            if k == "eval":
                eval = EvalSpec(**v)
            elif k == "plan":
                plan = EvalPlan(**v)
            elif k == "results":
                results = EvalResults(**v)
            elif k == "stats":
                stats = EvalStats(**v)
                if not has_error:
                    # No "error" section follows, so everything needed for the
                    # header has been read — exit before parsing samples
                    break
            elif k == "error":
                error = EvalError(**v)
                # "error" is the last header field; stop before "samples"
                break

        return EvalLog(
            eval=eval,
            plan=plan,
            results=results if has_results else None,
            stats=stats,
            status=status,
            version=version,
            error=error if has_error else None,
        )


def read_eval_log(log_file: str | FileInfo, header_only: bool = False) -> EvalLog:
"""Read an evaluation log.
Expand All @@ -166,51 +233,22 @@ def read_eval_log(log_file: str | FileInfo, header_only: bool = False) -> EvalLo
# resolve to file path
log_file = log_file if isinstance(log_file, str) else log_file.name

# verify we know about this version of the log file format
def validate_version(ver: int) -> None:
if ver > LOG_SCHEMA_VERSION:
raise ValueError(f"Unable to read version {ver} of log format.")

# header-only uses json-stream
if header_only:
with file(log_file, "r") as f:
try:
data = json_stream.load(f, persistent=True)

def read_field(field: str) -> Any:
if field in data.keys():
return json_stream.to_standard_types(data[field])
else:
return None

# fail for unknown version
validate_version(read_field("version"))

# set the version to the schema version we'll be returning
version = LOG_SCHEMA_VERSION

results = read_field("results")
error = read_field("error")

return EvalLog(
version=version,
status=read_field("status"),
eval=EvalSpec(**read_field("eval")),
plan=EvalPlan(**read_field("plan")),
results=EvalResults(**results) if results else None,
stats=EvalStats(**read_field("stats")),
error=EvalError(**error) if error else None,
)
# The Python JSON serializer supports NaN and Inf, however
# this isn't technically part of the JSON spec. The json-stream
# library shares this limitation, so if we fail with an
# invalid character then we move on and and parse w/ pydantic
# (which does support NaN and Inf by default)
except ValueError as ex:
if str(ex).find("Invalid JSON character") != -1:
pass
else:
raise ex
try:
return _read_header_streaming(log_file)
    # The Python JSON serializer supports NaN and Inf, however
    # this isn't technically part of the JSON spec. The ijson
    # library shares this limitation, so if we fail with an
    # invalid character then we move on and parse w/ pydantic
    # (which does support NaN and Inf by default)
except (ValueError, IncompleteJSONError) as ex:
if (
str(ex).find("Invalid JSON character") != -1
or str(ex).find("invalid char in json text") != -1
):
pass
else:
raise ValueError(f"Unable to read log file: {log_file}") from ex

# parse full log (also used as a fallback for header_only encountering NaN or Inf)
with file(log_file, "r") as f:
Expand All @@ -219,7 +257,7 @@ def read_field(field: str) -> Any:
log = EvalLog(**raw_data)

# fail for unknown version
validate_version(log.version)
_validate_version(log.version)

# set the version to the schema version we'll be returning
log.version = LOG_SCHEMA_VERSION
Expand Down
4 changes: 4 additions & 0 deletions src/inspect_ai/log/_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,10 @@ class EvalStats(BaseModel):


class EvalLog(BaseModel):
# WARNING: The order of these fields is important for the log file format.
# Do not change the order of these fields without incrementing the version number,
# updating the log file read/write functionality (such as read_eval_log),
# and updating the tests.
version: int = Field(default=2)
"""Eval log file format version."""

Expand Down
4 changes: 4 additions & 0 deletions tests/log/test_eval_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def test_fail_version():
check_log_raises(log_path("log_version_3"))


def test_valid_log_header():
read_eval_log(log_path("log_valid"), header_only=True)


def check_log_raises(log_file):
with pytest.raises(ValueError):
read_eval_log(log_file)
Expand Down
63 changes: 63 additions & 0 deletions tests/log/test_eval_log/log_valid.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
{
"version": 2,
"status": "success",
"eval": {
"task": "wikipedia",
"task_version": 0,
"task_file": "examples/langchain/wikipedia.py",
"task_id": "YAdbKczyeSb6mEgPd3R9Qs",
"run_id": "i5LyrzaUdD9K4EW5WTAd5t",
"created": "2024-05-05T07:59:35",
"dataset": {
"name": "wikipedia",
"location": "wikipedia.jsonl"
},
"model": "openai/gpt-4",
"task_attribs": {},
"task_args": {},
"model_args": {},
"config": {
"limit": 20
}
},
"plan": {
"name": "plan",
"steps": [
{
"solver": "wikipedia_search",
"params": {}
}
],
"config": {}
},
"results": {
"scorers": [{
"name": "model_graded_fact",
"params": {},
"metrics": {
"accuracy": {
"name": "accuracy",
"value": 1,
"options": {}
},
"bootstrap_std": {
"name": "bootstrap_std",
"value": 0.0,
"options": {}
}
}
}]
},
"stats": {
"started_at": "2024-05-05T07:59:35",
"completed_at": "2024-05-05T08:00:03",
"model_usage": {
"openai/gpt-4": {
"input_tokens": 8868,
"output_tokens": 1351,
"total_tokens": 10219
}
}
},
"logging": []
}
4 changes: 2 additions & 2 deletions tests/log/test_eval_log/log_version_3.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"results": {
"scorers": [{
"name": "model_graded_fact",
"params": {}
"params": {},
"metrics": {
"accuracy": {
"name": "accuracy",
Expand All @@ -60,4 +60,4 @@
}
},
"logging": []
}
}

0 comments on commit 8844754

Please sign in to comment.