Skip to content

Commit

Permalink
fixed issue with saving numpy arrays to json instead of lists
Browse files Browse the repository at this point in the history
  • Loading branch information
HenryMBaldwin committed Sep 30, 2023
1 parent f7a1c3f commit 95ff60f
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions data_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def parallel_handler(self, gen_mutex, gs_mutex, ld_mutex, num, solvable = True):
#not exactly sure if this is necassary but it is a fairly simple precaution to implement
self.gen_mutex = gen_mutex
self.gs_mutex = gs_mutex
self.ld_mutex = ld_(gs_mutex)
self.ld_mutex = ld_mutex
print(gs_mutex)
self.generate_levels(num, solvable)

def generate_levels(self, num, solvable=True):
Expand Down Expand Up @@ -85,7 +86,7 @@ def generate_levels(self, num, solvable=True):
norm_move = self.convert_to_normalized(move)
#save current move and last state, as each [state, move] pair is the current state and the move
#that **led** to the current state
save_training_data(one_hot_state.flatten(), norm_move)
self.save_training_data(one_hot_state.flatten(), norm_move)
last_state = state
break

Expand All @@ -106,7 +107,7 @@ def convert_to_normalized(self,move):
pattern = re.compile(r"\((\d+),(\d+)\) -> \((\d+),(\d+)\)")

# Search the string for matches
match = pattern.search(coord_string)
match = pattern.search(move)

# Validate the input string
if match is None:
Expand Down Expand Up @@ -167,18 +168,20 @@ def save_level(self, level_data, filename="levels/level_data.json"):
self.ld_mutex.release()

def save_training_data(self, state, move, filename="levels/training_data.json"):
if not (isinstance(state, list) and isinstance(move, list)):
raise ValueError("Inputs should be of type list")

# Storing arrays in a dictionary
data = {
"state": state,
"move": move
}

if isinstance(state, np.ndarray):
state = state.tolist()

if isinstance(move, np.ndarray):
move = move.tolist()

# Writing the dictionary to a JSON file
with open(filename, 'w') as f:
json.dump(data, f)
data = {
"state": state,
"move": move
}

with open(filename, 'w') as f:
json.dump(data, f)

def generate_levels_parallel(self, num_processes, num, solvable=True):
num = num//num_processes
Expand Down

0 comments on commit 95ff60f

Please sign in to comment.