mirror of
https://gitee.com/mindspore/mindformers.git
synced 2025-12-06 11:29:59 +08:00
【GLM4】GLM4模型,封装权重转换函数,统一到mf根目录下的convert_weight.py中,并修复文档中相应描述,防止使用时产生歧义
This commit is contained in:
@@ -36,7 +36,7 @@ reversed_dtype_map = {
|
||||
convert_map = {
|
||||
'llama': 'mindformers.models.llama.convert_weight.convert_pt_to_ms',
|
||||
'qwen2_5': 'research.qwen2_5.convert_weight.convert_weight',
|
||||
'glm-n': 'mindformers.models.glm2.convert_weight.convert_pt_to_ms',
|
||||
'glm-n': 'mindformers.models.glm2.convert_weight.convert_weight',
|
||||
'mixtral': 'research.mixtral.convert_weight.convert_pt_to_ms',
|
||||
'telechat': 'research.telechat.convert_weight.convert_pt_to_ms',
|
||||
'deepseekv3': 'toolkit.weight_convert.deepseekv3.convert_deepseekv3_hf_weight.convert_weight'
|
||||
@@ -60,6 +60,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--telechat_type', default="telechat_12b", type=str, required=False,
|
||||
help="Only for telechat. Telechat version.")
|
||||
parser.add_argument('--is_lora', default=False, type=str2bool, required=False)
|
||||
|
||||
args, extra_args = parser.parse_known_args()
|
||||
extra_args = [i
|
||||
for item in extra_args
|
||||
@@ -78,6 +79,9 @@ if __name__ == '__main__':
|
||||
raise ValueError("Custom config key need to start with --.")
|
||||
extra_kwargs[key[2:]] = value
|
||||
|
||||
if args.model in ["glm4"]:
|
||||
args.model = "glm-n"
|
||||
|
||||
if args.reversed:
|
||||
module_func = reversed_convert_map.get(args.model)
|
||||
dtype = reversed_dtype_map.get(args.dtype)
|
||||
@@ -93,8 +97,5 @@ if __name__ == '__main__':
|
||||
model_name, func_name = module_func.rsplit('.', 1)
|
||||
convert_func = getattr(importlib.import_module(model_name), func_name)
|
||||
|
||||
if args.model in ["qwen2_5", "deepseekv3"]:
|
||||
merged_args = argparse.Namespace(**{**vars(args), **extra_kwargs})
|
||||
convert_func(merged_args)
|
||||
else:
|
||||
convert_func(input_path=args.input_path, output_path=args.output_path, dtype=dtype, **extra_kwargs)
|
||||
merged_args = argparse.Namespace(**{**vars(args), **extra_kwargs})
|
||||
convert_func(merged_args)
|
||||
|
||||
@@ -129,7 +129,7 @@ MindSpore TransFormers 提供已经转换完成的预训练权重、词表文件
|
||||
|
||||
#### 模型权重转换
|
||||
|
||||
1. 如果使能高性能模式(`enable_high_performance=True`),需要按如下方式修改 yaml :
|
||||
1. 如果使能高性能模式(`enable_high_performance=True`),需要按如下方式修改 yaml 文件配置:
|
||||
|
||||
```yaml
|
||||
model:
|
||||
@@ -138,36 +138,37 @@ MindSpore TransFormers 提供已经转换完成的预训练权重、词表文件
|
||||
mlp_concat: False
|
||||
```
|
||||
|
||||
2. 执行 `convert_weight.py` 转换脚本,将 HuggingFace 的权重转换为完整的 ckpt 权重。
|
||||
2. 执行 mindforers 根目录下的 `convert_weight.py` [转换脚本](https://gitee.com/mindspore/mindformers/blob/dev/convert_weight.py),将 HuggingFace 的权重转换为完整的 MindSpore ckpt 权重。
|
||||
|
||||
```shell
|
||||
python convert_weight.py --torch_ckpt_path TORCH_CKPT_DIR --mindspore_ckpt_path MS_CKPT_NAME --dtype DTYPE --config YAML_PATH
|
||||
python convert_weight.py --model glm4 --input_path HF_CKPT_PATH --output_path MS_NOT_CONCAT_CKPT_PATH --dtype DTYPE --config YAML_PATH
|
||||
```
|
||||
|
||||
参数说明如下表:
|
||||
|
||||
| 参数名 | 含义 | 取值说明 |
|
||||
|-----------------------------|-------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------|
|
||||
| `--torch_ckpt_path` | HuggingFace 权重文件路径。 | (str, 可选) - 默认值: `None` 。 |
|
||||
| `--mindspore_ckpt_path` | 转换后的 MindSpore 权重文件保存路径 (qkv 和 ffn concat)。 | (str, 必选) - 默认值: `None` 。 |
|
||||
| `--dtype` | 权重的数值类型,一般有 `float16` 、 `float32` 、 `bfloat16` 。 | (str, 可选) - 配置为 [`fp32`, `fp16`, `bf16`] 其中之一,默认值: `fp32` 。 |
|
||||
| `--config` | glm4 模型所用 yaml 文件的路径。 | (str, 必选) - 如 `research/glm32k/finetune_glm32k.yaml` 。 |
|
||||
| 参数名 | 含义 | 取值说明 |
|
||||
|-----------------|--------------------------------------------------|---------------------------------------------------------------------|
|
||||
| `--model` | 需要进行权重转换的模型,此处使用 `glm4` 。 | (str, 必选) - 默认值: `None` 。 |
|
||||
| `--input_path` | HuggingFace 权重文件路径。 | (str, 必选) - 默认值: `None` 。 |
|
||||
| `--output_path` | 转换后的 MindSpore 权重文件保存路径 (qkv 和 ffn concat)。 | (str, 必选) - 默认值: `None` 。 |
|
||||
| `--dtype` | 权重的数值类型,一般有 `float16` 、 `float32` 、 `bfloat16` 。 | (str, 可选) - 配置为 [`fp32`, `fp16`, `bf16`] 其中之一,默认值: `fp32` 。 |
|
||||
| `--config` | glm4 模型所用 yaml 文件的路径。 | (str, 必选) - 如 `configs/glm4/finetune_glm4_9b.yaml` ,默认值: `None` 。 |
|
||||
|
||||
3. 如果使能高性能模式,除了将 HuggingFace 权重转换为 ckpt 权重后,还需要将转换后得到的 ckpt 权重作为 `--ms_not_concat_ckpt_path` 指定的路径,额外执行如下转换。
|
||||
3. 如果使能高性能模式(步骤1),除了将 HuggingFace 权重转换为 ckpt 权重后,还需要将转换后得到的 ckpt 权重作为 `--input_path` 指定的路径。在转换完 HuggingFace 权重后,需要将得到的 MindSpore 权重额外执行如下转换。
|
||||
|
||||
```shell
|
||||
python convert_weight.py --ms_not_concat_ckpt_path PRE_CKPT_DIR --mindspore_ckpt_path MS_CKPT_NAME --dtype DTYPE --config YAML_PATH --concat True
|
||||
python convert_weight.py --model glm4 --input_path MS_NOT_CONCAT_CKPT_PATH --output_path MS_CONCATED_BIAS_CKPT_DIR --config YAML_PATH --concat True
|
||||
```
|
||||
|
||||
参数说明如下表:
|
||||
|
||||
| 参数名 | 含义 | 取值说明 |
|
||||
|-----------------------------|-------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------|
|
||||
| `--ms_not_concat_ckpt_path` | qkv 和 ffn 没有 concat 的 MindSpore 权重路径。<br>结合 `--concat` 进行配置后,将对此路径下的权重进行 qkv 和 ffn 的 concat,并生成新权重在 `--mindspore_ckpt_path` 指定的路径下。 | (str, 可选) - 默认值: `None` 。 |
|
||||
| `--mindspore_ckpt_path` | 转换后的 MindSpore 权重文件保存路径 (qkv 和 ffn concat)。 | (str, 必选) - 默认值: `None` 。 |
|
||||
| `--dtype` | 权重的数值类型,一般有 `float16` 、 `float32` 、 `bfloat16` 。 | (str, 可选) - 配置为 [`fp32`, `fp16`, `bf16`] 其中之一,默认值: `fp32` 。 |
|
||||
| `--config` | glm4 模型所用 yaml 文件的路径。 | (str, 必选) - 如 `research/glm32k/finetune_glm32k.yaml` 。 |
|
||||
| `--concat` | 指定开启 qkv、ffn concat。 | (bool, 可选) - 默认值: `False` 。 |
|
||||
| 参数名 | 含义 | 取值说明 |
|
||||
|------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------|
|
||||
| `--model` | 需要进行权重转换的模型,此处使用 `glm4` 。 | (str, 必选) - 默认值: `None` 。 |
|
||||
| `--input_path` | qkv 和 ffn 没有 concat 的 MindSpore ckpt 权重路径。<br>结合 `--concat` 进行配置后,将对此路径下的 MindSpore 权重进行 qkv 和 ffn 的 concat,并生成新权重在 `--output_path` 指定的路径下。 | (str, 必选) - 默认值: `None` 。 |
|
||||
| `--output_path` | 转换后的 MindSpore 权重文件保存路径 (qkv 和 ffn concat)。 | (str, 必选) - 默认值: `None` 。 |
|
||||
| `--config` | glm4 模型所用 yaml 文件的路径。 | (str, 必选) - 如 `configs/glm4/finetune_glm4_9b.yaml` 。 |
|
||||
| `--concat` | 指定开启 qkv、ffn concat。<br>注意:使用 `--concat` 时,需要保证 `--input_path` 为 MindSpore 权重,而非 HuggingFace 权重。否则,会出现 `huggingface_hub.errors.HFValidationError` 的报错提示。 | (bool, 可选) - 默认值: `False` 。 |
|
||||
|
||||
## 全参微调
|
||||
|
||||
@@ -198,6 +199,8 @@ bash scripts/msrun_launcher.sh "run_mindformer.py \
|
||||
|
||||
MindSpore Transformers 提供了 `GLM-4-9B-Chat` 的快速推理脚本,脚本主要通过 `generate` 高阶接口实现,支持单卡、双卡多 batch 推理。
|
||||
|
||||
配置文件可以参考 `configs/glm4/predict_glm4_9b_chat.yaml` 和 `configs/glm4/predict_glm4_9b_chat_800I_A2.yaml` 示例。
|
||||
|
||||
```shell
|
||||
bash scripts/examples/glm4/run_glm4_predict.sh PARALLEL CONFIG_PATH CKPT_PATH TOKENIZER DEVICE_NUM
|
||||
```
|
||||
@@ -226,7 +229,7 @@ bash scripts/examples/glm4/run_glm4_predict.sh \
|
||||
# 双卡推理
|
||||
bash scripts/examples/glm4/run_glm4_predict.sh \
|
||||
parallel \
|
||||
/path/to/glm4/predict_glm4_9b_800I_A2.yaml \
|
||||
/path/to/glm4/predict_glm4_9b_chat_800I_A2.yaml \
|
||||
/path/to/glm4_ckpt_dir \
|
||||
/path/to/tokenizer.model \
|
||||
2
|
||||
|
||||
@@ -29,6 +29,11 @@ from mindformers.tools.utils import str2bool
|
||||
from mindformers.models.glm2.glm2_config import ChatGLM2Config
|
||||
from mindformers.utils.convert_utils import pt2ms
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": ms.float32,
|
||||
"fp16": ms.float16,
|
||||
"bf16": ms.bfloat16
|
||||
}
|
||||
|
||||
def npy2ms(arr: np.array, dtype):
|
||||
"""npy2ms"""
|
||||
@@ -250,8 +255,43 @@ def concat_weight_and_bias(param_dict, config):
|
||||
return param_dict
|
||||
|
||||
|
||||
def convert_weight(para):
|
||||
"""convert weight entrance"""
|
||||
if para.config is None:
|
||||
raise RuntimeError("config must be specified")
|
||||
|
||||
if not hasattr(para, 'concat'):
|
||||
para.concat = False
|
||||
else:
|
||||
para.concat = para.concat == "True"
|
||||
|
||||
if para.concat:
|
||||
if not hasattr(para, 'ms_not_concat_ckpt_path'):
|
||||
para.ms_not_concat_ckpt_path = para.input_path
|
||||
if not hasattr(para, 'mindspore_ckpt_path'):
|
||||
para.mindspore_ckpt_path = para.output_path
|
||||
|
||||
convert_to_concat_ckpt(
|
||||
ms_not_concat_ckpt_path=para.ms_not_concat_ckpt_path,
|
||||
ms_concat_ckpt_path=para.mindspore_ckpt_path,
|
||||
config_path=para.config
|
||||
)
|
||||
else:
|
||||
if not hasattr(para, 'torch_ckpt_path'):
|
||||
para.torch_ckpt_path = para.input_path
|
||||
if not hasattr(para, 'mindspore_ckpt_path'):
|
||||
para.mindspore_ckpt_path = para.output_path
|
||||
|
||||
convert_pt_to_ms(
|
||||
input_path=para.torch_ckpt_path,
|
||||
output_path=para.mindspore_ckpt_path,
|
||||
dtype=DTYPE_MAPPING.get(para.dtype, ms.float32),
|
||||
config=para.config
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="GLM2/3 weight convert script")
|
||||
parser = argparse.ArgumentParser(description="GLM4 weight convert script")
|
||||
parser.add_argument("--torch_ckpt_path",
|
||||
type=str,
|
||||
default="None",
|
||||
@@ -269,23 +309,10 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--config",
|
||||
type=str,
|
||||
required=True,
|
||||
default="None",
|
||||
help="Path to model config yaml")
|
||||
parser.add_argument('--concat', default=False, type=str2bool, help="Whether to concat weight and bias")
|
||||
parser.add_argument('--ms_not_concat_ckpt_path', default=None)
|
||||
mapping = {
|
||||
"fp32": ms.float32,
|
||||
"fp16": ms.float16,
|
||||
"bf16": ms.bfloat16
|
||||
}
|
||||
|
||||
opt = parser.parse_args()
|
||||
if opt.config is None:
|
||||
raise RuntimeError("config must be specified")
|
||||
if opt.concat:
|
||||
convert_to_concat_ckpt(ms_not_concat_ckpt_path=opt.ms_not_concat_ckpt_path,
|
||||
ms_concat_ckpt_path=opt.mindspore_ckpt_path,
|
||||
config_path=opt.config)
|
||||
else:
|
||||
convert_pt_to_ms(input_path=opt.torch_ckpt_path, output_path=opt.mindspore_ckpt_path,
|
||||
dtype=mapping.get(opt.dtype, ms.bfloat16),
|
||||
config=opt.config)
|
||||
convert_weight(opt)
|
||||
|
||||
Reference in New Issue
Block a user