Improve inference choice examples (#311)
* Improve inference choice examples

* Fix style

---------

Co-authored-by: Albert Villanova del Moral <[email protected]>
aymeric-roucher and albertvillanova authored Jan 24, 2025
1 parent 0196dc7 commit de7b0ee
Showing 5 changed files with 56 additions and 93 deletions.
51 changes: 51 additions & 0 deletions examples/agent_from_any_llm.py
@@ -0,0 +1,51 @@
from typing import Optional

from smolagents import HfApiModel, LiteLLMModel, TransformersModel, tool
from smolagents.agents import CodeAgent, ToolCallingAgent


# Choose which inference type to use!

available_inferences = ["hf_api", "transformers", "ollama", "litellm"]
chosen_inference = "transformers"

print(f"Chose model {chosen_inference}")

if chosen_inference == "hf_api":
    model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")

elif chosen_inference == "transformers":
    model = TransformersModel(model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct", device_map="auto", max_new_tokens=1000)

elif chosen_inference == "ollama":
    model = LiteLLMModel(
        model_id="ollama_chat/llama3.2",
        api_base="http://localhost:11434",  # replace with a remote OpenAI-compatible server if necessary
        api_key="your-api-key",  # replace with your API key if necessary
    )

elif chosen_inference == "litellm":
    # For Anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-latest'
    model = LiteLLMModel(model_id="gpt-4o")


@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
    """
    Get the weather in the next days at the given location.
    Secretly this tool does not care about the location, it hates the weather everywhere.

    Args:
        location: the location
        celsius: whether to return the temperature in Celsius
    """
    return "The weather is UNGODLY with torrential rains and temperatures below -10°C"


agent = ToolCallingAgent(tools=[get_weather], model=model)

print("ToolCallingAgent:", agent.run("What's the weather like in Paris?"))

agent = CodeAgent(tools=[get_weather], model=model)

print("ToolCallingAgent:", agent.run("What's the weather like in Paris?"))
30 changes: 0 additions & 30 deletions examples/tool_calling_agent_from_any_llm.py

This file was deleted.

29 changes: 0 additions & 29 deletions examples/tool_calling_agent_mcp.py

This file was deleted.

29 changes: 0 additions & 29 deletions examples/tool_calling_agent_ollama.py

This file was deleted.

10 changes: 5 additions & 5 deletions src/smolagents/models.py
@@ -480,7 +480,6 @@ def __call__(
             messages=messages,
             stop_sequences=stop_sequences,
             grammar=grammar,
-            tools_to_call_from=tools_to_call_from,
             **kwargs,
         )

@@ -497,9 +496,6 @@ def __call__(
         if max_new_tokens:
             completion_kwargs["max_new_tokens"] = max_new_tokens

-        if stop_sequences:
-            completion_kwargs["stopping_criteria"] = self.make_stopping_criteria(stop_sequences)
-
         if tools_to_call_from is not None:
             prompt_tensor = self.tokenizer.apply_chat_template(
                 messages,
@@ -518,7 +514,11 @@
             prompt_tensor = prompt_tensor.to(self.model.device)
         count_prompt_tokens = prompt_tensor["input_ids"].shape[1]

-        out = self.model.generate(**prompt_tensor, **completion_kwargs)
+        out = self.model.generate(
+            **prompt_tensor,
+            stopping_criteria=(self.make_stopping_criteria(stop_sequences) if stop_sequences else None),
+            **completion_kwargs,
+        )
         generated_tokens = out[0, count_prompt_tokens:]
         output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
         self.last_input_token_count = count_prompt_tokens
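The net effect of this change: stop sequences are now passed straight to model.generate as stopping_criteria rather than being stashed in completion_kwargs earlier in __call__, which keeps completion_kwargs free of transformers-specific arguments. For context, a minimal sketch of what a string-based stopping criterion can look like in transformers (a hand-rolled illustration; smolagents' actual make_stopping_criteria may differ):

import torch
from transformers import StoppingCriteria, StoppingCriteriaList


class StopOnStrings(StoppingCriteria):
    """Stop generation once the decoded text ends with any of the stop strings."""

    def __init__(self, stop_strings, tokenizer):
        self.stop_strings = stop_strings
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return any(text.endswith(stop) for stop in self.stop_strings)


# Assumed usage, mirroring the refactored call above:
# criteria = StoppingCriteriaList([StopOnStrings(stop_sequences, tokenizer)])
# out = model.generate(**prompt_tensor, stopping_criteria=criteria)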
