diff --git a/libai/tokenizer/tokenization_base.py b/libai/tokenizer/tokenization_base.py
index 026902fdf..ea6f6c4fa 100644
--- a/libai/tokenizer/tokenization_base.py
+++ b/libai/tokenizer/tokenization_base.py
@@ -782,9 +782,9 @@ def convert_to_tensors(self, token_ids, return_tensors=None, is_global=False, **
             return_token_ids = flow.tensor(token_ids, dtype=flow.long)
         elif is_global:
             sbp = kwargs.get("sbp", dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]))
-            placement = kwargs.get(
-                "placement", flow.placement("cuda", list(range(dist.get_world_size())))
-            )
+            placement = kwargs.get("placement")
+            if placement is None:
+                placement = flow.placement("cuda", list(range(dist.get_world_size())))
             return_token_ids = flow.tensor(
                 token_ids, sbp=sbp, placement=placement, dtype=flow.long
             )
diff --git a/libai/utils/distributed.py b/libai/utils/distributed.py
index e7914a0ad..42f0593fe 100644
--- a/libai/utils/distributed.py
+++ b/libai/utils/distributed.py
@@ -438,7 +438,7 @@ def convert_to_distributed_default_setting(t):
 def ttol(tensor, pure_local=False, ranks=None):
     """Global tensor to local tensor."""
     if tensor.is_global:
-        placement = tensor.placement if not ranks else flow.placement("cuda", ranks)
+        placement = tensor.placement if not ranks else flow.placement(tensor.placement.type, ranks)
         if pure_local:
             tensor = tensor.to_global(placement=placement).to_local()
         else:
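
A possible reading of the first hunk (this note is not part of the diff): `dict.get(key, default)` evaluates the default expression eagerly, so the old code built a CUDA placement spanning all ranks even when the caller already supplied a `placement`. The minimal sketch below uses a hypothetical `expensive_default` helper, standing in for `flow.placement("cuda", ...)`, to illustrate the pitfall the new `None` check avoids.

```python
# Minimal sketch (assumption: this mirrors the motivation for the first hunk).
# dict.get(key, default) is not lazy: the default expression is evaluated
# before .get() is called, even when the key is present.

def expensive_default():
    print("default evaluated")  # stand-in for flow.placement("cuda", ...)
    return "cuda-placement"

kwargs = {"placement": "caller-supplied-placement"}

# Old pattern: expensive_default() runs even though "placement" was passed.
placement = kwargs.get("placement", expensive_default())

# New pattern: the fallback is only constructed when the key is missing or None.
placement = kwargs.get("placement")
if placement is None:
    placement = expensive_default()
```

The second hunk appears to follow the same spirit: `tensor.placement.type` keeps the global tensor's existing device type (e.g. `"cpu"`) when re-placing it over `ranks`, rather than forcing `"cuda"`.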