
Sparse MoE code reading #28

Open · wants to merge 2 commits into main
Conversation


@long8v long8v commented May 10, 2022

from torch.distributions.normal import Normal
from mlp import MLP
import numpy as np
class SparseDispatcher(object):

A helper class that, given an input mini-batch, handles dispatch (handing each element to its experts) and combine (gathering the experts' outputs back into a single tensor).
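A minimal usage sketch following the class docstring (the expert modules, sizes, and dense random gates below are placeholders, not from the repo):

import torch
import torch.nn as nn

# hypothetical setup: 4 experts, dense gates just for illustration
num_experts = 4
experts = [nn.Linear(16, 16) for _ in range(num_experts)]
x = torch.randn(8, 16)                      # [batch_size, input_size]
gates = torch.rand(8, num_experts)          # gates[b, e] != 0 => element b routed to expert e

dispatcher = SparseDispatcher(num_experts, gates)   # the class under review
expert_inputs = dispatcher.dispatch(x)              # one input batch per expert
expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
y = dispatcher.combine(expert_outputs)              # gate-weighted sum, back to [batch_size, 16]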

Comment on lines +23 to +25
combine - take output Tensors from each expert and form a combined output
Tensor. Outputs from different experts for the same batch element are
summed together, weighted by the provided "gates".

combine takes a weighted sum of each expert's output, weighted by its gate value.

summed together, weighted by the provided "gates".
The class is initialized with a "gates" Tensor, which specifies which
batch elements go to which experts, and the weights to use when combining
the outputs. Batch element b is sent to expert e iff gates[b, e] != 0.

gates works like a one-hot-style matrix: the entry is nonzero if batch element b is sent to expert e, and 0 otherwise.
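For instance, a hypothetical gates tensor for 2 batch elements and 3 experts with top-2 routing:

import torch

# batch 0 -> experts 1 and 2, batch 1 -> experts 0 and 2
gates = torch.tensor([[0.0, 0.7, 0.3],
                      [0.5, 0.0, 0.5]])
# gates[b, e] != 0 means element b is dispatched to expert e,
# and the nonzero value is the weight used later in combine()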

self._gates = gates
self._num_experts = num_experts
# sort experts
sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)

nonzero -> pick the indices of the nonzero entries -> sort them in ascending order.
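Roughly what this produces, using the hypothetical gates from above:

import torch

gates = torch.tensor([[0.0, 0.7, 0.3],
                      [0.5, 0.0, 0.5]])
nz = torch.nonzero(gates)
# nz lists (batch, expert) index pairs in row-major order:
# tensor([[0, 1], [0, 2], [1, 0], [1, 2]])
sorted_experts, index_sorted_experts = nz.sort(0)
# sort(0) sorts each column independently along dim 0, so the second column
# of sorted_experts holds the expert ids in ascending order: [0, 1, 2, 2];
# index_sorted_experts records where each sorted value came from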

# sort experts
sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
# drop indices
_, self._expert_index = sorted_experts.split(1, dim=1)

Splits along dim=1 into width-1 tensors, drops the first one (the batch indices), and stores the remaining column as self._expert_index.
https://pytorch.org/docs/stable/generated/torch.split.html
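A quick check of the split (the sorted_experts values here are the hypothetical ones from above):

import torch

sorted_experts = torch.tensor([[0, 0],
                               [0, 1],
                               [1, 2],
                               [1, 2]])
batch_col, expert_col = sorted_experts.split(1, dim=1)   # two [N, 1] columns
# the batch-index column is discarded; the expert-index column becomes self._expert_index
print(expert_col.squeeze(1))   # tensor([0, 1, 2, 2])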


threshold_positions_if_in = torch.arange(batch).to(self.device) * m + self.k
threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
is_in = torch.gt(noisy_values, threshold_if_in)

greater than: https://pytorch.org/docs/stable/generated/torch.gt.html
Element-wise comparison: true where noisy_values > threshold_if_in.
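A tiny example of the comparison (shapes mimic the code: threshold_if_in is [batch, 1] and broadcasts over the m values):

import torch

noisy_values = torch.tensor([[0.9, 0.2],
                             [0.4, 0.8]])
threshold_if_in = torch.tensor([[0.5],
                                [0.5]])
is_in = torch.gt(noisy_values, threshold_if_in)
# tensor([[ True, False],
#         [False,  True]])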

Comment on lines +197 to +198
noisy_top_values: a `Tensor` of shape [batch, m].
"values" Output of tf.top_k(noisy_top_values, m). m >= k+1

The tensor of the top-m expert values.
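In the gating code this tensor comes from a top-k call along the expert dimension, roughly like the sketch below (with m = k + 1 so the (k+1)-th value can serve as the exclusion threshold):

import torch

noisy_logits = torch.randn(8, 4)   # [batch, num_experts], hypothetical values
k = 2
top_logits, top_indices = noisy_logits.topk(min(k + 1, noisy_logits.size(1)), dim=1)
# top_logits has shape [batch, m] with m = k + 1 and is what _prob_in_top_k
# receives as noisy_top_values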

and shapes `[expert_batch_size_i]`
"""
# split nonzero gates for each expert
return torch.split(self._nonzero_gates, self._part_sizes, dim=0)

Cuts _nonzero_gates into one chunk per expert, each chunk as long as the number of batch elements sent to that expert.
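For example, with the hypothetical routing above (expert 0 gets 1 element, expert 1 gets 1, expert 2 gets 2):

import torch

nonzero_gates = torch.tensor([0.5, 0.7, 0.3, 0.5])   # gate weights sorted by expert
part_sizes = [1, 1, 2]                                # batch elements per expert
per_expert_gates = torch.split(nonzero_gates, part_sizes, dim=0)
# (tensor([0.5]), tensor([0.7]), tensor([0.3, 0.5]))
# one chunk per expert, later used to weight that expert's outputs in combine()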

Comment on lines +154 to +155
"""The squared coefficient of variation of a sample.
Useful as a loss to encourage a positive distribution to be more uniform.
@long8v long8v May 11, 2022


Computes the (squared) coefficient of variation of a tensor.
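A sketch of the idea (the eps guard and the single-sample shortcut are assumptions about how such a function typically avoids division by zero):

import torch

def cv_squared(x, eps=1e-10):
    # squared coefficient of variation: Var(x) / Mean(x)^2
    if x.shape[0] == 1:              # undefined for a single value
        return torch.tensor(0.0)
    return x.float().var() / (x.float().mean() ** 2 + eps)

print(cv_squared(torch.tensor([4., 4., 4., 4.])))   # 0: perfectly balanced experts
print(cv_squared(torch.tensor([16., 0., 0., 0.])))  # large: everything on one expert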

Comment on lines +172 to +174
def _gates_to_load(self, gates):
"""Compute the true load per expert, given the gates.
The load is the number of examples for which the corresponding gate is >0.
@long8v long8v May 11, 2022


Given the gates, computes the true load per expert ("true" presumably meaning not the noisy estimate).
The load is defined as the number of examples whose gate is > 0.
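A sketch, assuming gates has shape [batch_size, num_experts]:

import torch

def gates_to_load(gates):
    # number of examples routed to each expert (nonzero gate), shape [num_experts]
    return (gates > 0).sum(0)

gates = torch.tensor([[0.0, 0.7, 0.3],
                      [0.5, 0.0, 0.5]])
print(gates_to_load(gates))   # tensor([1, 1, 2])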

@long8v long8v changed the title from "add MoE code review" to "Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer" on May 13, 2022
Comment on lines +271 to +272
loss = self.cv_squared(importance) + self.cv_squared(load)
loss *= loss_coef

Sums the loss on importance and the loss on load, then scales by loss_coef.
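Putting the two balancing terms together (a self-contained sketch; the loss_coef value is an assumption):

import torch

def cv_squared(x, eps=1e-10):
    return x.float().var() / (x.float().mean() ** 2 + eps)

gates = torch.tensor([[0.0, 0.7, 0.3],
                      [0.5, 0.0, 0.5]])
importance = gates.sum(0)           # total gate weight per expert
load = (gates > 0).sum(0).float()   # number of examples per expert
loss_coef = 1e-2                    # small auxiliary-loss weight (assumed value)
aux_loss = loss_coef * (cv_squared(importance) + cv_squared(load))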

Comment on lines +218 to +228
def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
"""Noisy top-k gating.
See paper: https://arxiv.org/abs/1701.06538.
Args:
x: input Tensor with shape [batch_size, input_size]
train: a boolean - we only add noise at training time.
noise_epsilon: a float
Returns:
gates: a Tensor with shape [batch_size, num_experts]
load: a Tensor with shape [num_experts]
"""

The part that makes the top-k gating noisy (noise is added to the gate logits before selecting the top k).
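A condensed sketch of the routine (weight names follow the code; the load computation via _prob_in_top_k is omitted, so treat this as an outline rather than the exact implementation):

import torch
import torch.nn.functional as F

def noisy_top_k_gating(x, w_gate, w_noise, k, train=True, noise_epsilon=1e-2):
    clean_logits = x @ w_gate                               # [batch, num_experts]
    if train:
        # learn a per-example noise scale and perturb the logits only at train time
        noise_stddev = F.softplus(x @ w_noise) + noise_epsilon
        logits = clean_logits + torch.randn_like(clean_logits) * noise_stddev
    else:
        logits = clean_logits
    # keep the k largest logits, softmax them, scatter back into a sparse gate matrix
    top_k_logits, top_k_indices = logits.topk(k, dim=1)
    top_k_gates = torch.softmax(top_k_logits, dim=1)
    gates = torch.zeros_like(logits).scatter(1, top_k_indices, top_k_gates)
    return gates

x = torch.randn(8, 16)
w_gate = torch.zeros(16, 4, requires_grad=True)
w_noise = torch.zeros(16, 4, requires_grad=True)
gates = noisy_top_k_gating(x, w_gate, w_noise, k=2)   # [8, 4], 2 nonzeros per row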

top_k_gates = self.softmax(top_k_logits)

zeros = torch.zeros_like(logits, requires_grad=True).to(self.device)
gates = zeros.scatter(1, top_k_indices, top_k_gates).to(self.device)
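zeros.scatter(1, top_k_indices, top_k_gates) writes the softmaxed top-k gate values back into a full [batch_size, num_experts] tensor of zeros, so non-selected experts keep gate 0 and the result is the sparse gates matrix. A small example of this scatter pattern (hypothetical values):

import torch

top_k_gates = torch.tensor([[0.6, 0.4],
                            [0.7, 0.3]])    # softmaxed top-k weights
top_k_indices = torch.tensor([[1, 2],
                              [0, 2]])      # chosen experts per batch element
zeros = torch.zeros(2, 3)                   # [batch_size, num_experts]
gates = zeros.scatter(1, top_k_indices, top_k_gates)
# tensor([[0.0, 0.6, 0.4],
#         [0.7, 0.0, 0.3]])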

@long8v long8v changed the title from "Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer" to "Sparse MoE code reading" on Jul 21, 2022