Skip to content

Commit

Permalink
[Bug] Fix wrong parallel mapping in tensor parallelism
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Nov 27, 2023
1 parent 9fd3394 commit 7a28f7a
Showing 1 changed file with 0 additions and 16 deletions.
16 changes: 0 additions & 16 deletions pipegoose/nn/parallel_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,13 @@ def __init__(self, module_name: Tuple[str], **kwargs):


class ParallelMapping:
# def __init__(self, mapping: Dict):
# self.mapping = mapping

@staticmethod
def _search(module_name: str) -> Optional[ParallelInfo]:
"""
Search for module_name in mappings.
"""
module_name = ParallelMapping._extract_module_name(module_name)
for child_class in ParallelMapping.__subclasses__():
from pipegoose.nn.tensor_parallel.parallel_mapping import (
TensorParallelMapping,
)

if child_class == TensorParallelMapping:
continue

if hasattr(child_class, "__MAPPING__"):
for items in child_class.__MAPPING__.values():
for item in items:
Expand All @@ -34,12 +24,6 @@ def _search(module_name: str) -> Optional[ParallelInfo]:
# NOTE: only search the first subclass of the current instance
break

# for items in self.mapping.values():
# for item in items:
# item = cast(ParallelInfo, item)
# if any(module_name in mapping_name for mapping_name in item.module_name):
# return item

return None

@staticmethod
Expand Down

0 comments on commit 7a28f7a

Please sign in to comment.