
build: stream llama test
zhudotexe committed Dec 9, 2024
1 parent 514c8cd · commit 22be725
Showing 1 changed file with 20 additions and 3 deletions.
tests/test_llama.py (23 changes: 20 additions & 3 deletions)
@@ -44,7 +44,12 @@ def debug_logging():
 async def test_llama(create_kani, gh_log):
     """Do one round of conversation with LLaMA."""
     ai = create_kani()
-    resp = await ai.chat_round_str("What are some cool things to do in Tokyo?")
+    print("Q: What are some cool things to do in Tokyo?\n")
+    stream = ai.chat_round_stream("What are some cool things to do in Tokyo?")
+    async for token in stream:
+        print(token, end="", flush=True)
+    resp = await stream.message()
+
     gh_log.write(
         "# LLaMA Basic\n"
         "*This is a real output from kani running LLaMA v2 on GitHub Actions.*\n\n"
@@ -71,6 +76,18 @@ async def test_chatting_llamas(create_kani, gh_log):
         f"### Tourist\n{tourist_response}\n"
     )
     for _ in range(5):
-        guide_response = await guide.chat_round_str(tourist_response)
-        tourist_response = await tourist.chat_round_str(guide_response)
+        print("\n========== GUIDE ==========\n")
+        guide_stream = guide.chat_round_stream(tourist_response)
+        async for token in guide_stream:
+            print(token, end="", flush=True)
+        guide_msg = await guide_stream.message()
+        guide_response = guide_msg.text
+
+        print("\n========== TOURIST ==========\n")
+        tourist_stream = tourist.chat_round_stream(guide_response)
+        async for token in tourist_stream:
+            print(token, end="", flush=True)
+        tourist_msg = await tourist_stream.message()
+        tourist_response = tourist_msg.text
+
         gh_log.write(f"### Guide\n{guide_response}\n\n### Tourist\n{tourist_response}\n\n")
