Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate Ruff Formatter into Pre-commit Workflow #119

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
11 changes: 8 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
repos:
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.1
rev: v0.8.6
hooks:
# Run the linter.
- id: ruff
types_or: [ python, pyi ]
args:
- --fix
- --exclude=examples/
# Run the formatter.
- id: ruff-format
types_or: [ python, pyi ]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v5.0.0
hooks:
- id: check-merge-conflict
- id: check-yaml
20 changes: 13 additions & 7 deletions examples/e2b_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

load_dotenv()


class GetCatImageTool(Tool):
name="get_cat_image"
name = "get_cat_image"
description = "Get a cat image"
inputs = {}
output_type = "image"
Expand All @@ -27,17 +28,22 @@ def forward(self):
get_cat_image = GetCatImageTool()

agent = CodeAgent(
tools = [get_cat_image, VisitWebpageTool()],
tools=[get_cat_image, VisitWebpageTool()],
model=HfApiModel(),
additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search",
use_e2b_executor=True
additional_authorized_imports=[
"Pillow",
"requests",
"markdownify",
], # "duckduckgo-search",
use_e2b_executor=True,
)

agent.run(
"Return me an image of a cat. Directly use the image provided in your state.", additional_args={"cat_image":get_cat_image()}
) # Asking to directly return the image from state tests that additional_args are properly sent to server.
"Return me an image of a cat. Directly use the image provided in your state.",
additional_args={"cat_image": get_cat_image()},
) # Asking to directly return the image from state tests that additional_args are properly sent to server.

# Try the agent in a Gradio UI
from smolagents import GradioUI

GradioUI(agent).launch()
GradioUI(agent).launch()
2 changes: 1 addition & 1 deletion examples/gradio_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
tools=[], model=HfApiModel(), max_steps=4, verbosity_level=1
)

GradioUI(agent, file_upload_folder='./data').launch()
GradioUI(agent, file_upload_folder='./data').launch()
4 changes: 3 additions & 1 deletion examples/inspect_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
# Let's setup the instrumentation first

trace_provider = TracerProvider()
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter("http://0.0.0.0:6006/v1/traces")))
trace_provider.add_span_processor(
SimpleSpanProcessor(OTLPSpanExporter("http://0.0.0.0:6006/v1/traces"))
)

SmolagentsInstrumentor().instrument(tracer_provider=trace_provider, skip_dep_check=True)

Expand Down
19 changes: 13 additions & 6 deletions examples/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@


knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))
knowledge_base = knowledge_base.filter(
lambda row: row["source"].startswith("huggingface/transformers")
)

source_docs = [
Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
Expand All @@ -26,6 +28,7 @@

from smolagents import Tool


class RetrieverTool(Tool):
name = "retriever"
description = "Uses semantic search to retrieve the parts of transformers documentation that could be most relevant to answer your query."
Expand All @@ -39,9 +42,7 @@ class RetrieverTool(Tool):

def __init__(self, docs, **kwargs):
super().__init__(**kwargs)
self.retriever = BM25Retriever.from_documents(
docs, k=10
)
self.retriever = BM25Retriever.from_documents(docs, k=10)

def forward(self, query: str) -> str:
assert isinstance(query, str), "Your search query must be a string"
Expand All @@ -56,14 +57,20 @@ def forward(self, query: str) -> str:
]
)


from smolagents import HfApiModel, CodeAgent

retriever_tool = RetrieverTool(docs_processed)
agent = CodeAgent(
tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_steps=4, verbosity_level=2
tools=[retriever_tool],
model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"),
max_steps=4,
verbose=True,
)

agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?")
agent_output = agent.run(
"For a transformers model training, which is slower, the forward or the backward pass?"
)

print("Final output:")
print(agent_output)
8 changes: 6 additions & 2 deletions examples/text_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,14 @@
inspector = inspect(engine)
columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")]

table_description = "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
table_description = "Columns:\n" + "\n".join(
[f" - {name}: {col_type}" for name, col_type in columns_info]
)
print(table_description)

from smolagents import tool


@tool
def sql_engine(query: str) -> str:
"""
Expand All @@ -66,10 +69,11 @@ def sql_engine(query: str) -> str:
output += "\n" + str(row)
return output


from smolagents import CodeAgent, HfApiModel

agent = CodeAgent(
tools=[sql_engine],
model=HfApiModel("meta-llama/Meta-Llama-3.1-8B-Instruct"),
)
agent.run("Can you give me the name of the client who got the most expensive receipt?")
agent.run("Can you give me the name of the client who got the most expensive receipt?")
4 changes: 3 additions & 1 deletion examples/tool_calling_agent_from_any_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# For anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-20240620'
model = LiteLLMModel(model_id="gpt-4o")


@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
"""
Expand All @@ -21,6 +22,7 @@ def get_weather(location: str, celsius: Optional[bool] = False) -> str:
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"


agent = ToolCallingAgent(tools=[get_weather], model=model)

