Files
MindSpeed-LLM/docs/pytorch/solutions/checkpoint_convert.md
HanhuiChen a840be21ef !3033 [pytorch][md]deprecate documents and cases for legacy model
Merge pull request !3033 from HanhuiChen/2.1.0
2025-07-19 03:51:26 +00:00

21 KiB
Raw Permalink Blame History

权重转换

权重转换背景

随着大规模预训练模型的广泛应用不同的训练框架和硬件平台之间的适配性问题逐渐显现。专有训练框架如MindSpeed-LLM通常采用定制的并行化策略例如Tensor Parallelism、Pipeline Parallelism以应对大规模模型训练中的内存和计算瓶颈。随着训练需求和硬件的变化模型参数的切分策略也需进行相应的调整。然而跨框架的权重转换往往面临格式不兼容和切分策略不同等挑战。权重转换旨在促进大规模预训练模型在不同训练框架之间的无缝迁移与评估解决框架间权重格式不兼容及切分策略差异等问题从而增强模型迁移的灵活性和可扩展性,支持更广泛的应用场景和业务需求。

权重转换使用

权重转换旨在解决不同深度学习框架和训练策略下模型权重的兼容性问题,支持在多个模型和训练配置之间进行高效的权重互转。核心功能包括:

权重互转支持100+种模型的权重互转,能够在 Hugging Face、Megatron-LM主流框架之间实现任意并行切分策略的权重格式互转。在转换过程中用户需要通过指定参数 --use-mcore-models 来将权重转换为 Megatron-Mcore 格式。

训练并行策略权重转换:支持多种训练并行策略之间的权重转换,包括 张量并行、流水线并行、专家并行、流水并行动态划分 和 虚拟流水并行 等。无论是针对不同并行策略的训练,还是需要在不同策略之间切换的场景,都能实现灵活的权重转换,以适应各种训练和推理需求。

Lora权重合并与转换:支持将 Lora 权重与 Base 权重合并简化了模型推理过程中的加载步骤。合并后的模型可直接用于推理显著提升了推理效率减少了不必要的计算资源消耗。支持将Lora微调权重单独转为Huggingface格式以支持客户下游任务。

优化器权重转换:支持多种并行切分策略,确保优化器状态在不同并行策略间的迁移与兼容,便于在不同训练环境下进行优化器状态恢复。

1. 权重下载

从Huggingface等网站下载开源模型权重

预训练权重链接在 稠密模型MoE模型 章节列表的参数列链接中可以获取;更多社区资源可以在模型列链接中获取,如Chat/Instruct权重等。

权重可以基于网页直接下载也可以基于命令行下载保存到MindSpeed-LLM/model_from_hf目录比如

#!/bin/bash
mkdir ./model_from_hf/llama-2-7b-hf/
cd ./model_from_hf/llama-2-7b-hf/
wget https://huggingface.co/daryl149/llama-2-7b-hf/resolve/main/config.json
wget https://huggingface.co/daryl149/llama-2-7b-hf/resolve/main/generation_config.json
wget https://huggingface.co/daryl149/llama-2-7b-hf/resolve/main/pytorch_model-00001-of-00002.bin
wget https://huggingface.co/daryl149/llama-2-7b-hf/resolve/main/pytorch_model-00002-of-00002.bin
wget https://huggingface.co/daryl149/llama-2-7b-hf/resolve/main/pytorch_model.bin.index.json
wget https://huggingface.co/daryl149/llama-2-7b-hf/resolve/main/special_tokens_map.json
wget https://huggingface.co/daryl149/llama-2-7b-hf/resolve/main/tokenizer.json
wget https://huggingface.co/daryl149/llama-2-7b-hf/resolve/main/tokenizer.model
wget https://huggingface.co/daryl149/llama-2-7b-hf/resolve/main/tokenizer_config.json
cd ../../

2. 权重转换

2.1 Huggingface权重转换到Megatron-LM格式

权重转换实现了 HuggingFace 权重到 Megatron-LM 格式的转换,支持多种并行策略(如张量并行、流水并行等),确保转换后可以在 MindSpeed-LLM 框架下继续训练和推理。下面提供一个Llama2-7b模型的hf-mg权重转换脚本仅供参考

python convert_ckpt.py \
    --model-type GPT \
    --load-model-type hf \
    --save-model-type mg \
    --target-tensor-parallel-size 2 \
    --target-pipeline-parallel-size 4 \
    --num-layer-list 8,8,8,8 \
    --model-type-hf llama2 \
    --use-mcore-models \
    --load-dir ./model_from_hf/llama-2-7b-hf/ \
    --save-dir ./model_weights/llama-2-7b-mcore/ \
    --tokenizer-model ./model_from_hf/llama-2-7b-hf/tokenizer.model
