mirror of https://gitee.com/ascend/MindSpeed-LLM.git (synced 2025-12-06 11:28:59 +08:00)
!1961 Adapt pai-megatron-aux-loss and alltoall-overlap 070
Merge pull request !1961 from jzh/master
3  .gitignore  vendored
@@ -146,6 +146,9 @@ cython_debug/
# megatron core
/megatron/

# mindspeed
/mindspeed/

# User stuff
/kernel*/
/logs/
@@ -98,7 +98,7 @@ source /usr/local/Ascend/nnal/atb/set_env.sh
git clone https://gitee.com/ascend/MindSpeed.git
cd MindSpeed
# checkout commit from MindSpeed core_r0.7.0 in 2024.11.04
git checkout f3332571
git checkout c9d20b5
pip install -r requirements.txt
pip3 install -e .
cd ..
@@ -1,9 +1,10 @@
# Modify the set_env.sh path according to your actual environment
# Modify the directory information as needed and configure the corresponding TP, PP, and EP parameters

source /usr/local/Ascend/ascend-toolkit/set_up.sh
source /usr/local/Ascend/ascend-toolkit/set_env.sh

python convert_ckpt.py \
--moe-grouped-gemm \
--use-mcore-models \
--model-type-hf deepseek2-lite \
--model-type GPT \
@@ -13,6 +14,7 @@ python convert_ckpt.py \
--target-tensor-parallel-size 1 \
--target-pipeline-parallel-size 1 \
--target-expert-parallel-size 8 \
--spec modellink.tasks.models.spec.deepseek_spec layer_spec \
--load-dir ./model_from_hf/deepseek_v2_lite/ \
--save-dir ./model_weights/deepseek2_lite_mcore/ \
--tokenizer-model ./model_from_hf/deepseek_v2_lite/
@@ -2,6 +2,7 @@
source /usr/local/Ascend/ascend-toolkit/set_env.sh

python convert_ckpt.py \
--moe-grouped-gemm \
--use-mcore-models \
--model-type-hf deepseek2-lite \
--model-type GPT \
@@ -11,5 +12,6 @@ python convert_ckpt.py \
--target-tensor-parallel-size 1 \
--target-pipeline-parallel-size 1 \
--target-expert-parallel-size 1 \
--spec modellink.tasks.models.spec.deepseek_spec layer_spec \
--load-dir ./model_weights/deepseek2_lite_mcore/ \
--save-dir ./model/deepseek2_lite/
@@ -33,6 +33,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS evaluation.py \
--no-chat-template \
--max-new-tokens 1 \
--use-mcore-models \
--moe-grouped-gemm \
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size ${PP} \
--num-layers 27 \
@@ -89,4 +90,4 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS evaluation.py \
--rope-scaling-original-max-position-embeddings 4096 \
--rope-scaling-type yarn \
--distributed-backend nccl \
| tee ./logs/evaluation_mcore_deepseek2_lite_16b_${TASK}.log
| tee ./logs/evaluation_deepseek2_lite_ptd_8p_${TASK}.log
@@ -36,6 +36,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \
--task chat \
--max-new-tokens 256 \
--use-mcore-models \
--moe-grouped-gemm \
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size ${PP} \
--num-layers 27 \
@@ -92,5 +93,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \
--rope-scaling-original-max-position-embeddings 4096 \
--rope-scaling-type yarn \
--distributed-backend nccl \
| tee logs/generate_mcore_deepseek2_lite.log
| tee logs/generate_deepseek2_lite.log
@@ -0,0 +1,150 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True

GPUS_PER_NODE=16
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

CKPT_SAVE_DIR="your model save ckpt path"
DATA_PATH="your data path"
TOKENIZER_MODEL="your tokenizer path"
CKPT_LOAD_DIR="your model ckpt path"

TP=1
PP=1
EP=8

DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"

MLA_ARGS="
--spec modellink.tasks.models.spec.deepseek_spec layer_spec \
--multi-head-latent-attention \
--qk-rope-head-dim 64 \
--qk-nope-head-dim 128 \
--kv-lora-rank 512 \
--v-head-dim 128 \
--qk-layernorm \
"

MOE_ARGS="
--moe-grouped-gemm \
--moe-alltoall-overlap-comm \
--moe-permutation-async-comm \
--moe-token-dispatcher-type alltoall \
--use-fused-moe-token-permute-and-unpermute \
--first-k-dense-replace 1 \
--moe-layer-freq 1 \
--n-shared-experts 2 \
--num-experts 64 \
--moe-router-topk 6 \
--moe-intermediate-size 1408 \
--moe-router-load-balancing-type pai_megatron_aux_loss \
--topk-group 1 \
--moe-aux-loss-coeff 0.01 \
--routed-scaling-factor 1.0 \
--seq-aux
"

ROPE_ARGS="
--rope-scaling-beta-fast 32 \
--rope-scaling-beta-slow 1 \
--rope-scaling-factor 40 \
--rope-scaling-mscale 0.707 \
--rope-scaling-mscale-all-dim 0.707 \
--rope-scaling-original-max-position-embeddings 4096 \
--rope-scaling-type yarn
"

GPT_ARGS="
--shape-order BNSD \
--reuse-fp32-param \
--load $CKPT_LOAD_DIR \
--use-distributed-optimizer \
--use-flash-attn \
--use-mcore-models \
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size ${PP} \
--expert-model-parallel-size ${EP} \
--sequence-parallel \
--num-layers 27 \
--hidden-size 2048 \
--ffn-hidden-size 10944 \
--num-attention-heads 16 \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path ${TOKENIZER_MODEL} \
--finetune \
--num-workers 8 \
--seq-length 4096 \
--max-position-embeddings 163840 \
--micro-batch-size 1 \
--global-batch-size 768 \
--make-vocab-size-divisible-by 1 \
--lr 2e-5 \
--train-iters 2000 \
--lr-decay-style cosine \
--lr-decay-iters 2000 \
--untie-embeddings-and-output-weights \
--disable-bias-linear \
--attention-dropout 0.0 \
--init-method-std 0.02 \
--hidden-dropout 0.0 \
--position-embedding-type rope \
--normalization RMSNorm \
--use-fused-rotary-pos-emb \
--use-rotary-position-embeddings \
--use-fused-swiglu \
--use-fused-rmsnorm \
--swiglu \
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
--min-lr 1.0e-8 \
--weight-decay 1e-1 \
--lr-warmup-iters 100 \
--clip-grad 1.0 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--initial-loss-scale 65536 \
--vocab-size 102400 \
--padded-vocab-size 102400 \
--rotary-base 10000 \
--no-gradient-accumulation-fusion \
--norm-epsilon 1e-6 \
--no-load-optim \
--no-load-rng \
--bf16
"

DATA_ARGS="
--data-path $DATA_PATH \
--split 99,1,0
"

OUTPUT_ARGS="
--log-interval 1 \
--save-interval 1000 \
--eval-interval 10000 \
--eval-iters 10 \
--no-save-optim \
--no-save-rng
"

python -m torch.distributed.launch $DISTRIBUTED_ARGS pretrain_gpt.py \
$GPT_ARGS \
$DATA_ARGS \
$OUTPUT_ARGS \
$MLA_ARGS \
$ROPE_ARGS \
$MOE_ARGS \
--distributed-backend nccl \
--save $CKPT_SAVE_DIR \
| tee logs/pretrain_deepseek2_lite_ptd_16p.log
@@ -37,17 +37,20 @@ MLA_ARGS="
"

MOE_ARGS="
--moe-grouped-gemm \
--moe-alltoall-overlap-comm \
--moe-permutation-async-comm \
--moe-token-dispatcher-type allgather \
--moe-token-dispatcher-type alltoall \
--use-fused-moe-token-permute-and-unpermute \
--first-k-dense-replace 1 \
--moe-layer-freq 1 \
--n-shared-experts 2 \
--num-experts 64 \
--moe-router-topk 6 \
--moe-intermediate-size 1408 \
--moe-router-load-balancing-type softmax_topk \
--moe-router-load-balancing-type pai_megatron_aux_loss \
--topk-group 1 \
--moe-aux-loss-coeff 0.001 \
--moe-aux-loss-coeff 0.01 \
--routed-scaling-factor 1.0 \
--seq-aux
"
@@ -63,6 +66,8 @@ ROPE_ARGS="
"

GPT_ARGS="
--shape-order BNSD \
--reuse-fp32-param \
--load $CKPT_LOAD_DIR \
--use-distributed-optimizer \
--use-flash-attn \
@@ -80,13 +85,16 @@ GPT_ARGS="
--num-attention-heads 16 \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path ${TOKENIZER_MODEL} \
--seq-length 8192 \
--finetune \
--num-workers 8 \
--seq-length 4096 \
--max-position-embeddings 163840 \
--micro-batch-size 1 \
--global-batch-size 8 \
--make-vocab-size-divisible-by 1 \
--lr 1.0e-6 \
--train-iters 2000 \
--lr 2e-5 \
--train-iters 462240 \
--lr-decay-iters 462240 \
--lr-decay-style cosine \
--untie-embeddings-and-output-weights \
--disable-bias-linear \
@@ -103,11 +111,11 @@ GPT_ARGS="
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
--min-lr 1.0e-8 \
--weight-decay 1e-2 \
--lr-warmup-iters 500 \
--weight-decay 1e-1 \
--lr-warmup-iters 1920 \
--clip-grad 1.0 \
--adam-beta1 0.9 \
--adam-beta2 0.999 \
--adam-beta2 0.95 \
--initial-loss-scale 65536 \
--vocab-size 102400 \
--padded-vocab-size 102400 \
@@ -121,14 +129,14 @@ GPT_ARGS="

DATA_ARGS="
--data-path $DATA_PATH \
--split 100,0,0
--split 99,1,0
"

OUTPUT_ARGS="
--log-interval 1 \
--save-interval 20000 \
--eval-interval 20000 \
--eval-iters 0 \
--save-interval 1000 \
--eval-interval 10000 \
--eval-iters 10 \
--no-save-optim \
--no-save-rng
"
@@ -142,4 +150,4 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS pretrain_gpt.py \
$MOE_ARGS \
--distributed-backend nccl \
--save $CKPT_SAVE_DIR \
| tee logs/npu_pretrain_mcore_deepseek2_lite_ptd_8p.log
| tee logs/pretrain_deepseek2_lite_ptd_8p.log
@@ -37,17 +37,20 @@ MLA_ARGS="
"

MOE_ARGS="
--moe-grouped-gemm \
--moe-alltoall-overlap-comm \
--moe-permutation-async-comm \
--moe-token-dispatcher-type alltoall \
--use-fused-moe-token-permute-and-unpermute \
--first-k-dense-replace 1 \
--moe-layer-freq 1 \
--n-shared-experts 2 \
--num-experts 64 \
--moe-router-topk 6 \
--moe-intermediate-size 1408 \
--moe-router-load-balancing-type softmax_topk \
--moe-router-load-balancing-type aux_loss \
--topk-group 1 \
--moe-aux-loss-coeff 0.001 \
--moe-aux-loss-coeff 0.01 \
--routed-scaling-factor 1.0 \
--seq-aux
"
@@ -72,6 +75,7 @@ FITUNE_ARGS="
"

GPT_ARGS="
--shape-order BNSD \
--load $CKPT_LOAD_DIR \
--use-distributed-optimizer \
--use-flash-attn \
@@ -89,18 +93,20 @@ GPT_ARGS="
--num-attention-heads 16 \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path ${TOKENIZER_MODEL} \
--seq-length 8192 \
--num-workers 8 \
--seq-length 4096 \
--max-position-embeddings 163840 \
--micro-batch-size 1 \
--global-batch-size 8 \
--make-vocab-size-divisible-by 1 \
--lr 5e-5 \
--lr 9e-6 \
--train-iters 2000 \
--lr-decay-style constant \
--lr-decay-style cosine \
--lr-decay-iters 2000 \
--untie-embeddings-and-output-weights \
--disable-bias-linear \
--attention-dropout 0.0 \
--init-method-std 0.02 \
--init-method-std 0.008 \
--hidden-dropout 0.0 \
--position-embedding-type rope \
--normalization RMSNorm \
@@ -109,12 +115,13 @@ GPT_ARGS="
--use-fused-swiglu \
--use-fused-rmsnorm \
--swiglu \
--dataloader-type cyclic \
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
--weight-decay 0e0 \
--weight-decay 0.1 \
--clip-grad 1.0 \
--adam-beta1 0.9 \
--adam-beta2 0.999 \
--adam-beta2 0.95 \
--initial-loss-scale 1 \
--vocab-size 102400 \
--padded-vocab-size 102400 \
@@ -124,7 +131,7 @@ GPT_ARGS="
--no-load-optim \
--no-load-rng \
--bf16 \
--reuse-fp32-param
--reuse-fp32-param \
"

DATA_ARGS="
@@ -151,4 +158,4 @@ torchrun $DISTRIBUTED_ARGS posttrain_gpt.py \
$FITUNE_ARGS \
--distributed-backend nccl \
--save $CKPT_SAVE_DIR \
| tee ./logs/npu_tune_mcore_deepseek2_lite.log
| tee ./logs/tune_deepseek2_lite_ptd_8p.log
@@ -0,0 +1,162 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True

GPUS_PER_NODE=16
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

CKPT_SAVE_DIR="your checkpoint save path"
DATA_PATH="your finetune dataset path"
TOKENIZER_MODEL="your tokenizer model path"
CKPT_LOAD_DIR="your checkpoint load path"

TP=1
PP=1
EP=8

DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"

MLA_ARGS="
--spec modellink.tasks.models.spec.deepseek_spec layer_spec \
--multi-head-latent-attention \
--qk-rope-head-dim 64 \
--qk-nope-head-dim 128 \
--kv-lora-rank 512 \
--v-head-dim 128 \
--qk-layernorm \
"

MOE_ARGS="
--moe-grouped-gemm \
--moe-alltoall-overlap-comm \
--moe-permutation-async-comm \
--moe-token-dispatcher-type alltoall \
--use-fused-moe-token-permute-and-unpermute \
--first-k-dense-replace 1 \
--moe-layer-freq 1 \
--n-shared-experts 2 \
--num-experts 64 \
--moe-router-topk 6 \
--moe-intermediate-size 1408 \
--moe-router-load-balancing-type aux_loss \
--topk-group 1 \
--moe-aux-loss-coeff 0.01 \
--routed-scaling-factor 1.0 \
--seq-aux
"

ROPE_ARGS="
--rope-scaling-beta-fast 32 \
--rope-scaling-beta-slow 1 \
--rope-scaling-factor 40 \
--rope-scaling-mscale 0.707 \
--rope-scaling-mscale-all-dim 0.707 \
--rope-scaling-original-max-position-embeddings 4096 \
--rope-scaling-type yarn
"

FITUNE_ARGS="
--stage sft \
--finetune \
--is-instruction-dataset \
--variable-seq-lengths \
--prompt-type deepseek2-lite \
--tokenizer-not-use-fast \
"


GPT_ARGS="
--shape-order BNSD \
--load $CKPT_LOAD_DIR \
--use-distributed-optimizer \
--use-flash-attn \
--use-mcore-models \
--reuse-fp32-param \
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size ${PP} \
--expert-model-parallel-size ${EP} \
--sequence-parallel \
--num-layers 27 \
--recompute-granularity full \
--recompute-method uniform \
--recompute-num-layers 1 \
--hidden-size 2048 \
--ffn-hidden-size 10944 \
--num-attention-heads 16 \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path ${TOKENIZER_MODEL} \
--num-workers 8 \
--seq-length 4096 \
--max-position-embeddings 163840 \
--micro-batch-size 1 \
--global-batch-size 768 \
--make-vocab-size-divisible-by 1 \
--lr 9e-6 \
--train-iters 462240 \
--lr-decay-style cosine \
--lr-decay-iters 462240 \
--untie-embeddings-and-output-weights \
--disable-bias-linear \
--attention-dropout 0.0 \
--init-method-std 0.008 \
--hidden-dropout 0.0 \
--position-embedding-type rope \
--normalization RMSNorm \
--use-fused-rotary-pos-emb \
--use-rotary-position-embeddings \
--use-fused-swiglu \
--use-fused-rmsnorm \
--swiglu \
--dataloader-type cyclic \
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
--weight-decay 0.1 \
--clip-grad 1.0 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--initial-loss-scale 1 \
--vocab-size 102400 \
--padded-vocab-size 102400 \
--rotary-base 10000 \
--no-gradient-accumulation-fusion \
--norm-epsilon 1e-6 \
--no-load-optim \
--no-load-rng \
--bf16 \
"

DATA_ARGS="
--data-path $DATA_PATH \
--split 100,0,0
"

OUTPUT_ARGS="
--log-interval 1 \
--save-interval 2000 \
--eval-interval 1000 \
--eval-iters 0 \
--no-save-optim \
--no-save-rng
"

torchrun $DISTRIBUTED_ARGS posttrain_gpt.py \
$GPT_ARGS \
$DATA_ARGS \
$OUTPUT_ARGS \
$MLA_ARGS \
$ROPE_ARGS \
$MOE_ARGS \
$FITUNE_ARGS \
--distributed-backend nccl \
--save $CKPT_SAVE_DIR \
| tee ./logs/tune_deepseek2_lite_ptd_16p.log
@@ -69,7 +69,7 @@ def should_recompute_activation(self):
    return recompute_priority < activation_recompute_layers


def core_mlp_init(self, config, submodules, is_expert=False, input_size=None):
def core_mlp_init(self, config, submodules, is_expert=False, input_size=None, shared_expert=False):
    super(MLP, self).__init__(config=config)

    self.config: TransformerConfig = config
@@ -94,30 +94,62 @@ def core_mlp_init(self, config, submodules, is_expert=False, input_size=None):
    if self.config.gated_linear_unit:
        ffn_hidden_size *= 2

    self.linear_fc1 = build_module(
        submodules.linear_fc1,
        self.input_size,
        ffn_hidden_size,
        config=self.config,
        init_method=self.config.init_method,
        gather_output=False,
        bias=self.config.add_bias_linear,
        skip_bias_add=True,
        is_expert=is_expert,
        tp_comm_buffer_name='fc1',
    )
    if shared_expert:
        self.linear_fc1 = build_module(
            submodules.linear_fc1,
            self.input_size,
            ffn_hidden_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='fc1',
            shared_expert=shared_expert
        )
    else:
        self.linear_fc1 = build_module(
            submodules.linear_fc1,
            self.input_size,
            ffn_hidden_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='fc1'
        )

    self.activation_func = self.config.activation_func

    self.linear_fc2 = build_module(
        submodules.linear_fc2,
        self.config.ffn_hidden_size,
        self.config.hidden_size,
        config=self.config,
        init_method=self.config.output_layer_init_method,
        bias=self.config.add_bias_linear,
        input_is_parallel=True,
        skip_bias_add=True,
        is_expert=is_expert,
        tp_comm_buffer_name='fc2',
    )
    if shared_expert:
        self.linear_fc2 = build_module(
            submodules.linear_fc2,
            self.config.ffn_hidden_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='fc2',
            shared_expert=shared_expert
        )
    else:
        self.linear_fc2 = build_module(
            submodules.linear_fc2,
            self.config.ffn_hidden_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='fc2'
        )

    self.shared_expert = shared_expert
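The two build_module branches above differ only in whether the shared_expert keyword is forwarded. Below is a minimal, self-contained sketch of that conditional-keyword pattern; build_linear and make_fc1 are hypothetical stand-ins, not MindSpeed-LLM code, and it assumes the underlying builder should only receive shared_expert when it is actually set.

def build_linear(in_features, out_features, **kwargs):
    # Hypothetical stand-in for build_module: records what it was called with.
    return ("linear", in_features, out_features, kwargs)


def make_fc1(input_size, ffn_hidden_size, shared_expert=False):
    # Forward `shared_expert` only when set, collapsing the if/else above into one call.
    extra = {"shared_expert": True} if shared_expert else {}
    return build_linear(input_size, ffn_hidden_size, tp_comm_buffer_name="fc1", **extra)


print(make_fc1(2048, 10944))
print(make_fc1(2048, 10944, shared_expert=True))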
@@ -13,6 +13,7 @@ from megatron.core.transformer.mlp import MLPSubmodules, MLP
from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP
from megatron.core.transformer.moe.moe_utils import save_to_aux_losses_tracker
from megatron.training import get_args
from mindspeed.core.transformer.moe.moe_layer_overlap_all2all import MoELayerOverlapAll2All


def moe_layer_init_wrapper(init_func):
@@ -34,8 +35,14 @@ def moe_layer_init_wrapper(init_func):
        if global_args.n_shared_experts:
            config = deepcopy(self.config)
            config.ffn_hidden_size = global_args.n_shared_experts * self.config.ffn_hidden_size
            self.shared_experts = MLP(config, MLPSubmodules(linear_fc1=ColumnParallelLinear,
                                                            linear_fc2=RowParallelLinear,))

            if global_args.moe_allgather_overlap_comm or global_args.moe_alltoall_overlap_comm:
                from mindspeed.core.transformer.moe.layers import ColumnParallelLinear, RowParallelLinear
                self.shared_experts = MLP(config, MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear), shared_expert=True)
            else:
                from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
                self.shared_experts = MLP(config, MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear))

            # For using layer_number when recompute activation function is enabled.
            self.shared_experts.layer_number = self.layer_number
            if global_args.shared_expert_gate:
@@ -53,6 +60,10 @@ def moe_layer_init_wrapper(init_func):


def moe_layer_forward(self, hidden_states: torch.Tensor):
    global_args = get_args()
    if global_args.moe_token_dispatcher_type == 'alltoall' and global_args.moe_alltoall_overlap_comm:
        return MoELayerOverlapAll2All.apply(hidden_states, self)

    # process MoE
    scores, indices = self.router(hidden_states)
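moe_layer_forward above short-circuits into MoELayerOverlapAll2All.apply(hidden_states, self) when the alltoall dispatcher is combined with --moe-alltoall-overlap-comm. The sketch below only illustrates the torch.autograd.Function entry-point shape of that call; OverlapStub is a hypothetical placeholder, not the MindSpeed implementation, which overlaps all-to-all communication with expert computation inside its forward and backward.

import torch


class OverlapStub(torch.autograd.Function):
    """Hypothetical stand-in showing the Function.apply(hidden_states, module) pattern."""

    @staticmethod
    def forward(ctx, hidden_states, module):
        # The real kernel would dispatch tokens with all-to-all and overlap the
        # communication with the expert GEMMs; here we just do a placeholder op.
        return hidden_states * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient flows back to hidden_states; the module argument receives None.
        return grad_output * 2.0, None


hidden = torch.randn(4, 8, requires_grad=True)
out = OverlapStub.apply(hidden, None)
out.sum().backward()
print(hidden.grad.shape)  # torch.Size([4, 8])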
@@ -21,9 +21,8 @@ from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.training import get_args
from megatron.core.transformer.moe.moe_utils import MoEAuxLossAutoScaler, save_to_aux_losses_tracker
from megatron.core import parallel_state

from .moe_utils import topk_softmax_with_capacity, switch_load_balancing_loss_func

from modellink.tasks.models.common.pai_megatron import pai_megatron_aux_loss


def group_limited_greedy_topKgating(self, logits: torch.Tensor):
    args = get_args()
@@ -234,6 +233,8 @@ def topk_router_routing(self, logits: torch.Tensor):
        scores, indices = torch.topk(logits_, k=self.topk, dim=1)
    elif self.routing_type == "group_limited_greedy":
        scores, indices = group_limited_greedy_topKgating(self, logits)
    elif self.routing_type == "pai_megatron_aux_loss":
        scores, indices = pai_megatron_aux_loss(self, logits)
    elif self.routing_type == "none":
        # A naive top-k routing without load balancing
        # top_logits, indices = torch.topk(logits, k=self.topk, dim=1)
@@ -235,14 +235,13 @@ class CoreAdaptation(MegatronAdaptationABC):
                                    transformer_block_checkpointed_forward_wrapper)

    def patch_core_transformers(self):
        from mindspeed.core.transformer.moe.router import aux_loss_load_balancing
        from mindspeed.core.transformer.moe.token_dispatcher import allgather_token_permutation, \
            allgather_token_unpermutation
        from mindspeed.core.transformer.moe.grouped_gemm_util import Ops, grouped_gemm_is_available, \
            get_device_capability, assert_grouped_gemm_is_available
        from mindspeed.core.transformer.transformer import core_mlp_forward_wrapper
        from mindspeed.core.transformer.moe.moe_utils import permute, unpermute

        from mindspeed.core.transformer.moe.experts import group_mlp_forward
        from ..core.transformer.moe.moe_layer import moe_layer_init_wrapper, moe_layer_forward
        from ..core.transformer.transformer_block import _transformer_block_build_layers
        from ..core.transformer.transformer_layer import transformer_layer_init_wrapper
@@ -286,6 +285,7 @@ class CoreAdaptation(MegatronAdaptationABC):
        args = MegatronAdaptation.get_args()
        if args.moe_permutation_async_comm:
            if args.moe_token_dispatcher_type == 'allgather':
                from mindspeed.core.transformer.moe.router import aux_loss_load_balancing
                MegatronAdaptation.register(
                    'megatron.core.transformer.moe.token_dispatcher.MoEAllGatherTokenDispatcher.token_permutation',
                    allgather_token_permutation)
@@ -300,13 +300,26 @@ class CoreAdaptation(MegatronAdaptationABC):
                MegatronAdaptation.register(
                    'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.preprocess',
                    preprocess)
                MegatronAdaptation.register(
                    'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_permutation',
                    alltoall_token_permutation)
                MegatronAdaptation.register('megatron.core.transformer.moe.experts.SequentialMLP.forward', sequential_mlp_forward)
                MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute', permute)
                MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.unpermute', unpermute)

                if args.moe_alltoall_overlap_comm:
                    from mindspeed.core.transformer.moe.token_dispatcher import alltoall_token_permutation_new, \
                        alltoall_token_unpermutation_new
                    MegatronAdaptation.register('megatron.core.transformer.moe.experts.GroupedMLP.forward',
                                                group_mlp_forward)
                    MegatronAdaptation.register(
                        'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_permutation',
                        alltoall_token_permutation_new)
                    MegatronAdaptation.register(
                        'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_unpermutation',
                        alltoall_token_unpermutation_new)
                else:
                    MegatronAdaptation.register(
                        'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_permutation',
                        alltoall_token_permutation)

        if hasattr(args, 'use_fused_moe_token_permute_and_unpermute') and args.use_fused_moe_token_permute_and_unpermute and not args.moe_expert_capacity_factor:
            from mindspeed.core.fusions.npu_moe_token_permute import permute_wrapper
            from mindspeed.core.fusions.npu_moe_token_unpermute import unpermute_wrapper
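MegatronAdaptation.register pairs a dotted attribute path with a replacement callable; the registrations above are how the alltoall-overlap token permutation and GroupedMLP forward get swapped in. The helper below is a hypothetical illustration of how such a dotted-path patch can be applied in general; register_patch is not the MindSpeed-LLM API, whose real mechanics may defer or wrap these patches differently.

import importlib


def register_patch(dotted_path, replacement):
    # Hypothetical helper: import the longest importable module prefix of the
    # dotted path, walk the remaining attributes, and rebind the last one.
    parts = dotted_path.split('.')
    for i in range(len(parts) - 1, 0, -1):
        try:
            target = importlib.import_module('.'.join(parts[:i]))
        except ModuleNotFoundError:
            continue
        for name in parts[i:-1]:
            target = getattr(target, name)
        setattr(target, parts[-1], replacement)
        return
    raise ImportError('could not resolve ' + dotted_path)


# Demo against the standard library rather than Megatron:
register_patch('json.decoder.JSONDecoder.decode', lambda self, s, *a, **kw: {'patched': True})
import json
print(json.loads('{"a": 1}'))  # {'patched': True}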
40  modellink/tasks/models/common/pai_megatron.py  Normal file
@@ -0,0 +1,40 @@
# coding=utf-8
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from megatron.training import get_args


def pai_megatron_aux_loss(self, logits: torch.Tensor):
    routing_weights = torch.softmax(logits, dim=1, dtype=torch.float32).type_as(logits)
    scores, indices = torch.topk(routing_weights, k=self.topk, dim=-1)

    # TopK without capacity
    num_experts = logits.shape[1]
    tokens_per_expert = torch.histc(indices, bins=num_experts, min=0, max=num_experts)

    # Apply load balancing loss
    probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
    scores = self.apply_load_balancing_loss(probs, tokens_per_expert, activation=scores)

    args = get_args()
    global_indices = indices
    if args.moe_token_dispatcher_type == "allgather":
        if args.moe_permutation_async_comm and (
                self.config.sequence_parallel or (self.config.expert_model_parallel_size > 1)):
            from mindspeed.core.transformer.moe.router import gather_from_sequence_parallel_region_to_moe_async
            with torch.no_grad():
                global_indices = gather_from_sequence_parallel_region_to_moe_async(indices)
    return scores, global_indices
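For reference, a toy, self-contained run of the routing arithmetic used by pai_megatron_aux_loss above: softmax over the expert logits, top-k selection, and per-expert token counts via torch.histc. The load-balancing-loss hook and the sequence-parallel gather are omitted, and the shapes below are made up for illustration.

import torch

torch.manual_seed(0)
num_tokens, num_experts, topk = 5, 4, 2
logits = torch.randn(num_tokens, num_experts)

# Softmax over experts first, then keep the top-k routing weights per token.
routing_weights = torch.softmax(logits, dim=1, dtype=torch.float32)
scores, indices = torch.topk(routing_weights, k=topk, dim=-1)

# Per-expert token counts feeding the auxiliary balancing loss
# (cast to float because torch.histc does not accept integer tensors on CPU).
tokens_per_expert = torch.histc(indices.float(), bins=num_experts, min=0, max=num_experts)

print(scores.shape, indices.shape)  # torch.Size([5, 2]) torch.Size([5, 2])
print(tokens_per_expert.sum())      # tensor(10.) == num_tokens * topk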
@@ -185,7 +185,7 @@ class DecoderPackedMTFDataset(torch.utils.data.Dataset):

    def _cut_token(self, token, dtype):
        token_length = len(token)
        if token_length >= self.seq_length:
        if not self.args.no_cut_token and token_length >= self.seq_length:
            token = token[:self.seq_length]
        return token.astype(dtype)
@@ -288,13 +288,14 @@ def _add_moe_args(parser):
    group.add_argument('--moe-router-topk', type=int, default=2,
                       help='Number of experts to route to for each token. The default is 2.')
    group.add_argument('--moe-router-load-balancing-type', type=str,
                       choices=['aux_loss', "group_limited_greedy", "softmax_topk"],
                       choices=['aux_loss', "group_limited_greedy", "softmax_topk", "pai_megatron_aux_loss"],
                       default='aux_loss',
                       help='Determines the load balancing strategy for the router. "aux_loss" corresponds '
                            'to the load balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds '
                            'to the balancing algorithm used in S-BASE, "softmax_topk" implies no load balancing and '
                            'softmax before topk , "None" implies no load balancing, and "group_limited_greedy" corresponds '
                            'to the Device-Limited Routing method in DeepSeekV2.'
                            'to the Device-Limited Routing method in DeepSeekV2. and "pai_megatron_aux_loss" corresponds '
                            ' to the load balancing loss used in pai-megatron loss'
                            'The default is "aux_loss".')
    group.add_argument('--expert-interval', type=int, default=1,
                       help='Use experts in every "expert-interval" layers')
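A self-contained argparse snippet showing how the extended choices list admits the new routing type (help text trimmed; only the option in question is reproduced):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--moe-router-load-balancing-type', type=str,
                    choices=['aux_loss', 'group_limited_greedy', 'softmax_topk',
                             'pai_megatron_aux_loss'],
                    default='aux_loss')

args = parser.parse_args(['--moe-router-load-balancing-type', 'pai_megatron_aux_loss'])
print(args.moe_router_load_balancing_type)  # pai_megatron_aux_loss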
@@ -334,6 +335,8 @@ def _add_moe_args(parser):
                       help="moe model shared expert gate output dimension for qwen2 moe, this parameter can only configured with"
                            "1 or hidden_state")
    group.add_argument("--fix-router", action='store_true', help="fix router for load balancing.")
    group.add_argument('--moe-alltoall-overlap-comm', action='store_true', default=False,
                       help='moe_alltoall_overlap_comm')
    return parser
@@ -589,6 +592,8 @@ def _add_training_args(parser):
                       help='scale embed tokens')
    group.add_argument('--dim-model-base', type=float, default=None,
                       help='dim-model-base')
    group.add_argument('--no-cut-token', action='store_true', default=False,
                       help='Used for not cut token in finetune.')
    group.add_argument('--scale-depth', type=float, default=None,
                       help='scale-depth')
    group.add_argument('--swap-attention', action='store_true', default=False,
@@ -771,6 +776,14 @@ def _validate_moe_args(args):
            raise AssertionError('shared expert gate output dimension can only be configured with 1 or hidden_size')
        if hasattr(args, 'use_fused_moe_token_permute_and_unpermute') and args.use_fused_moe_token_permute_and_unpermute:
            raise AssertionError('moe_expert_capacity_factor mode does not support use_fused_moe_token_permute_and_unpermute')
    if args.moe_alltoall_overlap_comm:
        if not args.moe_permutation_async_comm or not args.moe_grouped_gemm:
            raise AssertionError(
                '`--moe-alltoall-overlap-comm` or `--moe-allgather-overlap-comm` only support with `--moe-permutation-async-comm` and `--moe-grouped-gemm`.')
    if args.moe_alltoall_overlap_comm and not args.moe_token_dispatcher_type == 'alltoall':
        raise AssertionError('`--moe-alltoall-overlap-comm` only support with `--moe-token-dispatcher-type alltoall`.')
    if args.moe_alltoall_overlap_comm and args.tensor_model_parallel_size > 1:
        raise AssertionError('`--moe-alltoall-overlap-comm` do not support tp for now.')


def _validate_mla(args):
@@ -962,11 +975,12 @@ def _add_dummy_args(args):
    args.tp_x = 1
    args.tp_y = 1
    args.use_nd_matmul = False
    args.moe_alltoall_overlap_comm = False
    args.moe_allgather_overlap_comm = False
    args.moe_without_activation = False
    args.disable_gloo_group = None
    args.ampipe_degree = 0
    args.moe_zero_memory = 'disable'
    args.moe_zero_memory_num_layers = None


def _validate_noop_layer(args):