Hello!
I'm really grateful that you shared your open-source prediction model.
I have a few questions about your code.
1. Dataset
Most prediction papers evaluate on the ETH and UCY datasets, split into ETH, HOTEL, UNIV, ZARA1, and ZARA2 (as in the Social GAN paper), so I want to train on ETH and UCY only. I found several compressed datasets in the data folder of the trajnetplusplusdataset repository, listed below with their contents:
ewap_dataset_light.tgz: seq_eth, seq_hotel
data_zara.rar: crowds_zara01, crowds_zara02, crowds_zara03
data_university_students.rar: students001, students003, uni_examples
Is it correct to use seq_eth and seq_hotel for ETH and HOTEL, crowds_zara01 for ZARA1, crowds_zara02 for ZARA2, and students001, students003, and uni_examples for UNIV? If not, please let me know which files are the proper ones.
2. Data conversion for leave-one-out approach training
Most papers use a leave-one-out approach (as in the Social LSTM and Social GAN papers): train on four sets and test on the remaining one. So I tried to train your Social LSTM model on seq_hotel (HOTEL), crowds_zara01 (ZARA1), crowds_zara02 (ZARA2), and students001, students003, uni_examples (UNIV), and then test the trained model on seq_eth (ETH).
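To make the split concrete, here is a minimal sketch of the leave-one-out folds I have in mind. The mapping from scene name to file names is my own assumption from question 1 and is not confirmed by the repository; `SCENE_FILES` and `leave_one_out` are hypothetical names used only for illustration.

```python
# Assumed mapping from benchmark scene to raw files (unconfirmed, see question 1).
SCENE_FILES = {
    "ETH": ["seq_eth"],
    "HOTEL": ["seq_hotel"],
    "UNIV": ["students001", "students003", "uni_examples"],
    "ZARA1": ["crowds_zara01"],
    "ZARA2": ["crowds_zara02"],
}


def leave_one_out(scene_files, held_out):
    """Return (train_files, test_files) with `held_out` used only for testing."""
    if held_out not in scene_files:
        raise KeyError(f"unknown scene: {held_out}")
    # Every file from every other scene goes into the training set.
    train = [
        name
        for scene, files in scene_files.items()
        if scene != held_out
        for name in files
    ]
    test = list(scene_files[held_out])
    return train, test


train_files, test_files = leave_one_out(SCENE_FILES, "ETH")
print("train:", train_files)
print("test:", test_files)
```

With "ETH" held out, the training list contains the six remaining files and the test list contains only seq_eth.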
After training the Social LSTM model this way and testing on the seq_eth (ETH) dataset, I got the test_pred folder.
However, when I tried to visualize the result with visualize_predictions.py, I hit the error below:
python -m evaluator.visualize_predictions DATA_BLOCK/trajdata/test_private/biwi_eth.ndjson DATA_BLOCK/trajdata/test_pred/lstm_social_ETH_modes1/biwi_eth.ndjson --n 10
Scene ID: 714
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/kblee/Study/trajnet++/trajnetplusplusbaselines/evaluator/visualize_predictions.py", line 100, in <module>
    main()
  File "/home/kblee/Study/trajnet++/trajnetplusplusbaselines/evaluator/visualize_predictions.py", line 90, in main
    full_predicted_paths = add_gt_observation_to_prediction(paths, predicted_paths)
  File "/home/kblee/Study/trajnet++/trajnetplusplusbaselines/evaluator/visualize_predictions.py", line 18, in add_gt_observation_to_prediction
    full_predicted_paths = [gt_observation[ped_id][:obs_length] + pred for ped_id, pred in enumerate(model_prediction)]
  File "/home/kblee/Study/trajnet++/trajnetplusplusbaselines/evaluator/visualize_predictions.py", line 18, in <listcomp>
    full_predicted_paths = [gt_observation[ped_id][:obs_length] + pred for ped_id, pred in enumerate(model_prediction)]
IndexError: list index out of range
The cause is that the number of tracks (pedestrians) in the ground-truth observation and in the model prediction differs within a scene. Could you suggest a way to solve this problem? The commands I used for training and testing are as follows:
training: python -m trajnetbaselines.lstm.trainer --type social --augment --epochs 25 --step_size 10 --n 16 --cell_side 0.6 --embedding_arch two_layer --layer_dims 1024 --batch_size 8 --loss pred --output ETH
test (evaluate): python -m trajnetbaselines.lstm.trajnet_evaluator --path trajdata --output OUTPUT_BLOCK/trajdata/lstm_social_ETH.pkl
The test data was created from the seq_eth dataset only, using: python -m trajnetdataset.convert --train_fraction 0.0 --val_fraction 0.0
(I also changed some floating-point-related code, such as test_fraction = 1 - args.train_fraction - args.val_fraction.)
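To show the mismatch in isolation, here is a small standalone sketch that reproduces the indexing from the traceback with a guard added. It is not the repository's function: `merge_observation_and_prediction` is a hypothetical name, the default `obs_length=9` and the toy tracks are assumptions, and skipping unmatched tracks only hides the problem rather than explaining why the counts differ.

```python
def merge_observation_and_prediction(gt_observation, model_prediction, obs_length=9):
    """Prepend observed frames to each predicted track.

    Predicted tracks without a ground-truth track at the same index are
    skipped and counted instead of raising IndexError.
    """
    merged = []
    skipped = 0
    for ped_id, pred in enumerate(model_prediction):
        if ped_id >= len(gt_observation):
            # This is the case that raises IndexError in the traceback above.
            skipped += 1
            continue
        merged.append(gt_observation[ped_id][:obs_length] + pred)
    return merged, skipped


# Toy scene: two ground-truth tracks, but three predicted tracks.
gt = [[(float(p), float(t)) for t in range(21)] for p in range(2)]
pred = [[(float(p), 9.5)] for p in range(3)]

merged, skipped = merge_observation_and_prediction(gt, pred)
print(f"merged {len(merged)} tracks, skipped {skipped}")
```

If the skipped count is nonzero for a scene, the prediction file and the ground-truth file probably do not describe the same set of pedestrians, which may point back to how the test data was converted.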
3. The difference between output_pre and output
What is the difference between output_pre and output during data conversion?
4. What is the --n parameter in visualize_predictions.py?
5. Best hyperparameters for sgan
I want to train the Social GAN model on the ETH and UCY datasets, just as in the Social GAN paper. Would you mind sharing the best hyperparameters that reproduce the results stated in the paper? I tried the command below.
python -m trajnetbaselines.sgan.trainer --type hiddenstatemlp --augment --noise_dim 8 --k 20 --output ETH
Is there any parameter I need to add or change?
Thanks.