Mirror of https://gitee.com/mindspore/mindformers.git (synced 2025-12-06 11:29:59 +08:00)
Add support for NoPE layers and update configuration parameters
@@ -319,6 +319,9 @@ class GPTModel(nn.Cell):
             )
         elif self.position_embedding_type == 'mrope':
             raise NotImplementedError("position_embedding_type = mrope is not supported now.")
+        elif self.position_embedding_type == 'none':
+            self.rotary_pos_emb = None
+
         if self.use_rotary_position_embeddings:
             self.rotary_pos_emb.shard(config)
         # Transformer.
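For context, a minimal standalone sketch of what the new branch means (hypothetical helper names, plain Python, not the repo's GPTModel code): with position_embedding_type set to 'none', no rotary table is built at all, so downstream code must tolerate rotary_pos_emb being None before calling shard().

def build_rotary_pos_emb(position_embedding_type, make_rope):
    # Hedged sketch: 'rope' builds a table, 'none' (NoPE) returns None, 'mrope' is unsupported.
    if position_embedding_type == 'rope':
        return make_rope()
    if position_embedding_type == 'mrope':
        raise NotImplementedError("position_embedding_type = mrope is not supported now.")
    if position_embedding_type == 'none':
        return None
    raise ValueError(f"unknown position_embedding_type: {position_embedding_type}")

rotary_pos_emb = build_rotary_pos_emb('none', make_rope=lambda: object())
assert rotary_pos_emb is None   # callers must guard on None before shard()/apply()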
@@ -24,6 +24,7 @@ from mindspore.context import ParallelMode
 from mindspore.parallel._utils import _get_parallel_mode
 import mindspore.common.dtype as mstype

+from mindformers.tools.logger import logger
 from mindformers.parallel_core.utils.spec_utils import ModuleSpec, build_module
 from mindformers.parallel_core.transformer_config import TransformerConfig
 from mindformers.parallel_core.training_graph.transformer.enums import AttnMaskType
@@ -141,26 +142,31 @@ class Attention(nn.Cell):
         self.next_tokens = 0 if self.config.attention_next_tokens is None else self.config.attention_next_tokens
         self.keep_prob = 1.0 if self.config.attention_dropout is None else 1 - self.config.attention_dropout
         self.use_attention_mask = True if self.config.use_attention_mask is None else self.config.use_attention_mask
+        self.is_nope_layer = (
+            config.nope_layer_interval is not None
+            and (layer_number + 1) % config.nope_layer_interval == 0
+        )
+
+        if config.nope_layer_interval is not None and config.nope_layer_interval > 0:
+            logger.info(f"NOPE layer interleaving is enabled. Layer {layer_number} is_nope_layer={self.is_nope_layer}")
+
         # Define ulysses context parallel related parameters
         self.cp_ds = self.config.hierarchical_context_parallel_sizes
         self.cp_co = self.cp // self.cp_ds

         if self.hidden_size % self.num_heads != 0:
-            raise ValueError("For 'MultiHeadAttention', the class variable 'hidden_size' must be a multiple "
-                             "of 'num_heads', but got the hidden_size is {} and the num_heads is {}."
-                             .format(self.hidden_size, self.num_heads))
+            raise ValueError(f"For 'MultiHeadAttention', the class variable 'hidden_size' must be a multiple "
+                             f"of 'num_heads', but got the hidden_size is {self.hidden_size} "
+                             f"and the num_heads is {self.num_heads}.")
         # Check if num_heads and kv_num_heads are multiples of tp * cp_ds
         if self.num_heads % (self.tp * self.cp_ds) != 0:
-            raise ValueError("For 'ParallelAttention', the class variable 'num_heads' must be a multiple of "
-                             "'tensor_parallel * ulysses_cp_num', but got num_heads is {}, tensor_parallel is {}, "
-                             "ulysses_cp_num is {}."
-                             .format(self.num_heads, self.tp, self.cp_ds))
+            raise ValueError(f"For 'ParallelAttention', the class variable 'num_heads' must be a multiple of "
+                             f"'tensor_parallel * ulysses_cp_num', but got num_heads is {self.num_heads}, "
+                             f"tensor_parallel is {self.tp}, ulysses_cp_num is {self.cp_ds}.")
         if self.kv_num_heads % (self.tp * self.cp_ds) != 0 and self.kv_num_heads % self.tp != 0:
-            raise ValueError("For 'ParallelAttention', the class variable 'kv_num_heads' must be a multiple of "
-                             "'tensor_parallel * ulysses_cp_num', but got kv_num_heads is {}, tensor_parallel is {}, "
-                             "ulysses_cp_num is {}."
-                             .format(self.kv_num_heads, self.tp, self.cp_ds))
+            raise ValueError(f"For 'ParallelAttention', the class variable 'kv_num_heads' must be a multiple of "
+                             f"'tensor_parallel * ulysses_cp_num', but got kv_num_heads is {self.kv_num_heads}, "
+                             f"tensor_parallel is {self.tp}, ulysses_cp_num is {self.cp_ds}.")

         self.core_attention = build_module(
             submodules.core_attention,
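To make the interleaving rule concrete, a small standalone sketch of the same test (hypothetical helper, not repo code). Note the hunk does not show whether layer_number is 0- or 1-based; the printed indices below assume the values passed in are exactly what the formula receives.

def is_nope_layer(layer_number, nope_layer_interval):
    # Hedged sketch of the selection rule above; None disables NoPE interleaving.
    return (
        nope_layer_interval is not None
        and (layer_number + 1) % nope_layer_interval == 0
    )

# With nope_layer_interval = 4, every 4th value of (layer_number + 1) hits the test:
print([n for n in range(16) if is_nope_layer(n, 4)])     # [3, 7, 11, 15]
print([n for n in range(16) if is_nope_layer(n, None)])  # [] -> pure RoPE model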
@@ -271,7 +277,7 @@ class Attention(nn.Cell):
         key = self.reshape(key, (seq_len, bs, self.kv_num_heads, self.head_dim))

         # apply rotary position embedding
-        if rotary_pos_emb is not None:
+        if not self.is_nope_layer and rotary_pos_emb is not None:
             query = self.apply_rotary_pos_emb(query, rotary_pos_emb)
             key = self.apply_rotary_pos_emb(key, rotary_pos_emb)

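A short sketch of the forward-path effect (plain Python stand-ins rather than MindSpore ops, hypothetical function names): on a NoPE layer the query/key tensors pass through untouched even when a rotary table is supplied.

def maybe_apply_rope(query, key, rotary_pos_emb, is_nope_layer, apply_rotary_pos_emb):
    # Hedged sketch: the same guard as the hunk above, written as a free function.
    if not is_nope_layer and rotary_pos_emb is not None:
        query = apply_rotary_pos_emb(query, rotary_pos_emb)
        key = apply_rotary_pos_emb(key, rotary_pos_emb)
    return query, key

# NoPE layer: inputs come back unchanged even though a rotary table is provided.
q, k = maybe_apply_rope("q", "k", rotary_pos_emb="rope", is_nope_layer=True,
                        apply_rotary_pos_emb=lambda x, r: f"rot({x})")
assert (q, k) == ("q", "k")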
@@ -110,6 +110,9 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig):
     position_embedding_type: str = "rope"
     """Position embedding type to use for the attention layer."""

+    nope_layer_interval: int = None
+    """Interval for inserting NoPE (No Position Embedding) layers among RoPE layers."""
+
     rotary_base: float = 10000.0
     """Rotary base for the rotary embeddings, used by rope and yarn. Mindformers required."""

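A stripped-down stand-in for the fields this hunk touches (hypothetical dataclass, not the real TransformerConfig): the default of None keeps existing configs on pure RoPE, and Optional[int] would be the more precise annotation for the new field than the bare int used above.

from dataclasses import dataclass
from typing import Optional

@dataclass
class _PositionEmbeddingConfigSketch:
    # Hedged sketch of the relevant defaults only, not the full TransformerConfig.
    position_embedding_type: str = "rope"
    nope_layer_interval: Optional[int] = None   # None -> every layer applies RoPE
    rotary_base: float = 10000.0

cfg = _PositionEmbeddingConfigSketch(nope_layer_interval=4)   # interleave NoPE every 4 layers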
@@ -459,8 +462,8 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig):
             self.num_query_groups = self.num_attention_heads

         if self.context_parallel_size > 1 and not self.use_flash_attention:
-            raise ValueError(f"context_parallel is only available for flash attention for now, "
-                             f"please set use_flash_attention=True.")
+            raise ValueError("context_parallel is only available for flash attention for now, "
+                             "please set use_flash_attention=True.")

         if self.use_flash_attention:
             if self.use_eod_attn_mask_compression and not self.use_ring_attention:
@@ -507,7 +510,7 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig):

         if self.moe_shared_expert_intermediate_size is not None:
             if self.shared_expert_num == 0:
-                logger.warning(f"The hidden-size of shared experts ('moe_shared_expert_intermediate_size') is set, "
+                logger.warning("The hidden-size of shared experts ('moe_shared_expert_intermediate_size') is set, "
                                "but get shared_expert_num = 0. The shared_expert_num will be ignored.")
             elif self.moe_shared_expert_intermediate_size != self.moe_ffn_hidden_size * self.shared_expert_num:
                 logger.warning(
@@ -650,7 +653,7 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig):
             raise TypeError("'moe_layer_freq' should be <int> or <list[int]>, "
                             f"but got {type(self.moe_layer_freq)}")

-        self.is_dryrun = (os.environ.get('MS_SIMULATION_LEVEL', '0') != '0')
+        self.is_dryrun = os.environ.get('MS_SIMULATION_LEVEL', '0') != '0'
         if self.is_dryrun:
             if self.num_moe_experts is not None and self.seq_length % self.num_moe_experts != 0:
                 raise ValueError(
@@ -265,6 +265,7 @@ COMMON_CONFIG_MAPPING = {
     "activation_func": "activation_func",
     "normalization": "normalization",
     "fused_norm": "fused_norm",
+    "nope_layer_interval": "nope_layer_interval",
     "calculate_per_token_loss": "calculate_per_token_loss",
     "multi_latent_attention": "multi_latent_attention",
     "compute_dtype": "compute_dtype",
@@ -481,7 +482,7 @@ def convert_to_transformer_config(
         if not isinstance(mapping_key, str):
             (mapping_key, trans_func) = mapping_key
             value = trans_func(value)
-        if mapping_key in update_dict.keys():
+        if mapping_key in update_dict:
             raise KeyError(f"Multiple configurations provided for the same setting. "
                            f"Please check these conflicting configs: {list(reversed_mapping[mapping_key])}")
         update_dict[mapping_key] = value
@@ -491,16 +492,16 @@ def convert_to_transformer_config(
         for parallel_key, parallel_value in model_config['parallel_config'].items():
             if parallel_key == 'recompute' and isinstance(parallel_value, dict):
                 for recompute_key, recompute_value in parallel_value.items():
-                    if recompute_key in convert_map.keys():
+                    if recompute_key in convert_map:
                         mapping_config(recompute_key, recompute_value)
                 continue
-            if parallel_key in convert_map.keys():
+            if parallel_key in convert_map:
                 mapping_config(parallel_key, parallel_value)
         model_config.pop('parallel_config')
     for model_config_key, model_config_value in model_config.items():
         if model_config_key in not_convert_whitelist:
             continue
-        if model_config_key in convert_map.keys():
+        if model_config_key in convert_map:
             mapping_config(model_config_key, model_config_value)
         else:
             not_convert_keys_list.append(model_config_key)
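To tie the new mapping entry and the conversion loop together, a condensed sketch of the assumed flow (all names other than the identity mapping entry are hypothetical, and the real function also handles tuple mappings, recompute sub-configs, and a whitelist): a model-config key that appears in the map is copied into the TransformerConfig kwargs, unknown keys are collected, and a duplicate target key raises KeyError.

# Hedged sketch of the conversion flow assumed by these hunks.
convert_map = {"nope_layer_interval": "nope_layer_interval"}   # excerpt of COMMON_CONFIG_MAPPING
model_config = {"nope_layer_interval": 4, "unknown_key": 1}

update_dict = {}
not_convert_keys_list = []
for key, value in model_config.items():
    mapping_key = convert_map.get(key)
    if mapping_key is None:
        not_convert_keys_list.append(key)
        continue
    if mapping_key in update_dict:   # same membership test as the cleaned-up code
        raise KeyError(f"Multiple configurations provided for the same setting: {mapping_key}")
    update_dict[mapping_key] = value

print(update_dict)             # {'nope_layer_interval': 4}
print(not_convert_keys_list)   # ['unknown_key']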