print(agent.run("What's the weather like in Paris?"))
print(agent.run("What's the weather like in Paris?"))
6 changes: 4 additions & 2 deletions examples/tool_calling_agent_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

model = LiteLLMModel(
model_id="ollama_chat/llama3.2",
api_base="http://localhost:11434", # replace with remote open-ai compatible server if necessary
api_key="your-api-key" # replace with API key if necessary
api_base="http://localhost:11434", # replace with remote open-ai compatible server if necessary
api_key="your-api-key", # replace with API key if necessary
)


@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
"""
Expand All @@ -20,6 +21,7 @@ def get_weather(location: str, celsius: Optional[bool] = False) -> str:
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"


agent = ToolCallingAgent(tools=[get_weather], model=model)

print(agent.run("What's the weather like in Paris?"))
6 changes: 4 additions & 2 deletions src/smolagents/e2b_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,16 @@ def __init__(self, additional_imports: List[str], tools: List[Tool], logger):
tool_definition_code = "\n".join(
[f"import {module}" for module in BASE_BUILTIN_MODULES]
)
tool_definition_code += textwrap.dedent("""
tool_definition_code += textwrap.dedent(
"""
class Tool:
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)

def forward(self, *args, **kwargs):
pass # to be implemented in child class
""")
"""
)
tool_definition_code += "\n\n".join(tool_codes)

tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
Expand Down
48 changes: 33 additions & 15 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,21 +1301,39 @@ def evaluate_ast(
return getattr(value, expression.attr)
elif isinstance(expression, ast.Slice):
return slice(
evaluate_ast(
expression.lower, state, static_tools, custom_tools, authorized_imports
)
if expression.lower is not None
else None,
evaluate_ast(
expression.upper, state, static_tools, custom_tools, authorized_imports
)
if expression.upper is not None
else None,
evaluate_ast(
expression.step, state, static_tools, custom_tools, authorized_imports
)
if expression.step is not None
else None,
(
evaluate_ast(
expression.lower,
state,
static_tools,
custom_tools,
authorized_imports,
)
if expression.lower is not None
else None
),
(
evaluate_ast(
expression.upper,
state,
static_tools,
custom_tools,
authorized_imports,
)
if expression.upper is not None
else None
),
(
evaluate_ast(
expression.step,
state,
static_tools,
custom_tools,
authorized_imports,
)
if expression.step is not None
else None
),
)
elif isinstance(expression, ast.DictComp):
return evaluate_dictcomp(
Expand Down
40 changes: 21 additions & 19 deletions src/smolagents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,12 @@ def validate_arguments(self):
f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
)
for input_name, input_content in self.inputs.items():
assert isinstance(input_content, dict), (
f"Input '{input_name}' should be a dictionary."
)
assert "type" in input_content and "description" in input_content, (
f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
)
assert isinstance(
input_content, dict
), f"Input '{input_name}' should be a dictionary."
assert (
"type" in input_content and "description" in input_content
), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
if input_content["type"] not in AUTHORIZED_TYPES:
raise Exception(
f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {AUTHORIZED_TYPES}."
Expand All @@ -207,13 +207,13 @@ def validate_arguments(self):
json_schema = _convert_type_hints_to_json_schema(self.forward)
for key, value in self.inputs.items():
if "nullable" in value:
assert key in json_schema and "nullable" in json_schema[key], (
f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
)
assert (
key in json_schema and "nullable" in json_schema[key]
), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
if key in json_schema and "nullable" in json_schema[key]:
assert "nullable" in value, (
f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
)
assert (
"nullable" in value
), f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."

def forward(self, *args, **kwargs):
return NotImplementedError("Write this method in your subclass of `Tool`.")
Expand Down Expand Up @@ -275,7 +275,8 @@ def save(self, output_dir):
raise (ValueError("\n".join(method_checker.errors)))

forward_source_code = inspect.getsource(self.forward)
tool_code = textwrap.dedent(f"""
tool_code = textwrap.dedent(
f"""
from smolagents import Tool
from typing import Optional

Expand All @@ -284,7 +285,8 @@ class {class_name}(Tool):
description = "{self.description}"
inputs = {json.dumps(self.inputs, separators=(",", ":"))}
output_type = "{self.output_type}"
""").strip()
"""
).strip()
import re

def add_self_argument(source_code: str) -> str:
Expand Down Expand Up @@ -325,15 +327,17 @@ def replacement(match):
app_file = os.path.join(output_dir, "app.py")
with open(app_file, "w", encoding="utf-8") as f:
f.write(
textwrap.dedent(f"""
textwrap.dedent(
f"""
from smolagents import launch_gradio_demo
from typing import Optional
from tool import {class_name}

tool = {class_name}()

launch_gradio_demo(tool)
""").lstrip()
"""
).lstrip()
)

# Save requirements file
Expand Down Expand Up @@ -449,9 +453,7 @@ def from_hub(
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
others will be passed along to its init.
"""
assert trust_remote_code, (
"Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
)
assert trust_remote_code, "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."

hub_kwargs_names = [
"cache_dir",
Expand Down
Loading