eliminate redundant casts and optimize gating

This commit is contained in:
w00521005
2025-08-15 14:43:29 +08:00
parent 04f872faf4
commit 449319eb8e
2 changed files with 24 additions and 14 deletions

View File

@@ -705,7 +705,8 @@ class DeepseekV3MoE(Cell):
self.routed_experts = ParallelMoEV2(ffn, self.config.hidden_size, self.moe_config)
else:
self.routed_experts = ExpertParallelMoE(ffn, self.config.hidden_size,
self.moe_config, self.config.parallel_config.use_alltoall)
self.moe_config, self.config.parallel_config.use_alltoall,
self.config.compute_dtype)
self.attn_reduce_scatter = config.parallel_config.attn_reduce_scatter
self.attn_allgather = config.parallel_config.attn_allgather
@@ -1602,7 +1603,7 @@ class InferenceDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
from mindformers.parallel_core.inference.parallel_state import get_data_parallel_group
tokens_len_per_dp = q_seq_len.sum().reshape(-1)
tokens_len_per_dp = ops.AllGather(group=get_data_parallel_group())(tokens_len_per_dp)
tokens_len_per_dp = ops.AllGather(group=get_data_parallel_group().group)(tokens_len_per_dp)
tokens_len_per_dp = tokens_len_per_dp.asnumpy()
padding_size = (tokens_len_per_dp.max() + tp_size - 1) // tp_size * tp_size
dp_rank_id = get_rank() // tp_size

View File

@@ -807,6 +807,7 @@ class ExpertParallelMoE(nn.Cell):
ffn (Cell): The FeedForward Module.
hidden_size (int): The hidden size of each token.
moe_config (MoEConfig): The configuration of MoE (Mixture of Expert).
compute_dtype (dtype.Number): The computation type of the layer.
Inputs:
- **input_tensor** (Tensor) - should be `[batch, seq_length, hidden_size]`.
@@ -818,8 +819,10 @@ class ExpertParallelMoE(nn.Cell):
ffn,
hidden_size,
moe_config,
use_alltoall):
use_alltoall,
compute_dtype):
super(ExpertParallelMoE, self).__init__()
self.compute_dtype = compute_dtype
self.hidden_size = hidden_size
self.moe_config = moe_config
self.expert_num = moe_config.expert_num
@@ -836,7 +839,8 @@ class ExpertParallelMoE(nn.Cell):
self.moe_token_unpermute = MoeTokenUnpermute()
self.moe_init_routing_v2 = MoeInitRoutingV2()
self.fused_add_topk_div = FusedAddTopKDiv()
self.dummy_token = mint.zeros((1, self.hidden_size), dtype=mstype.bfloat16)
self.dummy_token = mint.zeros((1, self.hidden_size), dtype=self.compute_dtype)
self.fill_value = Tensor(0, self.compute_dtype)
self.moe_tp_size = get_moe_tensor_parallel_world_size()
self.moe_ep_size = get_moe_expert_parallel_world_size()
@@ -853,17 +857,18 @@ class ExpertParallelMoE(nn.Cell):
self.local_ep_num = self.expert_num // self.moe_ep_size
self.ep_rank_index = get_rank() // self.moe_tp_size
self.in_start_expert_idx = self.ep_rank_index * self.local_ep_num
self.group_list_index = Tensor([0,], mstype.int32)
if self.moe_ep_size > 1 and not self.use_alltoall:
bias_idx = [idx for idx in range(self.expert_num)]
self.bias_idx = bias_idx[self.in_start_expert_idx:] + bias_idx[:self.in_start_expert_idx]
self.router.e_score_correction_bias = self.router.e_score_correction_bias[self.bias_idx]
def moe_with_allgather(self, input_tensor, expert_weight, expert_index):
"""moe feed forward with allgather."""
global_local_mask = expert_index < self.in_start_expert_idx
local_expert_index = expert_index - self.in_start_expert_idx
local_expert_index = self.cast(local_expert_index, mstype.int32)
local_expert_index = ops.masked_fill(local_expert_index, global_local_mask, self.expert_num - 1)
expert_weight = expert_weight.astype(input_tensor.dtype) # float32 -> bfloat16
global_local_mask1 = local_expert_index >= self.local_ep_num
expert_weight = ops.masked_fill(expert_weight, global_local_mask1, 0)
local_expert_index = self.cast(expert_index, mstype.int32)
expert_weight_mask = expert_index >= self.local_ep_num
expert_weight = ops.masked_fill(expert_weight, expert_weight_mask, self.fill_value)
sorted_input_tensor, unsort_map, group_list, _ = \
self.moe_init_routing_v2(
@@ -876,7 +881,11 @@ class ExpertParallelMoE(nn.Cell):
expert_tokens_count_or_cumsum_flag=2,
expert_tokens_before_capacity_flag=True)
group_list = mint.split(group_list, self.local_ep_num)[0]
# Avoid the poor performance of the split(int32) operator by using index_select instead.
group_list = group_list.reshape(self.moe_ep_size, -1)
group_list = mint.index_select(group_list, 0, self.group_list_index)
group_list = group_list.reshape(-1)
group_list = self.cast(group_list, mstype.int64)
expert_output = self.ffn(sorted_input_tensor, group_list)
expert_output = mint.nan_to_num(expert_output, 0, 0, 0)
@@ -926,7 +935,6 @@ class ExpertParallelMoE(nn.Cell):
yout = ops.AlltoAllV(block_size=self.hidden_size)(y.reshape(-1), recv_list, send_list)
expert_output = yout.reshape((-1, self.hidden_size))
expert_weight = expert_weight.astype(input_tensor.dtype) # float32 -> bfloat16
moe_output = self.moe_token_unpermute(permuted_tokens=expert_output,
sorted_indices=unsort_map,
probs=expert_weight,
@@ -990,6 +998,7 @@ class ExpertParallelMoE(nn.Cell):
# AllGather
if not self.use_alltoall:
expert_weight = expert_weight.astype(input_tensor.dtype)
return self.moe_with_allgather(input_tensor, expert_weight, expert_index)
return self.moe_with_dispatch_combine(input_tensor, expert_weight, expert_index)