Skip to content

Commit

Permalink
add flexible source dir feature
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruben Queiros committed May 17, 2023
1 parent 6007f4b commit db3fbc1
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
12 changes: 7 additions & 5 deletions model/ns3gym/ns3gym/ns3env.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

class Ns3ZmqBridge(object):
"""docstring for Ns3ZmqBridge"""
def __init__(self, port=0, startSim=True, simSeed=0, simArgs={}, debug=False):
def __init__(self, port=0, startSim=True, simSeed=0, simArgs={}, debug=False, src_dir=os.getcwd()):
super(Ns3ZmqBridge, self).__init__()
port = int(port)
self.port = port
Expand All @@ -35,6 +35,7 @@ def __init__(self, port=0, startSim=True, simSeed=0, simArgs={}, debug=False):
self.simPid = None
self.wafPid = None
self.ns3Process = None
self.src_dir = src_dir

context = zmq.Context()
self.socket = context.socket(zmq.REP)
Expand Down Expand Up @@ -63,7 +64,7 @@ def __init__(self, port=0, startSim=True, simSeed=0, simArgs={}, debug=False):

if self.startSim:
# run simulation script
self.ns3Process = start_sim_script(port, simSeed, simArgs, debug)
self.ns3Process = start_sim_script(port, simSeed, simArgs, debug, src_dir)
else:
print("Waiting for simulation script to connect on port: tcp://localhost:{}".format(port))
print('Please start proper ns-3 simulation script using ./waf --run "..."')
Expand Down Expand Up @@ -361,13 +362,14 @@ def _pack_data(self, actions, spaceDesc):


class Ns3Env(gym.Env):
def __init__(self, stepTime=0, port=0, startSim=True, simSeed=0, simArgs={}, debug=False):
def __init__(self, stepTime=0, port=0, startSim=True, simSeed=0, simArgs={}, debug=False, src_dir=os.getcwd()):
self.stepTime = stepTime
self.port = port
self.startSim = startSim
self.simSeed = simSeed
self.simArgs = simArgs
self.debug = debug
self.src_dir = src_dir

# Filled in reset function
self.ns3ZmqBridge = None
Expand All @@ -378,7 +380,7 @@ def __init__(self, stepTime=0, port=0, startSim=True, simSeed=0, simArgs={}, deb
self.state = None
self.steps_beyond_done = None

self.ns3ZmqBridge = Ns3ZmqBridge(self.port, self.startSim, self.simSeed, self.simArgs, self.debug)
self.ns3ZmqBridge = Ns3ZmqBridge(self.port, self.startSim, self.simSeed, self.simArgs, self.debug, self.src_dir)
self.ns3ZmqBridge.initialize_env(self.stepTime)
self.action_space = self.ns3ZmqBridge.get_action_space()
self.observation_space = self.ns3ZmqBridge.get_observation_space()
Expand Down Expand Up @@ -413,7 +415,7 @@ def reset(self):
self.ns3ZmqBridge = None

self.envDirty = False
self.ns3ZmqBridge = Ns3ZmqBridge(self.port, self.startSim, self.simSeed, self.simArgs, self.debug)
self.ns3ZmqBridge = Ns3ZmqBridge(self.port, self.startSim, self.simSeed, self.simArgs, self.debug, self.src_dir)
self.ns3ZmqBridge.initialize_env(self.stepTime)
self.action_space = self.ns3ZmqBridge.get_action_space()
self.observation_space = self.ns3ZmqBridge.get_observation_space()
Expand Down
9 changes: 4 additions & 5 deletions model/ns3gym/ns3gym/start_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,12 @@ def build_ns3_project(debug=True):
os.chdir(cwd)


def start_sim_script(port=5555, sim_seed=0, sim_args={}, debug=False):
def start_sim_script(port=5555, sim_seed=0, sim_args={}, debug=False, src_dir=os.getcwd()):
"""
Actually run the ns3 scenario
"""
cwd = os.getcwd()
sim_script_name = os.path.basename(cwd)
ns3_path = find_ns3_path(cwd)
sim_script_name = os.path.basename(src_dir)
ns3_path = find_ns3_path(src_dir)
base_ns3_dir = os.path.dirname(ns3_path)

os.chdir(base_ns3_dir)
Expand Down Expand Up @@ -133,5 +132,5 @@ def start_sim_script(port=5555, sim_seed=0, sim_args={}, debug=False):
print("Started ns3 simulation script, Process Id: ", ns3_proc.pid)

# go back to my dir
os.chdir(cwd)
os.chdir(src_dir)
return ns3_proc

0 comments on commit db3fbc1

Please sign in to comment.