class MixtralSparseMoeBlock(nn.Module):
    """
    This implementation is
    strictly equivalent to standard MoE with full capacity (no
    dropped tokens). It's faster since it formulates MoE operations
    in terms of block-sparse operations to accommodate imbalanced
    assignments of tokens to experts, whereas standard MoE either
    (1) drop tokens at the cost of reduced performance or (2) set
    capacity factor to number of experts and thus waste computation
    and memory on padding.
    """

    def __init__(self, config):
        super().__init__()
        # Model dimensions and routing hyperparameters come from the config.
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        # gating: one linear layer produces a logit per expert for each token
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        self.experts = nn.ModuleList(
            [MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)]
        )

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Route each token to its top-k experts and mix their outputs.

        Args:
            hidden_states: attention output of shape
                (batch_size, sequence_length, hidden_dim).

        Returns:
            A tuple of:
              - final_hidden_states: (batch_size, sequence_length, hidden_dim),
                the routing-weighted sum of the selected experts' outputs.
              - router_logits: (batch_size * sequence_length, num_experts),
                raw gate logits (kept for the auxiliary load-balancing loss).
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        # Flatten to 2D so each row is one token.
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        # The gate decides which experts process each token.
        router_logits = self.gate(hidden_states)
        # Per-token softmax over experts gives each expert's processing weight.
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        # Keep only the top_k experts per token.
        # selected_experts: (batch * sequence_length, top_k)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        # Renormalize the kept weights so they sum to 1 per token.
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        # Accumulator for every token's mixed expert output, starts at zero.
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be sollicitated
        # one hot: (batch * sequence_length, top_k, num_experts)
        #   => expert_mask: (num_experts, top_k, batch * sequence_length)
        # Taking the expert's view means each pass of the loop below only visits
        # the tokens that expert owns, instead of iterating token-by-token
        # (which would scale linearly with sequence length).
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.num_experts
        ).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            # idx: which routing slot (top-1 vs top-2) this expert fills for each token.
            # top_x: flat indices of the tokens this expert is responsible for.
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Skip experts that received no tokens this batch.
            if top_x.shape[0] == 0:
                continue

            # in torch it is faster to index using lists than torch tensors
            top_x_list = top_x.tolist()
            idx_list = idx.tolist()

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
            current_hidden_states = (
                expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
            )

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(
                0, top_x, current_hidden_states.to(hidden_states.dtype)
            )

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits