Halve import time by removing torch dependency #147

Merged · 4 commits · Jan 10, 2025
2 changes: 1 addition & 1 deletion .github/workflows/quality.yml
@@ -13,7 +13,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v2
         with:
-          python-version: "3.10"
+          python-version: "3.12"

       # Setup venv
       # TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed.
6 changes: 3 additions & 3 deletions docs/source/en/examples/multiagents.md
@@ -48,10 +48,10 @@ Run the line below to install the required dependencies:

 Let's log in to call the HF Inference API:

 ```py
-from huggingface_hub import notebook_login
+from huggingface_hub import login

-notebook_login()
+login()
 ```

 ⚡️ Our agent will be powered by [Qwen/Qwen2.5-Coder-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-32B-Instruct) using the `HfApiModel` class, which uses HF's Inference API: the Inference API lets you quickly and easily run any open-source model.
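
Note on the change above: `login()` works in terminals, scripts, and notebooks alike, while `notebook_login()` only renders its widget inside Jupyter. As a sketch of the non-interactive case (reading the token from the `HF_TOKEN` environment variable is an assumption here, not something this diff configures):

```py
import os

from huggingface_hub import login

# With no arguments, login() prompts interactively; in CI or scripts,
# pass the token explicitly instead.
login(token=os.environ.get("HF_TOKEN"))
```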
4 changes: 2 additions & 2 deletions docs/source/en/tutorials/tools.md
@@ -177,7 +177,7 @@ agent.run("How many more blocks (also denoted as layers) are in BERT base encode

 ### Manage your agent's toolbox

-You can manage an agent's toolbox by adding or replacing a tool.
+You can manage an agent's toolbox by adding or replacing a tool in the attribute `agent.tools`, since it is a standard dictionary.

 Let's add the `model_download_tool` to an existing agent initialized with only the default toolbox.

@@ -187,7 +187,7 @@ from smolagents import HfApiModel

 model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")

 agent = CodeAgent(tools=[], model=model, add_base_tools=True)
-agent.tools.append(model_download_tool)
+agent.tools[model_download_tool.name] = model_download_tool
 ```
 Now we can leverage the new tool:
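
Since `agent.tools` is a plain dict keyed by tool name, additions, replacements, and removals all follow standard dictionary semantics. A minimal sketch of that workflow (the `model_download_tool` body is adapted from earlier in the same tutorial page, so treat its details as an assumption):

```py
from smolagents import CodeAgent, HfApiModel, tool

@tool
def model_download_tool(task: str) -> str:
    """Returns the most downloaded model id for a task on the Hugging Face Hub.

    Args:
        task: The task to search for (e.g. "text-classification").
    """
    from huggingface_hub import list_models

    most_downloaded = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
    return most_downloaded.id

model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
agent = CodeAgent(tools=[], model=model, add_base_tools=True)

agent.tools[model_download_tool.name] = model_download_tool  # add or overwrite
print(sorted(agent.tools))  # inspect the toolbox by tool name
del agent.tools[model_download_tool.name]  # remove with standard dict syntax
```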
17 changes: 13 additions & 4 deletions pyproject.toml
@@ -12,9 +12,6 @@ authors = [
 readme = "README.md"
 requires-python = ">=3.10"
 dependencies = [
-    "torch",
-    "torchaudio",
-    "torchvision",
     "transformers>=4.0.0",
     "requests>=2.32.3",
     "rich>=13.9.4",
@@ -30,10 +27,22 @@ dependencies = [
 ]

 [tool.ruff]
-ignore = ["F403"]
+lint.ignore = ["F403"]

 [project.optional-dependencies]
+dev = [
+    "torch",
+    "torchaudio",
+    "torchvision",
+    "sqlalchemy",
+    "accelerate",
+    "soundfile",
+    "litellm>=1.55.10",
+]
 test = [
+    "torch",
+    "torchaudio",
+    "torchvision",
     "pytest>=8.1.0",
     "sqlalchemy",
     "ruff>=0.5.0",
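
Moving the torch stack into the `dev` and `test` extras means a bare `pip install smolagents` no longer pulls it in, which is where the import-time win in the PR title comes from. A rough way to check the effect locally (the sketch below is illustrative; timings vary by machine and are not taken from this PR):

```py
import subprocess
import sys

# Run `import smolagents` in a fresh interpreter with -X importtime, which
# prints a per-module import-time breakdown to stderr. If torch no longer
# appears in that breakdown, it is off the default import path.
result = subprocess.run(
    [sys.executable, "-X", "importtime", "-c", "import smolagents"],
    capture_output=True,
    text=True,
)
print(result.stderr.splitlines()[-1])  # last line: cumulative time for smolagents
```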
17 changes: 12 additions & 5 deletions src/smolagents/default_tools.py
@@ -20,11 +20,9 @@
 from typing import Dict, Optional

 from huggingface_hub import hf_hub_download, list_spaces
-from transformers.models.whisper import (
-    WhisperForConditionalGeneration,
-    WhisperProcessor,
-)
-from transformers.utils import is_offline_mode
+from transformers.utils import is_offline_mode, is_torch_available

 from .local_python_executor import (
     BASE_BUILTIN_MODULES,
@@ -34,6 +32,15 @@
 from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
 from .types import AgentAudio

+if is_torch_available():
+    from transformers.models.whisper import (
+        WhisperForConditionalGeneration,
+        WhisperProcessor,
+    )
+else:
+    WhisperForConditionalGeneration = object
+    WhisperProcessor = object
+

 @dataclass
 class PreTool:
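
Binding the Whisper names to `object` when torch is absent is what keeps this module importable: class bodies that reference `WhisperProcessor` or `WhisperForConditionalGeneration` as attributes still evaluate at import time, and the real classes are only needed once the tool is actually instantiated. A stripped-down sketch of the pattern, with illustrative names rather than the module's real ones:

```py
import importlib.util

# Stand-in for transformers.utils.is_torch_available().
def torch_available() -> bool:
    return importlib.util.find_spec("torch") is not None

if torch_available():
    from torch.nn import Module as ProcessorClass
else:
    ProcessorClass = object  # placeholder: keeps class bodies evaluable

class SpeechTool:
    pre_processor_class = ProcessorClass  # safe to reference either way

    def setup(self):
        # Only here, at instantiation time, does the real class matter.
        if self.pre_processor_class is object:
            raise ImportError("Install torch to use this tool.")
        self.pre_processor = self.pre_processor_class()
```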
29 changes: 10 additions & 19 deletions src/smolagents/models.py
@@ -22,7 +22,6 @@
 from enum import Enum
 from typing import Dict, List, Optional

-import torch
 from huggingface_hub import (
     InferenceClient,
     ChatCompletionOutputMessage,
@@ -35,6 +34,7 @@
     AutoTokenizer,
     StoppingCriteria,
     StoppingCriteriaList,
+    is_torch_available,
 )
 import openai

@@ -147,29 +147,12 @@ def __init__(self):
         self.last_input_token_count = None
         self.last_output_token_count = None

-    def get_token_counts(self):
+    def get_token_counts(self) -> Dict[str, int]:
         return {
             "input_token_count": self.last_input_token_count,
             "output_token_count": self.last_output_token_count,
         }

-    def generate(
-        self,
-        messages: List[Dict[str, str]],
-        stop_sequences: Optional[List[str]] = None,
-        grammar: Optional[str] = None,
-        max_tokens: int = 1500,
-    ):
-        raise NotImplementedError
-
-    def get_tool_call(
-        self,
-        messages: List[Dict[str, str]],
-        available_tools: List[Tool],
-        stop_sequences,
-    ):
-        raise NotImplementedError
-
     def __call__(
         self,
         messages: List[Dict[str, str]],
@@ -256,6 +239,10 @@ def __call__(
         max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> str:
+        """
+        Gets an LLM output message for the given list of input messages.
+        If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
+        """
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
@@ -293,6 +280,10 @@ class TransformersModel(Model):

     def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
         super().__init__()
+        if not is_torch_available():
+            raise ImportError("Please install torch in order to use TransformersModel.")
+        import torch
+
         default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
         if model_id is None:
             model_id = default_model_id
22 changes: 10 additions & 12 deletions src/smolagents/tools.py
@@ -27,7 +27,6 @@
 from pathlib import Path
 from typing import Callable, Dict, Optional, Union, get_type_hints

-import torch
 from huggingface_hub import (
     create_repo,
     get_collection,
@@ -37,7 +36,6 @@
 )
 from huggingface_hub.utils import RepositoryNotFoundError
 from packaging import version
-from transformers import AutoProcessor
 from transformers.dynamic_module_utils import get_imports
 from transformers.utils import (
     TypeHintParsingException,
@@ -54,13 +52,14 @@

 logger = logging.getLogger(__name__)

-
-if is_torch_available():
-    pass
-
 if is_accelerate_available():
-    pass
+    from accelerate import PartialState
+    from accelerate.utils import send_to_device
+
+if is_torch_available():
+    from transformers import AutoProcessor
+else:
+    AutoProcessor = object

 TOOL_CONFIG_FILE = "tool_config.json"

@@ -1026,8 +1025,6 @@ def setup(self):
         """
         Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
         """
-        from accelerate import PartialState
-
         if isinstance(self.pre_processor, str):
             self.pre_processor = self.pre_processor_class.from_pretrained(
                 self.pre_processor, **self.hub_kwargs
@@ -1066,6 +1063,8 @@ def forward(self, inputs):
         """
         Sends the inputs through the `model`.
         """
+        import torch
+
         with torch.no_grad():
             return self.model(**inputs)

@@ -1076,16 +1075,15 @@ def decode(self, outputs):
         return self.post_processor(outputs)

     def __call__(self, *args, **kwargs):
+        import torch
+
         args, kwargs = handle_agent_input_types(*args, **kwargs)

         if not self.is_initialized:
             self.setup()

         encoded_inputs = self.encode(*args, **kwargs)

-        import torch
-        from accelerate.utils import send_to_device
-
         tensor_inputs = {
             k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)
         }
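
Pushing `import torch` down into `forward` and `__call__` is cheap because Python caches modules: only the first executed import pays the initialization cost, and every later one is a `sys.modules` lookup. A small sketch of that property (it assumes torch is installed, since it exercises the real import):

```py
import sys
import time

def timed_local_import() -> float:
    start = time.perf_counter()
    import torch  # noqa: F401  (deliberate function-level import)
    return time.perf_counter() - start

first = timed_local_import()   # pays the full module-initialization cost
second = timed_local_import()  # hits the sys.modules cache
print(f"first={first:.3f}s, second={second:.6f}s, cached={'torch' in sys.modules}")
```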
8 changes: 4 additions & 4 deletions src/smolagents/types.py
@@ -22,10 +22,10 @@
 import numpy as np
 import requests
 from transformers.utils import (
-    is_soundfile_availble,
     is_torch_available,
     is_vision_available,
 )
+from transformers.utils.import_utils import _is_package_available

 logger = logging.getLogger(__name__)

@@ -41,7 +41,7 @@
 else:
     Tensor = object

-if is_soundfile_availble():
+if _is_package_available("soundfile"):
     import soundfile as sf


@@ -189,7 +189,7 @@ class AgentAudio(AgentType, str):
     def __init__(self, value, samplerate=16_000):
         super().__init__(value)

-        if not is_soundfile_availble():
+        if not _is_package_available("soundfile"):
             raise ImportError("soundfile must be installed in order to handle audio.")

         self._path = None

@@ -253,7 +253,7 @@ def to_string(self):
 INSTANCE_TYPE_MAPPING = {
     str: AgentText,
     ImageType: AgentImage,
-    torch.Tensor: AgentAudio,
+    Tensor: AgentAudio,
 }

 if is_torch_available():
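
`_is_package_available` is a private transformers helper (the leading underscore means no stability guarantee); it is used here in place of the public checker, whose name is the misspelled `is_soundfile_availble`. If depending on a private API is a concern, a roughly equivalent check can be built from the standard library; note this stand-in skips the distribution-metadata verification the transformers helper performs:

```py
import importlib.util

def is_package_available(name: str) -> bool:
    """Stdlib-only stand-in for transformers' private _is_package_available."""
    return importlib.util.find_spec(name) is not None

if is_package_available("soundfile"):
    import soundfile as sf  # imported only when actually installed
```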
13 changes: 9 additions & 4 deletions tests/test_types.py
@@ -18,20 +18,19 @@
 import uuid
 from pathlib import Path

-import torch
 from PIL import Image
 from transformers.testing_utils import (
     require_soundfile,
     require_torch,
     require_vision,
 )
-from transformers.utils import (
-    is_soundfile_availble,
+from transformers.utils.import_utils import (
+    _is_package_available,
 )

 from smolagents.types import AgentAudio, AgentImage, AgentText

-if is_soundfile_availble():
+if _is_package_available("soundfile"):
     import soundfile as sf


@@ -44,6 +43,8 @@ def get_new_path(suffix="") -> str:
 @require_torch
 class AgentAudioTests(unittest.TestCase):
     def test_from_tensor(self):
+        import torch
+
         tensor = torch.rand(12, dtype=torch.float64) - 0.5
         agent_type = AgentAudio(tensor)
         path = str(agent_type.to_string())
@@ -61,6 +62,8 @@ def test_from_tensor(self):
         self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4))

     def test_from_string(self):
+        import torch
+
         tensor = torch.rand(12, dtype=torch.float64) - 0.5
         path = get_new_path(suffix=".wav")
         sf.write(path, tensor, 16000)
@@ -75,6 +78,8 @@ def test_from_string(self):
 @require_torch
 class AgentImageTests(unittest.TestCase):
     def test_from_tensor(self):
+        import torch
+
         tensor = torch.randint(0, 256, (64, 64, 3))
         agent_type = AgentImage(tensor)
         path = str(agent_type.to_string())