Skip to content

Commit

Permalink
[sharktank] Remove 'torch' from deps and warn instead
Browse files Browse the repository at this point in the history
Instead of enforcing the installation for 'torch' as a dependency, error
if 'torch' cannot be imported and point the user to how to install.
  • Loading branch information
marbre committed Dec 16, 2024
1 parent ba78824 commit 4de7adf
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
10 changes: 7 additions & 3 deletions docs/user_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,18 @@ Setup your Python environment with the following commands:
# Set up a virtual environment to isolate packages from other envs.
python3.11 -m venv 3.11.venv
source 3.11.venv/bin/activate
```

## Install SHARK and its dependencies

Before installing `shark-ai` install a torch version that fulfills your needs.

# Optional: faster installation of torch with just CPU support.
```bash
# Fast installation of torch with just CPU support.
# See other options at https://pytorch.org/get-started/locally/
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
```

## Install SHARK and its dependencies

```bash
pip install shark-ai[apps]
```
Expand Down
4 changes: 0 additions & 4 deletions sharktank/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@ huggingface-hub==0.22.2
transformers==4.40.0
datasets

# It is expected that you have installed a PyTorch version/variant specific
# to your needs, so we only include a minimum version spec.
torch>=2.3.0

# Serving deps.
fastapi>=0.112.2
uvicorn>=0.30.6
10 changes: 10 additions & 0 deletions sharktank/sharktank/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,13 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import importlib.util

msg = """No module named 'torch'. Follow https://pytorch.org/get-started/locally/#start-locally to install 'torch'.
For example, on Linux to install with CPU support run:
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
"""

if spec := importlib.util.find_spec("torch") is None:
raise ModuleNotFoundError(msg)

0 comments on commit 4de7adf

Please sign in to comment.