From dc386615dfac01225a84ec6651be92461af3e7b4 Mon Sep 17 00:00:00 2001 From: zxq <342239412@qq.com> Date: Fri, 11 Jul 2025 17:54:31 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90feature=E3=80=91=E3=80=90dev=E3=80=91?= =?UTF-8?q?=E4=BF=AE=E6=94=B9Qwen2=E6=A8=A1=E5=9E=8B=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E5=8F=8Ayaml=EF=BC=8C=E9=80=82=E9=85=8Drun=5Fmindformer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- configs/qwen2/predict_qwen2.yaml | 41 +++ mindformers/models/__init__.py | 5 + mindformers/models/qwen2/__init__.py | 6 + .../models/qwen2/configuration_qwen2.py | 326 ++++++++++-------- mindformers/models/qwen2/modeling_qwen2.py | 2 +- tests/st/test_ut/base_schema.json | 39 ++- 6 files changed, 265 insertions(+), 154 deletions(-) create mode 100644 configs/qwen2/predict_qwen2.yaml diff --git a/configs/qwen2/predict_qwen2.yaml b/configs/qwen2/predict_qwen2.yaml new file mode 100644 index 000000000..9eca2e1f0 --- /dev/null +++ b/configs/qwen2/predict_qwen2.yaml @@ -0,0 +1,41 @@ +seed: 0 +output_dir: './output' # path to save checkpoint/strategy +load_checkpoint: '' +use_parallel: False +run_mode: 'predict' +use_legacy: False +load_ckpt_format: 'safetensors' + +trainer: + type: CausalLanguageModelingTrainer + model_name: 'qwen2' + +# default parallel of device num = 1 for Atlas 800T A2 +parallel_config: + data_parallel: 1 + model_parallel: 1 +# HuggingFace file directory +pretrained_model_dir: '/path/hf_dir' +model: + model_config: + compute_dtype: "bfloat16" + layernorm_compute_dtype: "float32" + softmax_compute_dtype: "float32" + rotary_dtype: "bfloat16" + params_dtype: "bfloat16" + add_qkv_bias: True + +# mindspore context init config +context: + mode: 0 #0--Graph Mode; 1--Pynative Mode + enable_graph_kernel: False + ascend_config: + precision_mode: "must_keep_origin_dtype" + max_device_memory: "59GB" + save_graphs: False + save_graphs_path: "./graph" + +# parallel context config +parallel: + parallel_mode: "MANUAL_PARALLEL" + enable_alltoall: False diff --git a/mindformers/models/__init__.py b/mindformers/models/__init__.py index f0315f27e..be2f27890 100644 --- a/mindformers/models/__init__.py +++ b/mindformers/models/__init__.py @@ -52,6 +52,11 @@ from .llama import ( LlamaTokenizer, LlamaTokenizerFast ) +from .qwen2 import ( + Qwen2Config, + Qwen2PreTrainedModel, + Qwen2ForCausalLM, +) from .qwen3 import ( Qwen3Config, Qwen3PreTrainedModel, diff --git a/mindformers/models/qwen2/__init__.py b/mindformers/models/qwen2/__init__.py index d7ef687ea..41b6b50e6 100644 --- a/mindformers/models/qwen2/__init__.py +++ b/mindformers/models/qwen2/__init__.py @@ -13,3 +13,9 @@ # limitations under the License. # ============================================================================ """qwen2 model""" +from .utils import Qwen2PreTrainedModel +from .configuration_qwen2 import Qwen2Config +from .modeling_qwen2 import Qwen2ForCausalLM +from .modeling_qwen2_infer import InferenceQwen2ForCausalLM + +__all__ = ['Qwen2Config', 'Qwen2ForCausalLM', 'InferenceQwen2ForCausalLM', 'Qwen2PreTrainedModel'] diff --git a/mindformers/models/qwen2/configuration_qwen2.py b/mindformers/models/qwen2/configuration_qwen2.py index 505a4bed4..058d1321f 100644 --- a/mindformers/models/qwen2/configuration_qwen2.py +++ b/mindformers/models/qwen2/configuration_qwen2.py @@ -15,167 +15,205 @@ """Qwen2 Config API.""" __all__ = ['Qwen2Config'] -from typing import Optional, Union - -from mindspore._checkparam import args_type_check - -from mindformers.modules.transformer.transformer import default_transformer_config, \ - TransformerOpParallelConfig -from mindformers.tools.register import MindFormerRegister, MindFormerModuleType from mindformers.models.configuration_utils import PretrainedConfig -from mindformers.models.utils import convert_mstype +from mindformers.models.model_config_utils import ( + register_mf_model_parameter, + ignore_and_delete_parameter, + NotSupportedInfo +) +from mindformers.parallel_core.mf_model_config import MFModelConfig +from mindformers.tools.register import MindFormerRegister, MindFormerModuleType -@MindFormerRegister.register(MindFormerModuleType.CONFIG) +@MindFormerRegister.register(MindFormerModuleType.CONFIG, legacy=False, search_names='qwen2') class Qwen2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a + Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). - """ Qwen2 Model Config """ + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers using full attention. The first `max_window_layers` + layers will use full attention, while any + additional layer afterwards will use SWA (Sliding Window Attention). + layer_types (`list`, *optional*): + Attention pattern for each layer. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + """ model_type = "Qwen2" + keys_to_ignore_at_inference = ["past_key_values"] - @args_type_check(parallel_config=(dict, TransformerOpParallelConfig)) + # Default tensor parallel plan for base model `Qwen2` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + @register_mf_model_parameter( + mf_model_kwargs=MFModelConfig( + pad_token_id=151643, + block_size=32, + num_blocks=1024, + normalization='RMSNorm', + add_bias_linear=False, + gated_linear_unit=True + )) + @ignore_and_delete_parameter(extra_ignore_param=[ + ('max_window_layers', NotSupportedInfo.useless), + ('sliding_window', NotSupportedInfo.useless), + ('layer_types', NotSupportedInfo.useless), + ('use_sliding_window', NotSupportedInfo.useless), + ]) def __init__(self, - vocab_size: int = 151936, - hidden_size: int = 4096, - intermediate_size: Optional[int] = 22016, - num_hidden_layers: int = 32, - num_attention_heads: int = 32, - num_key_value_heads: Optional[int] = 32, - hidden_act: str = "silu", - max_position_embeddings: Optional[int] = 32768, - rms_norm_eps: float = 1e-6, - tie_word_embeddings: bool = False, - rope_theta: float = 10000.0, - position_embedding_type: str = "rope", - seq_length: int = 2048, - bos_token_id: int = 1, - eos_token_id: int = 2, - pad_token_id: int = 0, - normalization: str = "RMSNorm", - compute_dtype: str = "bfloat16", - layernorm_compute_dtype: str = "float32", - softmax_compute_dtype: str = "float32", - rotary_dtype: str = "float32", - params_dtype: str = "bfloat16", - residual_dtype: str = None, - add_qkv_bias: bool = False, - add_bias_linear: bool = False, - gated_linear_unit: bool = True, - parallel_config: Union[dict, TransformerOpParallelConfig] = default_transformer_config, - use_flash_attention: bool = True, - repetition_penalty: float = 1.0, - max_decode_length: int = 1024, - block_size: int = 16, - num_blocks: int = 512, - top_k: int = 5, - top_p: float = 1.0, - do_sample: bool = True, - parallel_decoding_params: dict = None, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + layer_types=None, + attention_dropout=0.0, **kwargs): - """ - Qwen2 config class which defines the model size. - - Args: - vocab_size (int): Vocabulary size of the qwen2 model. Default: ``151936``. - hidden_size (int): Dimensionality of the encoder layers and the pooler layer. Default: ``4096``. - intermediate_size (int): Customize the number of dimension of the intermediate layer. - Default: ``22016``. - num_hidden_layers (int): Number of hidden layers in the Transformer decoder. Default: ``32``. - num_attention_heads (int): Number of attention heads for each attention layer in the Transformer decoder. - Default: ``32``. - num_key_value_heads (int): Define multi group head attention heads number. Default: ``32``. - hidden_act (str): Specifies the activation function for hidden layers. Default: ``silu``. - max_position_embedding (int): Customize the maximum sequence length that the model can handle. - Default: "32768". - rms_norm_eps (float): The epsilon value of the denominator. Default: ``1e-6``. - tie_word_embeddings (bool): Whether to tie input and output embeddings. Default: ``False``. - rope_theta (float): Frequency factors for sine and cosine functions in RoPE. Default: ``10000.0``. - batch_size (int): Batch size for input data, use in predict. Default: ``1``. - seq_length (int): The sequence length of input_ids. Default: ``2048``. - multiple_of (int): Define SwiGLU hidden layer size multiples. Default: ``256``. - ffn_dim_multiplier (int): Define ffn layer dim multiples. Default: ``None``. - bos_token_id (int): The id of the *beginning-of-sequence* token. Default: ``1``. - eos_token_id (int): The id of the *end-of-sequence* token. Default: ``2``. - pad_token_id (int): The id of the *padding* token. Default: ``0``. - normalization (str): Defines the normalization layer type. Default: ``RMSNorm``. - compute_dtype (str): Linear layer compute dtype. Default: ``bfloat16``. - layernorm_compute_type (str): Layernorm compute dtype. Default: ``float32``. - softmax_compute_type (str): Softmax compute dtype. Default: ``float32``. - rotary_dtype (str): RoPE compute dtype. Default: ``float32``. - params_dtype (str): Parameter initial dtype. Default: ``bfloat16``. - residual_dtype (str): Residual compute dtype. Default: ``None``. - embedding_init_type (str): Embedding weight initial dtype. Default: ``None``. - qkv_has_bias (bool): Whether the Query, Key, and Value projection has bias. Default: ``False``. - attn_proj_has_bias (bool): Whether the attn projection has bias. Default: ``False``. - out_proj_has_bias (bool): Whether the wo projection has bias. Default: ``False``. - add_bias_linear (bool): Whether the attn mlp has bias. Default: ``False``. - parallel_config (Union[dict, TransformerOpParallelConfig]): The parallel configuration. - moe_config (Union[dict, MoEConfig]): The MoE configuration. Default: ``default_moe_config`` , - an instance of `MoEConfig` with default args. - scaling_factor (float): Scaling factor to adjust the weights of the frequency factors in the sine - and cosine functions. Default: ``1.0``. - use_flash_attention (bool): Whether to enable flash attention ops. Default: ``False``. - repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty. - See `this paper `_ for more details. Default: ``1.0``. - max_decode_length (int): The maximum length the generated tokens can have. - block_size (int): The maximum number of tokens in one block can have when using paged attention. - Default: ``16``. - num_blocks (int): The maximum number of blocks when using paged attention. Default: ``512``. - top_k (int): The number of highest probability vocabulary tokens to keep for top-k-filtering. - Default: ``5``. - top_p (float): If set to float < 1, only the smallest set of most probable tokens with probabilities - that add up to `top_p` or higher are kept for generation. Default: ``1.0``. - do_sample (bool): Whether to use sampling; use greedy decoding otherwise. Default: ``True``. - quant_config (dict): Quantitative configuration. Default: ``None``. - parallel_decoding_params (dict): Parallel decoding params. Default: ``None``. - kwargs: Other arguments. - - """ - super(Qwen2Config, self).__init__(**kwargs) - # hf params self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings if max_position_embeddings else seq_length - self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size + self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if self.use_sliding_window else None + self.max_window_layers = max_window_layers + + # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act + self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache self.rope_theta = rope_theta - self.position_embedding_type = position_embedding_type - self.tie_word_embeddings = tie_word_embeddings - # common params - if isinstance(parallel_config, dict): - parallel_config = TransformerOpParallelConfig(**parallel_config) - self.seq_length = seq_length - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.pad_token_id = pad_token_id - self.normalization = normalization - self.compute_dtype = convert_mstype(compute_dtype) - self.layernorm_compute_dtype = convert_mstype(layernorm_compute_dtype) - self.softmax_compute_dtype = convert_mstype(softmax_compute_dtype) - self.rotary_dtype = convert_mstype(rotary_dtype) - self.params_dtype = convert_mstype(params_dtype) - if residual_dtype is not None: - self.residual_dtype = convert_mstype(residual_dtype) - else: - self.residual_dtype = self.compute_dtype - self.add_qkv_bias = add_qkv_bias - self.add_bias_linear = add_bias_linear - self.gated_linear_unit = gated_linear_unit - self.use_flash_attention = use_flash_attention - # infer params - self.repetition_penalty = repetition_penalty - self.max_decode_length = max_decode_length - self.top_k = top_k - self.top_p = top_p - self.do_sample = do_sample - self.block_size = block_size - self.num_blocks = num_blocks - self.parallel_decoding_params = parallel_decoding_params - self.parallel_config = parallel_config - self.post_process = True + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if self.sliding_window is not None and i >= self.max_window_layers + else "full_attention" + for i in range(self.num_hidden_layers) + ] + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/mindformers/models/qwen2/modeling_qwen2.py b/mindformers/models/qwen2/modeling_qwen2.py index ac931fd20..060ab27f4 100644 --- a/mindformers/models/qwen2/modeling_qwen2.py +++ b/mindformers/models/qwen2/modeling_qwen2.py @@ -24,7 +24,7 @@ from mindformers.models.qwen2.utils import Qwen2PreTrainedModel from mindformers.models.qwen2.modeling_qwen2_infer import InferenceQwen2ForCausalLM -@MindFormerRegister.register(MindFormerModuleType.MODELS) +@MindFormerRegister.register(MindFormerModuleType.MODELS, legacy=False) class Qwen2ForCausalLM(Qwen2PreTrainedModel): r""" Provide Qwen2 Model for training and inference. diff --git a/tests/st/test_ut/base_schema.json b/tests/st/test_ut/base_schema.json index 767160b64..73f963cc8 100644 --- a/tests/st/test_ut/base_schema.json +++ b/tests/st/test_ut/base_schema.json @@ -3002,8 +3002,29 @@ "mindformers.models.multi_modal.ModalContentTransformTemplate.post_process": { "signature": "(self, output_ids, **kwargs)" }, + "mindformers.models.qwen2.InferenceQwen2ForCausalLM": { + "signature": "(config)" + }, + "mindformers.models.qwen2.InferenceQwen2ForCausalLM.construct": { + "signature": "(self, input_ids, positions=None, batch_valid_length=None, context_lens_tensor=None, q_seq_lens=None, block_tables=None, slot_mapping=None, attention_mask=None, attn_metadata=None, key_cache=None, value_cache=None)" + }, + "mindformers.models.qwen2.Qwen2Config": { + "signature": "(vocab_size=151936, hidden_size=4096, intermediate_size=22016, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, use_sliding_window=False, sliding_window=4096, max_window_layers=28, layer_types=None, attention_dropout=0.0, **kwargs)" + }, + "mindformers.models.qwen2.Qwen2ForCausalLM": { + "signature": "(config)" + }, + "mindformers.models.qwen2.Qwen2PreTrainedModel": { + "signature": "(config: mindformers.models.configuration_utils.PretrainedConfig, *inputs, **kwargs)" + }, + "mindformers.models.qwen2.Qwen2PreTrainedModel.config_class": { + "signature": "(vocab_size=151936, hidden_size=4096, intermediate_size=22016, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, use_sliding_window=False, sliding_window=4096, max_window_layers=28, layer_types=None, attention_dropout=0.0, **kwargs)" + }, + "mindformers.models.qwen2.Qwen2PreTrainedModel.convert_name": { + "signature": "(self, weight_name)" + }, "mindformers.models.qwen2.configuration_qwen2.Qwen2Config": { - "signature": "(vocab_size: int = 151936, hidden_size: int = 4096, intermediate_size: Optional[int] = 22016, num_hidden_layers: int = 32, num_attention_heads: int = 32, num_key_value_heads: Optional[int] = 32, hidden_act: str = 'silu', max_position_embeddings: Optional[int] = 32768, rms_norm_eps: float = 1e-06, tie_word_embeddings: bool = False, rope_theta: float = 10000.0, position_embedding_type: str = 'rope', seq_length: int = 2048, bos_token_id: int = 1, eos_token_id: int = 2, pad_token_id: int = 0, normalization: str = 'RMSNorm', compute_dtype: str = 'bfloat16', layernorm_compute_dtype: str = 'float32', softmax_compute_dtype: str = 'float32', rotary_dtype: str = 'float32', params_dtype: str = 'bfloat16', residual_dtype: str = None, add_qkv_bias: bool = False, add_bias_linear: bool = False, gated_linear_unit: bool = True, parallel_config: Union[dict, mindformers.modules.transformer.transformer.TransformerOpParallelConfig] = , use_flash_attention: bool = True, repetition_penalty: float = 1.0, max_decode_length: int = 1024, block_size: int = 16, num_blocks: int = 512, top_k: int = 5, top_p: float = 1.0, do_sample: bool = True, parallel_decoding_params: dict = None, **kwargs)" + "signature": "(vocab_size=151936, hidden_size=4096, intermediate_size=22016, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, use_sliding_window=False, sliding_window=4096, max_window_layers=28, layer_types=None, attention_dropout=0.0, **kwargs)" }, "mindformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM": { "signature": "(config)" @@ -4167,7 +4188,7 @@ "signature": "(input_size: int, output_size: int, config: mindformers.parallel_core.model_parallel_config.ModelParallelConfig, init_method: Callable = None, bias: bool = False, gather_output: bool = False, stride: int = 1, keep_master_weight_for_test: bool = False, skip_bias_add: bool = False, skip_weight_param_allocation: bool = False, embedding_activation_buffer: Optional[List[mindspore.common.tensor.Tensor]] = None, grad_output_buffer: Optional[List[mindspore.common.tensor.Tensor]] = None, is_expert: bool = True, tp_comm_buffer_name: str = None, disable_grad_reduce: bool = False, transpose_b: bool = True, bias_init: Callable = None)" }, "mindformers.parallel_core.training_graph.tensor_parallel.batched_layers.ColumnParallelBatchedLinear.construct": { - "signature": "(self, input_: mindspore.common.tensor.Tensor, weight: mindspore.common.tensor.Tensor = None) -> tuple[mindspore.common.tensor.Tensor, mindspore.common.tensor.Tensor]" + "signature": "(self, input_: mindspore.common.tensor.Tensor, weight: mindspore.common.tensor.Tensor = None) -> tuple" }, "mindformers.parallel_core.training_graph.tensor_parallel.batched_layers.ColumnParallelBatchedLinear.shard": { "signature": "(self, config: mindformers.parallel_core.model_parallel_config.ModelParallelConfig) -> None" @@ -4179,7 +4200,7 @@ "signature": "(input_size: int, output_size: int, config: mindformers.parallel_core.model_parallel_config.ModelParallelConfig, init_method: Callable = None, bias: bool = False, input_is_parallel: bool = False, skip_bias_add: bool = False, stride: int = 1, keep_master_weight_for_test: bool = False, is_expert: bool = True, tp_comm_buffer_name: str = None, transpose_b: bool = True, bias_init: Callable = None)" }, "mindformers.parallel_core.training_graph.tensor_parallel.batched_layers.RowParallelBatchedLinear.construct": { - "signature": "(self, input_: mindspore.common.tensor.Tensor) -> tuple[mindspore.common.tensor.Tensor, mindspore.common.tensor.Tensor]" + "signature": "(self, input_: mindspore.common.tensor.Tensor) -> tuple" }, "mindformers.parallel_core.training_graph.tensor_parallel.batched_layers.RowParallelBatchedLinear.shard": { "signature": "(self, config: mindformers.parallel_core.model_parallel_config.ModelParallelConfig) -> None" @@ -4191,7 +4212,7 @@ "signature": "(input_size: int, output_size: int, config: mindformers.parallel_core.transformer_config.TransformerConfig, init_method: Callable = None, bias: bool = True, gather_output: bool = False, stride: int = 1, keep_master_weight_for_test: bool = False, skip_bias_add: bool = False, skip_weight_param_allocation: bool = False, embedding_activation_buffer: Optional[List[mindspore.common.tensor.Tensor]] = None, grad_output_buffer: Optional[List[mindspore.common.tensor.Tensor]] = None, is_expert: bool = False, tp_comm_buffer_name: str = None, disable_grad_reduce: bool = False, transpose_b: bool = True, bias_init: Callable = None)" }, "mindformers.parallel_core.training_graph.tensor_parallel.layers.ColumnParallelLinear.construct": { - "signature": "(self, input_: mindspore.common.tensor.Tensor, weight: mindspore.common.tensor.Tensor = None) -> tuple[mindspore.common.tensor.Tensor, mindspore.common.tensor.Tensor]" + "signature": "(self, input_: mindspore.common.tensor.Tensor, weight: mindspore.common.tensor.Tensor = None) -> tuple" }, "mindformers.parallel_core.training_graph.tensor_parallel.layers.ColumnParallelLinear.shard": { "signature": "(self, config: mindformers.parallel_core.transformer_config.TransformerConfig) -> None" @@ -4206,7 +4227,7 @@ "signature": "(input_size: int, output_size: int, config: mindformers.parallel_core.transformer_config.TransformerConfig, init_method: Callable = None, bias: bool = True, input_is_parallel: bool = False, skip_bias_add: bool = False, stride: int = 1, keep_master_weight_for_test: bool = False, is_expert: bool = False, tp_comm_buffer_name: str = None, transpose_b: bool = True, bias_init: Callable = None)" }, "mindformers.parallel_core.training_graph.tensor_parallel.layers.RowParallelLinear.construct": { - "signature": "(self, input_: mindspore.common.tensor.Tensor) -> tuple[mindspore.common.tensor.Tensor, mindspore.common.tensor.Tensor]" + "signature": "(self, input_: mindspore.common.tensor.Tensor) -> tuple" }, "mindformers.parallel_core.training_graph.tensor_parallel.layers.RowParallelLinear.shard": { "signature": "(self, config: mindformers.parallel_core.transformer_config.TransformerConfig) -> None" @@ -4230,7 +4251,7 @@ "signature": "(input_size: int, output_size: int, config: mindformers.parallel_core.transformer_config.TransformerConfig, init_method: Callable = None, bias: bool = True, gather_output: bool = False, stride: int = 1, keep_master_weight_for_test: bool = False, skip_bias_add: bool = False, skip_weight_param_allocation: bool = False, embedding_activation_buffer: Optional[List[mindspore.common.tensor.Tensor]] = None, grad_output_buffer: Optional[List[mindspore.common.tensor.Tensor]] = None, is_expert: bool = False, tp_comm_buffer_name: str = None, disable_grad_reduce: bool = False, transpose_b: bool = True, bias_init: Callable = None, lora_rank: int = 8, lora_alpha: int = 32, lora_dropout: float = 0.0, lora_a_init='normal', lora_b_init='zeros')" }, "mindformers.parallel_core.training_graph.tensor_parallel.lora_layers.ColumnParallelLinearWithLoRA.construct": { - "signature": "(self, input_: mindspore.common.tensor.Tensor, weight: mindspore.common.tensor.Tensor = None) -> tuple[mindspore.common.tensor.Tensor, mindspore.common.tensor.Tensor]" + "signature": "(self, input_: mindspore.common.tensor.Tensor, weight: mindspore.common.tensor.Tensor = None) -> tuple" }, "mindformers.parallel_core.training_graph.tensor_parallel.lora_layers.ColumnParallelLinearWithLoRA.shard_lora": { "signature": "(self, config: mindformers.parallel_core.transformer_config.TransformerConfig) -> None" @@ -4239,7 +4260,7 @@ "signature": "(input_size: int, output_size: int, config: mindformers.parallel_core.transformer_config.TransformerConfig, init_method: Callable = None, bias: bool = True, input_is_parallel: bool = False, skip_bias_add: bool = False, stride: int = 1, keep_master_weight_for_test: bool = False, is_expert: bool = False, tp_comm_buffer_name: str = None, transpose_b: bool = True, bias_init: Callable = None, lora_rank: int = 8, lora_alpha: int = 32, lora_dropout: float = 0.0, lora_a_init='normal', lora_b_init='zeros')" }, "mindformers.parallel_core.training_graph.tensor_parallel.lora_layers.RowParallelLinearWithLoRA.construct": { - "signature": "(self, input_: mindspore.common.tensor.Tensor) -> tuple[mindspore.common.tensor.Tensor, mindspore.common.tensor.Tensor]" + "signature": "(self, input_: mindspore.common.tensor.Tensor) -> tuple" }, "mindformers.parallel_core.training_graph.tensor_parallel.lora_layers.RowParallelLinearWithLoRA.shard_lora": { "signature": "(self, config: mindformers.parallel_core.transformer_config.TransformerConfig) -> None" @@ -4287,7 +4308,7 @@ "signature": "(config: mindformers.parallel_core.transformer_config.TransformerConfig, submodules: mindformers.parallel_core.training_graph.transformer.mlp.MLPSubmodules, is_expert: bool = False, input_size: int = None)" }, "mindformers.parallel_core.training_graph.transformer.mlp.MLP.construct": { - "signature": "(self, hidden_states: mindspore.common.tensor.Tensor, per_token_scale=None, extra_loss=0.0) -> tuple[mindspore.common.tensor.Tensor, mindspore.common.tensor.Tensor, float]" + "signature": "(self, hidden_states: mindspore.common.tensor.Tensor, per_token_scale=None, extra_loss=0.0) -> tuple" }, "mindformers.parallel_core.training_graph.transformer.mlp.MLP.shard": { "signature": "(self, config: mindformers.parallel_core.transformer_config.TransformerConfig)" @@ -4302,7 +4323,7 @@ "signature": "(config: mindformers.parallel_core.transformer_config.TransformerConfig, submodules: mindformers.parallel_core.training_graph.transformer.mlp.MLPSubmodules)" }, "mindformers.parallel_core.training_graph.transformer.moe.shared_experts.SharedExpertMLP.construct": { - "signature": "(self, hidden_states: mindspore.common.tensor.Tensor) -> tuple[mindspore.common.tensor.Tensor, mindspore.common.tensor.Tensor]" + "signature": "(self, hidden_states: mindspore.common.tensor.Tensor) -> tuple" }, "mindformers.parallel_core.training_graph.transformer.moe.shared_experts.SharedExpertMLP.expert_gate_shard": { "signature": "(self, config: mindformers.parallel_core.transformer_config.TransformerConfig)"