Skip to content

Commit

Permalink
fix(trace): add retry to download_trace example
Browse files Browse the repository at this point in the history
  • Loading branch information
jezekra1 committed Dec 20, 2024
1 parent 843f16a commit 67dd0ca
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 117 deletions.
24 changes: 15 additions & 9 deletions examples/download_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

import json
import os
import time
from contextlib import suppress
from pprint import pprint

from openai import BaseModel, OpenAI
from openai import BaseModel, OpenAI, NotFoundError


def heading(text: str) -> str:
Expand All @@ -20,11 +22,7 @@ def heading(text: str) -> str:
bee_client = OpenAI(base_url=f'{os.getenv("BEE_API")}/v1', api_key=os.getenv("BEE_API_KEY"))

# Instantiate Observe client with Bee credentials from env, but DIFFERENT base_url (!)
observe_client = OpenAI(
base_url=f'{os.getenv("BEE_API")}/observe/v1', api_key=os.getenv("BEE_API_KEY"),
# Uploading trace is an asynchronous process that takes 40-60s, hence we use a higher number of retries
max_retries=10
)
observe_client = OpenAI(base_url=f'{os.getenv("BEE_API")}/observe/v1', api_key=os.getenv("BEE_API_KEY"))

print(heading("Create run"))
assistant = bee_client.beta.assistants.create(model="meta-llama/llama-3-1-70b-instruct")
Expand All @@ -44,9 +42,17 @@ def heading(text: str) -> str:
# Get trace_id
trace_info = bee_client.get(f"/threads/{thread.id}/runs/{run.id}/trace", cast_to=BaseModel)

# Get trace
params = {"include_tree": True}
trace = observe_client.get(f"/traces/{trace_info.id}", options={"params": params}, cast_to=BaseModel)

def get_trace(trace_id: str, params: dict):
# Uploading trace is an asynchronous process that takes 40-60s, hence we need to retry a few times.
for attempt in range(1, 10):
with suppress(NotFoundError):
return observe_client.get(f"/traces/{trace_info.id}", options={"params": params}, cast_to=BaseModel)
time.sleep(attempt * 0.5)
raise RuntimeError("Unable to download trace")


trace = get_trace(trace_info.id, {"include_tree": True})
print("Trace:")
print(json.dumps(trace.model_dump(mode="json"), indent=2))

Expand Down
Loading

0 comments on commit 67dd0ca

Please sign in to comment.