Question about dataset conversion and best hyperparameters for SGAN #32

Open
kangbeenlee opened this issue Apr 9, 2023 · 0 comments


kangbeenlee commented Apr 9, 2023

Hello!

I'm really grateful that you shared your open-source prediction model.
I have a few questions about your code.

1. Dataset
Most prediction papers use the ETH and UCY datasets, i.e., ETH, HOTEL, UNIV, ZARA1, and ZARA2 (as in the Social GAN paper). So I want to use only the ETH and UCY datasets for training. I found several compressed datasets in the data folder of the trajnetplusplusdataset repository, listed below:

  • ewap_dataset_light.tgz

    seq_eth, seq_hotel

  • data_zara.rar

    crowds_zara01, crowds_zara02, crowds_zara03

  • data_university_students.rar

    students001, students003, uni_examples

Is it right to use seq_eth and seq_hotel for the ETH and HOTEL datasets, crowds_zara01 for ZARA1, crowds_zara02 for ZARA2, and students001, students003, and uni_examples for UNIV? If not, please let me know the proper dataset files!

2. Data conversion for leave-one-out training
Most papers use the leave-one-out approach (as in the Social LSTM and Social GAN papers): train on four sets and test on the remaining one. So I tried to train your Social LSTM model on seq_hotel (HOTEL), crowds_zara01 (ZARA1), crowds_zara02 (ZARA2), and students001, students003, uni_examples (UNIV), and then test the trained model on seq_eth (ETH).
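The leave-one-out protocol described above can be sketched as follows. This is a minimal illustration, not part of trajnetplusplusbaselines: the `BENCHMARKS` mapping simply encodes the file assignment proposed in question 1 (which is itself unconfirmed), and `leave_one_out_splits` is a hypothetical helper that returns the training and test sequences for each held-out benchmark.

```python
from typing import Dict, List, Tuple

# Hypothetical mapping from benchmark names to raw sequences, following the
# assignment proposed in question 1 (not confirmed by the maintainers).
BENCHMARKS: Dict[str, List[str]] = {
    "ETH": ["seq_eth"],
    "HOTEL": ["seq_hotel"],
    "ZARA1": ["crowds_zara01"],
    "ZARA2": ["crowds_zara02"],
    "UNIV": ["students001", "students003", "uni_examples"],
}


def leave_one_out_splits(
    benchmarks: Dict[str, List[str]]
) -> Dict[str, Tuple[List[str], List[str]]]:
    """For each held-out benchmark, return (train_sequences, test_sequences)."""
    splits = {}
    for held_out, test_seqs in benchmarks.items():
        # Train on every sequence that does not belong to the held-out benchmark.
        train_seqs = [
            seq for name, seqs in benchmarks.items() if name != held_out for seq in seqs
        ]
        splits[held_out] = (train_seqs, list(test_seqs))
    return splits


if __name__ == "__main__":
    for held_out, (train, test) in leave_one_out_splits(BENCHMARKS).items():
        print(f"{held_out}: train on {train}, test on {test}")
```

For ETH, this yields the same split used here: train on everything except seq_eth, and test on seq_eth.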
After training the Social LSTM model this way and testing it on the seq_eth (ETH) dataset, I got the test_pred folder.
However, when I tried to visualize the results with visualize_predictions.py, I encountered the following error:

```
python -m evaluator.visualize_predictions DATA_BLOCK/trajdata/test_private/biwi_eth.ndjson DATA_BLOCK/trajdata/test_pred/lstm_social_ETH_modes1/biwi_eth.ndjson --n 10
Scene ID: 714
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/kblee/Study/trajnet++/trajnetplusplusbaselines/evaluator/visualize_predictions.py", line 100, in
    main()
  File "/home/kblee/Study/trajnet++/trajnetplusplusbaselines/evaluator/visualize_predictions.py", line 90, in main
    full_predicted_paths = add_gt_observation_to_prediction(paths, predicted_paths)
  File "/home/kblee/Study/trajnet++/trajnetplusplusbaselines/evaluator/visualize_predictions.py", line 18, in add_gt_observation_to_prediction
    full_predicted_paths = [gt_observation[ped_id][:obs_length] + pred for ped_id, pred in enumerate(model_prediction)]
  File "/home/kblee/Study/trajnet++/trajnetplusplusbaselines/evaluator/visualize_predictions.py", line 18, in
    full_predicted_paths = [gt_observation[ped_id][:obs_length] + pred for ped_id, pred in enumerate(model_prediction)]
IndexError: list index out of range
```

The cause is that, for some scenes, the number of tracks (pedestrians) in the ground-truth observation differs from the number of tracks in the model prediction. Could you suggest a way to solve this problem? The commands I used for training and testing are as follows:

```
# training
python -m trajnetbaselines.lstm.trainer --type social --augment --epochs 25 --step_size 10 --n 16 --cell_side 0.6 --embedding_arch two_layer --layer_dims 1024 --batch_size 8 --loss pred --output ETH
# test (evaluate)
python -m trajnetbaselines.lstm.trajnet_evaluator --path trajdata --output OUTPUT_BLOCK/trajdata/lstm_social_ETH.pkl
```

The test data is generated from the seq_eth dataset only, like this: `python -m trajnetdataset.convert --train_fraction 0.0 --val_fraction 0.0`
(I changed some floating-point-related code, e.g., `test_fraction = 1 - args.train_fraction - args.val_fraction`.)
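The track-count mismatch described above can be reproduced in isolation with the minimal sketch below. It is hypothetical and not the repository's code: `add_gt_observation_by_index` mirrors the list comprehension from the traceback, which pairs prediction *i* with ground-truth track *i* and therefore raises `IndexError` whenever a scene has more predicted tracks than ground-truth tracks. `add_gt_observation_by_id` assumes each track can be keyed by a pedestrian id in both files, pairs tracks by id, and returns the unmatched ids instead of crashing. It only sidesteps the crash; it does not explain why the counts differ.

```python
from typing import Dict, List, Sequence, Tuple

Point = Tuple[float, float]


def add_gt_observation_by_index(
    gt_observation: Sequence[List[Point]],
    model_prediction: Sequence[List[Point]],
    obs_length: int,
) -> List[List[Point]]:
    # Same pairing as the failing comprehension: prediction i is matched with
    # ground-truth track i, so extra predictions raise IndexError.
    return [
        gt_observation[ped_id][:obs_length] + pred
        for ped_id, pred in enumerate(model_prediction)
    ]


def add_gt_observation_by_id(
    gt_observation: Dict[int, List[Point]],
    model_prediction: Dict[int, List[Point]],
    obs_length: int,
) -> Tuple[Dict[int, List[Point]], List[int]]:
    # Hypothetical variant: pair tracks by pedestrian id and collect ids that
    # have a prediction but no ground-truth observation.
    full_paths: Dict[int, List[Point]] = {}
    unmatched: List[int] = []
    for ped_id, pred in model_prediction.items():
        if ped_id in gt_observation:
            full_paths[ped_id] = gt_observation[ped_id][:obs_length] + pred
        else:
            unmatched.append(ped_id)
    return full_paths, unmatched


if __name__ == "__main__":
    gt = [[(0.0, 0.0), (1.0, 0.0)], [(0.0, 1.0), (1.0, 1.0)]]  # two GT tracks
    pred = [[(2.0, 0.0)], [(2.0, 1.0)], [(5.0, 5.0)]]         # three predicted tracks
    try:
        add_gt_observation_by_index(gt, pred, obs_length=2)
    except IndexError as exc:
        print("IndexError:", exc)  # IndexError: list index out of range

    full, unmatched = add_gt_observation_by_id(
        {1: gt[0], 2: gt[1]}, {1: pred[0], 2: pred[1], 7: pred[2]}, obs_length=2
    )
    print(sorted(full), unmatched)  # [1, 2] [7]
```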

3. The difference between output_pre and output
What is the difference between output_pre and output during data conversion?

4. What is the --n parameter in visualize_predictions.py?

5. Best hyperparameters for SGAN
I want to train the Social GAN model on the ETH and UCY datasets, just as in the Social GAN paper. Would you mind sharing the hyperparameters that achieve the results reported in the paper? I tried the command below:
```
python -m trajnetbaselines.sgan.trainer --type hiddenstatemlp --augment --noise_dim 8 --k 20 --output ETH
```
Is there any parameter I need to add or change?
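The Social GAN paper reports its errors by drawing 20 samples per pedestrian and keeping the best one. Assuming `--k 20` plays that role in the command above (not confirmed here), the evaluation it implies can be sketched as a best-of-K ADE/FDE computation. The function names are hypothetical, and minimizing ADE and FDE independently over the samples is one common convention, not necessarily what trajnetplusplusbaselines does.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def ade(pred: Sequence[Point], gt: Sequence[Point]) -> float:
    """Average displacement error: mean Euclidean distance over the predicted steps."""
    if len(pred) != len(gt):
        raise ValueError("prediction and ground truth must have the same length")
    return sum(math.dist(p, g) for p, g in zip(pred, gt)) / len(gt)


def fde(pred: Sequence[Point], gt: Sequence[Point]) -> float:
    """Final displacement error: Euclidean distance at the last predicted step."""
    return math.dist(pred[-1], gt[-1])


def best_of_k(samples: List[Sequence[Point]], gt: Sequence[Point]) -> Tuple[float, float]:
    """Minimum ADE and minimum FDE over K sampled futures, each minimized independently."""
    return min(ade(s, gt) for s in samples), min(fde(s, gt) for s in samples)


if __name__ == "__main__":
    gt = [(1.0, 0.0), (2.0, 0.0)]
    samples = [
        [(1.0, 1.0), (2.0, 1.0)],  # ADE 1.0, FDE 1.0
        [(1.0, 0.0), (2.0, 0.5)],  # ADE 0.25, FDE 0.5
    ]
    print(best_of_k(samples, gt))  # (0.25, 0.5)
```

With K = 20, `samples` would hold 20 futures drawn from the generator for the same observed track.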

Thanks.
