Mirror of https://gitee.com/mindspore/mindformers.git, synced 2025-12-06 11:29:59 +08:00
eliminate redundant casts and optimize gating
@@ -705,7 +705,8 @@ class DeepseekV3MoE(Cell):
             self.routed_experts = ParallelMoEV2(ffn, self.config.hidden_size, self.moe_config)
         else:
             self.routed_experts = ExpertParallelMoE(ffn, self.config.hidden_size,
-                                                    self.moe_config, self.config.parallel_config.use_alltoall)
+                                                    self.moe_config, self.config.parallel_config.use_alltoall,
+                                                    self.config.compute_dtype)

         self.attn_reduce_scatter = config.parallel_config.attn_reduce_scatter
         self.attn_allgather = config.parallel_config.attn_allgather
@@ -1602,7 +1603,7 @@ class InferenceDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
         from mindformers.parallel_core.inference.parallel_state import get_data_parallel_group
         tokens_len_per_dp = q_seq_len.sum().reshape(-1)
-        tokens_len_per_dp = ops.AllGather(group=get_data_parallel_group())(tokens_len_per_dp)
+        tokens_len_per_dp = ops.AllGather(group=get_data_parallel_group().group)(tokens_len_per_dp)
         tokens_len_per_dp = tokens_len_per_dp.asnumpy()
         padding_size = (tokens_len_per_dp.max() + tp_size - 1) // tp_size * tp_size
         dp_rank_id = get_rank() // tp_size
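The padding computed after the gather rounds the largest per-DP token count up to a multiple of the tensor-parallel size. A minimal NumPy sketch of that rounding, with assumed example values:

    import numpy as np

    # Hypothetical gathered token counts, one entry per data-parallel rank.
    tokens_len_per_dp = np.array([37, 41, 29, 40])
    tp_size = 8

    # Same ceil-to-multiple formula as in the hunk above.
    padding_size = (tokens_len_per_dp.max() + tp_size - 1) // tp_size * tp_size
    print(padding_size)  # 48: the max count is 41, and the next multiple of 8 is 48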
@@ -807,6 +807,7 @@ class ExpertParallelMoE(nn.Cell):
         ffn (Cell): The FeedForward Module.
         hidden_size (int): The hidden size of each token.
         moe_config (MoEConfig): The configuration of MoE (Mixture of Expert).
+        compute_dtype(dtype.Number): The computation type of the layer.

     Inputs:
         - **input_tensor** (Tensor) - should be `[batch, seq_length, hidden_size].
@@ -818,8 +819,10 @@ class ExpertParallelMoE(nn.Cell):
                  ffn,
                  hidden_size,
                  moe_config,
-                 use_alltoall):
+                 use_alltoall,
+                 compute_dtype):
         super(ExpertParallelMoE, self).__init__()
+        self.compute_dtype = compute_dtype
         self.hidden_size = hidden_size
         self.moe_config = moe_config
         self.expert_num = moe_config.expert_num
@@ -836,7 +839,8 @@ class ExpertParallelMoE(nn.Cell):
         self.moe_token_unpermute = MoeTokenUnpermute()
         self.moe_init_routing_v2 = MoeInitRoutingV2()
         self.fused_add_topk_div = FusedAddTopKDiv()
-        self.dummy_token = mint.zeros((1, self.hidden_size), dtype=mstype.bfloat16)
+        self.dummy_token = mint.zeros((1, self.hidden_size), dtype=self.compute_dtype)
+        self.fill_value = Tensor(0, self.compute_dtype)

         self.moe_tp_size = get_moe_tensor_parallel_world_size()
         self.moe_ep_size = get_moe_expert_parallel_world_size()
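With compute_dtype threaded into the cell, the dummy token and the masked-fill value are built once in the configured dtype instead of hard-coding bfloat16 or passing a Python scalar at every call. A minimal sketch of the pattern, assuming float16 as the compute dtype and purely illustrative tensors (not the model's real buffers):

    import mindspore as ms
    from mindspore import Tensor, mint, ops

    compute_dtype = ms.float16        # assumption; in the diff it comes from the model config
    hidden_size = 4

    # Built once in __init__, already in the compute dtype.
    dummy_token = mint.zeros((1, hidden_size), dtype=compute_dtype)
    fill_value = Tensor(0, compute_dtype)

    # masked_fill can then reuse the prebuilt scalar without a per-call dtype conversion.
    weights = Tensor([[0.1, 0.9], [0.7, 0.3]], compute_dtype)
    mask = Tensor([[True, False], [False, True]])
    print(ops.masked_fill(weights, mask, fill_value))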
@@ -853,17 +857,18 @@ class ExpertParallelMoE(nn.Cell):
         self.local_ep_num = self.expert_num // self.moe_ep_size
         self.ep_rank_index = get_rank() // self.moe_tp_size
         self.in_start_expert_idx = self.ep_rank_index * self.local_ep_num
+        self.group_list_index = Tensor([0,], mstype.int32)

+        if self.moe_ep_size > 1 and not self.use_alltoall:
+            bias_idx = [idx for idx in range(self.expert_num)]
+            self.bias_idx = bias_idx[self.in_start_expert_idx:] + bias_idx[:self.in_start_expert_idx]
+            self.router.e_score_correction_bias = self.router.e_score_correction_bias[self.bias_idx]
+
     def moe_with_allgather(self, input_tensor, expert_weight, expert_index):
         """moe feed forward with allgather."""
-        global_local_mask = expert_index < self.in_start_expert_idx
-        local_expert_index = expert_index - self.in_start_expert_idx
-        local_expert_index = self.cast(local_expert_index, mstype.int32)
-        local_expert_index = ops.masked_fill(local_expert_index, global_local_mask, self.expert_num - 1)
-
-        expert_weight = expert_weight.astype(input_tensor.dtype)  # float32 -> bfloat16
-        global_local_mask1 = local_expert_index >= self.local_ep_num
-        expert_weight = ops.masked_fill(expert_weight, global_local_mask1, 0)
+        local_expert_index = self.cast(expert_index, mstype.int32)
+        expert_weight_mask = expert_index >= self.local_ep_num
+        expert_weight = ops.masked_fill(expert_weight, expert_weight_mask, self.fill_value)

         sorted_input_tensor, unsort_map, group_list, _ = \
             self.moe_init_routing_v2(
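Rotating the router's e_score_correction_bias in __init__ puts the experts owned by the local EP rank at the front of the index space, which is why the gating code above can compare raw expert indices against local_ep_num directly instead of subtracting in_start_expert_idx and re-filling out-of-range entries. A small worked example of the rotation, with assumed parallel sizes:

    # Assumed sizes, for illustration only.
    expert_num = 8
    moe_ep_size = 4
    ep_rank_index = 1                                    # this rank's position in the EP group

    local_ep_num = expert_num // moe_ep_size             # 2 experts hosted per EP rank
    in_start_expert_idx = ep_rank_index * local_ep_num   # first global expert owned here: 2

    bias_idx = list(range(expert_num))
    bias_idx = bias_idx[in_start_expert_idx:] + bias_idx[:in_start_expert_idx]
    print(bias_idx)  # [2, 3, 4, 5, 6, 7, 0, 1]: local experts 2 and 3 now sit at positions 0 and 1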
@@ -876,7 +881,11 @@ class ExpertParallelMoE(nn.Cell):
                 expert_tokens_count_or_cumsum_flag=2,
                 expert_tokens_before_capacity_flag=True)

-        group_list = mint.split(group_list, self.local_ep_num)[0]
+        # Avoid the poor performance of the split(int32) operator
+        group_list = group_list.reshape(self.moe_ep_size, -1)
+        group_list = mint.index_select(group_list, 0, self.group_list_index)
+        group_list = group_list.reshape(-1)
+
         group_list = self.cast(group_list, mstype.int64)
         expert_output = self.ffn(sorted_input_tensor, group_list)
         expert_output = mint.nan_to_num(expert_output, 0, 0, 0)
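Per the new comment, the reshape plus index_select keeps only the entries for the first local_ep_num experts while avoiding the slow split on int32 data. When group_list has expert_num = moe_ep_size * local_ep_num entries, selecting row 0 of the (moe_ep_size, -1) view is the same as taking the first chunk that mint.split produced. A NumPy sketch of that equivalence, with assumed sizes and values:

    import numpy as np

    moe_ep_size, local_ep_num = 4, 2
    # Hypothetical per-expert entries for all 8 global experts.
    group_list = np.array([5, 3, 7, 2, 4, 6, 1, 8], dtype=np.int32)

    # Old form: the first chunk of size local_ep_num (what split(...)[0] returned).
    first_chunk = group_list[:local_ep_num]

    # New form: view as (moe_ep_size, local_ep_num), select row 0, flatten.
    selected = group_list.reshape(moe_ep_size, -1)[0].reshape(-1)

    assert np.array_equal(first_chunk, selected)   # both are [5, 3]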
@@ -926,7 +935,6 @@ class ExpertParallelMoE(nn.Cell):
         yout = ops.AlltoAllV(block_size=self.hidden_size)(y.reshape(-1), recv_list, send_list)
         expert_output = yout.reshape((-1, self.hidden_size))

-        expert_weight = expert_weight.astype(input_tensor.dtype)  # float32 -> bfloat16
         moe_output = self.moe_token_unpermute(permuted_tokens=expert_output,
                                               sorted_indices=unsort_map,
                                               probs=expert_weight,
@@ -990,6 +998,7 @@ class ExpertParallelMoE(nn.Cell):

         # AllGather
         if not self.use_alltoall:
+            expert_weight = expert_weight.astype(input_tensor.dtype)
             return self.moe_with_allgather(input_tensor, expert_weight, expert_index)

         return self.moe_with_dispatch_combine(input_tensor, expert_weight, expert_index)