Sparse MoE code reading #28
Conversation
from torch.distributions.normal import Normal
from mlp import MLP
import numpy as np

class SparseDispatcher(object):
A helper class for dispatch, which hands an input mini-batch to each expert, and combine, which gathers each expert's outputs back into a single tensor.
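A minimal usage sketch, following the example in the class docstring of the repo; the input, gate values, and expert modules below are made up for illustration:

```python
import torch
import torch.nn as nn

# hypothetical setup for illustration
batch_size, input_size, output_size, num_experts = 4, 8, 8, 3
x = torch.randn(batch_size, input_size)
experts = nn.ModuleList(nn.Linear(input_size, output_size) for _ in range(num_experts))

# gates[b, e] != 0 means batch element b is routed to expert e
gates = torch.tensor([[0.7, 0.3, 0.0],
                      [0.0, 1.0, 0.0],
                      [0.5, 0.0, 0.5],
                      [0.0, 0.2, 0.8]])

dispatcher = SparseDispatcher(num_experts, gates)   # the class defined in this file
expert_inputs = dispatcher.dispatch(x)              # one mini-batch per expert
expert_outputs = [experts[e](expert_inputs[e]) for e in range(num_experts)]
y = dispatcher.combine(expert_outputs)              # back to [batch_size, output_size]
```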
combine - take output Tensors from each expert and form a combined output
Tensor. Outputs from different experts for the same batch element are
summed together, weighted by the provided "gates".
combine takes the weighted sum of each expert's output, weighted by its gate value.
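A toy example of that weighted sum, with made-up expert outputs and gate values:

```python
import torch

# outputs of two experts for the same 2-element batch
out_e0 = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
out_e1 = torch.tensor([[3.0, 3.0], [4.0, 4.0]])
gates  = torch.tensor([[0.25, 0.75],
                       [1.00, 0.00]])

# combine == per-element sum over experts, weighted by the gates
combined = gates[:, 0:1] * out_e0 + gates[:, 1:2] * out_e1
# row 0: 0.25*[1,1] + 0.75*[3,3] = [2.5, 2.5]
# row 1: 1.00*[2,2] + 0.00*[4,4] = [2.0, 2.0]
```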
summed together, weighted by the provided "gates".
The class is initialized with a "gates" Tensor, which specifies which
batch elements go to which experts, and the weights to use when combining
the outputs. Batch element b is sent to expert e iff gates[b, e] != 0.
gates works like a one-hot routing table: gates[b, e] is nonzero if batch element b goes to expert e, and 0 otherwise. (With k > 1, the nonzero entries are the softmax gate weights rather than exactly 1.)
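A small made-up example of how the routing condition reads off the gates tensor:

```python
import torch

gates = torch.tensor([[0.0, 0.9, 0.1],
                      [1.0, 0.0, 0.0]])
# batch element b is sent to expert e iff gates[b, e] != 0
for b in range(gates.shape[0]):
    print(b, torch.nonzero(gates[b]).flatten().tolist())
# 0 [1, 2]   -> element 0 goes to experts 1 and 2 (weights 0.9 and 0.1)
# 1 [0]      -> element 1 goes only to expert 0
```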
self._gates = gates
self._num_experts = num_experts
# sort experts
sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
nonzero -> take the indices of the tensor's nonzero entries -> sort them in ascending order.
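A small demo of the nonzero + sort step, using a made-up k=1 gates tensor (one expert per element, so the column-wise sort result is deterministic):

```python
import torch

gates = torch.tensor([[0.0, 1.0, 0.0],
                      [1.0, 0.0, 0.0],
                      [0.0, 0.0, 1.0]])
nz = torch.nonzero(gates)   # (batch, expert) pairs: [[0, 1], [1, 0], [2, 2]]
sorted_experts, index_sorted_experts = nz.sort(0)
# sort(0) sorts each column independently and also returns the argsort;
# column 1 of index_sorted_experts reorders the pairs by expert id
print(sorted_experts)         # tensor([[0, 0], [1, 1], [2, 2]])
print(index_sorted_experts)   # tensor([[0, 1], [1, 0], [2, 2]])
```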
# sort experts
sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
# drop indices
_, self._expert_index = sorted_experts.split(1, dim=1)
Split along dim=1 into chunks of size 1, drop the first column, and store the second as self._expert_index.
https://pytorch.org/docs/stable/generated/torch.split.html
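A standalone demo of that split (the sorted_experts values are made up):

```python
import torch

sorted_experts = torch.tensor([[0, 0], [0, 1], [1, 2]])
# split(1, dim=1) cuts the [n, 2] tensor into two [n, 1] columns
_, expert_index = sorted_experts.split(1, dim=1)
print(expert_index)   # tensor([[0], [1], [2]]) -- the expert id per routed example
```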
threshold_positions_if_in = torch.arange(batch).to(self.device) * m + self.k
threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
is_in = torch.gt(noisy_values, threshold_if_in)
greater than. https://pytorch.org/docs/stable/generated/torch.gt.html
Element-wise: true where noisy_values > threshold_if_in.
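A quick demo of torch.gt with broadcasting, as used here (values made up):

```python
import torch

noisy_values    = torch.tensor([[0.2, 1.5],
                                [0.9, 0.1]])
threshold_if_in = torch.tensor([[1.0],
                                [0.5]])   # one threshold per row, broadcast across columns
is_in = torch.gt(noisy_values, threshold_if_in)   # element-wise '>'
print(is_in)
# tensor([[False,  True],
#         [ True, False]])
```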
noisy_top_values: a `Tensor` of shape [batch, m].
"values" Output of tf.top_k(noisy_top_values, m). m >= k+1
The top m noisy expert values per example (a [batch, m] tensor), with m >= k+1.
and shapes `[expert_batch_size_i]`
"""
# split nonzero gates for each expert
return torch.split(self._nonzero_gates, self._part_sizes, dim=0)
Splits the _nonzero_gates into one chunk per expert, sized by the number of batch elements routed to that expert.
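A sketch of what that split does, with a hand-written gates tensor; the ordering of nonzero_gates (grouped by expert id, as the dispatcher's sort arranges it) is written out by hand here:

```python
import torch

gates = torch.tensor([[0.0, 0.9, 0.1],
                      [1.0, 0.0, 0.0],
                      [0.0, 0.5, 0.5]])
# number of routed examples per expert = nonzero gate count per column
part_sizes = (gates > 0).sum(0).tolist()   # [1, 2, 2]

# nonzero gate values, grouped by expert id (hand-ordered for illustration)
nonzero_gates = torch.tensor([1.0, 0.9, 0.5, 0.1, 0.5])
per_expert = torch.split(nonzero_gates, part_sizes, dim=0)
print([t.tolist() for t in per_expert])    # [[1.0], [0.9, 0.5], [0.1, 0.5]]
```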
"""The squared coefficient of variation of a sample. | ||
Useful as a loss to encourage a positive distribution to be more uniform. |
Computes the (squared) coefficient of variation of the tensor.
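A minimal reimplementation of the formula named in the docstring, Var(x) / Mean(x)^2; the repo's version may differ in details such as the eps value and device handling:

```python
import torch

def cv_squared(x, eps=1e-10):
    """Squared coefficient of variation: Var(x) / Mean(x)^2.

    0 for a perfectly uniform positive vector, larger the more skewed it is,
    so minimizing it as a loss pushes per-expert importance/load toward uniform.
    """
    x = x.float()
    if x.numel() <= 1:            # variance of a single value is undefined
        return torch.zeros((), dtype=x.dtype)
    return x.var() / (x.mean() ** 2 + eps)

print(cv_squared(torch.tensor([1.0, 1.0, 1.0])))   # tensor(0.)  -> uniform
print(cv_squared(torch.tensor([3.0, 0.0, 0.0])))   # tensor(3.)  -> very uneven
```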
def _gates_to_load(self, gates):
"""Compute the true load per expert, given the gates.
The load is the number of examples for which the corresponding gate is >0.
Given the gates, computes the true load per expert ("true" apparently meaning not a noisy estimate).
Load is defined as the number of examples with gate > 0.
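The count itself is a one-liner; a small made-up example:

```python
import torch

gates = torch.tensor([[0.0, 0.9, 0.1],
                      [1.0, 0.0, 0.0],
                      [0.0, 0.5, 0.5]])
# load per expert: how many examples have a nonzero gate for that expert
load = (gates > 0).sum(0)
print(load)   # tensor([1, 2, 2])
```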
loss = self.cv_squared(importance) + self.cv_squared(load)
loss *= loss_coef
Adds the loss on importance and the loss on load.
def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
"""Noisy top-k gating.
See paper: https://arxiv.org/abs/1701.06538.
Args:
x: input Tensor with shape [batch_size, input_size]
train: a boolean - we only add noise at training time.
noise_epsilon: a float
Returns:
gates: a Tensor with shape [batch_size, num_experts]
load: a Tensor with shape [num_experts]
"""
The part that makes the top-k gating noisy.
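A condensed sketch of the noise step from the paper (standalone tensors stand in for the module's self.w_gate / self.w_noise parameters; shapes are made up):

```python
import torch
import torch.nn.functional as F

# hypothetical shapes for illustration
batch_size, input_size, num_experts = 4, 8, 3
noise_epsilon = 1e-2
x = torch.randn(batch_size, input_size)
w_gate  = torch.randn(input_size, num_experts)   # stands in for self.w_gate
w_noise = torch.randn(input_size, num_experts)   # stands in for self.w_noise
train = True

clean_logits = x @ w_gate
if train:
    # learned, input-dependent noise scale; softplus keeps it positive
    noise_stddev = F.softplus(x @ w_noise) + noise_epsilon
    logits = clean_logits + torch.randn_like(clean_logits) * noise_stddev
else:
    logits = clean_logits    # no noise at eval time
```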
top_k_gates = self.softmax(top_k_logits)

zeros = torch.zeros_like(logits, requires_grad=True).to(self.device)
gates = zeros.scatter(1, top_k_indices, top_k_gates).to(self.device)
Paper
https://arxiv.org/abs/1701.06538
Paper notes: Notion
Implementation
https://github.com/davidmrau/mixture-of-experts
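As a footnote on the scatter step quoted above, a toy example with made-up top-k indices and gate values:

```python
import torch

top_k_indices = torch.tensor([[1, 3],
                              [0, 2]])
top_k_gates = torch.tensor([[0.6, 0.4],
                            [0.7, 0.3]])   # softmax over the top-k logits
zeros = torch.zeros(2, 4)                  # [batch_size, num_experts]
# scatter writes each top-k gate back into its expert's column,
# leaving the non-selected experts at exactly zero
gates = zeros.scatter(1, top_k_indices, top_k_gates)
print(gates)
# tensor([[0.0000, 0.6000, 0.0000, 0.4000],
#         [0.7000, 0.0000, 0.3000, 0.0000]])
```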