mirror of https://gitee.com/mindspore/mindformers.git
@@ -46,6 +46,7 @@ str_to_ms_type = {
}

format_type = {
"nd": 2,
"nz": 29,
}

@@ -96,7 +97,7 @@ def reverse_dict(d: dict):
new_d = {}
for k, v in d.items():
if v in new_d:
raise ValueError(f"Different keys in dict have same values.")
raise ValueError("Different keys in dict have same values.")
new_d[v] = k
return new_d
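A minimal stand-alone sketch (not part of the patch) of how the reverse_dict helper above behaves, using a local copy of the function and the format_type values from the first hunk:

# Illustrative sketch: normal path and the duplicate-value error path of reverse_dict.
def _reverse_dict(d: dict) -> dict:
    new_d = {}
    for k, v in d.items():
        if v in new_d:
            raise ValueError("Different keys in dict have same values.")
        new_d[v] = k
    return new_d

print(_reverse_dict({"nd": 2, "nz": 29}))   # {2: 'nd', 29: 'nz'}
try:
    _reverse_dict({"a": 1, "b": 1})         # two keys share a value -> rejected
except ValueError as err:
    print(err)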

@@ -123,16 +124,16 @@ def check_use_3d_tensor_parallel_valid(config):
if not use_3d_tensor_parallel or not is_config_valid:
return False
if not config.use_flash_attention:
raise ValueError(f"When the use_3d_tensor_parallel = True, the use_flash_attention must be True ")
raise ValueError("When the use_3d_tensor_parallel = True, the use_flash_attention must be True ")
if config.parallel_config.get_ulysses_cp_num() > 1:
raise ValueError(f"Currently, when the use_3d_tensor_parallel = True, "
raise ValueError("Currently, when the use_3d_tensor_parallel = True, "
"the cp_ds of the ulysses context parallel must be 1")
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
raise ValueError(f"Currently, when the use_3d_tensor_parallel = True, the auto parallel is not supported")
raise ValueError("Currently, when the use_3d_tensor_parallel = True, the auto parallel is not supported")
if config.moe_config is not None and config.moe_config.expert_num > 1:
raise ValueError(f"Currently, when the use_3d_tensor_parallel = True, the MoE is not supported")
raise ValueError("Currently, when the use_3d_tensor_parallel = True, the MoE is not supported")
if not config.parallel_config.use_seq_parallel:
raise ValueError(f"Currently, when the use_3d_tensor_parallel = True, the use_seq_parallel must be True")
raise ValueError("Currently, when the use_3d_tensor_parallel = True, the use_seq_parallel must be True")
if check_fine_grain_interleave_valid(config.fine_grain_interleave, config.parallel_config):
raise ValueError("Currently, when the use_3d_tensor_parallel = True, "
"the fine_grain_interleave is not supported")

@@ -141,8 +142,8 @@ def check_use_3d_tensor_parallel_valid(config):
tp_z = getattr(config, "tp_z", 1)
model_parallel = config.parallel_config.model_parallel
if model_parallel > 1 and tp_x * tp_y * tp_z != config.parallel_config.model_parallel:
raise ValueError("tp_x * tp_y * tp_z should be equal to model_parallel, but got "
"tp_x={}, tp_y={}, tp_z={}, model_parallel={}.".format(tp_x, tp_y, tp_z, model_parallel))
raise ValueError(f"tp_x * tp_y * tp_z should be equal to model_parallel, but got "
f"tp_x={tp_x}, tp_y={tp_y}, tp_z={tp_z}, model_parallel={model_parallel}.")
if model_parallel > 1:
logger.info(f"use_3d_tensor_parallel is True, (tp_x, tp_y, tp_z): ({tp_x}, {tp_y}, {tp_z})")
return True
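A small stand-alone sketch (not part of the patch) of the factorisation constraint checked above: the three 3D tensor-parallel factors tp_x, tp_y, tp_z must multiply to model_parallel whenever model parallelism is enabled.

# Illustrative check, mirroring the hunk above.
def check_tp_factors(tp_x: int, tp_y: int, tp_z: int, model_parallel: int) -> None:
    if model_parallel > 1 and tp_x * tp_y * tp_z != model_parallel:
        raise ValueError("tp_x * tp_y * tp_z should be equal to model_parallel, but got "
                         f"tp_x={tp_x}, tp_y={tp_y}, tp_z={tp_z}, model_parallel={model_parallel}.")

check_tp_factors(2, 2, 2, 8)   # valid: 2 * 2 * 2 == 8
check_tp_factors(2, 2, 1, 8)   # raises: 4 != 8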

@@ -31,11 +31,13 @@ from mindformers.parallel_core.inference.base_models.common.embeddings.rope_util
from mindformers.parallel_core.inference.utils import divide, generate_padding_index
from mindformers.parallel_core.process_group_config import ModelCommProcessGroups
from mindformers.parallel_core.inference.weights_utils import (default_weight_loader, make_expert_params_mapping,
make_expert_params_mapping_with_expert_dim)
make_expert_params_mapping_with_expert_dim,
process_weights_for_310p)
from mindformers.parallel_core.inference.parallel_state import is_pipeline_last_stage
from mindformers.parallel_core.process_group_config import get_model_comm_pgs
from mindformers.tools.logger import logger
from mindformers.tools.utils import is_pynative
from mindformers.version_control import is_310p


class GPTModel(nn.Cell):

@@ -116,7 +118,7 @@ class GPTModel(nn.Cell):
model_comm_pgs: Optional[ModelCommProcessGroups] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super(GPTModel, self).__init__()
super().__init__()
self.check_support(fp16_lm_cross_entropy, rope_scaling)
self.config = config
self.quant_config = quant_config

@@ -317,7 +319,7 @@ class GPTModel(nn.Cell):
return logits

def get_params_dict(self):
params_dict = dict()
params_dict = {}
for _, module in self.modules_dict.items():
module_params = module.parameters_dict()
for param_name, param in module_params.items():

@@ -403,16 +405,19 @@ class GPTModel(nn.Cell):
num_experts = self.config.num_moe_experts
if '.weight1' in name:
weight = loaded_weight[:].reshape(num_experts, self.config.hidden_size, -1)
if '.weight2' in name:
elif '.weight2' in name:
weight = loaded_weight[:].reshape(num_experts, self.config.moe_ffn_hidden_size, -1)
else:
weight = None

for expert_id in range(num_experts):
expert_id = self.map_global_expert_id_to_local_expert_id(expert_id)
loaded_weight = weight[expert_id]
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id=None, expert_id=expert_id)
loaded_params.add(name)
if weight is not None:
for expert_id in range(num_experts):
expert_id = self.map_global_expert_id_to_local_expert_id(expert_id)
loaded_weight = weight[expert_id]
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id=None, expert_id=expert_id)
loaded_params.add(name)
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)

@@ -447,6 +452,9 @@ class GPTModel(nn.Cell):
if "weight_scale_inv" in name:
continue

if is_310p():
loaded_weight = process_weights_for_310p(name, loaded_weight, params_dict, loaded_params)

for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
@@ -58,6 +58,14 @@ class QuantizeMethodBase(ABC):

raise RuntimeError

def process_weight_before_loading(self, param_name, loaded_weight):
"""
Process the weight before loading.
This can be used for example, to transpose weights for computation.
"""

return loaded_weight

def process_weights_after_loading(self, layer: nn.Cell) -> None:
"""
Process the weight after loading.

@@ -73,7 +81,7 @@ class QuantizationConfig(ABC):
def __init__(self):
super().__init__()
# mapping is updated by models as they initialize
self.packed_modules_mapping: dict[str, list[str]] = dict()
self.packed_modules_mapping: dict[str, list[str]] = {}

@abstractmethod
def get_name(self) -> QuantizationBackends:

@@ -21,12 +21,17 @@ import mindspore
from mindspore import nn, Parameter, ops
from mindspore.common.initializer import initializer
from mindspore.ops.auto_generate import WeightQuantBatchMatmul, DynamicQuantExt, GroupedMatmulV4

try:
from ms_custom_ops import grouped_matmul_w4
GMM_310P = True
except ImportError:
GMM_310P = False
from mindformers.version_control import is_310p
from mindformers.parallel_core.inference.weights_utils import set_weight_attrs
from mindformers.parallel_core.inference.transformer.moe.experts import GroupedMLP
from mindformers.parallel_core.inference.tensor_parallel.layers import LinearMethodBase
from mindformers.parallel_core.inference.quantization import QuantizationConfig

from mindformers.parallel_core.inference.quantization.utils import np_pack_int4_to_int8, np_unpack_int8_to_int4

class A8W4DynamicLinearMethod(LinearMethodBase):
"""Linear method with A8W4 quantization."""

@@ -35,6 +40,7 @@ class A8W4DynamicLinearMethod(LinearMethodBase):
self.quant_config = quant_config
self.quant = DynamicQuantExt()
self.bias_add = ops.Add()
self.is_310p = is_310p() and GMM_310P

def create_weights(self,
layer: nn.Cell,

@@ -42,6 +48,7 @@ class A8W4DynamicLinearMethod(LinearMethodBase):
output_partition_sizes: list[int],
params_dtype,
*weight_args,
transpose_b=False,
num_local_experts=None, **extra_weight_attrs) -> Union[Parameter, None]:
output_size_per_partition = sum(output_partition_sizes)
self.output_size_per_partition = output_size_per_partition

@@ -52,17 +59,24 @@ class A8W4DynamicLinearMethod(LinearMethodBase):
raise ValueError(f"group_size should >=0 but group_size is : {group_size}")
if self.is_group_mm:
weight = None
self.matmul = GroupedMatmulV4()
if self.is_310p:
self.matmul = grouped_matmul_w4
else:
self.matmul = GroupedMatmulV4()
if not extra_weight_attrs.get('skip_weight_param_allocation', False):
weight_shape = (num_local_experts, self.input_size_per_partition, self.output_size_per_partition // 2)
weight_shape = (num_local_experts, self.output_size_per_partition,
self.input_size_per_partition // 2) \
if transpose_b else (num_local_experts, self.input_size_per_partition,
self.output_size_per_partition // 2)
weight = Parameter(initializer('ones', weight_shape, mindspore.qint4x2), requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 2})
input_dim, output_dim = (2, 1) if transpose_b else (1, 2)
set_weight_attrs(weight, {"input_dim": input_dim, "output_dim": output_dim})
set_weight_attrs(weight, extra_weight_attrs)
return weight

w_scale_shape = (num_local_experts, self.input_size_per_partition // group_size,
self.output_size_per_partition)
w_scale_dtype = mindspore.uint64
w_scale_dtype = mindspore.float32 if self.is_310p else mindspore.uint64
w_scale = Parameter(
initializer('ones', w_scale_shape, w_scale_dtype), name="w_scale", requires_grad=False)
set_weight_attrs(w_scale, {"input_dim": 1, "output_dim": 2})

@@ -98,12 +112,27 @@ class A8W4DynamicLinearMethod(LinearMethodBase):

return weight

def process_weight_before_loading(self, param_name, loaded_weight):
"""preprocess before loading weight"""
if self.is_310p and (param_name.endswith("weight1") or param_name.endswith('weight2')) \
and loaded_weight.ndim == 2:
# In 310P, transpose_b is True, so the param shape is (out_dim, in_dim//2).
# But the weights are (out_dim//2, in_dim) after packing int4 to int8.
loaded_weight = loaded_weight.astype(np.uint8)
loaded_weight = loaded_weight.transpose(1, 0)
loaded_weight = np_unpack_int8_to_int4(loaded_weight)
loaded_weight = loaded_weight.transpose(1, 0)
loaded_weight = np_pack_int4_to_int8(loaded_weight)
if self.is_310p and param_name.endswith("w_scale"):
loaded_weight = loaded_weight.view(np.float32)
loaded_weight = loaded_weight[..., 0::2]
return loaded_weight
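A numpy-only shape sketch (not part of the patch) of the 310P repack above: the checkpoint stores the int4 weight packed along the output dimension, while on 310P the transposed parameter expects it packed along the input dimension, so the loader unpacks, transposes, and repacks. The helper functions are the ones added to quantization/utils later in this diff.

import numpy as np
from mindformers.parallel_core.inference.quantization.utils import np_pack_int4_to_int8, np_unpack_int8_to_int4

# Illustrative shapes only; a real checkpoint weight would hold packed int4 data.
out_dim, in_dim = 4, 8
ckpt = np.zeros((out_dim // 2, in_dim), dtype=np.int8)   # checkpoint layout: packed along out_dim
step = np_unpack_int8_to_int4(ckpt.transpose(1, 0))      # (in_dim, out_dim), one int4 value per byte
param = np_pack_int4_to_int8(step.transpose(1, 0))       # (out_dim, in_dim // 2), packed along in_dim
print(param.shape)                                       # (4, 4) -> matches the 310P param layout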

def process_weights_after_loading(self, layer: nn.Cell) -> None:
"""calculate gmm_bias"""
if isinstance(layer, GroupedMLP):
return

# int4 pack to int8
np_data = layer.weight.asnumpy().astype(np.uint8)
np_data_low = ((np_data & 0x0F) << 4).astype(np.int8) >> 4
np_data_high = ((np_data >> 4) << 4).astype(np.int8) >> 4

@@ -113,8 +142,12 @@ class A8W4DynamicLinearMethod(LinearMethodBase):
np_int4_data[:, :, ::2] = np_data_low
np_int4_data[:, :, 1::2] = np_data_high
w_scale = layer.w_scale.asnumpy()
w_scale_repeat = np.repeat(w_scale, layer.weight.shape[1] // w_scale.shape[1],
axis=1).astype(np.uint32).view(np.float32)
if self.is_310p:
np_int4_data = np_int4_data.transpose(0, 2, 1)
w_scale_repeat = np.repeat(w_scale, np_int4_data.shape[1] // w_scale.shape[1], axis=1)
else:
w_scale_repeat = np.repeat(w_scale, np_int4_data.shape[1] // w_scale.shape[1],
axis=1).astype(np.uint32).view(np.float32)
gmm_bias = 8 * np.sum(
np_int4_data.astype(np.float32) * w_scale_repeat, axis=1)

@@ -136,15 +169,24 @@ class A8W4DynamicLinearMethod(LinearMethodBase):
output_shape = qx.shape[:-1] + (self.output_size_per_partition,)
qx = qx.reshape(-1, self.input_size_per_partition)
if self.is_group_mm:
out = self.matmul([qx], [weight],
[gmm_bias], [w_scale],
None,
None,
None, [qx_scale],
group_list,
split_item=3,
group_type=0,
group_list_type=1)[0]
if self.is_310p:
group_list = ops.cast(group_list, dtype=mindspore.int32)
out = self.matmul(qx,
weight,
group_list,
gmm_bias,
qx_scale,
w_scale)
else:
out = self.matmul([qx], [weight],
[gmm_bias], [w_scale],
None,
None,
None, [qx_scale],
group_list,
split_item=3,
group_type=0,
group_list_type=1)[0]
else:
w_scale = ops.cast(w_scale, mindspore.float16)
qx = ops.cast(qx, mindspore.float16)
@@ -24,8 +24,14 @@ from mindspore.ops.auto_generate import QuantBatchMatmul, DynamicQuantExt, Group
from mindformers.parallel_core.inference.tensor_parallel.layers import LinearMethodBase
from mindformers.parallel_core.inference.tensor_parallel.mappings import reduce_from_model_parallel_region
from mindformers.parallel_core.inference.quantization import QuantizationConfig
from mindformers.version_control import is_310p
from mindformers.parallel_core.inference.weights_utils import set_weight_attrs
from mindformers.models.utils import format_type
try:
from ms_custom_ops import grouped_matmul
GMM_310P = True
except ImportError:
GMM_310P = False


class A8W8DynamicLinearMethod(LinearMethodBase):

@@ -35,6 +41,7 @@ class A8W8DynamicLinearMethod(LinearMethodBase):
self.quant_config = quant_config
self.quant = DynamicQuantExt()
self.bias_add = ops.Add()
self.is_310p = is_310p() and GMM_310P

def create_weights(self,
layer: nn.Cell,

@@ -42,6 +49,7 @@ class A8W8DynamicLinearMethod(LinearMethodBase):
output_partition_sizes: list[int],
params_dtype,
*weight_args,
transpose_b=False,
num_local_experts=None, **extra_weight_attrs) -> Union[Parameter, None]:
output_size_per_partition = sum(output_partition_sizes)
self.output_size_per_partition = output_size_per_partition

@@ -50,11 +58,16 @@ class A8W8DynamicLinearMethod(LinearMethodBase):

if self.is_group_mm:
weight = None
self.matmul = GroupedMatmulV4()
if self.is_310p:
self.matmul = grouped_matmul
else:
self.matmul = GroupedMatmulV4()
if not extra_weight_attrs.get('skip_weight_param_allocation', False):
shape = (num_local_experts, input_size_per_partition, output_size_per_partition)
shape = (num_local_experts, output_size_per_partition, input_size_per_partition) \
if transpose_b else (num_local_experts, input_size_per_partition, output_size_per_partition)
weight = Parameter(initializer('ones', shape, mindspore.int8), requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 2})
input_dim, output_dim = (2, 1) if transpose_b else (1, 2)
set_weight_attrs(weight, {"input_dim": input_dim, "output_dim": output_dim})
set_weight_attrs(weight, extra_weight_attrs)
return weight

@@ -106,15 +119,27 @@ class A8W8DynamicLinearMethod(LinearMethodBase):
output_shape = qx.shape[:-1] + (self.output_size_per_partition,)
qx = qx.reshape(-1, self.input_size_per_partition)
if self.is_group_mm:
out = self.matmul([qx], [weight],
None, [w_scale],
None,
None,
None, [qx_scale],
group_list,
split_item=3,
group_type=0,
group_list_type=1)[0]
if self.is_310p:
group_list = ops.cast(group_list, dtype=mindspore.int32)
out = self.matmul(qx,
weight,
group_list,
None,
w_scale,
qx_scale,
None,
False,
True)
else:
out = self.matmul([qx], [weight],
None, [w_scale],
None,
None,
None, [qx_scale],
group_list,
split_item=3,
group_type=0,
group_list_type=1)[0]
if hasattr(layer, 'delay_allreduce'):
if not layer.delay_allreduce and not layer.skip_bias_add:
out = reduce_from_model_parallel_region(out, layer.tp_group)

@@ -18,6 +18,7 @@ import mindspore
from mindspore import Tensor, Parameter, ops, nn
from mindspore.common.initializer import initializer
from mindspore.ops.auto_generate import QuantBatchMatmul, QuantV2
from mindformers.version_control import is_310p
from mindformers.parallel_core.inference.weights_utils import set_weight_attrs
from mindformers.parallel_core.inference.quantization import QuantizationConfig
from mindformers.parallel_core.inference.tensor_parallel.layers import LinearMethodBase

@@ -31,10 +32,11 @@ class A8W8LinearMethod(LinearMethodBase):
self.quant = QuantV2()
self.bias_add = ops.Add()
self.is_modelslim = self.quant_config.is_modelslim
self.is_310p = is_310p()
self.is_ms_custom_ops = False
try:
import ms_custom_ops
self.is_ms_custom_ops = True
import ms_custom_ops # pylint: disable=import-outside-toplevel
self.is_ms_custom_ops = True and not self.is_310p
self.ms_custom_ops = ms_custom_ops
except ModuleNotFoundError:
pass

@@ -56,7 +58,7 @@ class A8W8LinearMethod(LinearMethodBase):
weight_shape = (self.output_size_per_partition, self.input_size_per_partition)
weight = Parameter(initializer('ones', weight_shape, mindspore.int8), requires_grad=False)
deq_scale_shape = self.output_size_per_partition
scale_dtype = mindspore.float32
scale_dtype = mindspore.int64 if self.is_310p else mindspore.float32
deq_scale = Parameter(
initializer('ones', deq_scale_shape, scale_dtype), name="deq_scale", requires_grad=False)
shape = (self.output_size_per_partition,)

@@ -113,7 +115,7 @@ class A8W8LinearMethod(LinearMethodBase):
return
input_scale = 1 / layer.input_scale.asnumpy()
layer.input_scale = Parameter(
Tensor(input_scale, dtype=mindspore.bfloat16), name=layer.input_scale.name, requires_grad=False)
Tensor(input_scale, dtype=self.params_dtype), name=layer.input_scale.name, requires_grad=False)

def apply(self,
layer: mindspore.nn.Cell,
@@ -16,6 +16,7 @@
import os
import json
import glob
import numpy as np
from mindformers.parallel_core.inference.quantization import (get_quantization_config,
QuantizationConfig)
from mindformers.models.configuration_utils import PretrainedConfig

@@ -66,3 +67,24 @@ def get_quant_config(model_config: PretrainedConfig, weight_mapping: list) -> Qu
config["weight_mapping"] = weight_mapping
config["quantization"] = quantization
return quant_cls.from_config(config)

def np_unpack_int8_to_int4(packed_data):
"""unpack int8 to int4 numpy array."""
low_nibbles = (packed_data & 0x0F).astype(np.uint8)
high_nibbles = ((packed_data >> 4) & 0x0F).astype(np.uint8)

unpacked = np.empty((*packed_data.shape[:-1], packed_data.shape[-1] * 2),
dtype=np.uint8)
unpacked[..., 0::2] = low_nibbles
unpacked[..., 1::2] = high_nibbles

return unpacked

def np_pack_int4_to_int8(np_data):
"""pack int4 numpy array to int8."""
np_data = np_data.astype(np.int8)
np_data &= 0x000F
np_data[..., 0::2] <<= 0
np_data[..., 1::2] <<= 4
np_int4_data = np_data[..., 0::2] | np_data[..., 1::2]
return np_int4_data
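A quick round-trip check of the two helpers above (not part of the patch, assuming both functions are in scope): the even-indexed nibble goes into the low four bits of each byte and the odd-indexed nibble into the high four bits, and unpacking restores the original values.

import numpy as np

nibbles = np.array([[1, 2, 3, 4], [5, 6, 7, 15]], dtype=np.uint8)  # int4 values in [0, 15]
packed = np_pack_int4_to_int8(nibbles)      # shape (2, 2), two nibbles per byte
restored = np_unpack_int8_to_int4(packed)   # shape (2, 4)
assert np.array_equal(restored, nibbles)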

@@ -32,7 +32,7 @@ from mindspore.common.initializer import initializer
from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.parallel_core.inference.quantization import QuantizationConfig
from mindformers.parallel_core.inference.quantization.base_config import QuantizeMethodBase
from mindformers.parallel_core.inference.utils import divide
from mindformers.parallel_core.inference.utils import divide, cast_weight_for_310p
from mindformers.parallel_core.inference.weights_utils import set_weight_attrs
from mindformers.parallel_core.inference.tensor_parallel.mappings import (
gather_from_model_parallel_region,

@@ -41,14 +41,15 @@ from mindformers.parallel_core.inference.tensor_parallel.mappings import (
)
from mindformers.parallel_core.inference.parallel_state import ProcessGroup, default_pgs
from mindformers.parallel_core.inference.weights_utils import split_loaded_weight, deal_training_moe_weight

from mindformers.version_control import is_310p
from mindformers.models.utils import format_type

class GroupedLinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) grouped linear methods."""

@abstractmethod
def create_weights(self, layer: nn.Cell, num_local_experts: int,
input_size_per_partition: int, output_size_per_partition: int,
input_size_per_partition: int, output_partition_sizes: list[int],
params_dtype, **extra_weight_attrs):
"""Create weights for a grouped linear layer.
The weights will be set as attributes of the layer.

@@ -57,7 +58,7 @@ class GroupedLinearMethodBase(QuantizeMethodBase):
layer: The layer that is using the GroupedLinearMethodBase factory.
num_local_experts: The number of local experts.
input_size_per_partition: Sizes of the input dim on rank X.
output_size_per_partition: Sizes of the output dim on rank X.
output_partition_sizes: Sizes of the output dim on rank X.
params_dtype: Datatype of the parameters.
"""
raise NotImplementedError

@@ -86,7 +87,7 @@ class UnquantizedGroupedLinearMethod(GroupedLinearMethodBase):

def create_weights(self, layer: nn.Cell, num_local_experts: int,
input_size_per_partition: int, output_partition_sizes: list[int],
params_dtype, **extra_weight_attrs):
params_dtype, **extra_weight_attrs): # pylint: disable=arguments-renamed
weight_shape = (num_local_experts, input_size_per_partition, sum(output_partition_sizes))
weight = Parameter(initializer("zeros", weight_shape, params_dtype), requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 2})

@@ -151,10 +152,30 @@ class GroupedLinearBase(nn.Cell):
QuantizeMethodBase] = UnquantizedGroupedLinearMethod()
if quant_config is not None:
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
self.param_load_counts: Dict[str, int] = {}

def construct(self, x: Tensor, weight: Tensor, group_list: Tensor) -> Tensor:
raise NotImplementedError

def format_to_nz(self, param, merge_count=1):
"""Format parameter to nz format."""
current_count = self.param_load_counts.get(param.name, 0) + 1
self.param_load_counts[param.name] = current_count

# Only format when all shards are loaded.
if current_count == merge_count:
cast_weight = param
if param.dtype == ms.qint4x2:
# format_cast doesn't support qint4x2
import ms_custom_ops # pylint: disable=import-outside-toplevel
cast_weight = ms_custom_ops.type_cast(param, ms.int8)
cast_weight = ops.auto_generate.format_cast(cast_weight, format_type['nz'])
if param.dtype == ms.qint4x2:
cast_weight = ms_custom_ops.type_cast(cast_weight, ms.qint4x2)
param.set_data(cast_weight)
del self.param_load_counts[param.name]
ms.runtime.empty_cache()
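The merge_count bookkeeping above defers the NZ format cast until every shard of a fused parameter has been written; a minimal stand-alone sketch of that gating logic (not part of the patch):

# Illustrative gating: run an expensive conversion only after the last shard arrives.
param_load_counts = {}

def on_shard_loaded(param_name: str, merge_count: int) -> bool:
    """Return True once all `merge_count` shards of `param_name` have been loaded."""
    current = param_load_counts.get(param_name, 0) + 1
    param_load_counts[param_name] = current
    if current == merge_count:
        del param_load_counts[param_name]
        return True   # caller would perform the format cast here
    return False

assert on_shard_loaded("experts.weight1", 2) is False  # first shard
assert on_shard_loaded("experts.weight1", 2) is True   # second shard triggers the cast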


class ColumnParallelGroupedLinear(GroupedLinearBase):
"""

@@ -212,21 +233,22 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
weight: Tensor = None,
is_expert: bool = True,
tp_comm_buffer_name: str = None,
transpose_b: bool = False,
tp_group: ProcessGroup = default_pgs,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""
):
super(ColumnParallelGroupedLinear, self).__init__(num_local_experts,
input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
super().__init__(num_local_experts,
input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
if stride > 1:
raise NotImplementedError(
"For ColumnParallelGroupedLinear, `stride > 1` is not supported for now, "
"but got `stride={}`".format(stride))
f"For ColumnParallelGroupedLinear, `stride > 1` is not supported for now, "
f"but got `stride={stride}`")
if skip_bias_add:
raise NotImplementedError(
"For ColumnParallelGroupedLinear, `skip_bias_add=True` is not supported for now."

@@ -249,6 +271,7 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
self.tp_group = tp_group
self.tensor_parallel_group_size = self.tp_group.size
self.output_size_per_partition = divide(output_size, self.tensor_parallel_group_size)
self.transpose_b = transpose_b
if self.quant_method is None:
raise ValueError("`quant_method` is not initialized in ColumnParallelGroupedLinear.")

@@ -259,6 +282,7 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
input_size_per_partition=self.input_size,
output_partition_sizes=[self.output_size_per_partition],
params_dtype=self.config.params_dtype,
transpose_b=self.transpose_b,
skip_weight_param_allocation=self.skip_weight_param_allocation,
weight_loader=self.weight_loader
)

@@ -275,7 +299,7 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
else:
self.bias = None

def construct(self, input_parallel, weight=None, group_list=None):
def construct(self, input_parallel, weight=None, group_list=None): # pylint: disable=arguments-renamed
"""Forward of ColumnParallelGroupedLinear."""
if weight is None:
weight = self.weight

@@ -297,7 +321,8 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
def sharded_state_dict(self):
"""Provide the sharded state dict."""
expert_parallel_group_size = self.config.num_moe_experts // self.num_local_experts
w_shard = (expert_parallel_group_size, 1, self.tensor_parallel_group_size)
w_shard = (expert_parallel_group_size, self.tensor_parallel_group_size, 1) \
if self.transpose_b else (expert_parallel_group_size, 1, self.tensor_parallel_group_size)

state_dict = {}
if not self.skip_weight_param_allocation:

@@ -323,7 +348,6 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
"""
if expert_id == -1:
return

if shard_id is not None:
param_output_dim = getattr(param, "output_dim", None)
shard_size = param.shape[param_output_dim] // 2

@@ -332,17 +356,20 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
# but the dimension splitting in the network is defined based on three dimensions,
# so the splitting dimension needs to be subtracted by 1.
weight_shape = len(loaded_weight.get_shape())
if weight_shape == 1 or (weight_shape != 1 and loaded_weight.get_shape()[1] == 1):
shard_dim = getattr(param, "output_dim", None) - 1
else:
shard_dim = getattr(param, "input_dim", None) - 1
weight_need_transpose = not (self.transpose_b and param.name.endswith("weight1"))
cond = weight_shape == 1 or (weight_shape != 1 and loaded_weight.get_shape()[1] == 1) \
or not weight_need_transpose
shard_dim_map = {True: "output_dim", False: "input_dim"}
shard_dim = getattr(param, shard_dim_map[cond], None) - 1

tp_rank = self.tp_group.rank
start_idx = tp_rank * shard_size
# 310P needs to unpack qint4x2, transpose, then pack back to qint4x2
loaded_weight = self.quant_method.process_weight_before_loading(param.name, loaded_weight[:])
loaded_weight = split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size)
if weight_shape == 1:
loaded_weight = loaded_weight.reshape(-1, 1)
if loaded_weight.shape[1] != 1:
if loaded_weight.shape[1] != 1 and weight_need_transpose:
# The Hugging Face weight shape is [hidden_size, moe_ffn_hidden_size]
# The shape of param is [moe_ffn_hidden_size, hidden_size]
# So must be transposed.

@@ -353,6 +380,8 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
if weight_dtype == ms.bfloat16:
loaded_weight = ms.from_numpy(loaded_weight).astype(ms.float32).asnumpy()

expected_shape = list(param.shape)
expected_shape[param_output_dim] = shard_size
if loaded_weight.shape[1] == 1:
if loaded_weight.shape[0] != shard_size:
raise ValueError(

@@ -365,15 +394,18 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
elif shard_id == "w3":
param.asnumpy()[expert_id][shard_size:shard_size + shard_size] = loaded_weight.squeeze(1)
else:
if loaded_weight.shape != (param.shape[1], shard_size):
update_indices = [slice(None)] * len(param.shape)
update_indices[0] = expert_id
if shard_id == "w1":
update_indices[param_output_dim] = slice(None, shard_size)
elif shard_id == "w3":
update_indices[param_output_dim] = slice(shard_size, 2 * shard_size)
if loaded_weight.shape != tuple(expected_shape[1:]):
raise ValueError(
f"'{param.name}.shape' should be equal to 'loaded_weight.shape',"
f" but got the shape of param is {(param.shape[1], shard_size)} and "
f" but got the shape of param is {expected_shape[1:]} and "
f"the shape of weight is{loaded_weight.shape}")
if shard_id == "w1":
param.asnumpy()[expert_id][:, :shard_size] = loaded_weight
elif shard_id == "w3":
param.asnumpy()[expert_id][:, shard_size:shard_size + shard_size] = loaded_weight
param.asnumpy()[tuple(update_indices)] = loaded_weight
else:
loaded_weight = deal_training_moe_weight(loaded_weight)
param.init_data()

@@ -383,6 +415,12 @@ class ColumnParallelGroupedLinear(GroupedLinearBase):
f" but got the shape of param is {param.data[expert_id].shape} and "
f"the shape of weight is{loaded_weight.shape}")
param[expert_id] = ms.from_numpy(loaded_weight)
self.post_process(param)

def post_process(self, param):
"""post process in weight loader"""
if is_310p() and param.name.endswith("weight1"):
self.format_to_nz(param, self.num_local_experts * 2)

class RowParallelGroupedLinear(GroupedLinearBase):

@@ -441,21 +479,22 @@ class RowParallelGroupedLinear(GroupedLinearBase):
delay_allreduce: bool = True,
is_expert: bool = True,
tp_comm_buffer_name: str = None,
transpose_b: bool = False,
tp_group: ProcessGroup = default_pgs,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""
):
super(RowParallelGroupedLinear, self).__init__(num_local_experts,
input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
super().__init__(num_local_experts,
input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
if stride > 1:
raise NotImplementedError(
"For RowParallelGroupedLinear, `stride > 1` is not supported for now, "
"but got `stride={}`".format(stride))
f"For RowParallelGroupedLinear, `stride > 1` is not supported for now, "
f"but got `stride={stride}`")
if not is_expert:
raise NotImplementedError(
"For RowParallelGroupedLinear, `is_expert=False` is not supported for now.")

@@ -478,6 +517,7 @@ class RowParallelGroupedLinear(GroupedLinearBase):
self.tp_group = tp_group
self.tensor_parallel_group_size = self.tp_group.size
self.input_size_per_partition = divide(input_size, self.tensor_parallel_group_size)
self.transpose_b = transpose_b
if self.quant_method is None:
raise ValueError("`quant_method` is not initialized in RowParallelGroupedLinear.")

@@ -487,6 +527,7 @@ class RowParallelGroupedLinear(GroupedLinearBase):
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=[self.output_size],
params_dtype=self.config.params_dtype,
transpose_b=self.transpose_b,
skip_weight_param_allocation=self.skip_weight_param_allocation,
weight_loader=self.weight_loader
)

@@ -502,13 +543,14 @@ class RowParallelGroupedLinear(GroupedLinearBase):
else:
self.bias = None

def construct(self, input_, weight=None, group_list=None):
def construct(self, input_, weight=None, group_list=None): # pylint: disable=arguments-renamed
"""Forward of RowParallelGroupedLinear."""
if weight is None:
weight = self.weight
else:
# Check the weight passed in is the correct shape.
expected_shape = (self.num_local_experts, self.input_size_per_partition, self.output_size)
expected_shape = (self.num_local_experts,) + (self.output_size, self.input_size_per_partition) \
if self.transpose_b else (self.input_size_per_partition, self.output_size)
if weight.shape != expected_shape:
raise ValueError(
f"supplied weight's shape is {tuple(weight.shape)}, "

@@ -531,7 +573,8 @@ class RowParallelGroupedLinear(GroupedLinearBase):
def sharded_state_dict(self):
"""Provide the sharded state dict."""
expert_parallel_group_size = self.config.num_moe_experts // self.num_local_experts
w_shard = (expert_parallel_group_size, self.tensor_parallel_group_size, 1)
w_shard = (expert_parallel_group_size, 1, self.tensor_parallel_group_size) \
if self.transpose_b else (expert_parallel_group_size, self.tensor_parallel_group_size, 1)

state_dict = {}
if not self.skip_weight_param_allocation:

@@ -567,12 +610,12 @@ class RowParallelGroupedLinear(GroupedLinearBase):
tp_rank = self.tp_group.rank
if shard_id is not None:
param_output_dim = getattr(param, "input_dim", None)

if not param.name.endswith("weight") and param_output_dim is None:
if param.name.endswith("w_scale") and len(loaded_weight.get_shape()) == 2 \
and loaded_weight.get_shape()[1] == 1:
loaded_weight = loaded_weight[:].squeeze(-1)
param.init_data()
loaded_weight = cast_weight_for_310p(loaded_weight[:])
weight_dtype = ms.from_numpy(loaded_weight[:]).dtype
if weight_dtype == ms.bfloat16:
loaded_weight = ms.from_numpy(loaded_weight[:]).astype(ms.float32).asnumpy()

@@ -583,17 +626,22 @@ class RowParallelGroupedLinear(GroupedLinearBase):
# Because this weight shape is two-dimensional,
# but the dimension splitting in the network is defined based on three dimensions,
# so the splitting dimension needs to be subtracted by 1.
shard_dim = getattr(param, "output_dim", None) - 1
shard_dim = getattr(param, "input_dim", None) - 1
weight_need_transpose = not self.transpose_b or param.name.endswith("w_scale")
if weight_need_transpose:
shard_dim = 1 - shard_dim
start_idx = tp_rank * shard_size
loaded_weight = self.quant_method.process_weight_before_loading(param.name, loaded_weight[:])
loaded_weight = split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size)
# The Hugging Face weight shape is [hidden_size, moe_ffn_hidden_size]
# The shape of param is [moe_ffn_hidden_size, hidden_size]
# So must be transposed.
loaded_weight = loaded_weight.T
if loaded_weight.shape != (shard_size, param.shape[2]):
if weight_need_transpose:
loaded_weight = loaded_weight.T
if loaded_weight.shape != param.shape[1:]:
raise ValueError(
f"'{param.name}.shape' should be equal to 'loaded_weight.shape',"
f" but got the shape of param is {(shard_size, param.shape[2])} and "
f" but got the shape of param is {(param.shape[1:])} and "
f"the shape of weight is{loaded_weight.shape}")
param.init_data()
weight_dtype = ms.from_numpy(loaded_weight).dtype

@@ -612,3 +660,5 @@ class RowParallelGroupedLinear(GroupedLinearBase):
f" but got the shape of param is {param.data[expert_id].shape} and "
f"the shape of weight is{loaded_weight.shape}")
param[expert_id] = ms.from_numpy(loaded_weight)
if is_310p() and param.name.endswith("weight2"):
self.format_to_nz(param, self.num_local_experts)
@@ -169,13 +169,14 @@ class LinearBase(ms.nn.Cell):
if is_310p():
cast_weight = ops.auto_generate.format_cast(param, format_type['nz'])
else:
import ms_custom_ops
import ms_custom_ops # pylint: disable=import-outside-toplevel
cast_weight = ms_custom_ops.trans_data(param, transdata_type=1)
if move_to_cpu:
cast_weight = cast_weight.move_to("CPU")
param.set_data(cast_weight)
del self.param_load_counts[param.name]


class ColumnParallelLinear(LinearBase):
"""
The dense layer with weight sliced on second dimension by tensor parallel size.

@@ -242,7 +243,7 @@ class ColumnParallelLinear(LinearBase):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""
):
super(ColumnParallelLinear, self).__init__(
super().__init__(
input_size,
output_size,
skip_bias_add,

@@ -251,8 +252,8 @@ class ColumnParallelLinear(LinearBase):
prefix=prefix
)
if stride > 1:
raise NotImplementedError("For ColumnParallelLinear, `stride > 1` is not supported for now, "
"but got `stride={}`".format(stride))
raise NotImplementedError(f"For ColumnParallelLinear, `stride > 1` is not supported for now, "
f"but got `stride={stride}`")
if keep_master_weight_for_test:
raise NotImplementedError(
"For ColumnParallelLinear, `keep_master_weight_for_test` is not supported for now")

@@ -395,7 +396,7 @@ class ColumnParallelLinear(LinearBase):
f"'{param.name}.shape' should be equal to 'loaded_weight.shape',"
f" but got the shape of param is {param.shape} and the shape of weight is{loaded_weight.shape}")
param.init_data()
param.set_data(ms.from_numpy(loaded_weight))
param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype))
if is_310p() and param.name.endswith("weight"):
self.format_to_nz(param)

@@ -480,6 +481,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
array_id = 0
elif loaded_shard_id == 'hidden':
array_id = 1
else:
raise ValueError(f"Unknown loaded_shard_id: {loaded_shard_id}")
shard_offset = sum(self.output_sizes[:array_id]) // tp_size
shard_size = self.output_sizes[array_id] // tp_size

@@ -634,6 +637,8 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset = (self.num_heads +
self.num_kv_heads) * self.head_size
shard_size = self.num_kv_heads * self.head_size
else:
raise ValueError(f"Unknown loaded_shard_id: {loaded_shard_id}")

if loaded_shard_id == "q":
shard_id = tp_rank

@@ -739,15 +744,15 @@ class RowParallelLinear(LinearBase):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""
):
super(RowParallelLinear, self).__init__(input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
super().__init__(input_size,
output_size,
skip_bias_add,
config.params_dtype,
quant_config=quant_config,
prefix=prefix)
if stride > 1:
raise NotImplementedError("For RowParallelLinear, `stride > 1` is not supported for now, "
"but got `stride={}`".format(stride))
raise NotImplementedError(f"For RowParallelLinear, `stride > 1` is not supported for now, "
f"but got `stride={stride}`")
if skip_bias_add:
raise NotImplementedError("For RowParallelLinear, `skip_bias_add=True` is not supported for now")
if keep_master_weight_for_test:

@@ -878,7 +883,7 @@ class RowParallelLinear(LinearBase):
raise ValueError(
f"'{param.name}.shape' should be equal to 'loaded_weight.shape',"
f" but got the shape of param is {param.shape} and the shape of weight is{loaded_weight.shape}")
param.set_data(ms.from_numpy(loaded_weight))
param.set_data(ms.from_numpy(loaded_weight)) # TODO param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype))
if is_310p() and param.name.endswith("weight"):
self.format_to_nz(param)

@@ -932,8 +937,8 @@ class ReplicatedLinear(LinearBase):
quant_config=quant_config,
prefix=prefix)
if stride > 1:
raise NotImplementedError("For ReplicatedLinear, `stride > 1` is not supported for now, "
"but got `stride={}`".format(stride))
raise NotImplementedError(f"For ReplicatedLinear, `stride > 1` is not supported for now, "
f"but got `stride={stride}`")
if skip_bias_add:
raise NotImplementedError("For ReplicatedLinear, `skip_bias_add=True` is not supported for now")
if keep_master_weight_for_test:

@@ -1061,7 +1066,9 @@ class ReplicatedLinear(LinearBase):
f" but got the shape of param is {param.shape} "
f"and the shape of weight is{loaded_weight.shape}")
param.init_data()
param.set_data(ms.from_numpy(loaded_weight))
param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype))
if is_310p() and param.name.endswith("weight"):
self.format_to_nz(param, 2)


class VocabParallelEmbedding(nn.Cell):
@@ -30,6 +30,7 @@ from mindformers.parallel_core.inference.transformer.activation import get_act_f
from mindformers.parallel_core.inference.utils import divide
from mindformers.parallel_core.process_group_config import ModelCommProcessGroups, default_model_comm_pgs
from mindformers.parallel_core.inference.weights_utils import set_weight_attrs
from mindformers.version_control import is_310p


class GroupedMLP(nn.Cell):

@@ -54,6 +55,7 @@ class GroupedMLP(nn.Cell):
# use model_comm_pgs.moe_tp_group as tensor parallel group in this module.
self.tp_group = model_comm_pgs.moe_tp
self.tp_group_size = self.tp_group.size
self.transpose_b = is_310p()

ffn_hidden_size = self.config.moe_ffn_hidden_size
self.ffn_hidden_size_per_partition = divide(ffn_hidden_size, self.tp_group_size)

@@ -70,6 +72,7 @@ class GroupedMLP(nn.Cell):
num_local_experts=self.num_local_experts,
input_size_per_partition=self.input_size,
output_partition_sizes=[divide(ffn_hidden_size, self.tp_group_size)],
transpose_b=self.transpose_b,
params_dtype=self.config.params_dtype,
)
self.weight2 = self.quant_method.create_weights(

@@ -77,6 +80,7 @@ class GroupedMLP(nn.Cell):
num_local_experts=self.num_local_experts,
input_size_per_partition=self.ffn_hidden_size_per_partition,
output_partition_sizes=[self.input_size],
transpose_b=self.transpose_b,
params_dtype=self.config.params_dtype,
)

@@ -92,6 +96,7 @@ class GroupedMLP(nn.Cell):
weight=self.weight1, # Skip creating weights and use weight1 for gemm linear calculation
is_expert=True,
tp_group=self.tp_group,
transpose_b=self.transpose_b,
quant_config=quant_config,
prefix=f"{prefix}.linear_fc1",
)

@@ -112,6 +117,7 @@ class GroupedMLP(nn.Cell):
weight=self.weight2, # Skip creating weights and use weight2 for gemm linear calculation
is_expert=True,
tp_group=self.tp_group,
transpose_b=self.transpose_b,
quant_config=quant_config,
prefix=f"{prefix}.linear_fc2",
)

@@ -44,7 +44,7 @@ class Router(nn.Cell):
Args:
config (TransformerConfig): Configuration object for the Transformer model.
"""
super(Router, self).__init__()
super().__init__()
self.config = config
self.num_experts = self.config.num_moe_experts
self.router_dense_type = self.config.moe_router_dtype

@@ -98,7 +98,7 @@ class Router(nn.Cell):
"""
loaded_weight = loaded_weight[:]
if self.ep_group_size > 1 and not self.config.use_alltoall:
expert_idx_list = [idx for idx in range(self.num_experts)]
expert_idx_list = list(range(self.num_experts))
start_idx = self.num_experts // self.ep_group_size * self.ep_rank
expert_idx_list = expert_idx_list[start_idx:] + expert_idx_list[:start_idx]
loaded_weight = loaded_weight[expert_idx_list]

@@ -108,7 +108,7 @@ class Router(nn.Cell):
f"'param.data.shape' should be equal to 'loaded_weight.shape',"
f" but got the shape of param is {param.shape} "
f"and the shape of weight is{loaded_weight.shape}")
param.set_data(ms.from_numpy(loaded_weight).astype(param.dtype))
param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype))
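The index reordering in the router hunk above rotates the global expert ids so that the experts owned by the current expert-parallel rank come first; a small stand-alone sketch of that rotation (not part of the patch):

# Illustrative rotation of expert ids for one EP rank.
num_experts, ep_group_size, ep_rank = 8, 4, 2
expert_idx_list = list(range(num_experts))
start_idx = num_experts // ep_group_size * ep_rank
rotated = expert_idx_list[start_idx:] + expert_idx_list[:start_idx]
print(rotated)   # [4, 5, 6, 7, 0, 1, 2, 3] -> rank 2's experts (4, 5) lead the list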

def construct(self, input_tensor: Tensor):
"""

@@ -27,7 +27,7 @@ from mindspore.ops.auto_generate import (MoeInitRoutingV2,

from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.parallel_core.process_group_config import ModelCommProcessGroups, default_model_comm_pgs
from mindformers.version_control import is_910b
from mindformers.version_control import is_910b, is_310p


class MoETokenDispatcher:

@@ -109,6 +109,7 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher):
self.cast = ops.Cast()
self.moe_init_routing_v2 = MoeInitRoutingV2()
self.moe_token_unpermute = MoeTokenUnpermute()
self.is_310p = is_310p()

def dispatch_preprocess(self, expert_weight, routing_map):
"""Preprocess expert weight by masking out invalid experts."""

@@ -129,8 +130,8 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher):
expert_capacity=0,
expert_num=self.num_experts,
drop_pad_mode=0,
expert_tokens_count_or_cumsum_flag=2,
expert_tokens_before_capacity_flag=True
expert_tokens_count_or_cumsum_flag=1 if self.is_310p else 2,
expert_tokens_before_capacity_flag=not self.is_310p
)

# Avoid the problem of poor performance of the split(int32) operator

@@ -143,7 +144,8 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher):
def token_combine(self, hidden_states, expert_weight, *args):
"""Combines expert outputs."""
(tokens_per_expert,) = args
hidden_states = mint.nan_to_num(hidden_states, 0, 0, 0)
if not self.is_310p:
hidden_states = mint.nan_to_num(hidden_states, 0, 0, 0)
expert_weight = expert_weight.astype(hidden_states.dtype)
hidden_states = self.moe_token_unpermute(
permuted_tokens=hidden_states,

@@ -45,7 +45,8 @@ from mindformers.parallel_core.inference.base_models.common.embeddings.yarn_rota
from mindformers.parallel_core.inference.base_models.common.embeddings.rope_utils import get_rope
from mindformers.parallel_core.process_group_config import ModelCommProcessGroups, default_model_comm_pgs
from mindformers.parallel_core.inference.weights_utils import set_weight_attrs, split_loaded_weight

from mindformers.version_control import is_310p
from mindformers.models.utils import format_type

@dataclass
class MLASelfAttentionSubmodules:

@@ -172,6 +173,7 @@ class MultiLatentAttention(Attention):
self.dim_slice_3d = P.Slice()
self.transpose = P.Transpose()
self.out_absorb_matmul = P.BatchMatMul(transpose_b=True)
self.is_310p = is_310p()

def construct(
self,

@@ -363,7 +365,10 @@ class MLASelfAttention(MultiLatentAttention):
Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
q_absorb, out_absorb = mint.split(self.linear_kv_up_proj.weight,
weight = self.linear_kv_up_proj.weight
if self.is_310p:
weight = ops.auto_generate.format_cast(weight, format_type['nd'])
q_absorb, out_absorb = mint.split(weight,
[self.num_attention_heads_per_partition * self.config.qk_head_dim,
self.num_attention_heads_per_partition * self.config.v_head_dim], -2)
self.q_absorb = q_absorb.reshape(self.num_attention_heads_per_partition,

@@ -531,7 +536,7 @@ class FusedMLASelfAttention(MLASelfAttention):
self.is_modelslim = quant_config.is_modelslim
self.fa3_quant = quant_config.fa3_quant
self.fa3_quant_layer = quant_config.fa3_quant_layer
self.is_fa3_quant_layer = (layer_number - 1) in self.fa3_quant_layer # layer_number start from 1
self.is_fa3_quant_layer = layer_number - 1 in self.fa3_quant_layer # layer_number start from 1
self.input_layernorm_weight = None
self.qkv_down_proj_input_scale = None
self.q_layernorm_weight = None

@@ -542,10 +547,10 @@ class FusedMLASelfAttention(MLASelfAttention):
self.q_up_proj_input_offset = None
self.input_format = 1 if self.fa3_quant else 0
self.use_ringmla = use_ms_custom_ops() and get_tensor_model_parallel_world_size() < 16
import ms_custom_ops
import ms_custom_ops # pylint: disable=import-outside-toplevel
self.ms_custom_ops = ms_custom_ops
self.scale_value = 1 / math.sqrt(self.config.kv_lora_rank + self.config.qk_head_dim) \
if self.softmax_scale is None else self.softmax_scale
self.scale_value = (1 / math.sqrt(self.config.kv_lora_rank + self.config.qk_head_dim)
if self.softmax_scale is None else self.softmax_scale)
self.ring_mla_mask = Tensor(np.triu(np.ones((512, 512), dtype=np.float16), 1), dtype.bfloat16)
self.depend = P.Depend()
self.quant = QuantV2()

@@ -793,7 +798,7 @@ class FusedMLASelfAttention(MLASelfAttention):
k_cache = self.transpose(key_cache.reshape(-1, self.config.kv_lora_rank // 32, \
self.config.block_size, 32), (0, 2, 1, 3)).reshape( \
-1, self.config.block_size, self.config.kv_lora_rank)
k_cache = (self.cast(k_cache, dtype.bfloat16) / self.quant_ctkv_scale)
k_cache = self.cast(k_cache, dtype.bfloat16) / self.quant_ctkv_scale
else:
k_cache = self.ms_custom_ops.trans_data(key_cache, transdata_type=0)
v_cache = self.ms_custom_ops.trans_data(value_cache, transdata_type=0)
@@ -76,8 +76,8 @@ def get_attn_mask_func(mask_func_type):
"""
if mask_func_type not in ATTNMASK_FUNC_MAP:
raise KeyError("Invalid attention mask function. Supported attention "
"mask function are ['attn_mask_fill', 'attn_mask_add'] "
", but got {}.".format(mask_func_type))
f"mask function are ['attn_mask_fill', 'attn_mask_add'] "
f", but got {mask_func_type}.")
return ATTNMASK_FUNC_MAP[mask_func_type]


@@ -158,7 +158,7 @@ def create_empty_parameter(shape, *, dtype=None, device=None, **kwargs):
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
if numerator % denominator != 0:
raise ValueError("{} is not divisible by {}".format(numerator, denominator))
raise ValueError(f"{numerator} is not divisible by {denominator}")


def divide(numerator, denominator):

@@ -178,9 +178,9 @@ def save_strategy_file(state_dict, strategy_file_name):
Supported Platforms:
``Ascend``
"""
import os
import stat
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy
import os # pylint: disable=import-outside-toplevel
import stat # pylint: disable=import-outside-toplevel
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy # pylint: disable=import-outside-toplevel

stra = ckpt_strategy()

@@ -367,9 +367,21 @@ def use_ms_custom_ops():
"""
try:
# pylint: disable=W0611
import ms_custom_ops
import ms_custom_ops # pylint: disable=import-outside-toplevel
except ModuleNotFoundError:
# the environment needs the ms_custom_ops package installed
return False

return not is_310p()

def cast_weight_for_310p(loaded_weight):
"""
Casts weights to float16 for 310p.

In non-quantized scenarios, the 310P hardware only supports float16 weights.
This function converts float32 or bfloat16 weights to float16.
"""
cast_weight = (loaded_weight.astype(np.float16) if
(str(loaded_weight.dtype) == "float32" or str(
loaded_weight.dtype) == "bfloat16") else loaded_weight)
return cast_weight

@@ -64,7 +64,7 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size):
elif shard_dim == 2:
loaded_weight = loaded_weight[:, :, start_idx:end_idx]
else:
raise ValueError("shard_dim:{} is not supported.".format(shard_dim))
raise ValueError(f"shard_dim:{shard_dim} is not supported.")
loaded_weight = loaded_weight.astype(np.float16) \
if (str(loaded_weight.dtype) == 'bfloat16' and is_310p()) else loaded_weight
return loaded_weight

@@ -393,3 +393,22 @@ def split_fusion_loaded_weight(loaded_weight, start_idxs, shard_sizes):
loaded_weight_parts.append(loaded_weight[start_idx:start_idx + shard_size])
perrank_ffn_weight = np.concatenate(loaded_weight_parts, axis=0)
return perrank_ffn_weight


def process_weights_for_310p(name, loaded_weight, params_dict, loaded_params):
"""
Process the loaded weight for 310P.

Args:
name: The name of the weight to be loaded.
loaded_weight: The weights to be loaded.
params_dict: The dictionary of model parameters.
loaded_params: The set of already loaded parameter names.
Returns:
The processed weights.
"""

if "deq_scale" in name:
loaded_weight = loaded_weight[:].astype(np.float32).view(np.int32).astype(np.int64)

return loaded_weight
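The deq_scale handling above reinterprets the raw bits of the float32 scale as int32 and then widens them to int64, rather than converting the numeric value; a numpy sketch of that bit-level reinterpretation (not part of the patch):

import numpy as np

deq_scale = np.array([0.0078125, 1.5], dtype=np.float32)
as_int64 = deq_scale.astype(np.float32).view(np.int32).astype(np.int64)
print(as_int64)                      # IEEE-754 bit patterns, e.g. 0.0078125 -> 0x3C000000
assert as_int64.dtype == np.int64    # same bits, wider integer container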

@@ -3975,7 +3975,7 @@
"signature": "(self, layer: mindspore.nn.cell.Cell) -> None"
},
"mindformers.parallel_core.inference.tensor_parallel.grouped_layers.ColumnParallelGroupedLinear": {
"signature": "(num_local_experts: int, input_size: int, output_size: int, *, config: mindformers.parallel_core.transformer_config.TransformerConfig, bias: bool = False, gather_output: bool = False, stride: int = 1, skip_bias_add: bool = False, weight: mindspore.common.tensor.Tensor = None, is_expert: bool = True, tp_comm_buffer_name: str = None, tp_group: mindformers.parallel_core.inference.parallel_state.ProcessGroup = <mindformers.parallel_core.inference.parallel_state.ProcessGroup object>, quant_config: Optional[mindformers.parallel_core.inference.quantization.base_config.QuantizationConfig] = None, prefix: str = '')"
"signature": "(num_local_experts: int, input_size: int, output_size: int, *, config: mindformers.parallel_core.transformer_config.TransformerConfig, bias: bool = False, gather_output: bool = False, stride: int = 1, skip_bias_add: bool = False, weight: mindspore.common.tensor.Tensor = None, is_expert: bool = True, tp_comm_buffer_name: str = None, transpose_b: bool = False, tp_group: mindformers.parallel_core.inference.parallel_state.ProcessGroup = <mindformers.parallel_core.inference.parallel_state.ProcessGroup object>, quant_config: Optional[mindformers.parallel_core.inference.quantization.base_config.QuantizationConfig] = None, prefix: str = '')"
},
"mindformers.parallel_core.inference.tensor_parallel.grouped_layers.ColumnParallelGroupedLinear.construct": {
"signature": "(self, input_parallel, weight=None, group_list=None)"

@@ -3987,7 +3987,7 @@
"signature": "(self, param, loaded_weight, shard_id=None, expert_id=None) -> None"
},
"mindformers.parallel_core.inference.tensor_parallel.grouped_layers.RowParallelGroupedLinear": {
"signature": "(num_local_experts: int, input_size: int, output_size: int, *, config: mindformers.parallel_core.transformer_config.TransformerConfig, bias: bool = False, input_is_parallel: bool = True, skip_bias_add: bool = False, weight: mindspore.common.tensor.Tensor = None, stride: int = 1, delay_allreduce: bool = True, is_expert: bool = True, tp_comm_buffer_name: str = None, tp_group: mindformers.parallel_core.inference.parallel_state.ProcessGroup = <mindformers.parallel_core.inference.parallel_state.ProcessGroup object>, quant_config: Optional[mindformers.parallel_core.inference.quantization.base_config.QuantizationConfig] = None, prefix: str = '')"
"signature": "(num_local_experts: int, input_size: int, output_size: int, *, config: mindformers.parallel_core.transformer_config.TransformerConfig, bias: bool = False, input_is_parallel: bool = True, skip_bias_add: bool = False, weight: mindspore.common.tensor.Tensor = None, stride: int = 1, delay_allreduce: bool = True, is_expert: bool = True, tp_comm_buffer_name: str = None, transpose_b: bool = False, tp_group: mindformers.parallel_core.inference.parallel_state.ProcessGroup = <mindformers.parallel_core.inference.parallel_state.ProcessGroup object>, quant_config: Optional[mindformers.parallel_core.inference.quantization.base_config.QuantizationConfig] = None, prefix: str = '')"
},
"mindformers.parallel_core.inference.tensor_parallel.grouped_layers.RowParallelGroupedLinear.construct": {
"signature": "(self, input_, weight=None, group_list=None)"