Compare commits

...

4 Commits

Author SHA1 Message Date
i-robot
983b95ebc0 !6217 fix telechat2 infer
Merge pull request !6217 from 森镇/fix_telechat2_infer_1.6.0
2025-05-19 13:21:13 +00:00
senzhen
5d6081e56f fix telechat2 infer 2025-05-19 16:29:49 +08:00
zhulinhong
556c5f36c1 !6215 【bugfix】cherry-pick aux_loss from pr 6119
Merge pull request !6215 from 黄勇/aux_loss_r1.6
2025-05-16 13:33:04 +00:00
huangyong
fa3d2c9e74 aux_loss bugfix 2025-05-16 16:51:40 +08:00
11 changed files with 130 additions and 37 deletions

View File

@@ -1335,6 +1335,7 @@ class FreqsMgrDynamicNTK(Cell):
     def __init__(self,
                  head_dim,
                  max_position_embedding,
+                 base_seqlen=None,
                  rotary_dtype=mstype.float16,
                  theta=10000,
                  parallel_config=None,
@@ -1378,7 +1379,7 @@ class FreqsMgrDynamicNTK(Cell):
         self.base = theta
         self.max_position_embedding = max_position_embedding
-        self.max_position_embedding_inverse = 1 / max_position_embedding
+        self.max_position_embedding_inverse = 1 / base_seqlen if base_seqlen else 1 / max_position_embedding
         self.log_scale_inverse = 1 / math.log(2)
         self.log_scale = math.log(2)
         self.min_ntk_alpha = 1.0
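Context for the fix above: FreqsMgrDynamicNTK previously always derived this inverse from max_position_embedding, so a model pretrained on shorter sequences but served with a longer window fed the downstream NTK-alpha computation the wrong baseline. The new optional base_seqlen takes precedence when set. A minimal plain-Python sketch of just that fallback rule (names mirror the diff; the surrounding NTK math is unchanged):

    def ntk_inverse_base(max_position_embedding, base_seqlen=None):
        # Prefer the configured pretraining length; None/0 fall back to the
        # serving window length, preserving the old behaviour.
        return 1 / base_seqlen if base_seqlen else 1 / max_position_embedding

    # e.g. a model pretrained at 8k but served with a 32k window
    assert ntk_inverse_base(32768, base_seqlen=8192) == 1 / 8192
    assert ntk_inverse_base(32768) == 1 / 32768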

View File

@@ -1567,17 +1567,21 @@ class TopkRouterV2(Cell):
         expert_gate = self._normalize(expert_gate)  # (dp, N, k) <-- (dp, N, k)
         expert_gate = ops.mul(self.moe_config.routed_scaling_factor, expert_gate)  # (dp, N, k) <-- (dp, N, k)
+        if self.moe_config.balance_via_topk_bias and (self.aux_loss_config.get("expert", 0) > 0 \
+            or self.aux_loss_config.get("device", 0) > 0 or self.aux_loss_config.get("comm", 0) > 0):
+            _, expert_index_for_aux = self.topk(router_prob_for_aux, self.num_experts_chosen)
+        else:
+            expert_index_for_aux = expert_index
         if self.aux_loss_config.get("expert", 0):
-            expert_load_loss = self._expert_load_balancing(router_prob_for_aux, expert_index,
+            expert_load_loss = self._expert_load_balancing(router_prob_for_aux, expert_index_for_aux,
                                                            self.aux_loss_config.get("expert"), seq_chunk=seq_chunk)
             extra_loss = self.add_scalar(extra_loss, expert_load_loss)
         if self.aux_loss_config.get("device", 0):
-            device_load_loss = self._device_load_balancing(router_prob_for_aux, expert_index,
+            device_load_loss = self._device_load_balancing(router_prob_for_aux, expert_index_for_aux,
                                                            self.aux_loss_config.get("device"))
             extra_loss = self.add_scalar(extra_loss, device_load_loss)
         if self.aux_loss_config.get("comm", 0):
-            comm_load_loss = self._comm_load_balancing(router_prob_for_aux, expert_index,
+            comm_load_loss = self._comm_load_balancing(router_prob_for_aux, expert_index_for_aux,
                                                        self.aux_loss_config.get("comm"))
             extra_loss = self.add_scalar(extra_loss, comm_load_loss)
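The aux_loss bugfix in this hunk: when balance_via_topk_bias is enabled, expert_index is selected from bias-adjusted scores, so feeding it to the balancing losses measured the biased selection rather than the router's true preferences. The fix re-derives top-k from the unbiased router_prob_for_aux for the loss terms only, while dispatch keeps the biased indices. A self-contained numpy sketch of that re-selection (hypothetical helper, not the project's API):

    import numpy as np

    def indices_for_aux_loss(router_prob_for_aux, expert_index,
                             balance_via_topk_bias, k):
        # Loss terms should see the unbiased top-k; dispatch keeps using
        # the bias-adjusted expert_index.
        if balance_via_topk_bias:
            return np.argsort(-router_prob_for_aux, axis=-1)[..., :k]
        return expert_index

    probs = np.array([[0.5, 0.3, 0.2],
                      [0.1, 0.6, 0.3]])          # unbiased router probabilities
    dispatch_idx = np.array([[1, 2], [0, 2]])    # picked with top-k bias applied
    print(indices_for_aux_loss(probs, dispatch_idx, True, k=2))  # [[0 1] [1 2]]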

View File

@@ -241,7 +241,11 @@ class MoEV3(Cell):
         router_coeff = expert_gate
         router_coeff = self.mul(self.moe_config.routed_scaling_factor, router_coeff)
         # float32 <-- (dp, N, E) fp32, (dp, N, k) int32, float32
-        router_aux_loss = self._expert_load_balancing(router_prob_for_aux, expert_index, self.aux_loss_factor,
+        if self.moe_config.balance_via_topk_bias:
+            _, expert_index_for_aux = self.topk(router_prob_for_aux, self.num_experts_chosen)
+        else:
+            expert_index_for_aux = expert_index
+        router_aux_loss = self._expert_load_balancing(router_prob_for_aux, expert_index_for_aux, self.aux_loss_factor,
                                                       seq_chunk=seq_chunk)
         if self.enable_deredundency or self.use_3d_tensor_parallel:

View File

@@ -15,9 +15,9 @@
 """Telechat models' APIs."""
 import numpy as np
+import mindspore as ms
 import mindspore.common.dtype as mstype
-from mindspore import Tensor, ops, mint
-from mindspore.communication import get_group_size
+from mindspore import Tensor, ops, mint, mutable
 from mindspore.communication._comm_helper import _is_initialized
 from mindformers.experimental.infer.core.layers import ColumnParallelLinear
@@ -25,9 +25,14 @@ from mindformers.experimental.parallel_core.pynative.parallel_state import get_g
 from mindformers.experimental.infer.models.llama.utils import convert_model_config
 from mindformers.models.modeling_utils import PreTrainedModel
 from mindformers.modules import Linear
-from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister
-from mindformers.tools.utils import get_predict_run_mode
 from mindformers.tools.logger import logger
+from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister
+from mindformers.tools.utils import get_predict_run_mode, is_pynative
+from mindformers.experimental.parallel_core.pynative.parallel_state import (
+    get_data_parallel_group,
+    get_tensor_model_parallel_group,
+)
 from mindformers.models.utils import jit
 from research.telechat2.infer.telechat_transformers import TelechatParallelTransformer
 from research.telechat2.telechat_config import TelechatConfig
@@ -49,7 +54,7 @@ class TelechatPreTrainedModel(PreTrainedModel):
 @MindFormerRegister.register(MindFormerModuleType.MODELS)
 class ParallelTelechatForCausalLM(TelechatPreTrainedModel):
     r"""
-    Provide llama training loss or logits through network.
+    Provide telechat training loss or logits through network.

     Args:
         config (TelechatConfig): The config of llama model.
@@ -61,15 +66,24 @@ class ParallelTelechatForCausalLM(TelechatPreTrainedModel):
     def __init__(self, config):
         super().__init__(config, auto_prefix=True)
-        if get_group_info('tp').group is None and _is_initialized():
-            initialize_model_parallel(get_group_size(), order='tp')
         self.config = convert_model_config(config)
         self.config.out_proj_has_bias = True
+        tp_group = get_group_info('tp').group is None
+        dp_group = get_group_info('dp').group is None
+        logger.info(f"tp_group is:{tp_group}")
+        logger.info(f"dp_group is:{dp_group}")
+        all_groups_initialized = tp_group and dp_group
+        if all_groups_initialized and _is_initialized():
+            initialize_model_parallel(tensor_model_parallel_size=self.config.parallel_config.model_parallel,
+                                      order='tp-dp')
+        logger.info(f"data_parallel_group:{get_data_parallel_group()}")
+        logger.info(f"tensor_model_parallel_group:{get_tensor_model_parallel_group()}")
         self.ignore_token_id = config.ignore_token_id
         self.pad_token_id = config.pad_token_id
         self.use_past = config.use_past
         self.vocab_size = config.vocab_size
         self.is_first_iteration = True
+        self.is_pynative = is_pynative()
         self.shape = ops.Shape()
         self.reshape = ops.Reshape()
@@ -79,7 +93,7 @@ class ParallelTelechatForCausalLM(TelechatPreTrainedModel):
         self.mul = ops.Mul()
         self.add = ops.Add()
         self.ones = ops.Ones()
-        self.gather = ops.Gather()
+        self.gather = ops.Gather(1) if self.is_pynative else ops.Gather()
         self.sub_batch_valid_len = ops.Sub()
         self.model = TelechatParallelTransformer(config=config)
         if config.parallel_config.vocab_emb_dp:
@@ -106,6 +120,8 @@ class ParallelTelechatForCausalLM(TelechatPreTrainedModel):
         self.predict_run_mode = get_predict_run_mode()
         self.use_past = config.use_past
         self.npu_mem_size = config.npu_mem_size if hasattr(config, "npu_mem_size") else 2
+        self.return_hidden_states = config.return_hidden_states

     # pylint: disable=W0613
     def prepare_inputs_for_predict_layout(self, input_ids, **kwargs):
@@ -117,22 +133,73 @@ class ParallelTelechatForCausalLM(TelechatPreTrainedModel):
         prefix_keys_values = Tensor(kwargs["prefix_keys_values"]) if "prefix_keys_values" in kwargs else None
         return input_ids, labels, None, None, None, None, None, None, None, None, None, slot_mapping, prefix_keys_values

     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         """
         prepare inputs for generation.

         A model class needs to define a `prepare_inputs_for_generation` method
         in order to use `.generate()`
         """
         model_inputs = {"input_ids": Tensor.from_numpy(input_ids.astype(np.int32))}
+        batch_valid_length = kwargs.get("valid_length_each_example")
+        prefill = kwargs.get("prefill")
+        if self.is_pynative:
+            model_inputs = {}
+            if self.config.is_dynamic and "origin_inputs" in kwargs and self.use_past:
+                input_ids = kwargs["origin_inputs"]
+            model_inputs["input_ids"] = Tensor.from_numpy(input_ids.astype(np.int32))
+        else:
+            if self.config.is_dynamic:
+                if prefill and "origin_inputs" in kwargs:
+                    origin_inputs = kwargs["origin_inputs"]
+                    slot_mapping = kwargs.get("slot_mapping")
+                    model_inputs = self._prepare_inputs_for_prefill_flatten(origin_inputs,
+                                                                            batch_valid_length,
+                                                                            slot_mapping,
+                                                                            model_inputs)
+            position_ids = batch_valid_length - 1
+            model_inputs["position_ids"] = ms.Tensor(position_ids, dtype=ms.int32).reshape(-1)
+            if not prefill:
+                q_seq_lens = np.ones(batch_valid_length.shape, dtype=np.int32).reshape(-1)
+            else:
+                q_seq_lens = batch_valid_length.astype(np.int32).reshape(-1)
+            model_inputs["q_seq_lens"] = Tensor.from_numpy(q_seq_lens)
+            model_inputs["attention_mask"] = self.model.casual_mask.gen_attention_mask(prefill)
         return model_inputs

     def set_dynamic_inputs(self, **kwargs):
         """Set dynamic input for telechat."""
         dynamic_input_ids = Tensor(shape=[None, None], dtype=mstype.int32)
         dynamic_batch_valid_length = Tensor(shape=[None, None], dtype=mstype.int32)
         dynamic_block_tables = Tensor(shape=[None, None], dtype=mstype.int32)
         dynamic_slot_mapping = Tensor(shape=[None], dtype=mstype.int32)
+        dynamic_position_ids = Tensor(shape=[None], dtype=mstype.int32)
+        dynamic_q_seq_lens = Tensor(shape=[None], dtype=mstype.int32)
+        dynamic_attention_mask = Tensor(shape=[None, None], dtype=self.config.compute_dtype)
         have_prefix_keys_values = getattr(kwargs, "have_prefix_keys_values", False)
+
+        def get_input():
+            if self.npu_mem_size > 0:
+                return None
+            cache_list = []
+            for _ in self.model.layers:
+                cache_list.append(Tensor(shape=[None, None, None, None], dtype=self.config.compute_dtype))
+            return mutable(cache_list)
+
+        key_cache = get_input()
+        value_cache = get_input()
         if have_prefix_keys_values:
             dynamic_prefix_keys_values = Tensor(shape=[2, None, None, None, None], dtype=mstype.float16)
             self.set_inputs(dynamic_input_ids, None, None, None, None, None, None,
                             dynamic_batch_valid_length, None, None, dynamic_block_tables,
-                            dynamic_slot_mapping, dynamic_prefix_keys_values, None)
+                            dynamic_slot_mapping, dynamic_prefix_keys_values, None, key_cache, value_cache)
         else:
-            self.set_inputs(dynamic_input_ids, None, None, None, None, None, None,
+            self.set_inputs(dynamic_input_ids, None, None, dynamic_position_ids, dynamic_attention_mask, None, None,
                             dynamic_batch_valid_length, None, None, dynamic_block_tables,
-                            dynamic_slot_mapping, None, None)
+                            dynamic_slot_mapping, None, None, dynamic_q_seq_lens, key_cache, value_cache)
         logger.info("Set dynamic input for telechat.")

     def add_flags_custom(self, is_first_iteration):
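A note on the set_dynamic_inputs change above: when npu_mem_size is not positive, the KV caches are managed outside the network, so one dynamic-shaped placeholder per layer is registered through set_inputs, and mindspore.mutable marks the Python list as a variable graph input rather than a constant to be folded at compile time. A minimal sketch of the pattern (layer count and dtype are illustrative stand-ins):

    import mindspore.common.dtype as mstype
    from mindspore import Tensor, mutable

    num_layers = 2  # stand-in for len(self.model.layers)

    # One [None, None, None, None] placeholder per layer; mutable() keeps
    # the list itself from being constant-folded during graph compilation.
    key_cache = mutable([Tensor(shape=[None, None, None, None], dtype=mstype.float16)
                         for _ in range(num_layers)])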
@@ -145,22 +212,30 @@ class ParallelTelechatForCausalLM(TelechatPreTrainedModel):
                 layer.attention.paged_attention_mgr.add_flags(is_first_iteration=is_first_iteration)

     # pylint: disable=W0613
     @jit
     def construct(self, input_ids, labels=None, input_position=None, position_ids=None, attention_mask=None,
                   input_embeds=None, init_reset=None, batch_valid_length=None, batch_index=None, zactivate_len=None,
-                  block_tables=None, slot_mapping=None, prefix_keys_values=None, llm_boost_inputs=None):
+                  block_tables=None, slot_mapping=None, prefix_keys_values=None, llm_boost_inputs=None,
+                  q_seq_lens=None, key_cache=None, value_cache=None):
         """
-        Forward of llama model.
+        Forward of telechat model.
         """
         bsz, _ = self.shape(input_ids)
-        if batch_valid_length is not None:
-            batch_valid_length = batch_valid_length.reshape(-1,)
-        else:
-            batch_valid_length = self.ones((bsz,), mstype.int32)
+        if self.use_past:
+            if not isinstance(batch_valid_length, Tensor):
+                batch_valid_length = self.ones((bsz,), mstype.int32)
+            else:
+                batch_valid_length = self.reshape(batch_valid_length, (-1,))
         output = self.model(input_ids, batch_valid_length, batch_index, zactivate_len, block_tables,
-                            slot_mapping, prefix_keys_values)
+                            slot_mapping, prefix_keys_values, position_ids=position_ids, attention_mask=attention_mask,
+                            q_seq_lens=q_seq_lens, key_cache=key_cache, value_cache=value_cache)
+        if self.return_hidden_states:
+            output = self.reshape(output, (-1, output.shape[-1]))
+            return output
         pre_gather = (not self.use_past or self.is_first_iteration) and batch_valid_length is not None
         if pre_gather:
-            batch_valid_length = mint.cumsum(batch_valid_length, 0)
+            if not self.is_pynative:
+                batch_valid_length = mint.cumsum(batch_valid_length, 0)
             output = self.gather(output, self.sub_batch_valid_len(batch_valid_length, 1), 1)
         logits = self.lm_head(output)
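The gather rework pairs with self.gather = ops.Gather(1) if self.is_pynative else ops.Gather() earlier in this file: graph mode flattens prefill outputs into one token axis, so last-token offsets need a cumulative sum, while PyNative keeps a [B, S] layout and indexes each row with batch_dims=1. A numpy illustration of the two index computations (illustrative only):

    import numpy as np

    valid_lens = np.array([3, 5, 2])        # valid tokens per sequence

    # Graph mode: sequences concatenated along one flat token axis, so the
    # last token of sequence i sits at cumsum(valid_lens)[i] - 1.
    flat_last = np.cumsum(valid_lens) - 1   # [2 7 9]

    # PyNative: output stays [B, S]; with batch_dims=1 each row is indexed
    # by its own valid_lens[i] - 1, so no cumulative sum is needed.
    row_last = valid_lens - 1               # [2 4 1]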

View File

@@ -204,8 +204,8 @@ class TelechatParallelAttention(ParallelAttention):
     """
     def construct(self, x, batch_valid_length, block_tables, slot_mapping, freqs_cis=None,
-                  attn_mask=None, alibi_mask=None, prefix_keys_values=None, encoder_output=None,
-                  key_cache=None, value_cache=None):
+                  attn_mask=None, alibi_mask=None, encoder_output=None, prefix_keys_values=None,
+                  q_seq_lens=None, key_cache=None, value_cache=None):
         """Construct function of attention block."""
         # hidden_states: [B, S, H]
         ori_dtype = x.dtype
@@ -283,7 +283,7 @@
                 context_layer = self.core_attention(query, key, value, attn_mask)
             else:
                 context_layer = self.paged_attention_mgr.paged_attn(query, batch_valid_length, block_tables,
-                                                                    key_cache=key_cache, value_cache=value_cache)
+                                                                    attn_mask, q_seq_lens, key_cache, value_cache)
         else:
             # [B, S, H] -> [B, N, S, D]
             query = query.reshape(bs, seq_len, -1, self.head_dim).transpose((0, 2, 1, 3))
@@ -390,8 +390,8 @@ class TelechatParallelTransformerLayer(ParallelTransformerLayer):
         # MLP
         self.expert_num = 1 if config.moe_config is None else config.moe_config.expert_num
         self.use_moe_infer = config.use_past and self.expert_num > 1
-        config.moe_config.router_dense_type = config.router_dense_type
         if self.use_moe_infer:
+            config.moe_config.router_dense_type = config.router_dense_type
             self.feed_forward = TelechatParallelMoE(
                 ffn=TelechatRoutedParallelMLP(config),
                 hidden_size=config.hidden_size,
@@ -436,8 +436,11 @@ class TelechatParallelTransformer(ParallelTransformer):
         self.enable_dynamic_ntk = False
         if config.extend_method == 'DYNAMIC_NTK':
             self.enable_dynamic_ntk = True
+            base_seqlen = config.base_seqlen \
+                if hasattr(config, "base_seqlen") and config.base_seqlen else None
             self.freqs_mgr = FreqsMgrDynamicNTK(head_dim=self.head_dim,
                                                 max_position_embedding=config.max_position_embedding,
+                                                base_seqlen=base_seqlen,
                                                 rotary_dtype=config.rotary_dtype,
                                                 theta=config.theta,
                                                 parallel_config=config.parallel_config,
@@ -451,7 +454,8 @@
     # pylint: disable=W0613
     def construct(self, tokens: Tensor, batch_valid_length=None, batch_index=None, zactivate_len=None,
-                  block_tables=None, slot_mapping=None, prefix_keys_values=None, key_cache=None, value_cache=None):
+                  block_tables=None, slot_mapping=None, prefix_keys_values=None, position_ids=None, attention_mask=None,
+                  q_seq_lens=None, key_cache=None, value_cache=None):
         """
         Forward of ParallelTransformer.
@@ -466,7 +470,7 @@
         """
         # preprocess
         bs, seq_len = self.shape(tokens)
-        mask = None
+        mask = attention_mask
         if self.use_past:
             if self.is_first_iteration:
                 if self.enable_dynamic_ntk:
@@ -476,8 +480,6 @@
                 if self.is_pynative:
                     mask = self.casual_mask(tokens)
-                else:
-                    mask = self.casual_mask.prefill()
         if prefix_keys_values is not None:
             if mask is None:
@@ -504,7 +506,7 @@
             value_cache_i = value_cache[i] if value_cache is not None else None
             hidden_states = self.layers[i](hidden_states, freqs_cis, mask, batch_valid_length=batch_valid_length,
                                            block_tables=block_tables, slot_mapping=slot_mapping,
-                                           prefix_keys_values=prefix_kv,
+                                           prefix_keys_values=prefix_kv, q_seq_lens=q_seq_lens,
                                            key_cache=key_cache_i, value_cache=value_cache_i)
         if self.post_norm:
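Threading q_seq_lens down to paged_attn lets one attention path serve both phases: prepare_inputs_for_generation (earlier in this compare) fills it with the full prompt lengths at prefill and with ones at decode. A plain numpy sketch of that preparation step, mirroring the diff:

    import numpy as np

    def make_q_seq_lens(batch_valid_length, prefill):
        # Prefill: each sequence contributes its whole prompt as query tokens.
        # Decode: each sequence contributes exactly one new query token.
        if prefill:
            return batch_valid_length.astype(np.int32).reshape(-1)
        return np.ones(batch_valid_length.shape, dtype=np.int32).reshape(-1)

    lens = np.array([7, 3])
    print(make_q_seq_lens(lens, prefill=True))    # [7 3]
    print(make_q_seq_lens(lens, prefill=False))   # [1 1]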

View File

@@ -19,6 +19,7 @@ parallel:
 parallel_config:
   data_parallel: 1
   model_parallel: 8
+  vocab_emb_dp: False

 # mindspore context init config
 context:
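The same vocab_emb_dp: False line is added to each telechat2 predict config in this compare. In MindFormers' parallel_config it controls how the embedding table and LM head are sharded; False appears to select the model-parallel head path (the config.parallel_config.vocab_emb_dp branch in the model change above), splitting the vocabulary projection across tensor-parallel ranks instead of computing it under data parallelism.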

View File

@@ -19,6 +19,7 @@ parallel:
 parallel_config:
   data_parallel: 1
   model_parallel: 2
+  vocab_emb_dp: False

 # mindspore context init config
 context:

View File

@@ -33,6 +33,7 @@ parallel:
 parallel_config:
   data_parallel: 1
   model_parallel: 2
+  vocab_emb_dp: False

 # mindspore context init config
 context:

View File

@@ -20,6 +20,7 @@ parallel:
 parallel_config:
   data_parallel: 1
   model_parallel: 1
+  vocab_emb_dp: False

 # mindspore context init config
 context:

View File

@@ -93,6 +93,7 @@ class TelechatConfig(PretrainedConfig):
             The maximum number of tokens in one block can have when using paged attention.
         num_blocks (`int`, *optional*, defaults to 512):
             The maximum number of blocks when using paged attention.
+        return_hidden_states (bool, optional): Whether to return hidden states. Default: ``False``.

     Returns:
         Class, TelechatConfig.
     """
@@ -155,6 +156,7 @@ class TelechatConfig(PretrainedConfig):
                  quant: str = "",
                  sigma: float = 0.0048,
                  mean: float = 0.0,
+                 return_hidden_states: bool = False,
                  **kwargs):
         super(TelechatConfig, self).__init__(**kwargs)
         if isinstance(parallel_config, dict):
@@ -218,3 +220,4 @@ class TelechatConfig(PretrainedConfig):
         self.num_blocks = num_blocks
         self.quant = quant
         self.qkv_concat = qkv_concat
+        self.return_hidden_states = return_hidden_states

View File

@@ -52,17 +52,17 @@
         Expectation: AssertionError
         """
        commands = [
-            (f"export ASCEND_RT_VISIBLE_DEVICES=0,1 && export LCAL_IF_PORT=10068 && "
+            (f"export ASCEND_RT_VISIBLE_DEVICES=0,1 && export LCAL_COMM_ID=127.0.0.1:10068 && "
              f"export HCCL_IF_BASE_PORT=61000 && msrun --worker_num=2 "
              f"--local_worker_num=2 --master_port=8222 --log_dir=parallel_qwen2_0_5b_predict_mp2 --join=True "
              f"{cur_dir}/run_parallel.py --mode parallel_qwen2_0_5b_predict_mp2",
             'parallel_qwen2_0_5b_predict_mp2/worker_0.log'),
-            (f"export ASCEND_RT_VISIBLE_DEVICES=2,3 && export LCAL_IF_PORT=10069 && "
+            (f"export ASCEND_RT_VISIBLE_DEVICES=2,3 && export LCAL_COMM_ID=127.0.0.1:10069 && "
              f"export HCCL_IF_BASE_PORT=61100 && msrun --worker_num=2 "
              f"--local_worker_num=2 --master_port=8226 --log_dir=parallel_glm3_6b_predict_mp2 --join=True "
              f"{cur_dir}/run_parallel.py --mode parallel_glm3_6b_predict_mp2",
             'parallel_glm3_6b_predict_mp2/worker_0.log'),
-            (f"export ASCEND_RT_VISIBLE_DEVICES=4,5 && export LCAL_IF_PORT=10070 && "
+            (f"export ASCEND_RT_VISIBLE_DEVICES=4,5 && export LCAL_COMM_ID=127.0.0.1:10070 && "
              f"export HCCL_IF_BASE_PORT=61200 && msrun --worker_num=2 "
              f"--local_worker_num=2 --master_port=8228 --log_dir=parallel_shared_expert_predict_mp2 --join=True "
              f"{cur_dir}/run_parallel.py --mode parallel_shared_expert_predict_mp2",