From b682409a6244eb7d0a0789cb0a637ea477889ee5 Mon Sep 17 00:00:00 2001
From: Jianhua Zheng
Date: Fri, 28 Jun 2024 02:53:03 +0000
Subject: [PATCH] support new device

Stop hardcoding "cuda" so LibAI can run on other device types:

- tokenization_base.py: build the default CUDA placement lazily, so a
  caller-supplied placement never triggers construction of a CUDA
  placement.
- distributed.py: ttol() reuses the tensor's own placement type instead
  of forcing "cuda" when remapping to a rank subset.
---
 libai/tokenizer/tokenization_base.py | 6 +++---
 libai/utils/distributed.py           | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/libai/tokenizer/tokenization_base.py b/libai/tokenizer/tokenization_base.py
index 026902fdf..ea6f6c4fa 100644
--- a/libai/tokenizer/tokenization_base.py
+++ b/libai/tokenizer/tokenization_base.py
@@ -782,9 +782,9 @@ def convert_to_tensors(self, token_ids, return_tensors=None, is_global=False, **
             return_token_ids = flow.tensor(token_ids, dtype=flow.long)
         elif is_global:
             sbp = kwargs.get("sbp", dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]))
-            placement = kwargs.get(
-                "placement", flow.placement("cuda", list(range(dist.get_world_size())))
-            )
+            placement = kwargs.get("placement")
+            if placement is None:
+                placement = flow.placement("cuda", list(range(dist.get_world_size())))
             return_token_ids = flow.tensor(
                 token_ids, sbp=sbp, placement=placement, dtype=flow.long
             )
diff --git a/libai/utils/distributed.py b/libai/utils/distributed.py
index e7914a0ad..42f0593fe 100644
--- a/libai/utils/distributed.py
+++ b/libai/utils/distributed.py
@@ -438,7 +438,7 @@ def convert_to_distributed_default_setting(t):
 def ttol(tensor, pure_local=False, ranks=None):
     """Global tensor to local tensor."""
     if tensor.is_global:
-        placement = tensor.placement if not ranks else flow.placement("cuda", ranks)
+        placement = tensor.placement if not ranks else flow.placement(tensor.placement.type, ranks)
         if pure_local:
             tensor = tensor.to_global(placement=placement).to_local()
         else:
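
Why the tokenizer hunk replaces the one-line `kwargs.get(...)` with a get-plus-guard: Python evaluates call arguments eagerly, so `dict.get(key, default)` constructs the default even when the key is present. With the old code, `flow.placement("cuda", ...)` was built on every call, including calls that passed a non-CUDA placement, which can fail on builds without a CUDA device. Below is a minimal, self-contained sketch of the pitfall; `make_cuda_placement` is a hypothetical stand-in for `flow.placement("cuda", ...)`, not a LibAI or OneFlow function:

```python
def make_cuda_placement():
    # Stand-in for flow.placement("cuda", ...): imagine it raises on a
    # machine without a CUDA device.
    raise RuntimeError("no CUDA device available")

def eager(**kwargs):
    # Pitfall: dict.get evaluates its default argument unconditionally,
    # so make_cuda_placement() runs even when "placement" was passed.
    return kwargs.get("placement", make_cuda_placement())

def lazy(**kwargs):
    # The patched pattern: build the default only when the caller did
    # not supply a placement.
    placement = kwargs.get("placement")
    if placement is None:
        placement = make_cuda_placement()
    return placement

print(lazy(placement="caller-supplied placement"))  # works without CUDA
try:
    eager(placement="caller-supplied placement")    # raises despite the argument
except RuntimeError as err:
    print("eager default was evaluated anyway:", err)
```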
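
The `ttol` hunk follows the same theme: when remapping a global tensor to a rank subset, the old code forced a `"cuda"` placement regardless of where the tensor actually lived; reusing `tensor.placement.type` preserves the original device type (e.g. cpu or cuda). A usage sketch, assuming a single-process OneFlow runtime and a CPU placement so it runs without CUDA:

```python
import oneflow as flow

# Build a global tensor on CPU (rank 0 only, broadcast sbp).
t = flow.ones(2, 2)
g = t.to_global(placement=flow.placement("cpu", ranks=[0]), sbp=flow.sbp.broadcast)

ranks = [0]
# Old: flow.placement("cuda", ranks) -- wrong device for a CPU tensor.
# New: reuse the tensor's own device type.
placement = g.placement if not ranks else flow.placement(g.placement.type, ranks)
local = g.to_global(placement=placement).to_local()
print(placement.type, local.shape)  # cpu oneflow.Size([2, 2])
```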