!7632 support ds w4

Merge pull request !7632 from HighCloud/mcore_boom
i-robot, committed by Gitee, 2025-11-21 06:18:30 +00:00
16 changed files with 464 additions and 255 deletions

View File

@@ -46,6 +46,7 @@ str_to_ms_type = {
}
format_type = {
"nd": 2,
"nz": 29,
}
@@ -96,7 +97,7 @@ def reverse_dict(d: dict):
new_d = {}
for k, v in d.items():
if v in new_d:
raise ValueError(f"Different keys in dict have same values.")
raise ValueError("Different keys in dict have same values.")
new_d[v] = k
return new_d
@@ -123,16 +124,16 @@ def check_use_3d_tensor_parallel_valid(config):
if not use_3d_tensor_parallel or not is_config_valid:
return False
if not config.use_flash_attention:
raise ValueError(f"When the use_3d_tensor_parallel = True, the use_flash_attention must be True ")
raise ValueError("When the use_3d_tensor_parallel = True, the use_flash_attention must be True ")
if config.parallel_config.get_ulysses_cp_num() > 1:
raise ValueError(f"Currently, when the use_3d_tensor_parallel = True, "
raise ValueError("Currently, when the use_3d_tensor_parallel = True, "
"the cp_ds of the ulysses context parallel must be 1")
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
raise ValueError(f"Currently, when the use_3d_tensor_parallel = True, the auto parallel is not supported")
raise ValueError("Currently, when the use_3d_tensor_parallel = True, the auto parallel is not supported")
if config.moe_config is not None and config.moe_config.expert_num > 1:
raise ValueError(f"Currently, when the use_3d_tensor_parallel = True, the MoE is not supported")
raise ValueError("Currently, when the use_3d_tensor_parallel = True, the MoE is not supported")
if not config.parallel_config.use_seq_parallel:
raise ValueError(f"Currently, when the use_3d_tensor_parallel = True, the use_seq_parallel must be True")
raise ValueError("Currently, when the use_3d_tensor_parallel = True, the use_seq_parallel must be True")
if check_fine_grain_interleave_valid(config.fine_grain_interleave, config.parallel_config):
raise ValueError("Currently, when the use_3d_tensor_parallel = True, "
"the fine_grain_interleave is not supported")
@@ -141,8 +142,8 @@ def check_use_3d_tensor_parallel_valid(config):
tp_z = getattr(config, "tp_z", 1)
model_parallel = config.parallel_config.model_parallel
if model_parallel > 1 and tp_x * tp_y * tp_z != config.parallel_config.model_parallel:
raise ValueError("tp_x * tp_y * tp_z should be equal to model_parallel, but got "
"tp_x={}, tp_y={}, tp_z={}, model_parallel={}.".format(tp_x, tp_y, tp_z, model_parallel))
raise ValueError(f"tp_x * tp_y * tp_z should be equal to model_parallel, but got "
f"tp_x={tp_x}, tp_y={tp_y}, tp_z={tp_z}, model_parallel={model_parallel}.")
if model_parallel > 1:
logger.info(f"use_3d_tensor_parallel is True, (tp_x, tp_y, tp_z): ({tp_x}, {tp_y}, {tp_z})")
return True
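
As a side note on the hunk above, the product constraint it reformats can be checked in isolation; the sketch below is a standalone restatement with illustrative argument names, not the MindFormers helper itself.

def check_3d_tp_product(tp_x, tp_y, tp_z, model_parallel):
    """Raise if the 3D tensor-parallel factors do not multiply to model_parallel."""
    if model_parallel > 1 and tp_x * tp_y * tp_z != model_parallel:
        raise ValueError(f"tp_x * tp_y * tp_z should be equal to model_parallel, but got "
                         f"tp_x={tp_x}, tp_y={tp_y}, tp_z={tp_z}, model_parallel={model_parallel}.")

check_3d_tp_product(2, 2, 2, 8)   # passes; (2, 2, 1, 8) would raise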

View File

@@ -31,11 +31,13 @@ from mindformers.parallel_core.inference.base_models.common.embeddings.rope_util
from mindformers.parallel_core.inference.utils import divide, generate_padding_index
from mindformers.parallel_core.process_group_config import ModelCommProcessGroups
from mindformers.parallel_core.inference.weights_utils import (default_weight_loader, make_expert_params_mapping,
make_expert_params_mapping_with_expert_dim)
make_expert_params_mapping_with_expert_dim,
process_weights_for_310p)
from mindformers.parallel_core.inference.parallel_state import is_pipeline_last_stage
from mindformers.parallel_core.process_group_config import get_model_comm_pgs
from mindformers.tools.logger import logger
from mindformers.tools.utils import is_pynative
from mindformers.version_control import is_310p
class GPTModel(nn.Cell):
@@ -116,7 +118,7 @@ class GPTModel(nn.Cell):
model_comm_pgs: Optional[ModelCommProcessGroups] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super(GPTModel, self).__init__()
super().__init__()
self.check_support(fp16_lm_cross_entropy, rope_scaling)
self.config = config
self.quant_config = quant_config
@@ -317,7 +319,7 @@ class GPTModel(nn.Cell):
return logits
def get_params_dict(self):
params_dict = dict()
params_dict = {}
for _, module in self.modules_dict.items():
module_params = module.parameters_dict()
for param_name, param in module_params.items():
@@ -403,16 +405,19 @@ class GPTModel(nn.Cell):
num_experts = self.config.num_moe_experts
if '.weight1' in name:
weight = loaded_weight[:].reshape(num_experts, self.config.hidden_size, -1)
if '.weight2' in name:
elif '.weight2' in name:
weight = loaded_weight[:].reshape(num_experts, self.config.moe_ffn_hidden_size, -1)
else:
weight = None
for expert_id in range(num_experts):
expert_id = self.map_global_expert_id_to_local_expert_id(expert_id)
loaded_weight = weight[expert_id]
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id=None, expert_id=expert_id)
loaded_params.add(name)
if weight is not None:
for expert_id in range(num_experts):
expert_id = self.map_global_expert_id_to_local_expert_id(expert_id)
loaded_weight = weight[expert_id]
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id=None, expert_id=expert_id)
loaded_params.add(name)
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
@@ -447,6 +452,9 @@ class GPTModel(nn.Cell):
if "weight_scale_inv" in name:
continue
if is_310p():
loaded_weight = process_weights_for_310p(name, loaded_weight, params_dict, loaded_params)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
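
A minimal NumPy sketch of the '.weight1'/'.weight2' branch above: the flat checkpoint tensor is reshaped to (num_experts, hidden_size, -1) (or (num_experts, moe_ffn_hidden_size, -1)) and each expert's slice is handed to its weight loader. Sizes and the stand-in loader below are illustrative only.

import numpy as np

num_experts, hidden_size, moe_ffn_hidden_size = 4, 8, 16      # made-up sizes
loaded_weight = np.arange(num_experts * hidden_size * moe_ffn_hidden_size, dtype=np.float32)

weight = loaded_weight.reshape(num_experts, hidden_size, -1)   # '.weight1' branch
for expert_id in range(num_experts):
    expert_slice = weight[expert_id]                           # (hidden_size, moe_ffn_hidden_size)
    # in the model this goes to param.weight_loader(param, expert_slice, shard_id=None, expert_id=expert_id)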

View File

@@ -58,6 +58,14 @@ class QuantizeMethodBase(ABC):
raise RuntimeError
def process_weight_before_loading(self, param_name, loaded_weight):
"""
Process the weight before loading.
This can be used for example, to transpose weights for computation.
"""
return loaded_weight
def process_weights_after_loading(self, layer: nn.Cell) -> None:
"""
Process the weight after loading.
@@ -73,7 +81,7 @@ class QuantizationConfig(ABC):
def __init__(self):
super().__init__()
# mapping is updated by models as they initialize
self.packed_modules_mapping: dict[str, list[str]] = dict()
self.packed_modules_mapping: dict[str, list[str]] = {}
@abstractmethod
def get_name(self) -> QuantizationBackends:

View File

@@ -21,12 +21,17 @@ import mindspore
from mindspore import nn, Parameter, ops
from mindspore.common.initializer import initializer
from mindspore.ops.auto_generate import WeightQuantBatchMatmul, DynamicQuantExt, GroupedMatmulV4
try:
from ms_custom_ops import grouped_matmul_w4
GMM_310P = True
except ImportError:
GMM_310P = False
from mindformers.version_control import is_310p
from mindformers.parallel_core.inference.weights_utils import set_weight_attrs
from mindformers.parallel_core.inference.transformer.moe.experts import GroupedMLP
from mindformers.parallel_core.inference.tensor_parallel.layers import LinearMethodBase
from mindformers.parallel_core.inference.quantization import QuantizationConfig
from mindformers.parallel_core.inference.quantization.utils import np_pack_int4_to_int8, np_unpack_int8_to_int4
class A8W4DynamicLinearMethod(LinearMethodBase):
"""Linear method with A8W4 quantization."""
@@ -35,6 +40,7 @@ class A8W4DynamicLinearMethod(LinearMethodBase):
self.quant_config = quant_config
self.quant = DynamicQuantExt()
self.bias_add = ops.Add()
self.is_310p = is_310p() and GMM_310P
def create_weights(self,
layer: nn.Cell,
@@ -42,6 +48,7 @@ class A8W4DynamicLinearMethod(LinearMethodBase):
output_partition_sizes: list[int],
params_dtype,
*weight_args,
transpose_b=False,
num_local_experts=None, **extra_weight_attrs) -> Union[Parameter, None]:
output_size_per_partition = sum(output_partition_sizes)
self.output_size_per_partition = output_size_per_partition
@@ -52,17 +59,24 @@ class A8W4DynamicLinearMethod(LinearMethodBase):
raise ValueError(f"group_size should >=0 but group_size is : {group_size}")
if self.is_group_mm:
weight = None
self.matmul = GroupedMatmulV4()
if self.is_310p:
self.matmul = grouped_matmul_w4
else:
self.matmul = GroupedMatmulV4()
if not extra_weight_attrs.get('skip_weight_param_allocation', False):
weight_shape = (num_local_experts, self.input_size_per_partition, self.output_size_per_partition // 2)
weight_shape = (num_local_experts, self.output_size_per_partition,
self.input_size_per_partition // 2) \
if transpose_b else (num_local_experts, self.input_size_per_partition,
self.output_size_per_partition // 2)
weight = Parameter(initializer('ones', weight_shape, mindspore.qint4x2), requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 2})
input_dim, output_dim = (2, 1) if transpose_b else (1, 2)
set_weight_attrs(weight, {"input_dim": input_dim, "output_dim": output_dim})
set_weight_attrs(weight, extra_weight_attrs)
return weight
w_scale_shape = (num_local_experts, self.input_size_per_partition // group_size,
self.output_size_per_partition)
w_scale_dtype = mindspore.uint64
w_scale_dtype = mindspore.float32 if self.is_310p else mindspore.uint64
w_scale = Parameter(
initializer('ones', w_scale_shape, w_scale_dtype), name="w_scale", requires_grad=False)
set_weight_attrs(w_scale, {"input_dim": 1, "output_dim": 2})
@@ -98,12 +112,27 @@ class A8W4DynamicLinearMethod(LinearMethodBase):
return weight
def process_weight_before_loading(self, param_name, loaded_weight):
"""preprocess before loading weight"""
if self.is_310p and (param_name.endswith("weight1") or param_name.endswith('weight2')) \
and loaded_weight.ndim == 2:
# In 310P, transpose_b is True, so the param shape is (out_dim, in_dim//2).
# But weights is (out_dim//2, in_dim) after packing int4 to int8.
loaded_weight = loaded_weight.astype(np.uint8)
loaded_weight = loaded_weight.transpose(1, 0)
loaded_weight = np_unpack_int8_to_int4(loaded_weight)
loaded_weight = loaded_weight.transpose(1, 0)
loaded_weight = np_pack_int4_to_int8(loaded_weight)
if self.is_310p and param_name.endswith("w_scale"):
loaded_weight = loaded_weight.view(np.float32)
loaded_weight = loaded_weight[..., 0::2]
return loaded_weight
def process_weights_after_loading(self, layer: nn.Cell) -> None:
"""calculate gmm_bias"""
if isinstance(layer, GroupedMLP):
return
# int4 pack to int8
np_data = layer.weight.asnumpy().astype(np.uint8)
np_data_low = ((np_data & 0x0F) << 4).astype(np.int8) >> 4
np_data_high = ((np_data >> 4) << 4).astype(np.int8) >> 4
@@ -113,8 +142,12 @@ class A8W4DynamicLinearMethod(LinearMethodBase):
np_int4_data[:, :, ::2] = np_data_low
np_int4_data[:, :, 1::2] = np_data_high
w_scale = layer.w_scale.asnumpy()
w_scale_repeat = np.repeat(w_scale, layer.weight.shape[1] // w_scale.shape[1],
axis=1).astype(np.uint32).view(np.float32)
if self.is_310p:
np_int4_data = np_int4_data.transpose(0, 2, 1)
w_scale_repeat = np.repeat(w_scale, np_int4_data.shape[1] // w_scale.shape[1], axis=1)
else:
w_scale_repeat = np.repeat(w_scale, np_int4_data.shape[1] // w_scale.shape[1],
axis=1).astype(np.uint32).view(np.float32)
gmm_bias = 8 * np.sum(
np_int4_data.astype(np.float32) * w_scale_repeat, axis=1)
@@ -136,15 +169,24 @@ class A8W4DynamicLinearMethod(LinearMethodBase):
output_shape = qx.shape[:-1] + (self.output_size_per_partition,)
qx = qx.reshape(-1, self.input_size_per_partition)
if self.is_group_mm:
out = self.matmul([qx], [weight],
[gmm_bias], [w_scale],
None,
None,
None, [qx_scale],
group_list,
split_item=3,
group_type=0,
group_list_type=1)[0]
if self.is_310p:
group_list = ops.cast(group_list, dtype=mindspore.int32)
out = self.matmul(qx,
weight,
group_list,
gmm_bias,
qx_scale,
w_scale)
else:
out = self.matmul([qx], [weight],
[gmm_bias], [w_scale],
None,
None,
None, [qx_scale],
group_list,
split_item=3,
group_type=0,
group_list_type=1)[0]
else:
w_scale = ops.cast(w_scale, mindspore.float16)
qx = ops.cast(qx, mindspore.float16)
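
To make the gmm_bias computation above concrete, here is a hedged NumPy-only sketch of the 310P branch, where w_scale is already float32: the per-group scale is repeated along the input dimension to match the unpacked int4 weight, and the bias is 8 * sum(weight * scale) over that dimension. All sizes and values are illustrative.

import numpy as np

num_experts, in_dim, out_dim, group_size = 2, 8, 4, 4          # made-up sizes
np_int4_data = np.ones((num_experts, in_dim, out_dim), dtype=np.int8)
w_scale = np.full((num_experts, in_dim // group_size, out_dim), 0.5, dtype=np.float32)

w_scale_repeat = np.repeat(w_scale, np_int4_data.shape[1] // w_scale.shape[1], axis=1)
gmm_bias = 8 * np.sum(np_int4_data.astype(np.float32) * w_scale_repeat, axis=1)
# gmm_bias.shape == (num_experts, out_dim); each entry here is 8 * 8 * 0.5 = 32.0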

View File

@@ -1,126 +1,151 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""A8 dynamic W8 quantization method."""
from typing import Union
import mindspore
from mindspore import Parameter, nn, ops
from mindspore.common.initializer import initializer
from mindspore.ops.auto_generate import QuantBatchMatmul, DynamicQuantExt, GroupedMatmulV4
from mindformers.parallel_core.inference.tensor_parallel.layers import LinearMethodBase
from mindformers.parallel_core.inference.tensor_parallel.mappings import reduce_from_model_parallel_region
from mindformers.parallel_core.inference.quantization import QuantizationConfig
from mindformers.parallel_core.inference.weights_utils import set_weight_attrs
from mindformers.models.utils import format_type
class A8W8DynamicLinearMethod(LinearMethodBase):
"""Linear method with A8W8 dynamic quantization."""
def __init__(self, quant_config: QuantizationConfig):
self.quant_config = quant_config
self.quant = DynamicQuantExt()
self.bias_add = ops.Add()
def create_weights(self,
layer: nn.Cell,
input_size_per_partition: int,
output_partition_sizes: list[int],
params_dtype,
*weight_args,
num_local_experts=None, **extra_weight_attrs) -> Union[Parameter, None]:
output_size_per_partition = sum(output_partition_sizes)
self.output_size_per_partition = output_size_per_partition
self.input_size_per_partition = input_size_per_partition
self.is_group_mm = num_local_experts is not None
if self.is_group_mm:
weight = None
self.matmul = GroupedMatmulV4()
if not extra_weight_attrs.get('skip_weight_param_allocation', False):
shape = (num_local_experts, input_size_per_partition, output_size_per_partition)
weight = Parameter(initializer('ones', shape, mindspore.int8), requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 2})
set_weight_attrs(weight, extra_weight_attrs)
return weight
w_scale_shape = (num_local_experts, output_size_per_partition)
w_scale_dtype = mindspore.bfloat16 if params_dtype == mindspore.bfloat16 else mindspore.float32
w_scale = Parameter(
initializer('ones', w_scale_shape, w_scale_dtype), name="w_scale", requires_grad=False)
set_weight_attrs(w_scale, {"output_dim": 1})
set_weight_attrs(w_scale, extra_weight_attrs)
else:
self.matmul = QuantBatchMatmul(transpose_x1=False, transpose_x2=True, dtype=params_dtype)
weight_shape = (self.output_size_per_partition, self.input_size_per_partition)
weight = Parameter(initializer('ones', weight_shape, mindspore.int8), requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
set_weight_attrs(weight, extra_weight_attrs)
w_scale_shape = (output_size_per_partition,)
w_scale_dtype = mindspore.bfloat16 if params_dtype == mindspore.bfloat16 else mindspore.float32
w_scale = Parameter(
initializer('ones', w_scale_shape, w_scale_dtype), name="w_scale", requires_grad=False)
set_weight_attrs(w_scale, {"output_dim": 0})
set_weight_attrs(w_scale, extra_weight_attrs)
if layer is not None:
layer.insert_param_to_cell("weight", weight)
layer.insert_param_to_cell("w_scale", w_scale)
return weight
def process_weights_after_loading(self, layer: nn.Cell) -> None:
"""
Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
if self.is_group_mm:
layer.weight = ops.auto_generate.format_cast(layer.weight, format_type['nz'])
def apply(self,
layer: mindspore.nn.Cell,
x: mindspore.Tensor,
weight: mindspore.Tensor = None,
bias: mindspore.Parameter = None,
group_list=None, **kwargs) -> mindspore.Tensor:
if weight is None:
weight = layer.weight
w_scale = layer.w_scale
qx, qx_scale = self.quant(x, None)
qx_scale = qx_scale.reshape(-1)
output_shape = qx.shape[:-1] + (self.output_size_per_partition,)
qx = qx.reshape(-1, self.input_size_per_partition)
if self.is_group_mm:
out = self.matmul([qx], [weight],
None, [w_scale],
None,
None,
None, [qx_scale],
group_list,
split_item=3,
group_type=0,
group_list_type=1)[0]
if hasattr(layer, 'delay_allreduce'):
if not layer.delay_allreduce and not layer.skip_bias_add:
out = reduce_from_model_parallel_region(out, layer.tp_group)
else:
out = self.matmul(qx, weight, w_scale, None, None, qx_scale)
if bias is not None:
out = self.bias_add(out, bias)
out = out.reshape(output_shape)
return out
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""A8 dynamic W8 quantization method."""
from typing import Union
import mindspore
from mindspore import Parameter, nn, ops
from mindspore.common.initializer import initializer
from mindspore.ops.auto_generate import QuantBatchMatmul, DynamicQuantExt, GroupedMatmulV4
from mindformers.parallel_core.inference.tensor_parallel.layers import LinearMethodBase
from mindformers.parallel_core.inference.tensor_parallel.mappings import reduce_from_model_parallel_region
from mindformers.parallel_core.inference.quantization import QuantizationConfig
from mindformers.version_control import is_310p
from mindformers.parallel_core.inference.weights_utils import set_weight_attrs
from mindformers.models.utils import format_type
try:
from ms_custom_ops import grouped_matmul
GMM_310P = True
except ImportError:
GMM_310P = False
class A8W8DynamicLinearMethod(LinearMethodBase):
"""Linear method with A8W8 dynamic quantization."""
def __init__(self, quant_config: QuantizationConfig):
self.quant_config = quant_config
self.quant = DynamicQuantExt()
self.bias_add = ops.Add()
self.is_310p = is_310p() and GMM_310P
def create_weights(self,
layer: nn.Cell,
input_size_per_partition: int,
output_partition_sizes: list[int],
params_dtype,
*weight_args,
transpose_b=False,
num_local_experts=None, **extra_weight_attrs) -> Union[Parameter, None]:
output_size_per_partition = sum(output_partition_sizes)
self.output_size_per_partition = output_size_per_partition
self.input_size_per_partition = input_size_per_partition
self.is_group_mm = num_local_experts is not None
if self.is_group_mm:
weight = None
if self.is_310p:
self.matmul = grouped_matmul
else:
self.matmul = GroupedMatmulV4()
if not extra_weight_attrs.get('skip_weight_param_allocation', False):
shape = (num_local_experts, output_size_per_partition, input_size_per_partition) \
if transpose_b else (num_local_experts, input_size_per_partition, output_size_per_partition)
weight = Parameter(initializer('ones', shape, mindspore.int8), requires_grad=False)
input_dim, output_dim = (2, 1) if transpose_b else (1, 2)
set_weight_attrs(weight, {"input_dim": input_dim, "output_dim": output_dim})
set_weight_attrs(weight, extra_weight_attrs)
return weight
w_scale_shape = (num_local_experts, output_size_per_partition)
w_scale_dtype = mindspore.bfloat16 if params_dtype == mindspore.bfloat16 else mindspore.float32
w_scale = Parameter(
initializer('ones', w_scale_shape, w_scale_dtype), name="w_scale", requires_grad=False)
set_weight_attrs(w_scale, {"output_dim": 1})
set_weight_attrs(w_scale, extra_weight_attrs)
else:
self.matmul = QuantBatchMatmul(transpose_x1=False, transpose_x2=True, dtype=params_dtype)
weight_shape = (self.output_size_per_partition, self.input_size_per_partition)
weight = Parameter(initializer('ones', weight_shape, mindspore.int8), requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
set_weight_attrs(weight, extra_weight_attrs)
w_scale_shape = (output_size_per_partition,)
w_scale_dtype = mindspore.bfloat16 if params_dtype == mindspore.bfloat16 else mindspore.float32
w_scale = Parameter(
initializer('ones', w_scale_shape, w_scale_dtype), name="w_scale", requires_grad=False)
set_weight_attrs(w_scale, {"output_dim": 0})
set_weight_attrs(w_scale, extra_weight_attrs)
if layer is not None:
layer.insert_param_to_cell("weight", weight)
layer.insert_param_to_cell("w_scale", w_scale)
return weight
def process_weights_after_loading(self, layer: nn.Cell) -> None:
"""
Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
if self.is_group_mm:
layer.weight = ops.auto_generate.format_cast(layer.weight, format_type['nz'])
def apply(self,
layer: mindspore.nn.Cell,
x: mindspore.Tensor,
weight: mindspore.Tensor = None,
bias: mindspore.Parameter = None,
group_list=None, **kwargs) -> mindspore.Tensor:
if weight is None:
weight = layer.weight
w_scale = layer.w_scale
qx, qx_scale = self.quant(x, None)
qx_scale = qx_scale.reshape(-1)
output_shape = qx.shape[:-1] + (self.output_size_per_partition,)
qx = qx.reshape(-1, self.input_size_per_partition)
if self.is_group_mm:
if self.is_310p:
group_list = ops.cast(group_list, dtype=mindspore.int32)
out = self.matmul(qx,
weight,
group_list,
None,
w_scale,
qx_scale,
None,
False,
True)
else:
out = self.matmul([qx], [weight],
None, [w_scale],
None,
None,
None, [qx_scale],
group_list,
split_item=3,
group_type=0,
group_list_type=1)[0]
if hasattr(layer, 'delay_allreduce'):
if not layer.delay_allreduce and not layer.skip_bias_add:
out = reduce_from_model_parallel_region(out, layer.tp_group)
else:
out = self.matmul(qx, weight, w_scale, None, None, qx_scale)
if bias is not None:
out = self.bias_add(out, bias)
out = out.reshape(output_shape)
return out
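
The transpose_b handling added in create_weights above boils down to a shape and dim-attribute swap; a plain-Python sketch with illustrative sizes:

num_local_experts, input_size, output_size = 4, 32, 64         # made-up sizes
transpose_b = True                                             # 310P path stores (experts, out, in)

shape = ((num_local_experts, output_size, input_size)
         if transpose_b else (num_local_experts, input_size, output_size))
input_dim, output_dim = (2, 1) if transpose_b else (1, 2)      # which axis is split for TP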

View File

@@ -18,6 +18,7 @@ import mindspore
from mindspore import Tensor, Parameter, ops, nn
from mindspore.common.initializer import initializer
from mindspore.ops.auto_generate import QuantBatchMatmul, QuantV2
from mindformers.version_control import is_310p
from mindformers.parallel_core.inference.weights_utils import set_weight_attrs
from mindformers.parallel_core.inference.quantization import QuantizationConfig
from mindformers.parallel_core.inference.tensor_parallel.layers import LinearMethodBase
@@ -31,10 +32,11 @@ class A8W8LinearMethod(LinearMethodBase):
self.quant = QuantV2()
self.bias_add = ops.Add()
self.is_modelslim = self.quant_config.is_modelslim
self.is_310p = is_310p()
self.is_ms_custom_ops = False
try:
import ms_custom_ops
self.is_ms_custom_ops = True
import ms_custom_ops # pylint: disable=import-outside-toplevel
self.is_ms_custom_ops = True and not self.is_310p
self.ms_custom_ops = ms_custom_ops
except ModuleNotFoundError:
pass
@@ -56,7 +58,7 @@ class A8W8LinearMethod(LinearMethodBase):
weight_shape = (self.output_size_per_partition, self.input_size_per_partition)
weight = Parameter(initializer('ones', weight_shape, mindspore.int8), requires_grad=False)
deq_scale_shape = self.output_size_per_partition
scale_dtype = mindspore.float32
scale_dtype = mindspore.int64 if self.is_310p else mindspore.float32
deq_scale = Parameter(
initializer('ones', deq_scale_shape, scale_dtype), name="deq_scale", requires_grad=False)
shape = (self.output_size_per_partition,)
@@ -113,7 +115,7 @@ class A8W8LinearMethod(LinearMethodBase):
return
input_scale = 1 / layer.input_scale.asnumpy()
layer.input_scale = Parameter(
Tensor(input_scale, dtype=mindspore.bfloat16), name=layer.input_scale.name, requires_grad=False)
Tensor(input_scale, dtype=self.params_dtype), name=layer.input_scale.name, requires_grad=False)
def apply(self,
layer: mindspore.nn.Cell,

View File

@@ -16,6 +16,7 @@
import os
import json
import glob
import numpy as np
from mindformers.parallel_core.inference.quantization import (get_quantization_config,
QuantizationConfig)
from mindformers.models.configuration_utils import PretrainedConfig
@@ -66,3 +67,24 @@ def get_quant_config(model_config: PretrainedConfig, weight_mapping: list) -> Qu
config["weight_mapping"] = weight_mapping
config["quantization"] = quantization
return quant_cls.from_config(config)
def np_unpack_int8_to_int4(packed_data):
"""unpack int8 to int4 numpy array."""
low_nibbles = (packed_data & 0x0F).astype(np.uint8)
high_nibbles = ((packed_data >> 4) & 0x0F).astype(np.uint8)
unpacked = np.empty((*packed_data.shape[:-1], packed_data.shape[-1] * 2),
dtype=np.uint8)
unpacked[..., 0::2] = low_nibbles
unpacked[..., 1::2] = high_nibbles
return unpacked
def np_pack_int4_to_int8(np_data):
"""pack int4 numpy array to int8."""
np_data = np_data.astype(np.int8)
np_data &= 0x000F
np_data[..., 0::2] <<= 0
np_data[..., 1::2] <<= 4
np_int4_data = np_data[..., 0::2] | np_data[..., 1::2]
return np_int4_data
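
A small usage sketch for the two helpers above, assuming they are imported from the same module as in the A8W4 file earlier in this diff; the byte values are arbitrary examples with high nibbles <= 7 so the int8 arithmetic stays in range.

import numpy as np
from mindformers.parallel_core.inference.quantization.utils import (np_pack_int4_to_int8,
                                                                     np_unpack_int8_to_int4)

packed = np.array([[0x21, 0x74]], dtype=np.uint8)   # two bytes -> four int4 nibbles
unpacked = np_unpack_int8_to_int4(packed)           # [[1, 2, 4, 7]], low nibble first
repacked = np_pack_int4_to_int8(unpacked)           # back to [[0x21, 0x74]]
assert np.array_equal(repacked.astype(np.uint8), packed)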

View File

@@ -32,7 +32,7 @@ from mindspore.common.initializer import initializer
from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.parallel_core.inference.quantization import QuantizationConfig
from mindformers.parallel_core.inference.quantization.base_config import QuantizeMethodBase
from mindformers.parallel_core.inference.utils import divide
from mindformers.parallel_core.inference.utils import divide, cast_weight_for_310p
from mindformers.parallel_core.inference.weights_utils import set_weight_attrs
from mindformers.parallel_core.inference.tensor_parallel.mappings import (
gather_from_model_parallel_region,
@@ -41,14 +41,15 @@ from mindformers.parallel_core.inference.tensor_parallel.mappings import (
)
from mindformers.parallel_core.inference.parallel_state import ProcessGroup, default_pgs
from mindformers.parallel_core.inference.weights_utils import split_loaded_weight, deal_training_moe_weight
from mindformers.version_control import is_310p
from mindformers.models.utils import format_type
class GroupedLinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) grouped linear methods."""
@abstractmethod
def create_weights(self, layer: nn.Cell, num_local_experts: int,
input_size_per_partition: int, output_size_per_partition: int,
input_size_per_partition: int, output_partition_sizes: list[int],
params_dtype, **extra_weight_attrs):
"""Create weights for a grouped linear layer.
The weights will be set as attributes of the layer.
@@ -57,7 +58,7 @@ class GroupedLinearMethodBase(QuantizeMethodBase):
layer: The layer that is using the GroupedLinearMethodBase factory.
num_local_experts: The number of local experts.
input_size_per_partition: Sizes of the input dim on rank X.
output_size_per_partition: Sizes of the output dim on rank X.
output_partition_sizes: Sizes of the output dim on rank X.
params_dtype: Datatype of the parameters.
"""
raise NotImplementedError
@@ -86,7 +87,7 @@ class UnquantizedGroupedLinearMethod(GroupedLinearMethodBase):
def create_weights(self, layer: nn.Cell, num_local_experts: int,
input_size_per_partition: int, output_partition_sizes: list[int],
params_dtype, **extra_weight_attrs):
params_dtype, **extra_weight_attrs): # pylint: disable=arguments-renamed
weight_shape = (num_local_experts, input_size_per_partition, sum(output_partition_sizes))
weight = Parameter(initializer("zeros", weight_shape, params_dtype), requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 2})
@@ -151,10 +152,30 @@ class GroupedLinearBase(nn.Cell):
QuantizeMethodBase] = UnquantizedGroupedLinearMethod()
if quant_config is not None:
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
self.param_load_counts: Dict[str, int] = {}
def construct(self, x: Tensor, weight: Tensor, group_list: Tensor) -> Tensor:
raise NotImplementedError
def format_to_nz(self, param, merge_count=1):
"""Format parameter to nz format."""
current_count = self.param_load_counts.get(param.name, 0) + 1
self.param_load_counts[param.name] = current_count
# Only format when all shards are loaded.
if current_count == merge_count:
cast_weight = param
if param.dtype == ms.qint4x2:
# format_cast don't support qint4x2
import ms_custom_ops # pylint: disable=import-outside-toplevel
cast_weight = ms_custom_ops.type_cast(param, ms.int8)
cast_weight = ops.auto_generate.format_cast(cast_weight, format_type['nz'])
if param.dtype == ms.qint4x2:
cast_weight = ms_custom_ops.type_cast(cast_weight, ms.qint4x2)
param.set_data(cast_weight)
del self.param_load_counts[param.name]
ms.runtime.empty_cache()
class ColumnParallelGroupedLinear(GroupedLinearBase):
"""
@@ -212,21 +233,22 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
weight: Tensor = None,
is_expert: bool = True,
tp_comm_buffer_name: str = None,
transpose_b: bool = False,
tp_group: ProcessGroup = default_pgs,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""
):
super(ColumnParallelGroupedLinear, self).__init__(num_local_experts,
input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
super().__init__(num_local_experts,
input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
if stride > 1:
raise NotImplementedError(
"For ColumnParallelGroupedLinear, `stride > 1` is not supported for now, "
"but got `stride={}`".format(stride))
f"For ColumnParallelGroupedLinear, `stride > 1` is not supported for now, "
f"but got `stride={stride}`")
if skip_bias_add:
raise NotImplementedError(
"For ColumnParallelGroupedLinear, `skip_bias_add=True` is not supported for now."
@@ -249,6 +271,7 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
self.tp_group = tp_group
self.tensor_parallel_group_size = self.tp_group.size
self.output_size_per_partition = divide(output_size, self.tensor_parallel_group_size)
self.transpose_b = transpose_b
if self.quant_method is None:
raise ValueError("`quant_method` is not initialized in ColumnParallelGroupedLinear.")
@@ -259,6 +282,7 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
input_size_per_partition=self.input_size,
output_partition_sizes=[self.output_size_per_partition],
params_dtype=self.config.params_dtype,
transpose_b=self.transpose_b,
skip_weight_param_allocation=self.skip_weight_param_allocation,
weight_loader=self.weight_loader
)
@@ -275,7 +299,7 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
else:
self.bias = None
def construct(self, input_parallel, weight=None, group_list=None):
def construct(self, input_parallel, weight=None, group_list=None): # pylint: disable=arguments-renamed
"""Forward of ColumnParallelGroupedLinear."""
if weight is None:
weight = self.weight
@@ -297,7 +321,8 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
def sharded_state_dict(self):
"""Provide the sharded state dict."""
expert_parallel_group_size = self.config.num_moe_experts // self.num_local_experts
w_shard = (expert_parallel_group_size, 1, self.tensor_parallel_group_size)
w_shard = (expert_parallel_group_size, self.tensor_parallel_group_size, 1) \
if self.transpose_b else (expert_parallel_group_size, 1, self.tensor_parallel_group_size)
state_dict = {}
if not self.skip_weight_param_allocation:
@@ -323,7 +348,6 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
"""
if expert_id == -1:
return
if shard_id is not None:
param_output_dim = getattr(param, "output_dim", None)
shard_size = param.shape[param_output_dim] // 2
@@ -332,17 +356,20 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
# but the dimension splitting in the network is defined based on three dimensions,
# so the splitting dimension needs to be subtracted by 1.
weight_shape = len(loaded_weight.get_shape())
if weight_shape == 1 or (weight_shape != 1 and loaded_weight.get_shape()[1] == 1):
shard_dim = getattr(param, "output_dim", None) - 1
else:
shard_dim = getattr(param, "input_dim", None) - 1
weight_need_transpose = not (self.transpose_b and param.name.endswith("weight1"))
cond = weight_shape == 1 or (weight_shape != 1 and loaded_weight.get_shape()[1] == 1) \
or not weight_need_transpose
shard_dim_map = {True: "output_dim", False: "input_dim"}
shard_dim = getattr(param, shard_dim_map[cond], None) - 1
tp_rank = self.tp_group.rank
start_idx = tp_rank * shard_size
# 310P need unpack qint4x2, and do transpose, then pack back to qint4x2
loaded_weight = self.quant_method.process_weight_before_loading(param.name, loaded_weight[:])
loaded_weight = split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size)
if weight_shape == 1:
loaded_weight = loaded_weight.reshape(-1, 1)
if loaded_weight.shape[1] != 1:
if loaded_weight.shape[1] != 1 and weight_need_transpose:
# The Hugging Face weight shape is [hidden_size, moe_ffn_hidden_size]
# The shape of param is [moe_ffn_hidden_size, hidden_size]
# So must be transposed.
@@ -353,6 +380,8 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
if weight_dtype == ms.bfloat16:
loaded_weight = ms.from_numpy(loaded_weight).astype(ms.float32).asnumpy()
expected_shape = list(param.shape)
expected_shape[param_output_dim] = shard_size
if loaded_weight.shape[1] == 1:
if loaded_weight.shape[0] != shard_size:
raise ValueError(
@@ -365,15 +394,18 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
elif shard_id == "w3":
param.asnumpy()[expert_id][shard_size:shard_size + shard_size] = loaded_weight.squeeze(1)
else:
if loaded_weight.shape != (param.shape[1], shard_size):
update_indices = [slice(None)] * len(param.shape)
update_indices[0] = expert_id
if shard_id == "w1":
update_indices[param_output_dim] = slice(None, shard_size)
elif shard_id == "w3":
update_indices[param_output_dim] = slice(shard_size, 2 * shard_size)
if loaded_weight.shape != tuple(expected_shape[1:]):
raise ValueError(
f"'{param.name}.shape' should be equal to 'loaded_weight.shape',"
f" but got the shape of param is {(param.shape[1], shard_size)} and "
f" but got the shape of param is {expected_shape[1:]} and "
f"the shape of weight is{loaded_weight.shape}")
if shard_id == "w1":
param.asnumpy()[expert_id][:, :shard_size] = loaded_weight
elif shard_id == "w3":
param.asnumpy()[expert_id][:, shard_size:shard_size + shard_size] = loaded_weight
param.asnumpy()[tuple(update_indices)] = loaded_weight
else:
loaded_weight = deal_training_moe_weight(loaded_weight)
param.init_data()
@@ -383,6 +415,12 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
f" but got the shape of param is {param.data[expert_id].shape} and "
f"the shape of weight is{loaded_weight.shape}")
param[expert_id] = ms.from_numpy(loaded_weight)
self.post_process(param)
def post_process(self, param):
"""post process in weight loader"""
if is_310p() and param.name.endswith("weight1"):
self.format_to_nz(param, self.num_local_experts * 2)
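
The update_indices rewrite in the weight_loader above replaces hard-coded column slices with a per-dimension slice list, so the w1/w3 halves land correctly whichever axis is the output dim. A NumPy stand-in with made-up shapes:

import numpy as np

param = np.zeros((2, 8, 6))                    # (experts, out, in), transpose_b-style layout
loaded = np.ones((4, 6))                       # one expert's w1 shard
expert_id, shard_size, param_output_dim = 0, 4, 1

update_indices = [slice(None)] * param.ndim
update_indices[0] = expert_id
update_indices[param_output_dim] = slice(None, shard_size)             # w1 half
# for the w3 half: update_indices[param_output_dim] = slice(shard_size, 2 * shard_size)
param[tuple(update_indices)] = loaded
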
class RowParallelGroupedLinear(GroupedLinearBase):
@@ -441,21 +479,22 @@ class RowParallelGroupedLinear(GroupedLinearBase):
delay_allreduce: bool = True,
is_expert: bool = True,
tp_comm_buffer_name: str = None,
transpose_b: bool = False,
tp_group: ProcessGroup = default_pgs,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""
):
super(RowParallelGroupedLinear, self).__init__(num_local_experts,
input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
super().__init__(num_local_experts,
input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
if stride > 1:
raise NotImplementedError(
"For RowParallelGroupedLinear, `stride > 1` is not supported for now, "
"but got `stride={}`".format(stride))
f"For RowParallelGroupedLinear, `stride > 1` is not supported for now, "
f"but got `stride={stride}`")
if not is_expert:
raise NotImplementedError(
"For RowParallelGroupedLinear, `is_expert=False` is not supported for now.")
@@ -478,6 +517,7 @@ class RowParallelGroupedLinear(GroupedLinearBase):
self.tp_group = tp_group
self.tensor_parallel_group_size = self.tp_group.size
self.input_size_per_partition = divide(input_size, self.tensor_parallel_group_size)
self.transpose_b = transpose_b
if self.quant_method is None:
raise ValueError("`quant_method` is not initialized in RowParallelGroupedLinear.")
@@ -487,6 +527,7 @@ class RowParallelGroupedLinear(GroupedLinearBase):
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=[self.output_size],
params_dtype=self.config.params_dtype,
transpose_b=self.transpose_b,
skip_weight_param_allocation=self.skip_weight_param_allocation,
weight_loader=self.weight_loader
)
@@ -502,13 +543,14 @@ class RowParallelGroupedLinear(GroupedLinearBase):
else:
self.bias = None
def construct(self, input_, weight=None, group_list=None):
def construct(self, input_, weight=None, group_list=None): # pylint: disable=arguments-renamed
"""Forward of RowParallelGroupedLinear."""
if weight is None:
weight = self.weight
else:
# Check the weight passed in is the correct shape.
expected_shape = (self.num_local_experts, self.input_size_per_partition, self.output_size)
expected_shape = (self.num_local_experts,) + (self.output_size, self.input_size_per_partition) \
if self.transpose_b else (self.input_size_per_partition, self.output_size)
if weight.shape != expected_shape:
raise ValueError(
f"supplied weight's shape is {tuple(weight.shape)}, "
@@ -531,7 +573,8 @@ class RowParallelGroupedLinear(GroupedLinearBase):
def sharded_state_dict(self):
"""Provide the sharded state dict."""
expert_parallel_group_size = self.config.num_moe_experts // self.num_local_experts
w_shard = (expert_parallel_group_size, self.tensor_parallel_group_size, 1)
w_shard = (expert_parallel_group_size, 1, self.tensor_parallel_group_size) \
if self.transpose_b else (expert_parallel_group_size, self.tensor_parallel_group_size, 1)
state_dict = {}
if not self.skip_weight_param_allocation:
@@ -567,12 +610,12 @@ class RowParallelGroupedLinear(GroupedLinearBase):
tp_rank = self.tp_group.rank
if shard_id is not None:
param_output_dim = getattr(param, "input_dim", None)
if not param.name.endswith("weight") and param_output_dim is None:
if param.name.endswith("w_scale") and len(loaded_weight.get_shape()) == 2 \
and loaded_weight.get_shape()[1] == 1:
loaded_weight = loaded_weight[:].squeeze(-1)
param.init_data()
loaded_weight = cast_weight_for_310p(loaded_weight[:])
weight_dtype = ms.from_numpy(loaded_weight[:]).dtype
if weight_dtype == ms.bfloat16:
loaded_weight = ms.from_numpy(loaded_weight[:]).astype(ms.float32).asnumpy()
@@ -583,17 +626,22 @@ class RowParallelGroupedLinear(GroupedLinearBase):
# Because this weight shape is two-dimensional,
# but the dimension splitting in the network is defined based on three dimensions,
# so the splitting dimension needs to be subtracted by 1.
shard_dim = getattr(param, "output_dim", None) - 1
shard_dim = getattr(param, "input_dim", None) - 1
weight_need_transpose = not self.transpose_b or param.name.endswith("w_scale")
if weight_need_transpose:
shard_dim = 1 - shard_dim
start_idx = tp_rank * shard_size
loaded_weight = self.quant_method.process_weight_before_loading(param.name, loaded_weight[:])
loaded_weight = split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size)
# The Hugging Face weight shape is [hidden_size, moe_ffn_hidden_size]
# The shape of param is [moe_ffn_hidden_size, hidden_size]
# So must be transposed.
loaded_weight = loaded_weight.T
if loaded_weight.shape != (shard_size, param.shape[2]):
if weight_need_transpose:
loaded_weight = loaded_weight.T
if loaded_weight.shape != param.shape[1:]:
raise ValueError(
f"'{param.name}.shape' should be equal to 'loaded_weight.shape',"
f" but got the shape of param is {(shard_size, param.shape[2])} and "
f" but got the shape of param is {(param.shape[1:])} and "
f"the shape of weight is{loaded_weight.shape}")
param.init_data()
weight_dtype = ms.from_numpy(loaded_weight).dtype
@@ -612,3 +660,5 @@ class RowParallelGroupedLinear(GroupedLinearBase):
f" but got the shape of param is {param.data[expert_id].shape} and "
f"the shape of weight is{loaded_weight.shape}")
param[expert_id] = ms.from_numpy(loaded_weight)
if is_310p() and param.name.endswith("weight2"):
self.format_to_nz(param, self.num_local_experts)
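
The new format_to_nz/post_process pair defers the nz format cast until every shard of a parameter has been loaded, by counting loads per parameter name. A framework-free sketch of that bookkeeping, with a hypothetical callback standing in for the actual cast:

param_load_counts = {}

def on_shard_loaded(param_name, merge_count, cast_fn):
    """Run cast_fn exactly once, after merge_count shards of param_name were loaded."""
    count = param_load_counts.get(param_name, 0) + 1
    param_load_counts[param_name] = count
    if count == merge_count:
        cast_fn(param_name)                 # e.g. format_cast to nz in the real layer
        del param_load_counts[param_name]

# e.g. weight1 of a ColumnParallelGroupedLinear uses merge_count = num_local_experts * 2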

View File

@@ -169,13 +169,14 @@ class LinearBase(ms.nn.Cell):
if is_310p():
cast_weight = ops.auto_generate.format_cast(param, format_type['nz'])
else:
import ms_custom_ops
import ms_custom_ops # pylint: disable=import-outside-toplevel
cast_weight = ms_custom_ops.trans_data(param, transdata_type=1)
if move_to_cpu:
cast_weight = cast_weight.move_to("CPU")
param.set_data(cast_weight)
del self.param_load_counts[param.name]
class ColumnParallelLinear(LinearBase):
"""
The dense layer with weight sliced on second dimension by tensor parallel size.
@@ -242,7 +243,7 @@ class ColumnParallelLinear(LinearBase):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""
):
super(ColumnParallelLinear, self).__init__(
super().__init__(
input_size,
output_size,
skip_bias_add,
@@ -251,8 +252,8 @@ class ColumnParallelLinear(LinearBase):
prefix=prefix
)
if stride > 1:
raise NotImplementedError("For ColumnParallelLinear, `stride > 1` is not supported for now, "
"but got `stride={}`".format(stride))
raise NotImplementedError(f"For ColumnParallelLinear, `stride > 1` is not supported for now, "
f"but got `stride={stride}`")
if keep_master_weight_for_test:
raise NotImplementedError(
"For ColumnParallelLinear, `keep_master_weight_for_test` is not supported for now")
@@ -395,7 +396,7 @@ class ColumnParallelLinear(LinearBase):
f"'{param.name}.shape' should be equal to 'loaded_weight.shape',"
f" but got the shape of param is {param.shape} and the shape of weight is{loaded_weight.shape}")
param.init_data()
param.set_data(ms.from_numpy(loaded_weight))
param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype))
if is_310p() and param.name.endswith("weight"):
self.format_to_nz(param)
@@ -480,6 +481,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
array_id = 0
elif loaded_shard_id == 'hidden':
array_id = 1
else:
raise ValueError(f"Unknown loaded_shard_id: {loaded_shard_id}")
shard_offset = sum(self.output_sizes[:array_id]) // tp_size
shard_size = self.output_sizes[array_id] // tp_size
@@ -634,6 +637,8 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset = (self.num_heads +
self.num_kv_heads) * self.head_size
shard_size = self.num_kv_heads * self.head_size
else:
raise ValueError(f"Unknown loaded_shard_id: {loaded_shard_id}")
if loaded_shard_id == "q":
shard_id = tp_rank
@@ -739,15 +744,15 @@ class RowParallelLinear(LinearBase):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""
):
super(RowParallelLinear, self).__init__(input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
super().__init__(input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
if stride > 1:
raise NotImplementedError("For RowParallelLinear, `stride > 1` is not supported for now, "
"but got `stride={}`".format(stride))
raise NotImplementedError(f"For RowParallelLinear, `stride > 1` is not supported for now, "
f"but got `stride={stride}`")
if skip_bias_add:
raise NotImplementedError("For RowParallelLinear, `skip_bias_add=True` is not supported for now")
if keep_master_weight_for_test:
@@ -878,7 +883,7 @@ class RowParallelLinear(LinearBase):
raise ValueError(
f"'{param.name}.shape' should be equal to 'loaded_weight.shape',"
f" but got the shape of param is {param.shape} and the shape of weight is{loaded_weight.shape}")
param.set_data(ms.from_numpy(loaded_weight))
param.set_data(ms.from_numpy(loaded_weight)) # TODO param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype))
if is_310p() and param.name.endswith("weight"):
self.format_to_nz(param)
@@ -932,8 +937,8 @@ class ReplicatedLinear(LinearBase):
quant_config=quant_config,
prefix=prefix)
if stride > 1:
raise NotImplementedError("For ReplicatedLinear, `stride > 1` is not supported for now, "
"but got `stride={}`".format(stride))
raise NotImplementedError(f"For ReplicatedLinear, `stride > 1` is not supported for now, "
f"but got `stride={stride}`")
if skip_bias_add:
raise NotImplementedError("For ReplicatedLinear, `skip_bias_add=True` is not supported for now")
if keep_master_weight_for_test:
@@ -1061,7 +1066,9 @@ class ReplicatedLinear(LinearBase):
f" but got the shape of param is {param.shape} "
f"and the shape of weight is{loaded_weight.shape}")
param.init_data()
param.set_data(ms.from_numpy(loaded_weight))
param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype))
if is_310p() and param.name.endswith("weight"):
self.format_to_nz(param, 2)
class VocabParallelEmbedding(nn.Cell):
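
The loader changes in this file swap ms.from_numpy(loaded_weight) for ms.Tensor(loaded_weight, dtype=param.dtype), which casts the NumPy weight to the parameter's dtype at load time instead of keeping the checkpoint dtype. A minimal sketch, assuming MindSpore is available; names and shapes are illustrative.

import numpy as np
import mindspore as ms

loaded_weight = np.ones((4, 4), dtype=np.float32)
param = ms.Parameter(ms.Tensor(np.zeros((4, 4)), dtype=ms.float16), name="weight")
param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype))    # cast to float16 on load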

View File

@@ -30,6 +30,7 @@ from mindformers.parallel_core.inference.transformer.activation import get_act_f
from mindformers.parallel_core.inference.utils import divide
from mindformers.parallel_core.process_group_config import ModelCommProcessGroups, default_model_comm_pgs
from mindformers.parallel_core.inference.weights_utils import set_weight_attrs
from mindformers.version_control import is_310p
class GroupedMLP(nn.Cell):
@@ -54,6 +55,7 @@ class GroupedMLP(nn.Cell):
# use model_comm_pgs.moe_tp_group as tensor parallel group in this module.
self.tp_group = model_comm_pgs.moe_tp
self.tp_group_size = self.tp_group.size
self.transpose_b = is_310p()
ffn_hidden_size = self.config.moe_ffn_hidden_size
self.ffn_hidden_size_per_partition = divide(ffn_hidden_size, self.tp_group_size)
@@ -70,6 +72,7 @@ class GroupedMLP(nn.Cell):
num_local_experts=self.num_local_experts,
input_size_per_partition=self.input_size,
output_partition_sizes=[divide(ffn_hidden_size, self.tp_group_size)],
transpose_b=self.transpose_b,
params_dtype=self.config.params_dtype,
)
self.weight2 = self.quant_method.create_weights(
@@ -77,6 +80,7 @@ class GroupedMLP(nn.Cell):
num_local_experts=self.num_local_experts,
input_size_per_partition=self.ffn_hidden_size_per_partition,
output_partition_sizes=[self.input_size],
transpose_b=self.transpose_b,
params_dtype=self.config.params_dtype,
)
@@ -92,6 +96,7 @@ class GroupedMLP(nn.Cell):
weight=self.weight1, # Skip creating weights and use weight1 for gemm linear calculation
is_expert=True,
tp_group=self.tp_group,
transpose_b=self.transpose_b,
quant_config=quant_config,
prefix=f"{prefix}.linear_fc1",
)
@@ -112,6 +117,7 @@ class GroupedMLP(nn.Cell):
weight=self.weight2, # Skip creating weights and use weight2 for gemm linear calculation
is_expert=True,
tp_group=self.tp_group,
transpose_b=self.transpose_b,
quant_config=quant_config,
prefix=f"{prefix}.linear_fc2",
)

View File

@@ -44,7 +44,7 @@ class Router(nn.Cell):
Args:
config (TransformerConfig): Configuration object for the Transformer model.
"""
super(Router, self).__init__()
super().__init__()
self.config = config
self.num_experts = self.config.num_moe_experts
self.router_dense_type = self.config.moe_router_dtype
@@ -98,7 +98,7 @@ class Router(nn.Cell):
"""
loaded_weight = loaded_weight[:]
if self.ep_group_size > 1 and not self.config.use_alltoall:
expert_idx_list = [idx for idx in range(self.num_experts)]
expert_idx_list = list(range(self.num_experts))
start_idx = self.num_experts // self.ep_group_size * self.ep_rank
expert_idx_list = expert_idx_list[start_idx:] + expert_idx_list[:start_idx]
loaded_weight = loaded_weight[expert_idx_list]
@@ -108,7 +108,7 @@ class Router(nn.Cell):
f"'param.data.shape' should be equal to 'loaded_weight.shape',"
f" but got the shape of param is {param.shape} "
f"and the shape of weight is{loaded_weight.shape}")
param.set_data(ms.from_numpy(loaded_weight).astype(param.dtype))
param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype))
def construct(self, input_tensor: Tensor):
"""

View File

@@ -27,7 +27,7 @@ from mindspore.ops.auto_generate import (MoeInitRoutingV2,
from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.parallel_core.process_group_config import ModelCommProcessGroups, default_model_comm_pgs
from mindformers.version_control import is_910b
from mindformers.version_control import is_910b, is_310p
class MoETokenDispatcher:
@@ -109,6 +109,7 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher):
self.cast = ops.Cast()
self.moe_init_routing_v2 = MoeInitRoutingV2()
self.moe_token_unpermute = MoeTokenUnpermute()
self.is_310p = is_310p()
def dispatch_preprocess(self, expert_weight, routing_map):
"""Preprocess expert weight by masking out invalid experts."""
@@ -129,8 +130,8 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher):
expert_capacity=0,
expert_num=self.num_experts,
drop_pad_mode=0,
expert_tokens_count_or_cumsum_flag=2,
expert_tokens_before_capacity_flag=True
expert_tokens_count_or_cumsum_flag=1 if self.is_310p else 2,
expert_tokens_before_capacity_flag=not self.is_310p
)
# Avoid the problem of poor performance of the split(int32) operator
@@ -143,7 +144,8 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher):
def token_combine(self, hidden_states, expert_weight, *args):
"""Combines expert outputs."""
(tokens_per_expert,) = args
hidden_states = mint.nan_to_num(hidden_states, 0, 0, 0)
if not self.is_310p:
hidden_states = mint.nan_to_num(hidden_states, 0, 0, 0)
expert_weight = expert_weight.astype(hidden_states.dtype)
hidden_states = self.moe_token_unpermute(
permuted_tokens=hidden_states,

View File

@@ -45,7 +45,8 @@ from mindformers.parallel_core.inference.base_models.common.embeddings.yarn_rota
from mindformers.parallel_core.inference.base_models.common.embeddings.rope_utils import get_rope
from mindformers.parallel_core.process_group_config import ModelCommProcessGroups, default_model_comm_pgs
from mindformers.parallel_core.inference.weights_utils import set_weight_attrs, split_loaded_weight
from mindformers.version_control import is_310p
from mindformers.models.utils import format_type
@dataclass
class MLASelfAttentionSubmodules:
@@ -172,6 +173,7 @@ class MultiLatentAttention(Attention):
self.dim_slice_3d = P.Slice()
self.transpose = P.Transpose()
self.out_absorb_matmul = P.BatchMatMul(transpose_b=True)
self.is_310p = is_310p()
def construct(
self,
@@ -363,7 +365,10 @@ class MLASelfAttention(MultiLatentAttention):
Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
q_absorb, out_absorb = mint.split(self.linear_kv_up_proj.weight,
weight = self.linear_kv_up_proj.weight
if self.is_310p:
weight = ops.auto_generate.format_cast(weight, format_type['nd'])
q_absorb, out_absorb = mint.split(weight,
[self.num_attention_heads_per_partition * self.config.qk_head_dim,
self.num_attention_heads_per_partition * self.config.v_head_dim], -2)
self.q_absorb = q_absorb.reshape(self.num_attention_heads_per_partition,
@@ -531,7 +536,7 @@ class FusedMLASelfAttention(MLASelfAttention):
self.is_modelslim = quant_config.is_modelslim
self.fa3_quant = quant_config.fa3_quant
self.fa3_quant_layer = quant_config.fa3_quant_layer
self.is_fa3_quant_layer = (layer_number - 1) in self.fa3_quant_layer # layer_number start from 1
self.is_fa3_quant_layer = layer_number - 1 in self.fa3_quant_layer # layer_number start from 1
self.input_layernorm_weight = None
self.qkv_down_proj_input_scale = None
self.q_layernorm_weight = None
@@ -542,10 +547,10 @@ class FusedMLASelfAttention(MLASelfAttention):
self.q_up_proj_input_offset = None
self.input_format = 1 if self.fa3_quant else 0
self.use_ringmla = use_ms_custom_ops() and get_tensor_model_parallel_world_size() < 16
import ms_custom_ops
import ms_custom_ops # pylint: disable=import-outside-toplevel
self.ms_custom_ops = ms_custom_ops
self.scale_value = 1 / math.sqrt(self.config.kv_lora_rank + self.config.qk_head_dim) \
if self.softmax_scale is None else self.softmax_scale
self.scale_value = (1 / math.sqrt(self.config.kv_lora_rank + self.config.qk_head_dim)
if self.softmax_scale is None else self.softmax_scale)
self.ring_mla_mask = Tensor(np.triu(np.ones((512, 512), dtype=np.float16), 1), dtype.bfloat16)
self.depend = P.Depend()
self.quant = QuantV2()
@@ -793,7 +798,7 @@ class FusedMLASelfAttention(MLASelfAttention):
k_cache = self.transpose(key_cache.reshape(-1, self.config.kv_lora_rank // 32, \
self.config.block_size, 32), (0, 2, 1, 3)).reshape( \
-1, self.config.block_size, self.config.kv_lora_rank)
k_cache = (self.cast(k_cache, dtype.bfloat16) / self.quant_ctkv_scale)
k_cache = self.cast(k_cache, dtype.bfloat16) / self.quant_ctkv_scale
else:
k_cache = self.ms_custom_ops.trans_data(key_cache, transdata_type=0)
v_cache = self.ms_custom_ops.trans_data(value_cache, transdata_type=0)
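
To illustrate the kv_up_proj split in process_weights_after_loading above (after the optional nd format cast on 310P), here is a NumPy stand-in: the fused up-projection weight is split along the second-to-last axis into q_absorb and out_absorb blocks sized by head count and head dims. All sizes are made up.

import numpy as np

heads, qk_head_dim, v_head_dim, kv_lora_rank = 4, 16, 8, 32   # illustrative config values
weight = np.zeros((heads * (qk_head_dim + v_head_dim), kv_lora_rank), dtype=np.float16)

q_absorb = weight[:heads * qk_head_dim]                       # first split of mint.split(..., -2)
out_absorb = weight[heads * qk_head_dim:]                     # second split
q_absorb = q_absorb.reshape(heads, qk_head_dim, kv_lora_rank)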

View File

@@ -76,8 +76,8 @@ def get_attn_mask_func(mask_func_type):
"""
if mask_func_type not in ATTNMASK_FUNC_MAP:
raise KeyError("Invalid attention mask function. Supported attention "
"mask function are ['attn_mask_fill', 'attn_mask_add'] "
", but got {}.".format(mask_func_type))
f"mask function are ['attn_mask_fill', 'attn_mask_add'] "
f", but got {mask_func_type}.")
return ATTNMASK_FUNC_MAP[mask_func_type]
@@ -158,7 +158,7 @@ def create_empty_parameter(shape, *, dtype=None, device=None, **kwargs):
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
if numerator % denominator != 0:
raise ValueError("{} is not divisible by {}".format(numerator, denominator))
raise ValueError(f"{numerator} is not divisible by {denominator}")
def divide(numerator, denominator):
@@ -178,9 +178,9 @@ def save_strategy_file(state_dict, strategy_file_name):
Supported Platforms:
``Ascend``
"""
import os
import stat
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy
import os # pylint: disable=import-outside-toplevel
import stat # pylint: disable=import-outside-toplevel
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy # pylint: disable=import-outside-toplevel
stra = ckpt_strategy()
@@ -367,9 +367,21 @@ def use_ms_custom_ops():
"""
try:
# pylint: disable=W0611
import ms_custom_ops
import ms_custom_ops # pylint: disable=import-outside-toplevel
except ModuleNotFoundError:
# environment need install ms_custom_ops package
return False
return not is_310p()
def cast_weight_for_310p(loaded_weight):
"""
Casts weights to float16 for 310p.
In non-quantized scenarios, the 310P hardware only supports float16 weights.
This function converts float32 or bfloat16 weights to float16.
"""
cast_weight = (loaded_weight.astype(np.float16) if
(str(loaded_weight.dtype) == "float32" or str(
loaded_weight.dtype) == "bfloat16") else loaded_weight)
return cast_weight
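
A quick usage sketch for cast_weight_for_310p above, using the import path shown earlier in this diff; the input array is arbitrary.

import numpy as np
from mindformers.parallel_core.inference.utils import cast_weight_for_310p

w = np.ones((4, 4), dtype=np.float32)
w16 = cast_weight_for_310p(w)        # float32 / bfloat16 weights become float16
assert w16.dtype == np.float16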

View File

@@ -64,7 +64,7 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size):
elif shard_dim == 2:
loaded_weight = loaded_weight[:, :, start_idx:end_idx]
else:
raise ValueError("shard_dim:{} is not supported.".format(shard_dim))
raise ValueError(f"shard_dim:{shard_dim} is not supported.")
loaded_weight = loaded_weight.astype(np.float16) \
if (str(loaded_weight.dtype) == 'bfloat16' and is_310p()) else loaded_weight
return loaded_weight
@@ -393,3 +393,22 @@ def split_fusion_loaded_weight(loaded_weight, start_idxs, shard_sizes):
loaded_weight_parts.append(loaded_weight[start_idx:start_idx + shard_size])
perrank_ffn_weight = np.concatenate(loaded_weight_parts, axis=0)
return perrank_ffn_weight
def process_weights_for_310p(name, loaded_weight, params_dict, loaded_params):
"""
Process the loaded weight for 310P.
Args:
name: The name of the weight to be loaded.
loaded_weight: The weights to be loaded.
params_dict: The dictionary of model parameters.
loaded_params: The set of already loaded parameter names.
Returns:
The processed weights.
"""
if "deq_scale" in name:
loaded_weight = loaded_weight[:].astype(np.float32).view(np.int32).astype(np.int64)
return loaded_weight
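
The deq_scale branch above reinterprets the float32 scale's raw bits as int32 and widens them to int64, the dequant-scale layout the 310P quantized matmul path expects. A NumPy illustration with arbitrary values:

import numpy as np

deq_scale = np.array([0.0125, 0.25], dtype=np.float32)
as_int64 = deq_scale.astype(np.float32).view(np.int32).astype(np.int64)   # bit pattern kept, widened to int64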

View File

@@ -3975,7 +3975,7 @@
"signature": "(self, layer: mindspore.nn.cell.Cell) -> None"
},
"mindformers.parallel_core.inference.tensor_parallel.grouped_layers.ColumnParallelGroupedLinear": {
"signature": "(num_local_experts: int, input_size: int, output_size: int, *, config: mindformers.parallel_core.transformer_config.TransformerConfig, bias: bool = False, gather_output: bool = False, stride: int = 1, skip_bias_add: bool = False, weight: mindspore.common.tensor.Tensor = None, is_expert: bool = True, tp_comm_buffer_name: str = None, tp_group: mindformers.parallel_core.inference.parallel_state.ProcessGroup = <mindformers.parallel_core.inference.parallel_state.ProcessGroup object>, quant_config: Optional[mindformers.parallel_core.inference.quantization.base_config.QuantizationConfig] = None, prefix: str = '')"
"signature": "(num_local_experts: int, input_size: int, output_size: int, *, config: mindformers.parallel_core.transformer_config.TransformerConfig, bias: bool = False, gather_output: bool = False, stride: int = 1, skip_bias_add: bool = False, weight: mindspore.common.tensor.Tensor = None, is_expert: bool = True, tp_comm_buffer_name: str = None, transpose_b: bool = False, tp_group: mindformers.parallel_core.inference.parallel_state.ProcessGroup = <mindformers.parallel_core.inference.parallel_state.ProcessGroup object>, quant_config: Optional[mindformers.parallel_core.inference.quantization.base_config.QuantizationConfig] = None, prefix: str = '')"
},
"mindformers.parallel_core.inference.tensor_parallel.grouped_layers.ColumnParallelGroupedLinear.construct": {
"signature": "(self, input_parallel, weight=None, group_list=None)"
@@ -3987,7 +3987,7 @@
"signature": "(self, param, loaded_weight, shard_id=None, expert_id=None) -> None"
},
"mindformers.parallel_core.inference.tensor_parallel.grouped_layers.RowParallelGroupedLinear": {
"signature": "(num_local_experts: int, input_size: int, output_size: int, *, config: mindformers.parallel_core.transformer_config.TransformerConfig, bias: bool = False, input_is_parallel: bool = True, skip_bias_add: bool = False, weight: mindspore.common.tensor.Tensor = None, stride: int = 1, delay_allreduce: bool = True, is_expert: bool = True, tp_comm_buffer_name: str = None, tp_group: mindformers.parallel_core.inference.parallel_state.ProcessGroup = <mindformers.parallel_core.inference.parallel_state.ProcessGroup object>, quant_config: Optional[mindformers.parallel_core.inference.quantization.base_config.QuantizationConfig] = None, prefix: str = '')"
"signature": "(num_local_experts: int, input_size: int, output_size: int, *, config: mindformers.parallel_core.transformer_config.TransformerConfig, bias: bool = False, input_is_parallel: bool = True, skip_bias_add: bool = False, weight: mindspore.common.tensor.Tensor = None, stride: int = 1, delay_allreduce: bool = True, is_expert: bool = True, tp_comm_buffer_name: str = None, transpose_b: bool = False, tp_group: mindformers.parallel_core.inference.parallel_state.ProcessGroup = <mindformers.parallel_core.inference.parallel_state.ProcessGroup object>, quant_config: Optional[mindformers.parallel_core.inference.quantization.base_config.QuantizationConfig] = None, prefix: str = '')"
},
"mindformers.parallel_core.inference.tensor_parallel.grouped_layers.RowParallelGroupedLinear.construct": {
"signature": "(self, input_, weight=None, group_list=None)"