参数 说明 可选/必选
--target-tensor-parallel-size TP 切分数量,默认为 1 必选
--target-pipeline-parallel-size PP 切分数量,默认为 1 必选
--num-layer-list 动态PP划分通过列表指定每个PP Stage的层数默认为None 可选
--num-layers-per-virtual-pipeline-stage VPP划分指定VPP的每个Stage层数默认为None 可选
--target-expert-model-parallel-size 专家并行指定专家并行卡数默认为1 可选
--noop-layers 自定义空层操作指定在模型某层增加空层转换后层数为原huggingface模型层数+空层数默认为None 可选
--use-mcore-models 转换为Megatron-Mcore权重 必选
--model-type-hf huggingface模型类别默认为llama2 可选
--tokenizer-model 需要指明到具体的分词器模型文件,如 tokenizer.model、tokenizer.json、qwen.tiktoken、None等具体取决于huggingface中词表文件的格式形式 必选
--params-dtype 指定权重转换后的权重精度模式默认为fp16如果源格式文件为bf16则需要对应设置为bf16影响推理或评估结果 必选

注意

1、VPP和动态PP划分只能二选一

2、目前支持的模型见 model_cfg.json中“model_mappings”下包含的模型。

【启动脚本】

MindSpeed-LLM Huggingface到Megatron-Mcore权重转换脚本命名风格及启动方法为

# 命名及启动:
# bash examples/mcore/model_name/ckpt_convert_xxx_hf2mcore.sh
# 需要配置并行参数以及权重词表加载保存等路径

bash examples/mcore/llama2/ckpt_convert_llama2_hf2mcore.sh

2.2 Megatron-LM权重转换到Huggingface格式

权重转换实现了 Megatron-LM 权重到 HuggingFace 格式的转换,支持多种并行策略(如张量并行、流水并行等)。转换过程中,模型的权重会被适配为 HuggingFace 的标准格式,确保可以在 HuggingFace 环境下继续进行训练和推理。下面提供一个Llama2-7b模型的mg-hf权重转换脚本仅供参考

python convert_ckpt.py \
    --model-type GPT \
    --load-model-type mg \
    --save-model-type hf \
    --model-type-hf llama2 \
    --use-mcore-models \
    --load-dir ./model_weights/llama-2-7b-mcore/ \
    --target-tensor-parallel-size 1 \
    --target-pipeline-parallel-size 1 \
    --save-dir ./model_from_hf/llama-2-7b-hf/  # <-- 需要填入原始HF模型路径新权重会存于./model_from_hf/llama-2-7b-hf/mg2hf/

参数意义参考2.1

注意: 转到Huggingface权重必须设置--target-tensor-parallel-size = 1、--target-pipeline-parallel-size = 1。

【启动脚本】

MindSpeed-LLM Megatron-Mcore到Huggingface的权重转换脚本命名风格及启动方法为

# 命名及启动:
# bash examples/mcore/model_name/ckpt_convert_xxx_mcore2hf.sh
# 需要配置并行参数以及权重词表加载保存等路径

bash examples/mcore/llama2/ckpt_convert_llama2_mcore2hf.sh

2.3 lora权重转换

当前仓库支持以下两种lora权重转换方法:

(1) 将Lora微调权重与基础模型权重合并转换为Megatron或Huggingface格式 ;

(2) 将Lora微调权重单独转为Huggingface格式在lora微调脚本中加入参数--lora-ckpt-filter仅保存lora权重。

2.3.1 Megatron-Mcore格式权重合并

在权重转换命令中加入如下参数可以将训练的lora权重与权重转换出的base权重进行融合。

--lora-load ./ckpt/llama-2-7b-lora  \
--lora-r 16 \
--lora-alpha 32 \
--lora-target-modules linear_qkv linear_proj linear_fc1 linear_fc2 \
参数 说明 可选/必选
--lora-load 加载 lora 微调后生成的权重 可选
--lora-r LoRA中的秩rank它决定了低秩矩阵的大小 可选
--lora-alpha 定义了LoRA适应的学习率缩放因子。这个参数影响了低秩矩阵的更新速度 可选
--lora-target-modules 该参数定义了LoRA目标模块为一个由空格分隔的字符串列表且不具有默认值。每个字符串对应需要进行LoRA微调的层名称且只能在上述四种预定义的参数配置中选择。用户可根据具体需求调整该参数。 可选

