
Add invoke/batch test cases and pass stop to batch retry
ZixinYang committed Oct 26, 2023
1 parent b80e462 commit c2f6ce4
Showing 3 changed files with 104 additions and 15 deletions.
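
At a glance: the Fireworks LLM batch paths now pass stop explicitly to completion_with_retry_batching and acompletion_with_retry_batching instead of folding it into params inside get_batch_prompts, and new integration tests cover invoke, ainvoke, batch, and abatch with a stop sequence. Below is a minimal usage sketch of the behavior those tests exercise; it is not part of the diff, the import path is inferred from the changed file's location, and it assumes a valid FIREWORKS_API_KEY in the environment.

    from langchain.llms.fireworks import Fireworks

    llm = Fireworks()
    outputs = llm.batch(
        ["How is the weather in New York today?"] * 3,
        stop=[","],
    )
    for output in outputs:
        # per the tests in this commit, each completion ends at the first stop sequence
        assert isinstance(output, str)
        assert output[-1] == ","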
25 changes: 14 additions & 11 deletions libs/langchain/langchain/llms/fireworks.py
@@ -88,11 +88,16 @@ def _generate(
             "model": self.model,
             **self.model_kwargs,
         }
-        sub_prompts = self.get_batch_prompts(params, prompts, stop)
+        sub_prompts = self.get_batch_prompts(prompts)
         choices = []
         for _prompts in sub_prompts:
             response = completion_with_retry_batching(
-                self, self.use_retry, prompt=_prompts, run_manager=run_manager, **params
+                self,
+                self.use_retry,
+                prompt=_prompts,
+                run_manager=run_manager,
+                stop=stop,
+                **params,
             )
             choices.extend(response)

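The stop/params merging that used to live in get_batch_prompts (removed further down) presumably now happens inside the retry helpers, which receive stop as an explicit keyword here. The following is a hedged sketch of that merge, mirroring the removed validation; the actual body of completion_with_retry_batching is not shown in this diff, and _merge_stop is a hypothetical name.

    from typing import Any, Dict, List, Optional

    def _merge_stop(params: Dict[str, Any], stop: Optional[List[str]]) -> Dict[str, Any]:
        # hypothetical helper mirroring the check removed from get_batch_prompts
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params = {**params, "stop": stop}
        return params
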
@@ -110,28 +115,26 @@ async def _agenerate(
             "model": self.model,
             **self.model_kwargs,
         }
-        sub_prompts = self.get_batch_prompts(params, prompts, stop)
+        sub_prompts = self.get_batch_prompts(prompts)
         choices = []
         for _prompts in sub_prompts:
             response = await acompletion_with_retry_batching(
-                self, self.use_retry, prompt=_prompts, run_manager=run_manager, **params
+                self,
+                self.use_retry,
+                prompt=_prompts,
+                run_manager=run_manager,
+                stop=stop,
+                **params,
             )
             choices.extend(response)

         return self.create_llm_result(choices, prompts)

     def get_batch_prompts(
         self,
-        params: Dict[str, Any],
         prompts: List[str],
-        stop: Optional[List[str]] = None,
     ) -> List[List[str]]:
         """Get the sub prompts for llm call."""
-        if stop is not None:
-            if "stop" in params:
-                raise ValueError("`stop` found in both the input and default params.")
-            params["stop"] = stop
-
         sub_prompts = [
             prompts[i : i + self.batch_size]
             for i in range(0, len(prompts), self.batch_size)
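
With params and stop out of its signature, get_batch_prompts is now a pure chunking step over batch_size. A standalone illustration of the list comprehension above, using hypothetical prompt values (batch_size is an attribute of the Fireworks LLM):

    prompts = ["p1", "p2", "p3", "p4", "p5"]
    batch_size = 2
    sub_prompts = [
        prompts[i : i + batch_size]
        for i in range(0, len(prompts), batch_size)
    ]
    assert sub_prompts == [["p1", "p2"], ["p3", "p4"], ["p5"]]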
33 changes: 30 additions & 3 deletions libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
@@ -78,6 +78,23 @@ def test_chat_fireworks_llm_output_stop_words() -> None:
     assert llm_result.generations[0][0].text[-1] == ","


+def test_fireworks_invoke() -> None:
+    """Tests chat completion with invoke"""
+    chat = ChatFireworks()
+    result = chat.invoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(result.content, str)
+    assert result.content[-1] == ","
+
+
+@pytest.mark.asyncio
+async def test_fireworks_ainvoke() -> None:
+    """Tests chat completion with invoke"""
+    chat = ChatFireworks()
+    result = await chat.ainvoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(result.content, str)
+    assert result.content[-1] == ","
+
+
 def test_fireworks_batch() -> None:
     """Test batch tokens from ChatFireworks."""
     chat = ChatFireworks()
@@ -91,15 +108,18 @@ def test_fireworks_batch() -> None:
             "What is the weather in Redwood City, CA today",
             "What is the weather in Redwood City, CA today",
         ],
         config={"max_concurrency": 5},
+        stop=[","],
     )
     for token in result:
         assert isinstance(token.content, str)
+        assert token.content[-1] == ","


 @pytest.mark.asyncio
 async def test_fireworks_abatch() -> None:
     """Test batch tokens from ChatFireworks."""
-    llm = ChatFireworks()
-    result = await llm.abatch(
+    chat = ChatFireworks()
+    result = await chat.abatch(
         [
             "What is the weather in Redwood City, CA today",
             "What is the weather in Redwood City, CA today",
@@ -109,9 +129,11 @@ async def test_fireworks_abatch() -> None:
             "What is the weather in Redwood City, CA today",
             "What is the weather in Redwood City, CA today",
         ],
         config={"max_concurrency": 5},
+        stop=[","],
     )
     for token in result:
         assert isinstance(token.content, str)
+        assert token.content[-1] == ","


 def test_fireworks_streaming() -> None:
@@ -154,5 +176,10 @@ async def test_fireworks_astream() -> None:
     """Test streaming tokens from Fireworks."""
     llm = ChatFireworks()

-    async for token in llm.astream("Who's the best quarterback in the NFL?"):
+    last_token = ""
+    async for token in llm.astream(
+        "Who's the best quarterback in the NFL?", stop=[","]
+    ):
+        last_token = token.content
         assert isinstance(token.content, str)
+    assert last_token[-1] == ","
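
The streaming tests keep overwriting last_token because only the final streamed chunk can end with the stop character. Below is a standalone version of the same pattern, runnable outside pytest; it assumes FIREWORKS_API_KEY is set and that ChatFireworks is importable from langchain.chat_models, as it was at the time of this commit.

    import asyncio

    from langchain.chat_models import ChatFireworks

    async def main() -> None:
        chat = ChatFireworks()
        last_token = ""
        async for chunk in chat.astream(
            "Who's the best quarterback in the NFL?", stop=[","]
        ):
            # remember the most recent chunk; the stream ends at the stop sequence
            last_token = chunk.content
        assert last_token[-1] == ","

    asyncio.run(main())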
61 changes: 60 additions & 1 deletion libs/langchain/tests/integration_tests/llms/test_fireworks.py
@@ -41,6 +41,60 @@ def test_fireworks_model_param() -> None:
     assert llm.model == "foo"


+def test_fireworks_invoke() -> None:
+    """Tests completion with invoke"""
+    llm = Fireworks()
+    output = llm.invoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(output, str)
+    assert output[-1] == ","
+
+
+@pytest.mark.asyncio
+async def test_fireworks_ainvoke() -> None:
+    """Tests completion with invoke"""
+    llm = Fireworks()
+    output = await llm.ainvoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(output, str)
+    assert output[-1] == ","
+
+
+def test_fireworks_batch() -> None:
+    """Tests completion with invoke"""
+    llm = Fireworks()
+    output = llm.batch(
+        [
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+        ],
+        stop=[","],
+    )
+    for token in output:
+        assert isinstance(token, str)
+        assert token[-1] == ","
+
+
+@pytest.mark.asyncio
+async def test_fireworks_abatch() -> None:
+    """Tests completion with invoke"""
+    llm = Fireworks()
+    output = await llm.abatch(
+        [
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+        ],
+        stop=[","],
+    )
+    for token in output:
+        assert isinstance(token, str)
+        assert token[-1] == ","
+
+
 def test_fireworks_multiple_prompts() -> None:
     """Test completion with multiple prompts."""
     llm = Fireworks()
@@ -87,8 +141,13 @@ async def test_fireworks_streaming_async() -> None:
     """Test stream completion."""
     llm = Fireworks()

-    async for token in llm.astream("Who's the best quarterback in the NFL?"):
+    last_token = ""
+    async for token in llm.astream(
+        "Who's the best quarterback in the NFL?", stop=[","]
+    ):
+        last_token = token
         assert isinstance(token, str)
+    assert last_token[-1] == ","


@pytest.mark.asyncio
