I went through some of the code for PPO, and it seems to still be using numpy arrays for the observation and action spaces. So I'm guessing the environments are still on the CPU and their steps also take place on the CPU (correct me if I'm wrong). Does this mean the library doesn't address the issue of CPU-GPU data transfer, and that all the speed-up comes only from the policy optimization part being JIT-compiled with JAX? Am I right?

Could you guys please add support for custom environments created on the GPU? Maybe I could contribute? I currently have a jitted JAX environment that runs entirely on the GPU, and I wanted to use existing libraries to train the policy, but apparently all of them are built on top of some conventional environment library. My environment is built from scratch in JAX and exposes the main step() and reset() functions, so right now my only option is to implement the policy and optimization code on my own. I just want to know whether there is any way your library can be used with a custom environment where all the functions are stateless.
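For reference, the interface looks roughly like the minimal sketch below (the toy dynamics, `EnvState`, and field names are just illustrative placeholders, not my actual environment):

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple, Tuple

# Illustrative placeholder only: a stateless (purely functional) JAX environment.
# The real environment has different dynamics; only the interface shape matters.

class EnvState(NamedTuple):
    position: jnp.ndarray
    velocity: jnp.ndarray

def reset(rng: jax.Array) -> Tuple[EnvState, jnp.ndarray]:
    # All state is returned explicitly instead of being stored on an object.
    pos, vel = jax.random.uniform(rng, (2,), minval=-0.05, maxval=0.05)
    state = EnvState(position=pos, velocity=vel)
    obs = jnp.stack([pos, vel])
    return state, obs

def step(state: EnvState, action: jnp.ndarray):
    # Pure function: (state, action) -> (next_state, obs, reward, done)
    velocity = state.velocity + 0.1 * action
    position = state.position + 0.1 * velocity
    next_state = EnvState(position=position, velocity=velocity)
    obs = jnp.stack([position, velocity])
    reward = -jnp.abs(position)
    done = jnp.abs(position) > 1.0
    return next_state, obs, reward, done

# Because everything is pure, rollouts can be jitted and vmapped entirely on GPU:
batched_step = jax.jit(jax.vmap(step))
```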
> So I'm guessing that the environments are still on the CPU and their steps also take place on the CPU.
Yes.
> doesn't address the issue of CPU-GPU data transfer.
For PPO, when not using images (and when not using Isaac Sim), there is no need for a GPU (see the runtime reports from https://rlj.cs.umass.edu/2024/papers/Paper18.html and several issues about that on the SB3 repo).
> just due to the optimization taking place using jax jit compiled code only for the policy optimization part.
Yes, and also for the policy/value prediction.
> Could you guys please add support for custom environments created on the GPU?
This is currently not planned (as SBX is based on SB3).
However, if you want to avoid data transfer entirely, you will need:
- A custom VecEnv (we have some examples, including for Isaac Sim, in the SB3 docs); see the sketch after this list
- A custom replay buffer (should not be too hard)
- You might need to remove some calls to .numpy() in SB3 (I'm not 100% sure about that part)
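To make the first point more concrete, here is a rough, untested sketch of the idea. The `jax_reset`/`jax_step` names and the assumed `(state, obs, reward, done)` return signature are placeholders matching your description above; a real wrapper would still have to implement the full SB3 VecEnv interface (spaces, info dicts, auto-reset on done, etc.):

```python
import jax
import jax.numpy as jnp
import numpy as np

# Untested sketch: a vectorized wrapper around a pure JAX env.
# Assumes jax_reset(rng) -> (state, obs) and
# jax_step(state, action) -> (state, obs, reward, done).

class JaxVecEnvSketch:
    def __init__(self, jax_reset, jax_step, num_envs, rng_key):
        self.num_envs = num_envs
        # vmap + jit so all environments step in a single compiled GPU call
        self._reset = jax.jit(jax.vmap(jax_reset))
        self._step = jax.jit(jax.vmap(jax_step))
        self._keys = jax.random.split(rng_key, num_envs)
        self._state = None

    def reset(self):
        self._state, obs = self._reset(self._keys)
        # SB3 expects numpy arrays, so a device-to-host transfer happens here
        # unless the buffers/policy are also patched to accept jax arrays.
        return np.asarray(obs)

    def step(self, actions):
        self._state, obs, reward, done = self._step(self._state, jnp.asarray(actions))
        infos = [{} for _ in range(self.num_envs)]
        return np.asarray(obs), np.asarray(reward), np.asarray(done), infos
```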
If you do so, please open-source your fork of SBX; it might be helpful to others.