Skip to content

Commit

Permalink
[sharktank] Remove 'torch' from deps and warn instead (#706)
Browse files Browse the repository at this point in the history
Instead of enforcing the installation for 'torch' as a dependency, error
if 'torch' cannot be imported and point the user to how to install.

Co-authored-by: Scott Todd <[email protected]>
  • Loading branch information
marbre and ScottTodd authored Dec 16, 2024
1 parent 0660b07 commit d1980c7
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
13 changes: 10 additions & 3 deletions docs/user_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,20 @@ Setup your Python environment with the following commands:
# Set up a virtual environment to isolate packages from other envs.
python3.11 -m venv 3.11.venv
source 3.11.venv/bin/activate
```

## Install SHARK and its dependencies

First install a torch version that fulfills your needs:

# Optional: faster installation of torch with just CPU support.
# See other options at https://pytorch.org/get-started/locally/
```bash
# Fast installation of torch with just CPU support.
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
```

## Install SHARK and its dependencies
For other options, see https://pytorch.org/get-started/locally/.

Next install shark-ai:

```bash
pip install shark-ai[apps]
Expand Down
4 changes: 0 additions & 4 deletions sharktank/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@ huggingface-hub==0.22.2
transformers==4.40.0
datasets

# It is expected that you have installed a PyTorch version/variant specific
# to your needs, so we only include a minimum version spec.
torch>=2.3.0

# Serving deps.
fastapi>=0.112.2
uvicorn>=0.30.6
10 changes: 10 additions & 0 deletions sharktank/sharktank/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,13 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import importlib.util

msg = """No module named 'torch'. Follow https://pytorch.org/get-started/locally/#start-locally to install 'torch'.
For example, on Linux to install with CPU support run:
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
"""

if spec := importlib.util.find_spec("torch") is None:
raise ModuleNotFoundError(msg)

0 comments on commit d1980c7

Please sign in to comment.