【合并后转换为Megatron-Mcore权重】

下面提供Megatron-Mcore格式的Llama2-7b模型的Lora权重与base权重合并并转为Megatron-Mcore格式的示例脚本仅供参考

python convert_ckpt.py \
    --model-type GPT \
    --use-mcore-models \
    --load-model-type mg \
    --save-model-type mg \
    --load-dir ./model_weights/llama-2-7b-mcore/ \
    --lora-load ./ckpt/llama-2-7b-lora \
    --lora-r 16 \
    --lora-alpha 32 \
    --lora-target-modules linear_qkv linear_proj linear_fc1 linear_fc2 \
    --target-tensor-parallel-size 1 \
    --target-pipeline-parallel-size 1 \
    --save-dir ./model_weights/llama-2-7b-lora2mcore
参数 说明 可选/必选
--lora-target-modules 该参数定义了LoRA目标模块为一个由空格分隔的字符串列表且不具有默认值。每个字符串对应需要进行LoRA微调的层名称且只能在上述四种预定义的参数配置中选择。用户可根据具体需求调整该参数。 可选

转换脚本命名风格及启动方法为:

#命令启动方式以 llama2 为例
bash examples/mcore/llama2/ckpt_convert_llama2_mg2mg_lora.sh

【合并后转换为Huggingface权重】

下面提供Megatron-Mcore格式的Llama2-7b模型的Lora权重与base权重合并并转为Huggingface格式的示例脚本仅供参考

python convert_ckpt.py \
    --model-type GPT \
    --use-mcore-models \
    --load-model-type mg \
    --save-model-type hf \
    --load-dir ./model_weights/llama-2-7b-mcore/ \
    --lora-load ./ckpt/llama-2-7b-lora \
    --lora-r 16 \
    --lora-alpha 32 \
    --lora-target-modules linear_qkv linear_proj linear_fc1 linear_fc2 \
    --target-tensor-parallel-size 1 \
    --target-pipeline-parallel-size 1 \
    --save-dir ./model_from_hf/llama-2-7b-hf/    # <-- 需要填入原始HF模型路径新权重会存于./model_from_hf/llama-2-7b-hf/mg2hg/

转换脚本命名风格及启动方法为:

#命令启动方式以 llama2 为例
bash examples/mcore/llama2/ckpt_convert_llama2_mcore2hf_lora.sh

注意:

lora参数值需与lora微调时的参数保持一致,且lora权重的切分方式需与base权重的切分方式保持一致。

由于调用peft库合并lora权重后权重数据类型为float16但是部分模型如qwen系列模型默认数据类型为bfloat16合并后的权重转回hf格式会有精度损失问题。可以将原始HF模型的config.json中的数据类型改为float16暂时规避。

moe模型暂不支持开启--moe-grouped-gemm 特性后的lora权重转换

2.3.2 Lora权重转换为Huggingface权重

通过使能参数--save-lora-to-hf,支持将Lora微调后的lora权重转换为Huggingface格式下面提供Llama2-7b模型的Lora权重转为Huggingface格式的示例脚本仅供参考

python convert_ckpt.py \
    --model-type GPT \
    --use-mcore-models \
    --load-model-type mg \
    --save-model-type hf \
    --load-dir ./ckpt/llama2_lora_filter \
    --lora-r 16 \
    --lora-alpha 32 \
    --lora-target-modules linear_qkv linear_proj linear_fc1 linear_fc2 \
    --target-tensor-parallel-size 1 \
    --target-pipeline-parallel-size 1 \
    --load-checkpoint-loosely \
    --save-lora-to-hf \
    --save-dir ./model_from_hf/llama-2-7b-hf/  # <-- 需要填入原始HF模型路径新权重会存于./model_from_hf/llama-2-7b-hf/mg2hf/
参数 说明 可选/必选
--save-lora-to-hf lora转hf时设置此参数以指定仅转换lora权重 可选
--load-checkpoint-loosely 允许松弛加载转换lora权重时需要设置此参数 可选

注意:

原始权重仅为lora权重不包含base权重需要在lora微调脚本中加入参数--lora-ckpt-filter仅保存lora权重

--save-lora-to-hf和--moe-grouped-gemm两个参数不能同时使用,在lora微调时,脚本中不能加入--moe-grouped-gemm参数;

--save-lora-to-hf和--load-hf-from-config两个参数不能同时使用

lora权重转换仅支持mcore格式仅支持fc_type为gate_up_down的模型其余待适配当前仅支持llama2、mixtral。

【启动脚本】

MindSpeed-LLM lora到Huggingface的权重转换脚本命名风格及启动方法为

