diff --git a/distkeras/distributed.py b/distkeras/distributed.py index 4822dae..d5a8e13 100644 --- a/distkeras/distributed.py +++ b/distkeras/distributed.py @@ -424,9 +424,6 @@ def allocate_worker(self): def train(self, data, shuffle=False): # Start the communication service. self.start_service() - # Check if the data needs to be shuffled. - if shuffle: - data = shuffle(data) # Allocate a worker program. worker = self.allocate_worker() # Fetch the current number of partitions. @@ -436,6 +433,9 @@ def train(self, data, shuffle=False): data = data.coalesce(self.num_workers) else: data = data.repartition(self.num_workers) + # Check if the data needs to be shuffled. + if shuffle: + data = shuffle(data) for i in range(0, self.num_epoch): data.rdd.mapPartitionsWithIndex(worker.train).collect() # Stop the communication service.