diff --git a/libai/data/structures.py b/libai/data/structures.py
index 380a8a1ce..eeaf4d7d7 100644
--- a/libai/data/structures.py
+++ b/libai/data/structures.py
@@ -29,7 +29,7 @@ class DistTensorData:
     placement_idx: int = 0

     # Tensor-like methods
-    def to_global(self, sbp=None, placement=None, device_type="cuda"):
+    def to_global(self, sbp=None, placement=None, device_type="cuda", check_meta=True, sync_data=True):
         if sbp is not None:
             self.sbp = sbp
         else:
@@ -47,7 +47,7 @@ def to_global(self, sbp=None, placement=None, device_type="cuda"):
             self.sbp = dist.get_nd_sbp(sbp_list)

         if placement is not None:
-            self.tensor = self.tensor.to_global(sbp=self.sbp, placement=placement)
+            self.tensor = self.tensor.to_global(sbp=self.sbp, placement=placement, check_meta=check_meta, sync_data=sync_data)
         else:
             # Convert local tensor to global tensor with default setting,
             # if the placement parameter is not provided.
@@ -62,10 +62,11 @@ def to_global(self, sbp=None, placement=None, device_type="cuda"):
             # by the first device group, in case that each device group contains
             # some random augmentations to the tensors without setting the same global seed.
             main_placement = dist.get_layer_placement(0, device_type)
-            self.tensor = self.tensor.to_global(sbp=self.sbp, placement=main_placement)
+            self.tensor = self.tensor.to_global(sbp=self.sbp, placement=main_placement, check_meta=check_meta, sync_data=sync_data)
             if self.placement_idx != 0:
                 self.tensor = self.tensor.to_global(
-                    placement=dist.get_layer_placement(self.placement_idx, device_type)
+                    placement=dist.get_layer_placement(self.placement_idx, device_type),
+                    check_meta=check_meta, sync_data=sync_data
                 )

     @staticmethod
diff --git a/libai/engine/default.py b/libai/engine/default.py
index 14a107167..9239dddc5 100644
--- a/libai/engine/default.py
+++ b/libai/engine/default.py
@@ -496,6 +496,8 @@ def get_batch(
         data: Instance,
         input_placement_device: str = "cuda",
         mixup_func: Optional[Callable] = None,
+        check_meta: bool = True,
+        sync_data: bool = True,
     ):
         """
         Convert batched local tensor to distributed tensor for model step running.
@@ -516,7 +518,7 @@ def get_batch(
         ret_dict = {}
         for key, value in data.get_fields().items():
-            value.to_global(device_type=input_placement_device)
+            value.to_global(device_type=input_placement_device, check_meta=check_meta, sync_data=sync_data)
             ret_dict[key] = value.tensor
         return ret_dict
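
Usage sketch (illustrative, not part of the patch): the new `check_meta` and `sync_data`
flags are threaded from `get_batch` through `DistTensorData.to_global` down to the
tensor-level `to_global` call, presumably so callers can control the cross-rank meta
check and data synchronization. A minimal assumed example follows; the `data_iter`
argument, the `global_batches` helper, and calling `get_batch` through `DefaultTrainer`
are placeholders for illustration, not code from this patch.

from libai.engine.default import DefaultTrainer

def global_batches(data_iter, input_placement_device="cuda"):
    # Yield dicts of global tensors ready for a model step.
    for step, data in enumerate(data_iter):
        yield DefaultTrainer.get_batch(
            data,
            input_placement_device=input_placement_device,
            # Hypothetical policy: validate meta and sync data across ranks only
            # on the first step, then reuse the established layout; whether this
            # is safe depends on the loader producing identically-shaped batches.
            check_meta=(step == 0),
            sync_data=(step == 0),
        )

Keeping both flags defaulted to True preserves the previous behavior of `get_batch`,
so existing call sites are unaffected by this change.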