Commit 88ef381
Generic call for prepare_cos_sin in rotary embedding (#638)
Generic name discovery for rope.prepare_cos_sin. This fixes errors in models that don't follow a specific naming hierarchy.
tzielinski-habana authored Dec 18, 2024
1 parent d81f829 commit 88ef381
Showing 1 changed file with 55 additions and 44 deletions.
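
The fix is easiest to see against a layout the old lookup could not handle. Below is a minimal sketch of such a hierarchy; the class and attribute names (ToyModel, DecoderLayer, GroupedAttention, attn_wrapper) are hypothetical, chosen only to violate the Model -> ModuleList -> Attention -> RotaryEmbedding chain that the removed helper in the diff below hard-codes:

    import torch.nn as nn

    class RotaryEmbedding(nn.Module):
        def prepare_cos_sin(self, positions):
            pass  # stand-in for the real cos/sin cache preparation

    class GroupedAttention(nn.Module):
        def __init__(self):
            super().__init__()
            self.rotary_emb = RotaryEmbedding()

    class DecoderLayer(nn.Module):
        def __init__(self):
            super().__init__()
            # The attention module sits inside a ModuleDict, one level deeper
            # than the fixed chain expects, so get_names_for_rope (removed
            # below) would come back empty for a layout like this.
            self.attn_wrapper = nn.ModuleDict({"attn": GroupedAttention()})

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([DecoderLayer()])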
99 changes: 55 additions & 44 deletions vllm/worker/hpu_model_runner.py
@@ -169,40 +169,37 @@ def forward_hook(module, args, output):
         modify_decoder_layer(child_module, suffix, n, counter)
 
 
-def get_names_for_rope(model: torch.nn.Module):
-    """Dynamically get layer names needed for cos and sin preparation for rope.
-    Every model can have a different naming convention for it's layers.
-    This function dynamically retrieves layer names to access rope layer.
-    If there's no rope layer, the function returns None.
-    This function assumes the following layer type layout:
-    Model -> ModuleList -> Attention -> RotaryEmbedding
+def get_path_to_rope(model: torch.nn.Module):
+    """Dynamically get the path to the RotaryEmbedding layer in the model.
+    This function will recursively search through the module hierarchy to find
+    a RotaryEmbedding layer and return the full path to that layer as a list
+    of names.
+    If no such layer is found, it returns None.
     """
 
-    def get_child(parent, suffix, is_list=False):
+    def find_rope_layer(parent, path):
+        # Base case: check if this parent is None
         if parent is None:
-            return None, None
-        parent = parent[0] if is_list else parent
-        for child_name, child_module in parent.named_children():
-            if child_module.__class__.__name__.endswith(suffix):
-                return child_name, child_module
-        return None, None
-
-    model_name, model_module = get_child(model, "Model")
-    layers_name, layers_module = get_child(model_module, "ModuleList")
-    attn_name, attn_module = get_child(layers_module,
-                                       "Attention",
-                                       is_list=True)
-    rope_name, _ = get_child(attn_module, "RotaryEmbedding")
-
-    if rope_name is not None:
-        return {
-            'model_name': model_name,
-            'layers_name': layers_name,
-            'attn_name': attn_name,
-            'rope_name': rope_name
-        }
-    return None
+            return None
+        # Check if the current layer is a RotaryEmbedding
+        if hasattr(parent, 'named_children'):
+            for child_name, child_module in parent.named_children():
+                # If the current child is of type RotaryEmbedding,
+                # return the full path
+                if child_module.__class__.__name__.endswith("RotaryEmbedding"):
+                    return path + [child_name]
+                # Otherwise, recurse into this child to check its children
+                result = find_rope_layer(child_module, path + [child_name])
+                if result is not None:
+                    return result
+        return None
+
+    # Start the search from the top level model
+    path_to_rope = find_rope_layer(model, [])
+
+    # Return the result if found, otherwise None
+    return path_to_rope
 
 
 class HpuModelAdapter:
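
To make the new behavior concrete, running the added function on the hypothetical ToyModel sketched above yields the attribute path as a list of strings; ModuleList entries appear as digit strings, because that is how named_children() names them:

    model = ToyModel()
    path = get_path_to_rope(model)
    print(path)  # ['layers', '0', 'attn_wrapper', 'attn', 'rotary_emb']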
@@ -353,17 +350,31 @@ def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
         return attn_metadata
 
     def _prepare_cos_sin(self, positions):
-        model_name = self.layer_names['model_name']
-        layers_name = self.layer_names['layers_name']
-        attn_name = self.layer_names['attn_name']
-        rope_name = self.layer_names['rope_name']
-
-        base_model = getattr(self.model, model_name)
-        first_model_layer = getattr(base_model, layers_name)[0]
-        attention_layer = getattr(first_model_layer, attn_name)
-        rope = getattr(attention_layer, rope_name)
-
-        rope.prepare_cos_sin(positions)
+        """Navigate through the model using the provided path and call
+        the prepare_cos_sin method on the 'RotaryEmbedding' layer."""
+
+        current_module = self.model  # Start from the top level of the model
+
+        for layer in self.layer_names:
+            if layer.isdigit():  # Convert digit strings to list indices
+                layer = int(layer)
+
+            if isinstance(layer, str):  # Name-based access
+                current_module = getattr(current_module, layer)
+            elif isinstance(layer, int):  # Index-based access (like ModuleList)
+                current_module = list(current_module._modules.values())[layer]
+
+        # At the end, we should be at the RotaryEmbedding layer.
+        if hasattr(current_module, 'prepare_cos_sin'):
+            current_module.prepare_cos_sin(positions)
+        else:
+            raise AttributeError(
+                "The module at the end of the path does not have "
+                "a 'prepare_cos_sin' method.")
 
     def forward(self, *args, **kwargs):
         kwargs = kwargs.copy()
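
The traversal in _prepare_cos_sin can be sketched standalone. Assuming a path produced by get_path_to_rope, digit entries become positional lookups into the module's _modules table and every other entry goes through getattr, mirroring the method above:

    import torch

    def walk_path(module, path):
        # Follow a get_path_to_rope-style path down a module tree.
        current = module
        for name in path:
            if name.isdigit():
                # ModuleList registers children under digit-string names,
                # so a positional lookup recovers the right entry.
                current = list(current._modules.values())[int(name)]
            else:
                current = getattr(current, name)
        return current

    rope = walk_path(model, path)  # reuses the ToyModel example above
    rope.prepare_cos_sin(torch.arange(8))  # positions tensor, for illustration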
@@ -744,7 +755,7 @@ def load_model(self) -> None:
                 get_decoder_layer_suffix(model_config.model_type if
                                          model_config is not None else None),
                 hidden_layer_markstep_interval)
-            names_for_rope = get_names_for_rope(self.model)
+            path_to_rope = get_path_to_rope(self.model)
             torch.hpu.synchronize()
 
             with HabanaMemoryProfiler() as m_wrap:
@@ -753,7 +764,7 @@
                 self.block_size,
                 dtype=self.model_config.dtype,
                 enforce_eager=self.enforce_eager,
-                layer_names=names_for_rope)
+                layer_names=path_to_rope)
             msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}"
             logger.info(msg)
