diff --git a/data_factory.py b/data_factory.py index 7cd45a2..4bb09eb 100644 --- a/data_factory.py +++ b/data_factory.py @@ -42,7 +42,8 @@ def parallel_handler(self, gen_mutex, gs_mutex, ld_mutex, num, solvable = True): #not exactly sure if this is necassary but it is a fairly simple precaution to implement self.gen_mutex = gen_mutex self.gs_mutex = gs_mutex - self.ld_mutex = ld_(gs_mutex) + self.ld_mutex = ld_mutex + print(gs_mutex) self.generate_levels(num, solvable) def generate_levels(self, num, solvable=True): @@ -85,7 +86,7 @@ def generate_levels(self, num, solvable=True): norm_move = self.convert_to_normalized(move) #save current move and last state, as each [state, move] pair is the current state and the move #that **led** to the current state - save_training_data(one_hot_state.flatten(), norm_move) + self.save_training_data(one_hot_state.flatten(), norm_move) last_state = state break @@ -106,7 +107,7 @@ def convert_to_normalized(self,move): pattern = re.compile(r"\((\d+),(\d+)\) -> \((\d+),(\d+)\)") # Search the string for matches - match = pattern.search(coord_string) + match = pattern.search(move) # Validate the input string if match is None: @@ -167,18 +168,20 @@ def save_level(self, level_data, filename="levels/level_data.json"): self.ld_mutex.release() def save_training_data(self, state, move, filename="levels/training_data.json"): - if not (isinstance(state, list) and isinstance(move, list)): - raise ValueError("Inputs should be of type list") - - # Storing arrays in a dictionary - data = { - "state": state, - "move": move - } + + if isinstance(state, np.ndarray): + state = state.tolist() + + if isinstance(move, np.ndarray): + move = move.tolist() - # Writing the dictionary to a JSON file - with open(filename, 'w') as f: - json.dump(data, f) + data = { + "state": state, + "move": move + } + + with open(filename, 'w') as f: + json.dump(data, f) def generate_levels_parallel(self, num_processes, num, solvable=True): num = num//num_processes