Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

python 3.10 update #126

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,34 +40,41 @@ Our model consists of three key components: Generator (G), Pooling Module (PM) a
</div>

## Setup
All code was developed and tested on Ubuntu 16.04 with Python 3.5 and PyTorch 0.4.
All code was developed and tested on Ubuntu 22.04 with Python 3.10 and torch

You can setup a virtual environment to run the code like this:
You can setup a virtual conda environment to run the code like this:

```bash
python3 -m venv env # Create a virtual environment
source env/bin/activate # Activate virtual environment
pip install -r requirements.txt # Install dependencies
echo $PWD > env/lib/python3.5/site-packages/sgan.pth # Add current directory to python path
conda create -n test python=3.10 -y # Create a virtual environment
conda activate test # Activate virtual environment
# Work for a while ...
deactivate # Exit virtual environment
conda deactivate # Exit virtual environment
```

## Pretrained Models
You can download pretrained models by running the script `bash scripts/download_models.sh`. This will download the following models:
## clone repo and download files

- `sgan-models/<dataset_name>_<pred_len>.pt`: Contains 10 pretrained models for all five datasets. These models correspond to SGAN-20V-20 in Table 1.
- `sgan-p-models/<dataset_name>_<pred_len>.pt`: Contains 10 pretrained models for all five datasets. These models correspond to SGAN-20VP-20 in Table 1.
```bash
git clone https://github.com/bharath5673/Social-GAN.git
cd Social-GAN
pip install -r requirements.txt # Install dependencies
sh scripts/download_data.sh
sh scripts/download_models.sh
```

Please refer to [Model Zoo](MODEL_ZOO.md) for results.

## Running Models
You can use the script `scripts/evaluate_model.py` to easily run any of the pretrained models on any of the datsets. For example you can replicate the Table 1 results for all datasets for SGAN-20V-20 like this:

```bash
python scripts/evaluate_model.py \
--model_path models/sgan-models
cd scripts
sh run_eval.sh
```

## Training new models

```bash
cd scripts
sh run_traj.sh
```
Instructions for training new models can be [found here](TRAINING.md).
13 changes: 6 additions & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
attrdict==2.0.0
numpy==1.14.5
Pillow==6.2.0
pkg-resources==0.0.0
six==1.11.0
torch==0.4.0
torchvision==0.2.1
scripts/attrdict-2.0.1-py2.py3-none-any.whl
numpy
Pillow
six
torch
torchvision
Binary file added scripts/attrdict-2.0.1-py2.py3-none-any.whl
Binary file not shown.
3 changes: 3 additions & 0 deletions scripts/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

from attrdict import AttrDict

import sys
sys.path.append('../')

from sgan.data.loader import data_loader
from sgan.models import TrajectoryGenerator
from sgan.losses import displacement_error, final_displacement_error
Expand Down
1 change: 1 addition & 0 deletions scripts/run_eval.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python3 evaluate_model.py --model_path ../models/sgan-models
5 changes: 4 additions & 1 deletion scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
import torch.nn as nn
import torch.optim as optim


import sys
sys.path.append('../')
from sgan.data.loader import data_loader
from sgan.losses import gan_g_loss, gan_d_loss, l2_loss
from sgan.losses import displacement_error, final_displacement_error
Expand Down Expand Up @@ -109,7 +112,7 @@ def get_dtypes(args):


def main(args):
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
# os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
train_path = get_dset_path(args.dataset_name, 'train')
val_path = get_dset_path(args.dataset_name, 'val')

Expand Down
Binary file added sgan/__pycache__/losses.cpython-310.pyc
Binary file not shown.
Binary file added sgan/__pycache__/models.cpython-310.pyc
Binary file not shown.
Binary file added sgan/__pycache__/utils.cpython-310.pyc
Binary file not shown.
Binary file added sgan/data/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file added sgan/data/__pycache__/loader.cpython-310.pyc
Binary file not shown.
Binary file not shown.
2 changes: 1 addition & 1 deletion sgan/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def forward(self, obs_traj):
"""
# Encode observed Trajectory
batch = obs_traj.size(1)
obs_traj_embedding = self.spatial_embedding(obs_traj.view(-1, 2))
obs_traj_embedding = self.spatial_embedding(obs_traj.reshape(-1, 2))
obs_traj_embedding = obs_traj_embedding.view(
-1, batch, self.embedding_dim
)
Expand Down