Add support for NoPE layers and update configuration parameters

JavaZero
2025-10-25 11:46:35 +08:00
parent d2f5d626e4
commit 2a82110b0e
4 changed files with 33 additions and 20 deletions

View File

@@ -319,6 +319,9 @@ class GPTModel(nn.Cell):
)
elif self.position_embedding_type == 'mrope':
raise NotImplementedError("position_embedding_type = mrope is not supported now.")
elif self.position_embedding_type == 'none':
self.rotary_pos_emb = None
if self.use_rotary_position_embeddings and self.rotary_pos_emb is not None:
self.rotary_pos_emb.shard(config)
# Transformer.
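For illustration, a minimal standalone sketch of the dispatch this hunk extends (stand-in names, not the real MindFormers RotaryEmbedding class): with 'none', no rotary object is built at all, which is why the subsequent shard call has to stay conditional.

    # Illustrative sketch only; RotaryEmbedding here is a stand-in object.
    class RotaryEmbedding:
        def shard(self, config):
            pass

    def build_rotary_pos_emb(position_embedding_type):
        if position_embedding_type == 'rope':
            return RotaryEmbedding()
        if position_embedding_type == 'mrope':
            raise NotImplementedError("position_embedding_type = mrope is not supported now.")
        if position_embedding_type == 'none':
            return None  # NoPE everywhere: nothing to build, nothing to shard

    rotary_pos_emb = build_rotary_pos_emb('none')
    if rotary_pos_emb is not None:  # mirrors the use_rotary_position_embeddings guard
        rotary_pos_emb.shard(config=None)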

View File

@@ -24,6 +24,7 @@ from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode
import mindspore.common.dtype as mstype
from mindformers.tools.logger import logger
from mindformers.parallel_core.utils.spec_utils import ModuleSpec, build_module
from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.parallel_core.training_graph.transformer.enums import AttnMaskType
@@ -141,26 +142,31 @@ class Attention(nn.Cell):
self.next_tokens = 0 if self.config.attention_next_tokens is None else self.config.attention_next_tokens
self.keep_prob = 1.0 if self.config.attention_dropout is None else 1 - self.config.attention_dropout
self.use_attention_mask = True if self.config.use_attention_mask is None else self.config.use_attention_mask
self.is_nope_layer = (
    config.nope_layer_interval is not None
    and config.nope_layer_interval > 0
    and (layer_number + 1) % config.nope_layer_interval == 0
)
if config.nope_layer_interval is not None and config.nope_layer_interval > 0:
    logger.info(f"NoPE layer interleaving is enabled. Layer {layer_number} is_nope_layer={self.is_nope_layer}")
# Define ulysses context parallel related parameters
self.cp_ds = self.config.hierarchical_context_parallel_sizes
self.cp_co = self.cp // self.cp_ds
if self.hidden_size % self.num_heads != 0:
raise ValueError("For 'MultiHeadAttention', the class variable 'hidden_size' must be a multiple "
"of 'num_heads', but got the hidden_size is {} and the num_heads is {}."
.format(self.hidden_size, self.num_heads))
raise ValueError(f"For 'MultiHeadAttention', the class variable 'hidden_size' must be a multiple "
f"of 'num_heads', but got the hidden_size is {self.hidden_size} "
f"and the num_heads is {self.num_heads}.")
# Check if num_heads and kv_num_heads are multiples of tp * cp_ds
if self.num_heads % (self.tp * self.cp_ds) != 0:
raise ValueError("For 'ParallelAttention', the class variable 'num_heads' must be a multiple of "
"'tensor_parallel * ulysses_cp_num', but got num_heads is {}, tensor_parallel is {}, "
"ulysses_cp_num is {}."
.format(self.num_heads, self.tp, self.cp_ds))
raise ValueError(f"For 'ParallelAttention', the class variable 'num_heads' must be a multiple of "
f"'tensor_parallel * ulysses_cp_num', but got num_heads is {self.num_heads}, "
f"tensor_parallel is {self.tp}, ulysses_cp_num is {self.cp_ds}.")
if self.kv_num_heads % (self.tp * self.cp_ds) != 0 and self.kv_num_heads % self.tp != 0:
raise ValueError("For 'ParallelAttention', the class variable 'kv_num_heads' must be a multiple of "
"'tensor_parallel * ulysses_cp_num', but got kv_num_heads is {}, tensor_parallel is {}, "
"ulysses_cp_num is {}."
.format(self.kv_num_heads, self.tp, self.cp_ds))
raise ValueError(f"For 'ParallelAttention', the class variable 'kv_num_heads' must be a multiple of "
f"'tensor_parallel * ulysses_cp_num', but got kv_num_heads is {self.kv_num_heads}, "
f"tensor_parallel is {self.tp}, ulysses_cp_num is {self.cp_ds}.")
self.core_attention = build_module(
submodules.core_attention,
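The net effect of nope_layer_interval can be checked with a small standalone sketch (assuming layer_number is 0-based, as the (layer_number + 1) offset above suggests):

    # Which layers run without rotary embedding for a given interval (illustrative).
    def is_nope_layer(layer_number, nope_layer_interval):
        return (
            nope_layer_interval is not None
            and nope_layer_interval > 0
            and (layer_number + 1) % nope_layer_interval == 0
        )

    num_layers, interval = 12, 4
    nope_layers = [l for l in range(num_layers) if is_nope_layer(l, interval)]
    print(nope_layers)  # [3, 7, 11]: every 4th layer skips RoPE, the rest keep it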
@@ -271,7 +277,7 @@ class Attention(nn.Cell):
key = self.reshape(key, (seq_len, bs, self.kv_num_heads, self.head_dim))
# apply rotary position embedding
if rotary_pos_emb is not None:
if not self.is_nope_layer and rotary_pos_emb is not None:
query = self.apply_rotary_pos_emb(query, rotary_pos_emb)
key = self.apply_rotary_pos_emb(key, rotary_pos_emb)
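What a NoPE layer actually skips is the position-dependent rotation of query and key. A generic, NumPy-only formulation in the split-half convention (illustrative; not the MindSpore apply_rotary_pos_emb kernel) looks roughly like this:

    import numpy as np

    # Textbook RoPE rotation (split-half convention); illustrative only.
    def apply_rope(x, positions, base=10000.0):
        seq_len, head_dim = x.shape
        half = head_dim // 2
        inv_freq = 1.0 / (base ** (np.arange(half) / half))  # per-pair frequencies
        angles = np.outer(positions, inv_freq)                # [seq_len, half]
        cos, sin = np.cos(angles), np.sin(angles)
        x1, x2 = x[:, :half], x[:, half:]
        return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

    q = np.random.randn(16, 64)          # [seq_len, head_dim]
    q_rotated = apply_rope(q, np.arange(16))
    # On a NoPE layer, query and key go to core attention unrotated instead.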

View File

@@ -110,6 +110,9 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig):
position_embedding_type: str = "rope"
"""Position embedding type to use for the attention layer."""
nope_layer_interval: int = None
"""Interval for inserting NoPE (No Position Embedding) layers among RoPE layers."""
rotary_base: float = 10000.0
"""Rotary base for the rotary embeddings, used by rope and yarn. Mindformers required."""
@@ -459,8 +462,8 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig):
self.num_query_groups = self.num_attention_heads
if self.context_parallel_size > 1 and not self.use_flash_attention:
raise ValueError(f"context_parallel is only available for flash attention for now, "
f"please set use_flash_attention=True.")
raise ValueError("context_parallel is only available for flash attention for now, "
"please set use_flash_attention=True.")
if self.use_flash_attention:
if self.use_eod_attn_mask_compression and not self.use_ring_attention:
@@ -507,7 +510,7 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig):
if self.moe_shared_expert_intermediate_size is not None:
if self.shared_expert_num == 0:
logger.warning(f"The hidden-size of shared experts ('moe_shared_expert_intermediate_size') is set, "
logger.warning("The hidden-size of shared experts ('moe_shared_expert_intermediate_size') is set, "
"but get shared_expert_num = 0. The shared_expert_num will be ignored.")
elif self.moe_shared_expert_intermediate_size != self.moe_ffn_hidden_size * self.shared_expert_num:
logger.warning(
@@ -650,7 +653,7 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig):
raise TypeError("'moe_layer_freq' should be <int> or <list[int]>, "
f"but got {type(self.moe_layer_freq)}")
self.is_dryrun = (os.environ.get('MS_SIMULATION_LEVEL', '0') != '0')
self.is_dryrun = os.environ.get('MS_SIMULATION_LEVEL', '0') != '0'
if self.is_dryrun:
if self.num_moe_experts is not None and self.seq_length % self.num_moe_experts != 0:
raise ValueError(
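The dry-run branch is driven purely by the MS_SIMULATION_LEVEL environment variable; a small sketch of the behaviour shown here (the divisibility figures are illustrative):

    import os

    os.environ["MS_SIMULATION_LEVEL"] = "1"  # any value other than '0' enables dry run
    is_dryrun = os.environ.get("MS_SIMULATION_LEVEL", "0") != "0"
    assert is_dryrun

    # Under dry run the config additionally checks seq_length % num_moe_experts == 0,
    # e.g. seq_length=4096 with num_moe_experts=8 passes, while 6 experts would not.
    seq_length, num_moe_experts = 4096, 8
    assert seq_length % num_moe_experts == 0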

View File

@@ -265,6 +265,7 @@ COMMON_CONFIG_MAPPING = {
"activation_func": "activation_func",
"normalization": "normalization",
"fused_norm": "fused_norm",
"nope_layer_interval": "nope_layer_interval",
"calculate_per_token_loss": "calculate_per_token_loss",
"multi_latent_attention": "multi_latent_attention",
"compute_dtype": "compute_dtype",
@@ -481,7 +482,7 @@ def convert_to_transformer_config(
if not isinstance(mapping_key, str):
(mapping_key, trans_func) = mapping_key
value = trans_func(value)
if mapping_key in update_dict.keys():
if mapping_key in update_dict:
raise KeyError(f"Multiple configurations provided for the same setting. "
f"Please check these conflicting configs: {list(reversed_mapping[mapping_key])}")
update_dict[mapping_key] = value
@@ -491,16 +492,16 @@ def convert_to_transformer_config(
for parallel_key, parallel_value in model_config['parallel_config'].items():
if parallel_key == 'recompute' and isinstance(parallel_value, dict):
for recompute_key, recompute_value in parallel_value.items():
if recompute_key in convert_map.keys():
if recompute_key in convert_map:
mapping_config(recompute_key, recompute_value)
continue
if parallel_key in convert_map.keys():
if parallel_key in convert_map:
mapping_config(parallel_key, parallel_value)
model_config.pop('parallel_config')
for model_config_key, model_config_value in model_config.items():
if model_config_key in not_convert_whitelist:
continue
if model_config_key in convert_map.keys():
if model_config_key in convert_map:
mapping_config(model_config_key, model_config_value)
else:
not_convert_keys_list.append(model_config_key)
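For completeness, a condensed sketch of the mapping loop above; convert_map entries other than nope_layer_interval are hypothetical placeholders, and the real convert_to_transformer_config also handles tuple-valued mappings and the parallel_config/recompute sub-dicts:

    convert_map = {
        "nope_layer_interval": "nope_layer_interval",  # key added in this commit
        "n_heads": "num_attention_heads",              # hypothetical example entry
    }
    model_config = {"nope_layer_interval": 4, "n_heads": 32, "custom_key": 1}

    update_dict, not_convert_keys_list = {}, []
    for key, value in model_config.items():
        if key in convert_map:
            mapping_key = convert_map[key]
            if mapping_key in update_dict:
                raise KeyError(f"Multiple configurations provided for the same setting: {mapping_key}")
            update_dict[mapping_key] = value
        else:
            not_convert_keys_list.append(key)

    print(update_dict)             # {'nope_layer_interval': 4, 'num_attention_heads': 32}
    print(not_convert_keys_list)   # ['custom_key']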