Onboard DeepSeek MoE with shared experts #1242

Open · wants to merge 3 commits into base: main

Conversation

RissyRan (Collaborator) commented Feb 6, 2025

Description

Onboard DeepSeek MoE with shared experts (functional support first; referenced from the DeepSeek implementation):

  • Refactor models.py to handle mixed layers, i.e. both dense and MoE layers
  • Add a DeepSeek v3 config and deepseek.py as the decoder layer
  • Add DeepSeekMoeBlock, reusing the dense and MoE blocks, so either token dropping or dropless routing can be used for future tuning (see the toy sketch after this list)
  • Add a compile test
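
A minimal, self-contained sketch of the block's idea (not the PR's actual DeepSeekMoeBlock): a shared expert runs on every token, while routed experts are combined with renormalized, scaled top-k weights. The class name, the sigmoid router, the single-Dense "experts", and all sizes below are illustrative assumptions, and the dense (dropless-style) combine is for clarity only.

import jax
import jax.numpy as jnp
import flax.linen as nn

class ToyDeepSeekMoeBlock(nn.Module):
  """Toy sketch: one shared expert plus top-k routed experts."""
  num_experts: int = 16
  num_experts_per_tok: int = 2
  hidden_dim: int = 64
  routed_scaling_factor: float = 1.0

  @nn.compact
  def __call__(self, x):  # x: [batch, seq, hidden_dim]
    # Shared expert: applied to every token, bypasses routing entirely.
    shared_out = nn.Dense(self.hidden_dim, name="shared_expert")(x)

    # Router: score all experts, keep the top-k per token.
    gate_logits = nn.Dense(self.num_experts, name="router")(x)
    scores = jax.nn.sigmoid(gate_logits.astype(jnp.float32))
    weights, selected = jax.lax.top_k(scores, self.num_experts_per_tok)
    # DeepSeek-style post-processing: renormalize the kept weights, then rescale.
    weights /= weights.sum(-1, keepdims=True)
    weights *= self.routed_scaling_factor

    # Dropless-style combine: evaluate every expert densely for clarity;
    # a real kernel would dispatch only the selected tokens to each expert.
    expert_outs = jnp.stack(
        [nn.Dense(self.hidden_dim, name=f"expert_{i}")(x) for i in range(self.num_experts)],
        axis=-2,
    )  # [batch, seq, num_experts, hidden_dim]
    one_hot = jax.nn.one_hot(selected, self.num_experts)       # [b, s, k, e]
    combine = jnp.einsum("bsk,bske->bse", weights, one_hot)    # [b, s, e]
    routed_out = jnp.einsum("bse,bseh->bsh", combine.astype(x.dtype), expert_outs)

    return shared_out + routed_out

# Example usage on random activations.
block = ToyDeepSeekMoeBlock()
x = jnp.ones((2, 8, 64))
params = block.init(jax.random.PRNGKey(0), x)
y = block.apply(params, x)  # y.shape == (2, 8, 64)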

Tests

Test config: base_num_decoder_layers: 5, num_experts: 16

  • Small config - Functional tests for scan_layers=True: link
  • Small config - Functional tests for scan_layers=False: link
  • One profile: matmul shapes on both dense and MoE layers LGTM

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

gagika (Collaborator) left a comment:

Only a few minor comments, thanks!



class DeepSeekMoELayer(nn.Module):
  """DeepSeek-style MoE layer."""
gagika (Collaborator):
Can you comment on the main differences between DeepSeekMoELayer and the regular MoELayer?

@@ -371,7 +378,11 @@ def permute(self, inputs, gate_logits):
    inputs_shape = inputs.shape
    inputs_2d = jnp.reshape(inputs, (inputs_shape[0] * inputs_shape[1], inputs_shape[2]))
    weights, selected_experts = jax.lax.top_k(gate_logits, self.num_experts_per_tok)
    weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1).astype(self.dtype)
    if self.config.decoder_block == "deepseek":
gagika (Collaborator):
Can you add a comment on how/why it's different for DeepSeek?

Also, perhaps you could move it to a function, since it's used in two places, e.g.:

def _deepseek_scale_weights(self, weights):
  """Scales weights according to DeepSeek's ... ."""
  weights /= weights.sum(-1, keepdims=True)
  weights *= self.config.routed_scaling_factor
  return weights
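
For example, both call sites could then reduce to the following (a sketch, assuming the helper is added to the same MoE module as permute):

if self.config.decoder_block == "deepseek":
  weights = self._deepseek_scale_weights(weights)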