# 命名及启动:
# bash examples/mcore/model_name/ckpt_convert_xxx_lora2hf.sh
# 需要配置并行参数以及权重词表加载保存等路径

bash examples/mcore/llama2/ckpt_convert_llama2_lora2hf.sh

2.4 优化器权重转换

在权重转换脚本中指定--load-model-type参数为optim , 则为优化器权重转换。

使用方法:

1.准备预训练权重

优化器状态为预训练保存得到,并且需要在预训练脚本中加入参数--use-distributed-optimizer 表示使用分布式优化器,并且删除参数--no-save-optim 使训练生成的每个权重文件夹都包括model_optim_rng.ptdistrib_optim.pt 模型权重文件和优化器状态文件。

2.mg-mg权重转换

优化器权重需要先做一次mg-mg的权重转换并指定所需的切分方式脚本参考2.3中mcore-mcore脚本:

    python convert_ckpt.py \
        --model-type GPT \
        --load-model-type mg \
        --save-model-type mg \
        --target-tensor-parallel-size 4 \
        --target-pipeline-parallel-size 2 \
        --load-dir ./ckpt/llama2-7b-tp2pp4 \
        --save-dir ./ckpt/llama2-7b-tp4pp2-optim \
        --use-mcore-models

在此步骤完成后,--save-dir 中应该会生成 model_optim_rng.pt 格式的权重文件。

3.权重转换优化器:

完成上述步骤后,可以执行优化器权重转换。此时,指定 --load-model-type optim 参数来加载优化器权重并进行转换下面提供Llama2-7b模型的优化器权重转换的示例脚本仅供参考

注意: 并行配置如TP、PP、EP、VPP、num-layer-list、noop-layers等参数需要与mcore-mcore权重转换脚本相同。

python convert_ckpt.py
--model-type GPT
--load-model-type optim
--load-dir ./ckpt/llama2-7b-tp2pp4
--target-tensor-parallel-size 4
--target-pipeline-parallel-size 2
--save-dir ./ckpt/llama2-7b-tp4pp2-optim
--use-mcore-models
--model-type-hf llama2
参数 说明 可选/必选
--save-dir 权重保存路径,需要与`mg-mg`转换时的保存路径一致 必选
--load-model-type 指定加载模型的方式。对于优化器权重转换,必须设置参数值为`optim` 必选
--moe-grouped-gemm 对于moe模型如果在预训练脚本和mcore-mcore权重转换脚本中加入此参数`--moe-grouped-gemm` ,则在优化器权重转换脚本中,也需要加入该参数。 可选

转换脚本命名风格及启动方法为:

# 命令启动方式以 llama2 为例子
bash examples/mcore/llama2/ckpt_convert_llama2_optim.sh

注意

优化器权重转换当前仅支持模型llama2-7b、deepseek2-lite

llama2-7b支持TP、PP、EP、VPP、DPP、noop-layers

deepseek2-lite支持PP、EP、DPP、noop-layers。

权重转换特性清单

MindSpeed-LLM 支持 Huggingface 和 Megatron-Core 之间的权重格式互转,具体功能列表如下:

源格式 目标格式 支持特性 特性入参
HuggingFace Megatron-Core 张量并行 --target-tensor-parallel-size
流水并行 --target-pipeline-parallel-size
流水并行动态划分 --num-layer-list
虚拟流水并行 --num-layers-per-virtual-pipeline-stage
专家并行 --target-expert-model-parallel-size
自定义空操作层 --noop-layers
Megatron-Core Huggingface 张量并行 --target-tensor-parallel-size
流水并行 --target-pipeline-parallel-size
LoRA训练模块 --lora-target-modules
LoRA权重 --lora-load
LoRA r --lora-r
LoRA alpha --lora-alpha
Megatron-Core 张量并行 --target-tensor-parallel-size
流水并行 --target-pipeline-parallel-size
专家并行 --target-expert-model-parallel-size
流水并行动态划分 --num-layer-list
虚拟流水并行 --num-layers-per-virtual-pipeline-stage
LoRA训练模块 --lora-target-modules
LoRA权重 --lora-load
LoRA r --lora-r
LoRA alpha --lora-alpha
自定义空操作层 --noop-layers
distributed-optimizer 张量并行 --target-tensor-parallel-size
流水并行 --target-pipeline-parallel-size
专家并行 --target-expert-model-parallel-size
虚拟流水并行 --num-layers-per-virtual-pipeline-stage
流水并行动态划分 --target-expert-model-parallel-size
自定义空操作层 --noop-layers