Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add compatibility with ray and tensorflow #121

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions model/gym-interface/py/ns3ai_gym_env/envs/ns3_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ def _create_space(self, spaceDesc):
mtype = boxSpacePb.dtype

if mtype == pb.INT:
mtype = np.int
mtype = int
elif mtype == pb.UINT:
mtype = np.uint
raise NotImplementedError("uint is not supported by all rl frameworks. Use int instead!")
elif mtype == pb.DOUBLE:
mtype = np.float
mtype = np.float64
else:
mtype = np.float
mtype = np.float32

space = spaces.Box(low=low, high=high, shape=shape, dtype=mtype)

Expand Down Expand Up @@ -203,8 +203,7 @@ def _pack_data(self, actions, spaceDesc):
boxContainerPb.intData.extend(actions)

elif spaceDesc.dtype in ['uint', 'uint8', 'uint16', 'uint32', 'uint64']:
boxContainerPb.dtype = pb.UINT
boxContainerPb.uintData.extend(actions)
raise NotImplementedError("uint is not supported by all rl frameworks. Use int instead!")

elif spaceDesc.dtype in ['float', 'float32', 'float64']:
boxContainerPb.dtype = pb.FLOAT
Expand Down Expand Up @@ -274,6 +273,8 @@ def get_state(self):
def __init__(self, targetName, ns3Path, ns3Settings=None, shmSize=4096):
if self._created:
raise Exception('Error: Ns3Env is singleton')
self.targetName = targetName
self.shmSize = shmSize
self._created = True
self.exp = Experiment(targetName, ns3Path, py_binding, shmSize=shmSize)
self.ns3Settings = ns3Settings
Expand Down Expand Up @@ -336,3 +337,16 @@ def close(self):
self.exp.kill()
# destroy the message interface and its shared memory segment
del self.exp

def __getstate__(self):
return {
"targetName": self.targetName,
"ns3Path": ".",
"ns3Settings": self.ns3Settings,
"shmSize": self.shmSize,
}

def __setstate__(self, state):
if hasattr(self, "exp"):
self.close()
self.__init__(**state)