From ef9eff052cb022dabff4b8cbd8acb022a5011687 Mon Sep 17 00:00:00 2001 From: HighCloud Date: Wed, 17 Sep 2025 14:18:08 +0800 Subject: [PATCH 1/2] deepseek support 310P --- mindformers/models/utils.py | 1 + .../inference/base_models/gpt/gpt_model.py | 7 +- .../inference/quantization/base_config.py | 8 + .../quantization/golden_stick/a8dynw4.py | 80 +++-- .../quantization/golden_stick/a8dynw8.py | 277 ++++++++++-------- .../quantization/golden_stick/a8w8.py | 8 +- .../inference/quantization/utils.py | 22 ++ .../tensor_parallel/grouped_layers.py | 84 ++++-- .../inference/tensor_parallel/layers.py | 9 +- .../inference/transformer/moe/experts.py | 6 + .../inference/transformer/moe/router.py | 2 +- .../transformer/moe/token_dispatcher.py | 10 +- .../transformer/multi_latent_attention.py | 9 +- mindformers/parallel_core/inference/utils.py | 12 + .../parallel_core/inference/weights_utils.py | 19 ++ 15 files changed, 376 insertions(+), 178 deletions(-) diff --git a/mindformers/models/utils.py b/mindformers/models/utils.py index c405c995e..d75e6a6eb 100644 --- a/mindformers/models/utils.py +++ b/mindformers/models/utils.py @@ -46,6 +46,7 @@ str_to_ms_type = { } format_type = { + "nd": 2, "nz": 29, } diff --git a/mindformers/parallel_core/inference/base_models/gpt/gpt_model.py b/mindformers/parallel_core/inference/base_models/gpt/gpt_model.py index 647b73f25..f4567237d 100644 --- a/mindformers/parallel_core/inference/base_models/gpt/gpt_model.py +++ b/mindformers/parallel_core/inference/base_models/gpt/gpt_model.py @@ -31,11 +31,13 @@ from mindformers.parallel_core.inference.base_models.common.embeddings.rope_util from mindformers.parallel_core.inference.utils import divide, generate_padding_index from mindformers.parallel_core.process_group_config import ModelCommProcessGroups from mindformers.parallel_core.inference.weights_utils import (default_weight_loader, make_expert_params_mapping, - make_expert_params_mapping_with_expert_dim) + make_expert_params_mapping_with_expert_dim, + process_weights_for_310p) from mindformers.parallel_core.inference.parallel_state import is_pipeline_last_stage from mindformers.parallel_core.process_group_config import get_model_comm_pgs from mindformers.tools.logger import logger from mindformers.tools.utils import is_pynative +from mindformers.version_control import is_310p class GPTModel(nn.Cell): @@ -447,6 +449,9 @@ class GPTModel(nn.Cell): if "weight_scale_inv" in name: continue + if is_310p(): + loaded_weight = process_weights_for_310p(name, loaded_weight, params_dict, loaded_params) + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/mindformers/parallel_core/inference/quantization/base_config.py b/mindformers/parallel_core/inference/quantization/base_config.py index 9bd3dce05..e1ab75fc0 100644 --- a/mindformers/parallel_core/inference/quantization/base_config.py +++ b/mindformers/parallel_core/inference/quantization/base_config.py @@ -58,6 +58,14 @@ class QuantizeMethodBase(ABC): raise RuntimeError + def process_weight_before_loading(self, param_name, loaded_weight): + """ + Process the weight before loading. + This can be used for example, to transpose weights for computation. + """ + + return loaded_weight + def process_weights_after_loading(self, layer: nn.Cell) -> None: """ Process the weight after loading. 
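Note on the int4 layout change that drives the a8dynw4 and quantization/utils changes below: on 310P the created param is (out_dim, in_dim // 2) with transpose_b=True, while the checkpoint tensor arrives packed as (out_dim // 2, in_dim), so loading has to unpack each int8 byte into two nibbles, transpose, and repack along the other axis. A minimal standalone NumPy sketch of that round trip (function names and shapes here are illustrative only, not part of the patch):

import numpy as np

def unpack_int8_to_int4(packed):
    """Split each byte into two 4-bit values along the last axis."""
    low = (packed & 0x0F).astype(np.uint8)
    high = ((packed >> 4) & 0x0F).astype(np.uint8)
    nibbles = np.empty((*packed.shape[:-1], packed.shape[-1] * 2), dtype=np.uint8)
    nibbles[..., 0::2] = low    # even index -> low nibble
    nibbles[..., 1::2] = high   # odd index  -> high nibble
    return nibbles

def pack_int4_to_int8(nibbles):
    """Inverse of the above: fuse adjacent pairs of 4-bit values back into one byte."""
    nib = nibbles.astype(np.uint8) & 0x0F
    return (nib[..., 0::2] | (nib[..., 1::2] << 4)).astype(np.uint8)

# Logical int4 matrix (values 0..15), shape (out, in), both dims even.
rng = np.random.default_rng(0)
a = rng.integers(0, 16, size=(4, 6), dtype=np.uint8)

# Checkpoint-style layout: packed along axis 0 (row pairs fused) -> (out // 2, in).
packed_axis0 = (a[0::2, :] | (a[1::2, :] << 4)).astype(np.uint8)

# 310P-style layout: packed along axis 1 (column pairs fused) -> (out, in // 2),
# produced by the same transpose / unpack / transpose / pack sequence used at load time.
repacked = pack_int4_to_int8(unpack_int8_to_int4(packed_axis0.T).T)

assert np.array_equal(repacked, (a[:, 0::2] | (a[:, 1::2] << 4)).astype(np.uint8))
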
diff --git a/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw4.py b/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw4.py index da75d83c7..4c0ad51cf 100644 --- a/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw4.py +++ b/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw4.py @@ -21,12 +21,17 @@ import mindspore from mindspore import nn, Parameter, ops from mindspore.common.initializer import initializer from mindspore.ops.auto_generate import WeightQuantBatchMatmul, DynamicQuantExt, GroupedMatmulV4 - +try: + from ms_custom_ops import grouped_matmul_w4 + GMM_310P = True +except ImportError: + GMM_310P = False +from mindformers.version_control import is_310p from mindformers.parallel_core.inference.weights_utils import set_weight_attrs from mindformers.parallel_core.inference.transformer.moe.experts import GroupedMLP from mindformers.parallel_core.inference.tensor_parallel.layers import LinearMethodBase from mindformers.parallel_core.inference.quantization import QuantizationConfig - +from mindformers.parallel_core.inference.quantization.utils import np_pack_int4_to_int8, np_unpack_int8_to_int4 class A8W4DynamicLinearMethod(LinearMethodBase): """Linear method with A8W4 quantization.""" @@ -35,12 +40,14 @@ class A8W4DynamicLinearMethod(LinearMethodBase): self.quant_config = quant_config self.quant = DynamicQuantExt() self.bias_add = ops.Add() + self.is_310p = is_310p() and GMM_310P def create_weights(self, layer: nn.Cell, input_size_per_partition: int, output_partition_sizes: list[int], params_dtype, + transpose_b=False, *weight_args, num_local_experts=None, **extra_weight_attrs) -> Union[Parameter, None]: output_size_per_partition = sum(output_partition_sizes) @@ -52,17 +59,24 @@ class A8W4DynamicLinearMethod(LinearMethodBase): raise ValueError(f"group_size should >=0 but group_size is : {group_size}") if self.is_group_mm: weight = None - self.matmul = GroupedMatmulV4() + if self.is_310p: + self.matmul = grouped_matmul_w4 + else: + self.matmul = GroupedMatmulV4() if not extra_weight_attrs.get('skip_weight_param_allocation', False): - weight_shape = (num_local_experts, self.input_size_per_partition, self.output_size_per_partition // 2) + weight_shape = (num_local_experts, self.output_size_per_partition, + self.input_size_per_partition // 2) \ + if transpose_b else (num_local_experts, self.input_size_per_partition, + self.output_size_per_partition // 2) weight = Parameter(initializer('ones', weight_shape, mindspore.qint4x2), requires_grad=False) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 2}) + input_dim, output_dim = (2, 1) if transpose_b else (1, 2) + set_weight_attrs(weight, {"input_dim": input_dim, "output_dim": output_dim}) set_weight_attrs(weight, extra_weight_attrs) return weight w_scale_shape = (num_local_experts, self.input_size_per_partition // group_size, self.output_size_per_partition) - w_scale_dtype = mindspore.uint64 + w_scale_dtype = mindspore.float32 if self.is_310p else mindspore.uint64 w_scale = Parameter( initializer('ones', w_scale_shape, w_scale_dtype), name="w_scale", requires_grad=False) set_weight_attrs(w_scale, {"input_dim": 1, "output_dim": 2}) @@ -98,12 +112,27 @@ class A8W4DynamicLinearMethod(LinearMethodBase): return weight + def process_weight_before_loading(self, param_name, loaded_weight): + """preprocess before loading weight""" + if self.is_310p and (param_name.endswith("weight1") or param_name.endswith('weight2')) \ + and loaded_weight.ndim == 2: + # In 310P, 
transpose_b is True, so the param shape is (out_dim, in_dim//2). + # But weights is (out_dim//2, in_dim) after packing int4 to int8. + loaded_weight = loaded_weight.astype(np.uint8) + loaded_weight = loaded_weight.transpose(1, 0) + loaded_weight = np_unpack_int8_to_int4(loaded_weight) + loaded_weight = loaded_weight.transpose(1, 0) + loaded_weight = np_pack_int4_to_int8(loaded_weight) + if self.is_310p and param_name.endswith("w_scale"): + loaded_weight = loaded_weight.view(np.float32) + loaded_weight = loaded_weight[..., 0::2] + return loaded_weight + + def process_weights_after_loading(self, layer: nn.Cell) -> None: """calculate gmm_bias""" if isinstance(layer, GroupedMLP): return - - # int4 pack to int8 np_data = layer.weight.asnumpy().astype(np.uint8) np_data_low = ((np_data & 0x0F) << 4).astype(np.int8) >> 4 np_data_high = ((np_data >> 4) << 4).astype(np.int8) >> 4 @@ -113,8 +142,12 @@ class A8W4DynamicLinearMethod(LinearMethodBase): np_int4_data[:, :, ::2] = np_data_low np_int4_data[:, :, 1::2] = np_data_high w_scale = layer.w_scale.asnumpy() - w_scale_repeat = np.repeat(w_scale, layer.weight.shape[1] // w_scale.shape[1], - axis=1).astype(np.uint32).view(np.float32) + if self.is_310p: + np_int4_data = np_int4_data.transpose(0, 2, 1) + w_scale_repeat = np.repeat(w_scale, np_int4_data.shape[1] // w_scale.shape[1], axis=1) + else: + w_scale_repeat = np.repeat(w_scale, np_int4_data.shape[1] // w_scale.shape[1], + axis=1).astype(np.uint32).view(np.float32) gmm_bias = 8 * np.sum( np_int4_data.astype(np.float32) * w_scale_repeat, axis=1) @@ -136,15 +169,24 @@ class A8W4DynamicLinearMethod(LinearMethodBase): output_shape = qx.shape[:-1] + (self.output_size_per_partition,) qx = qx.reshape(-1, self.input_size_per_partition) if self.is_group_mm: - out = self.matmul([qx], [weight], - [gmm_bias], [w_scale], - None, - None, - None, [qx_scale], - group_list, - split_item=3, - group_type=0, - group_list_type=1)[0] + if self.is_310p: + group_list = ops.cast(group_list, dtype=mindspore.int32) + out = self.matmul(qx, + weight, + group_list, + gmm_bias, + qx_scale, + w_scale) + else: + out = self.matmul([qx], [weight], + [gmm_bias], [w_scale], + None, + None, + None, [qx_scale], + group_list, + split_item=3, + group_type=0, + group_list_type=1)[0] else: w_scale = ops.cast(w_scale, mindspore.float16) qx = ops.cast(qx, mindspore.float16) diff --git a/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw8.py b/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw8.py index 33a70c721..63303d6ab 100644 --- a/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw8.py +++ b/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw8.py @@ -1,126 +1,151 @@ -# Copyright 2025 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ -"""A8 dynamic W8 quantization method.""" - -from typing import Union - -import mindspore -from mindspore import Parameter, nn, ops -from mindspore.common.initializer import initializer -from mindspore.ops.auto_generate import QuantBatchMatmul, DynamicQuantExt, GroupedMatmulV4 - -from mindformers.parallel_core.inference.tensor_parallel.layers import LinearMethodBase -from mindformers.parallel_core.inference.tensor_parallel.mappings import reduce_from_model_parallel_region -from mindformers.parallel_core.inference.quantization import QuantizationConfig -from mindformers.parallel_core.inference.weights_utils import set_weight_attrs -from mindformers.models.utils import format_type - - -class A8W8DynamicLinearMethod(LinearMethodBase): - """Linear method with A8W8 dynamic quantization.""" - - def __init__(self, quant_config: QuantizationConfig): - self.quant_config = quant_config - self.quant = DynamicQuantExt() - self.bias_add = ops.Add() - - def create_weights(self, - layer: nn.Cell, - input_size_per_partition: int, - output_partition_sizes: list[int], - params_dtype, - *weight_args, - num_local_experts=None, **extra_weight_attrs) -> Union[Parameter, None]: - output_size_per_partition = sum(output_partition_sizes) - self.output_size_per_partition = output_size_per_partition - self.input_size_per_partition = input_size_per_partition - self.is_group_mm = num_local_experts is not None - - if self.is_group_mm: - weight = None - self.matmul = GroupedMatmulV4() - if not extra_weight_attrs.get('skip_weight_param_allocation', False): - shape = (num_local_experts, input_size_per_partition, output_size_per_partition) - weight = Parameter(initializer('ones', shape, mindspore.int8), requires_grad=False) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 2}) - set_weight_attrs(weight, extra_weight_attrs) - return weight - - w_scale_shape = (num_local_experts, output_size_per_partition) - w_scale_dtype = mindspore.bfloat16 if params_dtype == mindspore.bfloat16 else mindspore.float32 - w_scale = Parameter( - initializer('ones', w_scale_shape, w_scale_dtype), name="w_scale", requires_grad=False) - set_weight_attrs(w_scale, {"output_dim": 1}) - set_weight_attrs(w_scale, extra_weight_attrs) - - else: - self.matmul = QuantBatchMatmul(transpose_x1=False, transpose_x2=True, dtype=params_dtype) - weight_shape = (self.output_size_per_partition, self.input_size_per_partition) - weight = Parameter(initializer('ones', weight_shape, mindspore.int8), requires_grad=False) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - set_weight_attrs(weight, extra_weight_attrs) - - w_scale_shape = (output_size_per_partition,) - w_scale_dtype = mindspore.bfloat16 if params_dtype == mindspore.bfloat16 else mindspore.float32 - w_scale = Parameter( - initializer('ones', w_scale_shape, w_scale_dtype), name="w_scale", requires_grad=False) - set_weight_attrs(w_scale, {"output_dim": 0}) - set_weight_attrs(w_scale, extra_weight_attrs) - - if layer is not None: - layer.insert_param_to_cell("weight", weight) - layer.insert_param_to_cell("w_scale", w_scale) - return weight - - def process_weights_after_loading(self, layer: nn.Cell) -> None: - """ - Process the weight after loading. - This can be used for example, to transpose weights for computation. 
- """ - if self.is_group_mm: - layer.weight = ops.auto_generate.format_cast(layer.weight, format_type['nz']) - - def apply(self, - layer: mindspore.nn.Cell, - x: mindspore.Tensor, - weight: mindspore.Tensor = None, - bias: mindspore.Parameter = None, - group_list=None, **kwargs) -> mindspore.Tensor: - if weight is None: - weight = layer.weight - w_scale = layer.w_scale - qx, qx_scale = self.quant(x, None) - qx_scale = qx_scale.reshape(-1) - output_shape = qx.shape[:-1] + (self.output_size_per_partition,) - qx = qx.reshape(-1, self.input_size_per_partition) - if self.is_group_mm: - out = self.matmul([qx], [weight], - None, [w_scale], - None, - None, - None, [qx_scale], - group_list, - split_item=3, - group_type=0, - group_list_type=1)[0] - if hasattr(layer, 'delay_allreduce'): - if not layer.delay_allreduce and not layer.skip_bias_add: - out = reduce_from_model_parallel_region(out, layer.tp_group) - else: - out = self.matmul(qx, weight, w_scale, None, None, qx_scale) - if bias is not None: - out = self.bias_add(out, bias) - out = out.reshape(output_shape) - return out +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""A8 dynamic W8 quantization method.""" + +from typing import Union + +import mindspore +from mindspore import Parameter, nn, ops +from mindspore.common.initializer import initializer +from mindspore.ops.auto_generate import QuantBatchMatmul, DynamicQuantExt, GroupedMatmulV4 + +from mindformers.parallel_core.inference.tensor_parallel.layers import LinearMethodBase +from mindformers.parallel_core.inference.tensor_parallel.mappings import reduce_from_model_parallel_region +from mindformers.parallel_core.inference.quantization import QuantizationConfig +try: + from ms_custom_ops import grouped_matmul + GMM_310P = True +except ImportError: + GMM_310P = False +from mindformers.version_control import is_310p +from mindformers.parallel_core.inference.weights_utils import set_weight_attrs +from mindformers.models.utils import format_type + + +class A8W8DynamicLinearMethod(LinearMethodBase): + """Linear method with A8W8 dynamic quantization.""" + + def __init__(self, quant_config: QuantizationConfig): + self.quant_config = quant_config + self.quant = DynamicQuantExt() + self.bias_add = ops.Add() + self.is_310p = is_310p() and GMM_310P + + def create_weights(self, + layer: nn.Cell, + input_size_per_partition: int, + output_partition_sizes: list[int], + params_dtype, + *weight_args, + transpose_b=False, + num_local_experts=None, **extra_weight_attrs) -> Union[Parameter, None]: + output_size_per_partition = sum(output_partition_sizes) + self.output_size_per_partition = output_size_per_partition + self.input_size_per_partition = input_size_per_partition + self.is_group_mm = num_local_experts is not None + + if self.is_group_mm: + weight = None + if self.is_310p: + self.matmul = grouped_matmul + else: + self.matmul = GroupedMatmulV4() + if not 
extra_weight_attrs.get('skip_weight_param_allocation', False): + shape = (num_local_experts, output_size_per_partition, input_size_per_partition) \ + if transpose_b else (num_local_experts, input_size_per_partition, output_size_per_partition) + weight = Parameter(initializer('ones', shape, mindspore.int8), requires_grad=False) + input_dim, output_dim = (2, 1) if transpose_b else (1, 2) + set_weight_attrs(weight, {"input_dim": input_dim, "output_dim": output_dim}) + set_weight_attrs(weight, extra_weight_attrs) + return weight + + w_scale_shape = (num_local_experts, output_size_per_partition) + w_scale_dtype = mindspore.bfloat16 if params_dtype == mindspore.bfloat16 else mindspore.float32 + w_scale = Parameter( + initializer('ones', w_scale_shape, w_scale_dtype), name="w_scale", requires_grad=False) + set_weight_attrs(w_scale, {"output_dim": 1}) + set_weight_attrs(w_scale, extra_weight_attrs) + + else: + self.matmul = QuantBatchMatmul(transpose_x1=False, transpose_x2=True, dtype=params_dtype) + weight_shape = (self.output_size_per_partition, self.input_size_per_partition) + weight = Parameter(initializer('ones', weight_shape, mindspore.int8), requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + set_weight_attrs(weight, extra_weight_attrs) + + w_scale_shape = (output_size_per_partition,) + w_scale_dtype = mindspore.bfloat16 if params_dtype == mindspore.bfloat16 else mindspore.float32 + w_scale = Parameter( + initializer('ones', w_scale_shape, w_scale_dtype), name="w_scale", requires_grad=False) + set_weight_attrs(w_scale, {"output_dim": 0}) + set_weight_attrs(w_scale, extra_weight_attrs) + + if layer is not None: + layer.insert_param_to_cell("weight", weight) + layer.insert_param_to_cell("w_scale", w_scale) + return weight + + def process_weights_after_loading(self, layer: nn.Cell) -> None: + """ + Process the weight after loading. + This can be used for example, to transpose weights for computation. 
+ """ + if self.is_group_mm: + layer.weight = ops.auto_generate.format_cast(layer.weight, format_type['nz']) + + def apply(self, + layer: mindspore.nn.Cell, + x: mindspore.Tensor, + weight: mindspore.Tensor = None, + bias: mindspore.Parameter = None, + group_list=None, **kwargs) -> mindspore.Tensor: + if weight is None: + weight = layer.weight + w_scale = layer.w_scale + qx, qx_scale = self.quant(x, None) + qx_scale = qx_scale.reshape(-1) + output_shape = qx.shape[:-1] + (self.output_size_per_partition,) + qx = qx.reshape(-1, self.input_size_per_partition) + if self.is_group_mm: + if self.is_310p: + group_list = ops.cast(group_list, dtype=mindspore.int32) + out = self.matmul(qx, + weight, + group_list, + None, + w_scale, + qx_scale, + None, + False, + True) + else: + out = self.matmul([qx], [weight], + None, [w_scale], + None, + None, + None, [qx_scale], + group_list, + split_item=3, + group_type=0, + group_list_type=1)[0] + if hasattr(layer, 'delay_allreduce'): + if not layer.delay_allreduce and not layer.skip_bias_add: + out = reduce_from_model_parallel_region(out, layer.tp_group) + else: + out = self.matmul(qx, weight, w_scale, None, None, qx_scale) + if bias is not None: + out = self.bias_add(out, bias) + out = out.reshape(output_shape) + return out diff --git a/mindformers/parallel_core/inference/quantization/golden_stick/a8w8.py b/mindformers/parallel_core/inference/quantization/golden_stick/a8w8.py index 6cae4e9bd..63eca707e 100644 --- a/mindformers/parallel_core/inference/quantization/golden_stick/a8w8.py +++ b/mindformers/parallel_core/inference/quantization/golden_stick/a8w8.py @@ -18,6 +18,7 @@ import mindspore from mindspore import Tensor, Parameter, ops, nn from mindspore.common.initializer import initializer from mindspore.ops.auto_generate import QuantBatchMatmul, QuantV2 +from mindformers.version_control import is_310p from mindformers.parallel_core.inference.weights_utils import set_weight_attrs from mindformers.parallel_core.inference.quantization import QuantizationConfig from mindformers.parallel_core.inference.tensor_parallel.layers import LinearMethodBase @@ -31,10 +32,11 @@ class A8W8LinearMethod(LinearMethodBase): self.quant = QuantV2() self.bias_add = ops.Add() self.is_modelslim = self.quant_config.is_modelslim + self.is_310p = is_310p() self.is_ms_custom_ops = False try: import ms_custom_ops - self.is_ms_custom_ops = True + self.is_ms_custom_ops = True and not self.is_310p self.ms_custom_ops = ms_custom_ops except ModuleNotFoundError: pass @@ -56,7 +58,7 @@ class A8W8LinearMethod(LinearMethodBase): weight_shape = (self.output_size_per_partition, self.input_size_per_partition) weight = Parameter(initializer('ones', weight_shape, mindspore.int8), requires_grad=False) deq_scale_shape = self.output_size_per_partition - scale_dtype = mindspore.float32 + scale_dtype = mindspore.int64 if self.is_310p else mindspore.float32 deq_scale = Parameter( initializer('ones', deq_scale_shape, scale_dtype), name="deq_scale", requires_grad=False) shape = (self.output_size_per_partition,) @@ -113,7 +115,7 @@ class A8W8LinearMethod(LinearMethodBase): return input_scale = 1 / layer.input_scale.asnumpy() layer.input_scale = Parameter( - Tensor(input_scale, dtype=mindspore.bfloat16), name=layer.input_scale.name, requires_grad=False) + Tensor(input_scale, dtype=self.params_dtype), name=layer.input_scale.name, requires_grad=False) def apply(self, layer: mindspore.nn.Cell, diff --git a/mindformers/parallel_core/inference/quantization/utils.py 
b/mindformers/parallel_core/inference/quantization/utils.py index 800fee1ed..cfee0e97c 100644 --- a/mindformers/parallel_core/inference/quantization/utils.py +++ b/mindformers/parallel_core/inference/quantization/utils.py @@ -16,6 +16,7 @@ import os import json import glob +import numpy as np from mindformers.parallel_core.inference.quantization import (get_quantization_config, QuantizationConfig) from mindformers.models.configuration_utils import PretrainedConfig @@ -66,3 +67,24 @@ def get_quant_config(model_config: PretrainedConfig, weight_mapping: list) -> Qu config["weight_mapping"] = weight_mapping config["quantization"] = quantization return quant_cls.from_config(config) + +def np_unpack_int8_to_int4(packed_data): + """unpack int8 to int4 numpy array.""" + low_nibbles = (packed_data & 0x0F).astype(np.uint8) + high_nibbles = ((packed_data >> 4) & 0x0F).astype(np.uint8) + + unpacked = np.empty((*packed_data.shape[:-1], packed_data.shape[-1] * 2), + dtype=np.uint8) + unpacked[..., 0::2] = low_nibbles + unpacked[..., 1::2] = high_nibbles + + return unpacked + +def np_pack_int4_to_int8(np_data): + """pack int4 numpy array to int8.""" + np_data = np_data.astype(np.int8) + np_data &= 0x000F + np_data[..., 0::2] <<= 0 + np_data[..., 1::2] <<= 4 + np_int4_data = np_data[..., 0::2] | np_data[..., 1::2] + return np_int4_data diff --git a/mindformers/parallel_core/inference/tensor_parallel/grouped_layers.py b/mindformers/parallel_core/inference/tensor_parallel/grouped_layers.py index 39e51ebd0..afd43f37b 100644 --- a/mindformers/parallel_core/inference/tensor_parallel/grouped_layers.py +++ b/mindformers/parallel_core/inference/tensor_parallel/grouped_layers.py @@ -32,7 +32,7 @@ from mindspore.common.initializer import initializer from mindformers.parallel_core.transformer_config import TransformerConfig from mindformers.parallel_core.inference.quantization import QuantizationConfig from mindformers.parallel_core.inference.quantization.base_config import QuantizeMethodBase -from mindformers.parallel_core.inference.utils import divide +from mindformers.parallel_core.inference.utils import divide, cast_weight_for_310p from mindformers.parallel_core.inference.weights_utils import set_weight_attrs from mindformers.parallel_core.inference.tensor_parallel.mappings import ( gather_from_model_parallel_region, @@ -41,7 +41,8 @@ from mindformers.parallel_core.inference.tensor_parallel.mappings import ( ) from mindformers.parallel_core.inference.parallel_state import ProcessGroup, default_pgs from mindformers.parallel_core.inference.weights_utils import split_loaded_weight, deal_training_moe_weight - +from mindformers.version_control import is_310p +from mindformers.models.utils import format_type class GroupedLinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) grouped linear methods.""" @@ -151,10 +152,29 @@ class GroupedLinearBase(nn.Cell): QuantizeMethodBase] = UnquantizedGroupedLinearMethod() if quant_config is not None: self.quant_method = quant_config.get_quant_method(self, prefix=prefix) + self.param_load_counts: Dict[str, int] = {} def construct(self, x: Tensor, weight: Tensor, group_list: Tensor) -> Tensor: raise NotImplementedError + def format_to_nz(self, param, merge_count=1): + current_count = self.param_load_counts.get(param.name, 0) + 1 + self.param_load_counts[param.name] = current_count + + # Only format when all shards are loaded. 
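+        # merge_count is the number of slices expected for this parameter (e.g. one per
+        # local expert, or two per expert for the fused w1/w3 weight); each weight_loader
+        # call bumps the counter, and the format_cast to the NZ layout runs only once,
+        # after the final slice has been written into the parameter.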
+ if current_count == merge_count: + cast_weight = param + if param.dtype == ms.qint4x2: + # format_cast don't support qint4x2 + import ms_custom_ops + cast_weight = ms_custom_ops.type_cast(param, ms.int8) + cast_weight = ops.auto_generate.format_cast(cast_weight, format_type['nz']) + if param.dtype == ms.qint4x2: + cast_weight = ms_custom_ops.type_cast(cast_weight, ms.qint4x2) + param.set_data(cast_weight) + del self.param_load_counts[param.name] + ms.runtime.empty_cache() + class ColumnParallelGroupedLinear(GroupedLinearBase): """ @@ -212,6 +232,7 @@ class ColumnParallelGroupedLinear(GroupedLinearBase): weight: Tensor = None, is_expert: bool = True, tp_comm_buffer_name: str = None, + transpose_b: bool = False, tp_group: ProcessGroup = default_pgs, quant_config: Optional[QuantizationConfig] = None, prefix: str = "" @@ -249,6 +270,7 @@ class ColumnParallelGroupedLinear(GroupedLinearBase): self.tp_group = tp_group self.tensor_parallel_group_size = self.tp_group.size self.output_size_per_partition = divide(output_size, self.tensor_parallel_group_size) + self.transpose_b = transpose_b if self.quant_method is None: raise ValueError("`quant_method` is not initialized in ColumnParallelGroupedLinear.") @@ -259,6 +281,7 @@ class ColumnParallelGroupedLinear(GroupedLinearBase): input_size_per_partition=self.input_size, output_partition_sizes=[self.output_size_per_partition], params_dtype=self.config.params_dtype, + transpose_b=self.transpose_b, skip_weight_param_allocation=self.skip_weight_param_allocation, weight_loader=self.weight_loader ) @@ -297,7 +320,8 @@ class ColumnParallelGroupedLinear(GroupedLinearBase): def sharded_state_dict(self): """Provide the sharded state dict.""" expert_parallel_group_size = self.config.num_moe_experts // self.num_local_experts - w_shard = (expert_parallel_group_size, 1, self.tensor_parallel_group_size) + w_shard = (expert_parallel_group_size, self.tensor_parallel_group_size, 1) \ + if self.transpose_b else (expert_parallel_group_size, 1, self.tensor_parallel_group_size) state_dict = {} if not self.skip_weight_param_allocation: @@ -323,7 +347,6 @@ class ColumnParallelGroupedLinear(GroupedLinearBase): """ if expert_id == -1: return - if shard_id is not None: param_output_dim = getattr(param, "output_dim", None) shard_size = param.shape[param_output_dim] // 2 @@ -332,17 +355,21 @@ class ColumnParallelGroupedLinear(GroupedLinearBase): # but the dimension splitting in the network is defined based on three dimensions, # so the splitting dimension needs to be subtracted by 1. 
weight_shape = len(loaded_weight.get_shape()) - if weight_shape == 1 or (weight_shape != 1 and loaded_weight.get_shape()[1] == 1): + weight_need_transpose = not (self.transpose_b and param.name.endswith("weight1")) + if weight_shape == 1 or (weight_shape != 1 and loaded_weight.get_shape()[1] == 1) \ + or not weight_need_transpose: shard_dim = getattr(param, "output_dim", None) - 1 else: shard_dim = getattr(param, "input_dim", None) - 1 tp_rank = self.tp_group.rank start_idx = tp_rank * shard_size + # 310P need unpack qint4x2, and do transpose, then pack back to qint4x2 + loaded_weight = self.quant_method.process_weight_before_loading(param.name, loaded_weight[:]) loaded_weight = split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size) if weight_shape == 1: loaded_weight = loaded_weight.reshape(-1, 1) - if loaded_weight.shape[1] != 1: + if loaded_weight.shape[1] != 1 and weight_need_transpose: # The Hugging Face weight shape is [hidden_size, moe_ffn_hidden_size] # The shape of param is [moe_ffn_hidden_size, hidden_size] # So must be transposed. @@ -353,6 +380,8 @@ class ColumnParallelGroupedLinear(GroupedLinearBase): if weight_dtype == ms.bfloat16: loaded_weight = ms.from_numpy(loaded_weight).astype(ms.float32).asnumpy() + expected_shape = list(param.shape) + expected_shape[param_output_dim] = shard_size if loaded_weight.shape[1] == 1: if loaded_weight.shape[0] != shard_size: raise ValueError( @@ -365,15 +394,18 @@ class ColumnParallelGroupedLinear(GroupedLinearBase): elif shard_id == "w3": param.asnumpy()[expert_id][shard_size:shard_size + shard_size] = loaded_weight.squeeze(1) else: - if loaded_weight.shape != (param.shape[1], shard_size): + update_indices = [slice(None)] * len(param.shape) + update_indices[0] = expert_id + if shard_id == "w1": + update_indices[param_output_dim] = slice(None, shard_size) + elif shard_id == "w3": + update_indices[param_output_dim] = slice(shard_size, 2 * shard_size) + if loaded_weight.shape != tuple(expected_shape[1:]): raise ValueError( f"'{param.name}.shape' should be equal to 'loaded_weight.shape'," - f" but got the shape of param is {(param.shape[1], shard_size)} and " + f" but got the shape of param is {expected_shape[1:]} and " f"the shape of weight is{loaded_weight.shape}") - if shard_id == "w1": - param.asnumpy()[expert_id][:, :shard_size] = loaded_weight - elif shard_id == "w3": - param.asnumpy()[expert_id][:, shard_size:shard_size + shard_size] = loaded_weight + param.asnumpy()[tuple(update_indices)] = loaded_weight else: loaded_weight = deal_training_moe_weight(loaded_weight) param.init_data() @@ -383,6 +415,8 @@ class ColumnParallelGroupedLinear(GroupedLinearBase): f" but got the shape of param is {param.data[expert_id].shape} and " f"the shape of weight is{loaded_weight.shape}") param[expert_id] = ms.from_numpy(loaded_weight) + if is_310p() and param.name.endswith("weight1"): + self.format_to_nz(param, self.num_local_experts * 2) class RowParallelGroupedLinear(GroupedLinearBase): @@ -441,6 +475,7 @@ class RowParallelGroupedLinear(GroupedLinearBase): delay_allreduce: bool = True, is_expert: bool = True, tp_comm_buffer_name: str = None, + transpose_b: bool = False, tp_group: ProcessGroup = default_pgs, quant_config: Optional[QuantizationConfig] = None, prefix: str = "" @@ -478,6 +513,7 @@ class RowParallelGroupedLinear(GroupedLinearBase): self.tp_group = tp_group self.tensor_parallel_group_size = self.tp_group.size self.input_size_per_partition = divide(input_size, self.tensor_parallel_group_size) + self.transpose_b = transpose_b if 
self.quant_method is None: raise ValueError("`quant_method` is not initialized in RowParallelGroupedLinear.") @@ -487,6 +523,7 @@ input_size_per_partition=self.input_size_per_partition, output_partition_sizes=[self.output_size], params_dtype=self.config.params_dtype, + transpose_b=self.transpose_b, skip_weight_param_allocation=self.skip_weight_param_allocation, weight_loader=self.weight_loader ) @@ -508,7 +545,8 @@ weight = self.weight else: # Check the weight passed in is the correct shape. - expected_shape = (self.num_local_experts, self.input_size_per_partition, self.output_size) + expected_shape = (self.num_local_experts,) + ((self.output_size, self.input_size_per_partition) \ + if self.transpose_b else (self.input_size_per_partition, self.output_size)) if weight.shape != expected_shape: raise ValueError( f"supplied weight's shape is {tuple(weight.shape)}, " @@ -531,7 +569,8 @@ def sharded_state_dict(self): """Provide the sharded state dict.""" expert_parallel_group_size = self.config.num_moe_experts // self.num_local_experts - w_shard = (expert_parallel_group_size, self.tensor_parallel_group_size, 1) + w_shard = (expert_parallel_group_size, 1, self.tensor_parallel_group_size) \ + if self.transpose_b else (expert_parallel_group_size, self.tensor_parallel_group_size, 1) state_dict = {} if not self.skip_weight_param_allocation: @@ -567,12 +606,12 @@ tp_rank = self.tp_group.rank if shard_id is not None: param_output_dim = getattr(param, "input_dim", None) - if not param.name.endswith("weight") and param_output_dim is None: if param.name.endswith("w_scale") and len(loaded_weight.get_shape()) == 2 \ and loaded_weight.get_shape()[1] == 1: loaded_weight = loaded_weight[:].squeeze(-1) param.init_data() + loaded_weight = cast_weight_for_310p(loaded_weight[:]) weight_dtype = ms.from_numpy(loaded_weight[:]).dtype if weight_dtype == ms.bfloat16: loaded_weight = ms.from_numpy(loaded_weight[:]).astype(ms.float32).asnumpy() @@ -583,17 +622,22 @@ # Because this weight shape is two-dimensional, # but the dimension splitting in the network is defined based on three dimensions, # so the splitting dimension needs to be subtracted by 1. - shard_dim = getattr(param, "output_dim", None) - 1 + shard_dim = getattr(param, "input_dim", None) - 1 + weight_need_transpose = not self.transpose_b or param.name.endswith("w_scale") + if weight_need_transpose: + shard_dim = 1 - shard_dim start_idx = tp_rank * shard_size + loaded_weight = self.quant_method.process_weight_before_loading(param.name, loaded_weight[:]) loaded_weight = split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size) # The Hugging Face weight shape is [hidden_size, moe_ffn_hidden_size] # The shape of param is [moe_ffn_hidden_size, hidden_size] # So must be transposed. 
- loaded_weight = loaded_weight.T - if loaded_weight.shape != (shard_size, param.shape[2]): + if weight_need_transpose: + loaded_weight = loaded_weight.T + if loaded_weight.shape != param.shape[1:]: raise ValueError( f"'{param.name}.shape' should be equal to 'loaded_weight.shape'," - f" but got the shape of param is {(shard_size, param.shape[2])} and " + f" but got the shape of param is {(param.shape[1:])} and " f"the shape of weight is{loaded_weight.shape}") param.init_data() weight_dtype = ms.from_numpy(loaded_weight).dtype @@ -612,3 +656,5 @@ class RowParallelGroupedLinear(GroupedLinearBase): f" but got the shape of param is {param.data[expert_id].shape} and " f"the shape of weight is{loaded_weight.shape}") param[expert_id] = ms.from_numpy(loaded_weight) + if is_310p() and param.name.endswith("weight2"): + self.format_to_nz(param, self.num_local_experts) diff --git a/mindformers/parallel_core/inference/tensor_parallel/layers.py b/mindformers/parallel_core/inference/tensor_parallel/layers.py index 906fe2673..2074f390c 100644 --- a/mindformers/parallel_core/inference/tensor_parallel/layers.py +++ b/mindformers/parallel_core/inference/tensor_parallel/layers.py @@ -176,6 +176,7 @@ class LinearBase(ms.nn.Cell): param.set_data(cast_weight) del self.param_load_counts[param.name] + class ColumnParallelLinear(LinearBase): """ The dense layer with weight sliced on second dimension by tensor parallel size. @@ -395,7 +396,7 @@ class ColumnParallelLinear(LinearBase): f"'{param.name}.shape' should be equal to 'loaded_weight.shape'," f" but got the shape of param is {param.shape} and the shape of weight is{loaded_weight.shape}") param.init_data() - param.set_data(ms.from_numpy(loaded_weight)) + param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) if is_310p() and param.name.endswith("weight"): self.format_to_nz(param) @@ -878,7 +879,7 @@ class RowParallelLinear(LinearBase): raise ValueError( f"'{param.name}.shape' should be equal to 'loaded_weight.shape'," f" but got the shape of param is {param.shape} and the shape of weight is{loaded_weight.shape}") - param.set_data(ms.from_numpy(loaded_weight)) + param.set_data(ms.from_numpy(loaded_weight)) # TODO param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) if is_310p() and param.name.endswith("weight"): self.format_to_nz(param) @@ -1061,7 +1062,9 @@ class ReplicatedLinear(LinearBase): f" but got the shape of param is {param.shape} " f"and the shape of weight is{loaded_weight.shape}") param.init_data() - param.set_data(ms.from_numpy(loaded_weight)) + param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) + if is_310p() and param.name.endswith("weight"): + self.format_to_nz(param, 2) class VocabParallelEmbedding(nn.Cell): diff --git a/mindformers/parallel_core/inference/transformer/moe/experts.py b/mindformers/parallel_core/inference/transformer/moe/experts.py index 6e77b3d47..1608cfe4d 100644 --- a/mindformers/parallel_core/inference/transformer/moe/experts.py +++ b/mindformers/parallel_core/inference/transformer/moe/experts.py @@ -30,6 +30,7 @@ from mindformers.parallel_core.inference.transformer.activation import get_act_f from mindformers.parallel_core.inference.utils import divide from mindformers.parallel_core.process_group_config import ModelCommProcessGroups, default_model_comm_pgs from mindformers.parallel_core.inference.weights_utils import set_weight_attrs +from mindformers.version_control import is_310p class GroupedMLP(nn.Cell): @@ -54,6 +55,7 @@ class GroupedMLP(nn.Cell): # use model_comm_pgs.moe_tp_group as tensor parallel 
group in this module. self.tp_group = model_comm_pgs.moe_tp self.tp_group_size = self.tp_group.size + self.transpose_b = is_310p() ffn_hidden_size = self.config.moe_ffn_hidden_size self.ffn_hidden_size_per_partition = divide(ffn_hidden_size, self.tp_group_size) @@ -70,6 +72,7 @@ class GroupedMLP(nn.Cell): num_local_experts=self.num_local_experts, input_size_per_partition=self.input_size, output_partition_sizes=[divide(ffn_hidden_size, self.tp_group_size)], + transpose_b=self.transpose_b, params_dtype=self.config.params_dtype, ) self.weight2 = self.quant_method.create_weights( @@ -77,6 +80,7 @@ class GroupedMLP(nn.Cell): num_local_experts=self.num_local_experts, input_size_per_partition=self.ffn_hidden_size_per_partition, output_partition_sizes=[self.input_size], + transpose_b=self.transpose_b, params_dtype=self.config.params_dtype, ) @@ -92,6 +96,7 @@ class GroupedMLP(nn.Cell): weight=self.weight1, # Skip creating weights and use weight1 for gemm linear calculation is_expert=True, tp_group=self.tp_group, + transpose_b=self.transpose_b, quant_config=quant_config, prefix=f"{prefix}.linear_fc1", ) @@ -112,6 +117,7 @@ class GroupedMLP(nn.Cell): weight=self.weight2, # Skip creating weights and use weight2 for gemm linear calculation is_expert=True, tp_group=self.tp_group, + transpose_b=self.transpose_b, quant_config=quant_config, prefix=f"{prefix}.linear_fc2", ) diff --git a/mindformers/parallel_core/inference/transformer/moe/router.py b/mindformers/parallel_core/inference/transformer/moe/router.py index 832ab01fa..9000b545f 100644 --- a/mindformers/parallel_core/inference/transformer/moe/router.py +++ b/mindformers/parallel_core/inference/transformer/moe/router.py @@ -108,7 +108,7 @@ class Router(nn.Cell): f"'param.data.shape' should be equal to 'loaded_weight.shape'," f" but got the shape of param is {param.shape} " f"and the shape of weight is{loaded_weight.shape}") - param.set_data(ms.from_numpy(loaded_weight).astype(param.dtype)) + param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) def construct(self, input_tensor: Tensor): """ diff --git a/mindformers/parallel_core/inference/transformer/moe/token_dispatcher.py b/mindformers/parallel_core/inference/transformer/moe/token_dispatcher.py index 6f5e6e5ad..f1fa5f158 100644 --- a/mindformers/parallel_core/inference/transformer/moe/token_dispatcher.py +++ b/mindformers/parallel_core/inference/transformer/moe/token_dispatcher.py @@ -27,7 +27,7 @@ from mindspore.ops.auto_generate import (MoeInitRoutingV2, from mindformers.parallel_core.transformer_config import TransformerConfig from mindformers.parallel_core.process_group_config import ModelCommProcessGroups, default_model_comm_pgs -from mindformers.version_control import is_910b +from mindformers.version_control import is_910b, is_310p class MoETokenDispatcher: @@ -109,6 +109,7 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher): self.cast = ops.Cast() self.moe_init_routing_v2 = MoeInitRoutingV2() self.moe_token_unpermute = MoeTokenUnpermute() + self.is_310p = is_310p() def dispatch_preprocess(self, expert_weight, routing_map): """Preprocess expert weight by masking out invalid experts.""" @@ -129,8 +130,8 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher): expert_capacity=0, expert_num=self.num_experts, drop_pad_mode=0, - expert_tokens_count_or_cumsum_flag=2, - expert_tokens_before_capacity_flag=True + expert_tokens_count_or_cumsum_flag=1 if self.is_310p else 2, + expert_tokens_before_capacity_flag=not self.is_310p ) # Avoid the problem of poor performance of the split(int32) 
operator @@ -143,7 +144,8 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher): def token_combine(self, hidden_states, expert_weight, *args): """Combines expert outputs.""" (tokens_per_expert,) = args - hidden_states = mint.nan_to_num(hidden_states, 0, 0, 0) + if not self.is_310p: + hidden_states = mint.nan_to_num(hidden_states, 0, 0, 0) expert_weight = expert_weight.astype(hidden_states.dtype) hidden_states = self.moe_token_unpermute( permuted_tokens=hidden_states, diff --git a/mindformers/parallel_core/inference/transformer/multi_latent_attention.py b/mindformers/parallel_core/inference/transformer/multi_latent_attention.py index df402ff6e..d6474353d 100644 --- a/mindformers/parallel_core/inference/transformer/multi_latent_attention.py +++ b/mindformers/parallel_core/inference/transformer/multi_latent_attention.py @@ -45,7 +45,8 @@ from mindformers.parallel_core.inference.base_models.common.embeddings.yarn_rota from mindformers.parallel_core.inference.base_models.common.embeddings.rope_utils import get_rope from mindformers.parallel_core.process_group_config import ModelCommProcessGroups, default_model_comm_pgs from mindformers.parallel_core.inference.weights_utils import set_weight_attrs, split_loaded_weight - +from mindformers.version_control import is_310p +from mindformers.models.utils import format_type @dataclass class MLASelfAttentionSubmodules: @@ -172,6 +173,7 @@ class MultiLatentAttention(Attention): self.dim_slice_3d = P.Slice() self.transpose = P.Transpose() self.out_absorb_matmul = P.BatchMatMul(transpose_b=True) + self.is_310p = is_310p() def construct( self, @@ -363,7 +365,10 @@ class MLASelfAttention(MultiLatentAttention): Process the weight after loading. This can be used for example, to transpose weights for computation. """ - q_absorb, out_absorb = mint.split(self.linear_kv_up_proj.weight, + weight = self.linear_kv_up_proj.weight + if self.is_310p: + weight = ops.auto_generate.format_cast(weight, format_type['nd']) + q_absorb, out_absorb = mint.split(weight, [self.num_attention_heads_per_partition * self.config.qk_head_dim, self.num_attention_heads_per_partition * self.config.v_head_dim], -2) self.q_absorb = q_absorb.reshape(self.num_attention_heads_per_partition, diff --git a/mindformers/parallel_core/inference/utils.py b/mindformers/parallel_core/inference/utils.py index 852b26e12..3485bb01d 100644 --- a/mindformers/parallel_core/inference/utils.py +++ b/mindformers/parallel_core/inference/utils.py @@ -373,3 +373,15 @@ def use_ms_custom_ops(): return False return not is_310p() + +def cast_weight_for_310p(loaded_weight): + """ + Casts weights to float16 for 310p. + + In non-quantized scenarios, the 310P hardware only supports float16 weights. + This function converts float32 or bfloat16 weights to float16. 
+ """ + cast_weight = (loaded_weight.astype(np.float16) if + (str(loaded_weight.dtype) == "float32" or str( + loaded_weight.dtype) == "bfloat16") else loaded_weight) + return cast_weight diff --git a/mindformers/parallel_core/inference/weights_utils.py b/mindformers/parallel_core/inference/weights_utils.py index 86f3d6968..c0c175851 100644 --- a/mindformers/parallel_core/inference/weights_utils.py +++ b/mindformers/parallel_core/inference/weights_utils.py @@ -393,3 +393,22 @@ def split_fusion_loaded_weight(loaded_weight, start_idxs, shard_sizes): loaded_weight_parts.append(loaded_weight[start_idx:start_idx + shard_size]) perrank_ffn_weight = np.concatenate(loaded_weight_parts, axis=0) return perrank_ffn_weight + + +def process_weights_for_310p(name, loaded_weight, params_dict, loaded_params): + """ + Process the loadweight for 310P. + + Args: + name: The name of the weight to be loaded. + loaded_weight: The weights to be loaded. + params_dict: The dictionary of model parameters. + loaded_params: The set of already loaded parameter names. + Returns: + The processed weights. + """ + + if "deq_scale" in name: + loaded_weight = loaded_weight[:].astype(np.float32).view(np.int32).astype(np.int64) + + return loaded_weight From f3e09f69fe3bb1d782e35485169425b5df5642e1 Mon Sep 17 00:00:00 2001 From: HighCloud Date: Thu, 20 Nov 2025 11:33:26 +0800 Subject: [PATCH 2/2] code clean --- mindformers/models/utils.py | 16 ++--- .../inference/base_models/gpt/gpt_model.py | 23 ++++--- .../inference/quantization/base_config.py | 2 +- .../quantization/golden_stick/a8dynw4.py | 6 +- .../quantization/golden_stick/a8dynw8.py | 6 +- .../quantization/golden_stick/a8w8.py | 2 +- .../tensor_parallel/grouped_layers.py | 62 ++++++++++--------- .../inference/tensor_parallel/layers.py | 32 +++++----- .../inference/transformer/moe/router.py | 4 +- .../transformer/multi_latent_attention.py | 10 +-- mindformers/parallel_core/inference/utils.py | 14 ++--- .../parallel_core/inference/weights_utils.py | 2 +- tests/st/test_ut/base_schema.json | 4 +- 13 files changed, 97 insertions(+), 86 deletions(-) diff --git a/mindformers/models/utils.py b/mindformers/models/utils.py index d75e6a6eb..cdd4c9e9f 100644 --- a/mindformers/models/utils.py +++ b/mindformers/models/utils.py @@ -97,7 +97,7 @@ def reverse_dict(d: dict): new_d = {} for k, v in d.items(): if v in new_d: - raise ValueError(f"Different keys in dict have same values.") + raise ValueError("Different keys in dict have same values.") new_d[v] = k return new_d @@ -124,16 +124,16 @@ def check_use_3d_tensor_parallel_valid(config): if not use_3d_tensor_parallel or not is_config_valid: return False if not config.use_flash_attention: - raise ValueError(f"When the use_3d_tensor_parallel = True, the use_flash_attention must be True ") + raise ValueError("When the use_3d_tensor_parallel = True, the use_flash_attention must be True ") if config.parallel_config.get_ulysses_cp_num() > 1: - raise ValueError(f"Currently, when the use_3d_tensor_parallel = True, " + raise ValueError("Currently, when the use_3d_tensor_parallel = True, " "the cp_ds of the ulysses context parallel must be 1") if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation(): - raise ValueError(f"Currently, when the use_3d_tensor_parallel = True, the auto parallel is not supported") + raise ValueError("Currently, when the use_3d_tensor_parallel = True, the auto parallel is not supported") if config.moe_config is not None and config.moe_config.expert_num > 1: - raise ValueError(f"Currently, 
when the use_3d_tensor_parallel = True, the MoE is not supported") + raise ValueError("Currently, when the use_3d_tensor_parallel = True, the MoE is not supported") if not config.parallel_config.use_seq_parallel: - raise ValueError(f"Currently, when the use_3d_tensor_parallel = True, the use_seq_parallel must be True") + raise ValueError("Currently, when the use_3d_tensor_parallel = True, the use_seq_parallel must be True") if check_fine_grain_interleave_valid(config.fine_grain_interleave, config.parallel_config): raise ValueError("Currently, when the use_3d_tensor_parallel = True, " "the fine_grain_interleave is not supported") @@ -142,8 +142,8 @@ def check_use_3d_tensor_parallel_valid(config): tp_z = getattr(config, "tp_z", 1) model_parallel = config.parallel_config.model_parallel if model_parallel > 1 and tp_x * tp_y * tp_z != config.parallel_config.model_parallel: - raise ValueError("tp_x * tp_y * tp_z should be equal to model_parallel, but got " - "tp_x={}, tp_y={}, tp_z={}, model_parallel={}.".format(tp_x, tp_y, tp_z, model_parallel)) + raise ValueError(f"tp_x * tp_y * tp_z should be equal to model_parallel, but got " + f"tp_x={tp_x}, tp_y={tp_y}, tp_z={tp_z}, model_parallel={model_parallel}.") if model_parallel > 1: logger.info(f"use_3d_tensor_parallel is True, (tp_x, tp_y, tp_z): ({tp_x}, {tp_y}, {tp_z})") return True diff --git a/mindformers/parallel_core/inference/base_models/gpt/gpt_model.py b/mindformers/parallel_core/inference/base_models/gpt/gpt_model.py index f4567237d..493015eba 100644 --- a/mindformers/parallel_core/inference/base_models/gpt/gpt_model.py +++ b/mindformers/parallel_core/inference/base_models/gpt/gpt_model.py @@ -118,7 +118,7 @@ class GPTModel(nn.Cell): model_comm_pgs: Optional[ModelCommProcessGroups] = None, quant_config: Optional[QuantizationConfig] = None, ): - super(GPTModel, self).__init__() + super().__init__() self.check_support(fp16_lm_cross_entropy, rope_scaling) self.config = config self.quant_config = quant_config @@ -319,7 +319,7 @@ class GPTModel(nn.Cell): return logits def get_params_dict(self): - params_dict = dict() + params_dict = {} for _, module in self.modules_dict.items(): module_params = module.parameters_dict() for param_name, param in module_params.items(): @@ -405,16 +405,19 @@ class GPTModel(nn.Cell): num_experts = self.config.num_moe_experts if '.weight1' in name: weight = loaded_weight[:].reshape(num_experts, self.config.hidden_size, -1) - if '.weight2' in name: + elif '.weight2' in name: weight = loaded_weight[:].reshape(num_experts, self.config.moe_ffn_hidden_size, -1) + else: + weight = None - for expert_id in range(num_experts): - expert_id = self.map_global_expert_id_to_local_expert_id(expert_id) - loaded_weight = weight[expert_id] - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id=None, expert_id=expert_id) - loaded_params.add(name) + if weight is not None: + for expert_id in range(num_experts): + expert_id = self.map_global_expert_id_to_local_expert_id(expert_id) + loaded_weight = weight[expert_id] + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id=None, expert_id=expert_id) + loaded_params.add(name) else: param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/mindformers/parallel_core/inference/quantization/base_config.py b/mindformers/parallel_core/inference/quantization/base_config.py index e1ab75fc0..307a93020 100644 --- 
a/mindformers/parallel_core/inference/quantization/base_config.py +++ b/mindformers/parallel_core/inference/quantization/base_config.py @@ -81,7 +81,7 @@ class QuantizationConfig(ABC): def __init__(self): super().__init__() # mapping is updated by models as they initialize - self.packed_modules_mapping: dict[str, list[str]] = dict() + self.packed_modules_mapping: dict[str, list[str]] = {} @abstractmethod def get_name(self) -> QuantizationBackends: diff --git a/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw4.py b/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw4.py index 4c0ad51cf..79d767aa5 100644 --- a/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw4.py +++ b/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw4.py @@ -47,8 +47,8 @@ class A8W4DynamicLinearMethod(LinearMethodBase): input_size_per_partition: int, output_partition_sizes: list[int], params_dtype, - transpose_b=False, *weight_args, + transpose_b=False, num_local_experts=None, **extra_weight_attrs) -> Union[Parameter, None]: output_size_per_partition = sum(output_partition_sizes) self.output_size_per_partition = output_size_per_partition @@ -64,9 +64,9 @@ class A8W4DynamicLinearMethod(LinearMethodBase): else: self.matmul = GroupedMatmulV4() if not extra_weight_attrs.get('skip_weight_param_allocation', False): - weight_shape = (num_local_experts, self.output_size_per_partition, + weight_shape = (num_local_experts, self.output_size_per_partition, self.input_size_per_partition // 2) \ - if transpose_b else (num_local_experts, self.input_size_per_partition, + if transpose_b else (num_local_experts, self.input_size_per_partition, self.output_size_per_partition // 2) weight = Parameter(initializer('ones', weight_shape, mindspore.qint4x2), requires_grad=False) input_dim, output_dim = (2, 1) if transpose_b else (1, 2) diff --git a/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw8.py b/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw8.py index 63303d6ab..61ea805b6 100644 --- a/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw8.py +++ b/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw8.py @@ -24,14 +24,14 @@ from mindspore.ops.auto_generate import QuantBatchMatmul, DynamicQuantExt, Group from mindformers.parallel_core.inference.tensor_parallel.layers import LinearMethodBase from mindformers.parallel_core.inference.tensor_parallel.mappings import reduce_from_model_parallel_region from mindformers.parallel_core.inference.quantization import QuantizationConfig +from mindformers.version_control import is_310p +from mindformers.parallel_core.inference.weights_utils import set_weight_attrs +from mindformers.models.utils import format_type try: from ms_custom_ops import grouped_matmul GMM_310P = True except ImportError: GMM_310P = False -from mindformers.version_control import is_310p -from mindformers.parallel_core.inference.weights_utils import set_weight_attrs -from mindformers.models.utils import format_type class A8W8DynamicLinearMethod(LinearMethodBase): diff --git a/mindformers/parallel_core/inference/quantization/golden_stick/a8w8.py b/mindformers/parallel_core/inference/quantization/golden_stick/a8w8.py index 63eca707e..7a3886e2d 100644 --- a/mindformers/parallel_core/inference/quantization/golden_stick/a8w8.py +++ b/mindformers/parallel_core/inference/quantization/golden_stick/a8w8.py @@ -35,7 +35,7 @@ class A8W8LinearMethod(LinearMethodBase): self.is_310p = is_310p() 
self.is_ms_custom_ops = False try: - import ms_custom_ops + import ms_custom_ops # pylint: disable=import-outside-toplevel self.is_ms_custom_ops = True and not self.is_310p self.ms_custom_ops = ms_custom_ops except ModuleNotFoundError: diff --git a/mindformers/parallel_core/inference/tensor_parallel/grouped_layers.py b/mindformers/parallel_core/inference/tensor_parallel/grouped_layers.py index afd43f37b..6c3275204 100644 --- a/mindformers/parallel_core/inference/tensor_parallel/grouped_layers.py +++ b/mindformers/parallel_core/inference/tensor_parallel/grouped_layers.py @@ -49,7 +49,7 @@ class GroupedLinearMethodBase(QuantizeMethodBase): @abstractmethod def create_weights(self, layer: nn.Cell, num_local_experts: int, - input_size_per_partition: int, output_size_per_partition: int, + input_size_per_partition: int, output_partition_sizes: list[int], params_dtype, **extra_weight_attrs): """Create weights for a grouped linear layer. The weights will be set as attributes of the layer. @@ -58,7 +58,7 @@ class GroupedLinearMethodBase(QuantizeMethodBase): layer: The layer that is using the GroupedLinearMethodBase factory. num_local_experts: The number of local experts. input_size_per_partition: Sizes of the input dim on rank X. - output_size_per_partition: Sizes of the output dim on rank X. + output_partition_sizes: Sizes of the output dim on rank X. params_dtype: Datatype of the parameters. """ raise NotImplementedError @@ -87,7 +87,7 @@ class UnquantizedGroupedLinearMethod(GroupedLinearMethodBase): def create_weights(self, layer: nn.Cell, num_local_experts: int, input_size_per_partition: int, output_partition_sizes: list[int], - params_dtype, **extra_weight_attrs): + params_dtype, **extra_weight_attrs): # pylint: disable=arguments-renamed weight_shape = (num_local_experts, input_size_per_partition, sum(output_partition_sizes)) weight = Parameter(initializer("zeros", weight_shape, params_dtype), requires_grad=False) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 2}) @@ -158,6 +158,7 @@ class GroupedLinearBase(nn.Cell): raise NotImplementedError def format_to_nz(self, param, merge_count=1): + """Format parameter to nz format.""" current_count = self.param_load_counts.get(param.name, 0) + 1 self.param_load_counts[param.name] = current_count @@ -166,7 +167,7 @@ class GroupedLinearBase(nn.Cell): cast_weight = param if param.dtype == ms.qint4x2: # format_cast don't support qint4x2 - import ms_custom_ops + import ms_custom_ops # pylint: disable=import-outside-toplevel cast_weight = ms_custom_ops.type_cast(param, ms.int8) cast_weight = ops.auto_generate.format_cast(cast_weight, format_type['nz']) if param.dtype == ms.qint4x2: @@ -237,17 +238,17 @@ class ColumnParallelGroupedLinear(GroupedLinearBase): quant_config: Optional[QuantizationConfig] = None, prefix: str = "" ): - super(ColumnParallelGroupedLinear, self).__init__(num_local_experts, - input_size, - output_size, - skip_bias_add, - config.params_dtype, - quant_config=quant_config, - prefix=prefix) + super().__init__(num_local_experts, + input_size, + output_size, + skip_bias_add, + config.params_dtype, + quant_config=quant_config, + prefix=prefix) if stride > 1: raise NotImplementedError( - "For ColumnParallelGroupedLinear, `stride > 1` is not supported for now, " - "but got `stride={}`".format(stride)) + f"For ColumnParallelGroupedLinear, `stride > 1` is not supported for now, " + f"but got `stride={stride}`") if skip_bias_add: raise NotImplementedError( "For ColumnParallelGroupedLinear, `skip_bias_add=True` is not supported for now." 
@@ -298,7 +299,7 @@ class ColumnParallelGroupedLinear(GroupedLinearBase): else: self.bias = None - def construct(self, input_parallel, weight=None, group_list=None): + def construct(self, input_parallel, weight=None, group_list=None): # pylint: disable=arguments-renamed """Forward of ColumnParallelGroupedLinear.""" if weight is None: weight = self.weight @@ -356,11 +357,10 @@ class ColumnParallelGroupedLinear(GroupedLinearBase): # so the splitting dimension needs to be subtracted by 1. weight_shape = len(loaded_weight.get_shape()) weight_need_transpose = not (self.transpose_b and param.name.endswith("weight1")) - if weight_shape == 1 or (weight_shape != 1 and loaded_weight.get_shape()[1] == 1) \ - or not weight_need_transpose: - shard_dim = getattr(param, "output_dim", None) - 1 - else: - shard_dim = getattr(param, "input_dim", None) - 1 + cond = weight_shape == 1 or (weight_shape != 1 and loaded_weight.get_shape()[1] == 1) \ + or not weight_need_transpose + shard_dim_map = {True: "output_dim", False: "input_dim"} + shard_dim = getattr(param, shard_dim_map[cond], None) - 1 tp_rank = self.tp_group.rank start_idx = tp_rank * shard_size @@ -415,6 +415,10 @@ class ColumnParallelGroupedLinear(GroupedLinearBase): f" but got the shape of param is {param.data[expert_id].shape} and " f"the shape of weight is{loaded_weight.shape}") param[expert_id] = ms.from_numpy(loaded_weight) + self.post_process(param) + + def post_process(self, param): + """post process in weight loader""" if is_310p() and param.name.endswith("weight1"): self.format_to_nz(param, self.num_local_experts * 2) @@ -480,17 +484,17 @@ class RowParallelGroupedLinear(GroupedLinearBase): quant_config: Optional[QuantizationConfig] = None, prefix: str = "" ): - super(RowParallelGroupedLinear, self).__init__(num_local_experts, - input_size, - output_size, - skip_bias_add, - config.params_dtype, - quant_config=quant_config, - prefix=prefix) + super().__init__(num_local_experts, + input_size, + output_size, + skip_bias_add, + config.params_dtype, + quant_config=quant_config, + prefix=prefix) if stride > 1: raise NotImplementedError( - "For RowParallelGroupedLinear, `stride > 1` is not supported for now, " - "but got `stride={}`".format(stride)) + f"For RowParallelGroupedLinear, `stride > 1` is not supported for now, " + f"but got `stride={stride}`") if not is_expert: raise NotImplementedError( "For RowParallelGroupedLinear, `is_expert=False` is not supported for now.") @@ -539,7 +543,7 @@ class RowParallelGroupedLinear(GroupedLinearBase): else: self.bias = None - def construct(self, input_, weight=None, group_list=None): + def construct(self, input_, weight=None, group_list=None): # pylint: disable=arguments-renamed """Forward of RowParallelGroupedLinear.""" if weight is None: weight = self.weight diff --git a/mindformers/parallel_core/inference/tensor_parallel/layers.py b/mindformers/parallel_core/inference/tensor_parallel/layers.py index 2074f390c..e86e918d4 100644 --- a/mindformers/parallel_core/inference/tensor_parallel/layers.py +++ b/mindformers/parallel_core/inference/tensor_parallel/layers.py @@ -169,7 +169,7 @@ class LinearBase(ms.nn.Cell): if is_310p(): cast_weight = ops.auto_generate.format_cast(param, format_type['nz']) else: - import ms_custom_ops + import ms_custom_ops # pylint: disable=import-outside-toplevel cast_weight = ms_custom_ops.trans_data(param, transdata_type=1) if move_to_cpu: cast_weight = cast_weight.move_to("CPU") @@ -243,7 +243,7 @@ class ColumnParallelLinear(LinearBase): quant_config: Optional[QuantizationConfig] 
= None, prefix: str = "" ): - super(ColumnParallelLinear, self).__init__( + super().__init__( input_size, output_size, skip_bias_add, @@ -252,8 +252,8 @@ class ColumnParallelLinear(LinearBase): prefix=prefix ) if stride > 1: - raise NotImplementedError("For ColumnParallelLinear, `stride > 1` is not supported for now, " - "but got `stride={}`".format(stride)) + raise NotImplementedError(f"For ColumnParallelLinear, `stride > 1` is not supported for now, " + f"but got `stride={stride}`") if keep_master_weight_for_test: raise NotImplementedError( "For ColumnParallelLinear, `keep_master_weight_for_test` is not supported for now") @@ -481,6 +481,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): array_id = 0 elif loaded_shard_id == 'hidden': array_id = 1 + else: + raise ValueError(f"Unknown loaded_shard_id: {loaded_shard_id}") shard_offset = sum(self.output_sizes[:array_id]) // tp_size shard_size = self.output_sizes[array_id] // tp_size @@ -635,6 +637,8 @@ class QKVParallelLinear(ColumnParallelLinear): shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size shard_size = self.num_kv_heads * self.head_size + else: + raise ValueError(f"Unknown loaded_shard_id: {loaded_shard_id}") if loaded_shard_id == "q": shard_id = tp_rank @@ -740,15 +744,15 @@ class RowParallelLinear(LinearBase): quant_config: Optional[QuantizationConfig] = None, prefix: str = "" ): - super(RowParallelLinear, self).__init__(input_size, - output_size, - skip_bias_add, - config.params_dtype, - quant_config=quant_config, - prefix=prefix) + super().__init__(input_size, + output_size, + skip_bias_add, + config.params_dtype, + quant_config=quant_config, + prefix=prefix) if stride > 1: - raise NotImplementedError("For RowParallelLinear, `stride > 1` is not supported for now, " - "but got `stride={}`".format(stride)) + raise NotImplementedError(f"For RowParallelLinear, `stride > 1` is not supported for now, " + f"but got `stride={stride}`") if skip_bias_add: raise NotImplementedError("For RowParallelLinear, `skip_bias_add=True` is not supported for now") if keep_master_weight_for_test: @@ -933,8 +937,8 @@ class ReplicatedLinear(LinearBase): quant_config=quant_config, prefix=prefix) if stride > 1: - raise NotImplementedError("For ReplicatedLinear, `stride > 1` is not supported for now, " - "but got `stride={}`".format(stride)) + raise NotImplementedError(f"For ReplicatedLinear, `stride > 1` is not supported for now, " + f"but got `stride={stride}`") if skip_bias_add: raise NotImplementedError("For ReplicatedLinear, `skip_bias_add=True` is not supported for now") if keep_master_weight_for_test: diff --git a/mindformers/parallel_core/inference/transformer/moe/router.py b/mindformers/parallel_core/inference/transformer/moe/router.py index 9000b545f..47891748b 100644 --- a/mindformers/parallel_core/inference/transformer/moe/router.py +++ b/mindformers/parallel_core/inference/transformer/moe/router.py @@ -44,7 +44,7 @@ class Router(nn.Cell): Args: config (TransformerConfig): Configuration object for the Transformer model. 
""" - super(Router, self).__init__() + super().__init__() self.config = config self.num_experts = self.config.num_moe_experts self.router_dense_type = self.config.moe_router_dtype @@ -98,7 +98,7 @@ class Router(nn.Cell): """ loaded_weight = loaded_weight[:] if self.ep_group_size > 1 and not self.config.use_alltoall: - expert_idx_list = [idx for idx in range(self.num_experts)] + expert_idx_list = list(range(self.num_experts)) start_idx = self.num_experts // self.ep_group_size * self.ep_rank expert_idx_list = expert_idx_list[start_idx:] + expert_idx_list[:start_idx] loaded_weight = loaded_weight[expert_idx_list] diff --git a/mindformers/parallel_core/inference/transformer/multi_latent_attention.py b/mindformers/parallel_core/inference/transformer/multi_latent_attention.py index d6474353d..d458c2ac3 100644 --- a/mindformers/parallel_core/inference/transformer/multi_latent_attention.py +++ b/mindformers/parallel_core/inference/transformer/multi_latent_attention.py @@ -536,7 +536,7 @@ class FusedMLASelfAttention(MLASelfAttention): self.is_modelslim = quant_config.is_modelslim self.fa3_quant = quant_config.fa3_quant self.fa3_quant_layer = quant_config.fa3_quant_layer - self.is_fa3_quant_layer = (layer_number - 1) in self.fa3_quant_layer # layer_number start from 1 + self.is_fa3_quant_layer = layer_number - 1 in self.fa3_quant_layer # layer_number start from 1 self.input_layernorm_weight = None self.qkv_down_proj_input_scale = None self.q_layernorm_weight = None @@ -547,10 +547,10 @@ class FusedMLASelfAttention(MLASelfAttention): self.q_up_proj_input_offset = None self.input_format = 1 if self.fa3_quant else 0 self.use_ringmla = use_ms_custom_ops() and get_tensor_model_parallel_world_size() < 16 - import ms_custom_ops + import ms_custom_ops # pylint: disable=import-outside-toplevel self.ms_custom_ops = ms_custom_ops - self.scale_value = 1 / math.sqrt(self.config.kv_lora_rank + self.config.qk_head_dim) \ - if self.softmax_scale is None else self.softmax_scale + self.scale_value = (1 / math.sqrt(self.config.kv_lora_rank + self.config.qk_head_dim) + if self.softmax_scale is None else self.softmax_scale) self.ring_mla_mask = Tensor(np.triu(np.ones((512, 512), dtype=np.float16), 1), dtype.bfloat16) self.depend = P.Depend() self.quant = QuantV2() @@ -798,7 +798,7 @@ class FusedMLASelfAttention(MLASelfAttention): k_cache = self.transpose(key_cache.reshape(-1, self.config.kv_lora_rank // 32, \ self.config.block_size, 32), (0, 2, 1, 3)).reshape( \ -1, self.config.block_size, self.config.kv_lora_rank) - k_cache = (self.cast(k_cache, dtype.bfloat16) / self.quant_ctkv_scale) + k_cache = self.cast(k_cache, dtype.bfloat16) / self.quant_ctkv_scale else: k_cache = self.ms_custom_ops.trans_data(key_cache, transdata_type=0) v_cache = self.ms_custom_ops.trans_data(value_cache, transdata_type=0) diff --git a/mindformers/parallel_core/inference/utils.py b/mindformers/parallel_core/inference/utils.py index 3485bb01d..624629e97 100644 --- a/mindformers/parallel_core/inference/utils.py +++ b/mindformers/parallel_core/inference/utils.py @@ -76,8 +76,8 @@ def get_attn_mask_func(mask_func_type): """ if mask_func_type not in ATTNMASK_FUNC_MAP: raise KeyError("Invalid attention mask function. 
Supported attention " - "mask function are ['attn_mask_fill', 'attn_mask_add'] " - ", but got {}.".format(mask_func_type)) + f"mask function are ['attn_mask_fill', 'attn_mask_add'] " + f", but got {mask_func_type}.") return ATTNMASK_FUNC_MAP[mask_func_type] @@ -158,7 +158,7 @@ def create_empty_parameter(shape, *, dtype=None, device=None, **kwargs): def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" if numerator % denominator != 0: - raise ValueError("{} is not divisible by {}".format(numerator, denominator)) + raise ValueError(f"{numerator} is not divisible by {denominator}") def divide(numerator, denominator): @@ -178,9 +178,9 @@ def save_strategy_file(state_dict, strategy_file_name): Supported Platforms: ``Ascend`` """ - import os - import stat - from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy + import os # pylint: disable=import-outside-toplevel + import stat # pylint: disable=import-outside-toplevel + from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy # pylint: disable=import-outside-toplevel stra = ckpt_strategy() @@ -367,7 +367,7 @@ def use_ms_custom_ops(): """ try: # pylint: disable=W0611 - import ms_custom_ops + import ms_custom_ops # pylint: disable=import-outside-toplevel except ModuleNotFoundError: # environment need install ms_custom_ops package return False diff --git a/mindformers/parallel_core/inference/weights_utils.py b/mindformers/parallel_core/inference/weights_utils.py index c0c175851..6995794ce 100644 --- a/mindformers/parallel_core/inference/weights_utils.py +++ b/mindformers/parallel_core/inference/weights_utils.py @@ -64,7 +64,7 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): elif shard_dim == 2: loaded_weight = loaded_weight[:, :, start_idx:end_idx] else: - raise ValueError("shard_dim:{} is not supported.".format(shard_dim)) + raise ValueError(f"shard_dim:{shard_dim} is not supported.") loaded_weight = loaded_weight.astype(np.float16) \ if (str(loaded_weight.dtype) == 'bfloat16' and is_310p()) else loaded_weight return loaded_weight diff --git a/tests/st/test_ut/base_schema.json b/tests/st/test_ut/base_schema.json index 08a34b009..681315657 100644 --- a/tests/st/test_ut/base_schema.json +++ b/tests/st/test_ut/base_schema.json @@ -3975,7 +3975,7 @@ "signature": "(self, layer: mindspore.nn.cell.Cell) -> None" }, "mindformers.parallel_core.inference.tensor_parallel.grouped_layers.ColumnParallelGroupedLinear": { - "signature": "(num_local_experts: int, input_size: int, output_size: int, *, config: mindformers.parallel_core.transformer_config.TransformerConfig, bias: bool = False, gather_output: bool = False, stride: int = 1, skip_bias_add: bool = False, weight: mindspore.common.tensor.Tensor = None, is_expert: bool = True, tp_comm_buffer_name: str = None, tp_group: mindformers.parallel_core.inference.parallel_state.ProcessGroup = , quant_config: Optional[mindformers.parallel_core.inference.quantization.base_config.QuantizationConfig] = None, prefix: str = '')" + "signature": "(num_local_experts: int, input_size: int, output_size: int, *, config: mindformers.parallel_core.transformer_config.TransformerConfig, bias: bool = False, gather_output: bool = False, stride: int = 1, skip_bias_add: bool = False, weight: mindspore.common.tensor.Tensor = None, is_expert: bool = True, tp_comm_buffer_name: str = None, transpose_b: bool = False, tp_group: mindformers.parallel_core.inference.parallel_state.ProcessGroup = , 
quant_config: Optional[mindformers.parallel_core.inference.quantization.base_config.QuantizationConfig] = None, prefix: str = '')" }, "mindformers.parallel_core.inference.tensor_parallel.grouped_layers.ColumnParallelGroupedLinear.construct": { "signature": "(self, input_parallel, weight=None, group_list=None)" @@ -3987,7 +3987,7 @@ "signature": "(self, param, loaded_weight, shard_id=None, expert_id=None) -> None" }, "mindformers.parallel_core.inference.tensor_parallel.grouped_layers.RowParallelGroupedLinear": { - "signature": "(num_local_experts: int, input_size: int, output_size: int, *, config: mindformers.parallel_core.transformer_config.TransformerConfig, bias: bool = False, input_is_parallel: bool = True, skip_bias_add: bool = False, weight: mindspore.common.tensor.Tensor = None, stride: int = 1, delay_allreduce: bool = True, is_expert: bool = True, tp_comm_buffer_name: str = None, tp_group: mindformers.parallel_core.inference.parallel_state.ProcessGroup = , quant_config: Optional[mindformers.parallel_core.inference.quantization.base_config.QuantizationConfig] = None, prefix: str = '')" + "signature": "(num_local_experts: int, input_size: int, output_size: int, *, config: mindformers.parallel_core.transformer_config.TransformerConfig, bias: bool = False, input_is_parallel: bool = True, skip_bias_add: bool = False, weight: mindspore.common.tensor.Tensor = None, stride: int = 1, delay_allreduce: bool = True, is_expert: bool = True, tp_comm_buffer_name: str = None, transpose_b: bool = False, tp_group: mindformers.parallel_core.inference.parallel_state.ProcessGroup = , quant_config: Optional[mindformers.parallel_core.inference.quantization.base_config.QuantizationConfig] = None, prefix: str = '')" }, "mindformers.parallel_core.inference.tensor_parallel.grouped_layers.RowParallelGroupedLinear.construct": { "signature": "(self, input_, weight=None, group_list=None)"
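# Illustrative sketch (not from the patched modules): how the new `transpose_b` flag in the
# grouped-linear signatures above is expected to flip the packed int4 expert-weight layout
# and the (input_dim, output_dim) shard attributes, following the a8dynw4.py change earlier
# in this patch. The sizes in the example at the bottom are made up.

def int4_weight_layout(num_local_experts, input_size_per_partition,
                       output_size_per_partition, transpose_b=False):
    """Return (weight_shape, input_dim, output_dim) for a packed int4 grouped weight."""
    if transpose_b:
        # Transposed layout: int4 pairs are packed along the (last) input axis.
        weight_shape = (num_local_experts, output_size_per_partition,
                        input_size_per_partition // 2)
        input_dim, output_dim = 2, 1
    else:
        # Default layout: int4 pairs are packed along the (last) output axis.
        weight_shape = (num_local_experts, input_size_per_partition,
                        output_size_per_partition // 2)
        input_dim, output_dim = 1, 2
    return weight_shape, input_dim, output_dim


# Example with made-up sizes: 8 local experts, 7168 inputs and 2048 outputs per partition.
assert int4_weight_layout(8, 7168, 2048) == ((8, 7168, 1024), 1, 2)
assert int4_weight_layout(8, 7168, 2048, transpose_b=True) == ((8, 2048, 3584), 2, 1)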