mirror of
https://gitee.com/mindspore/mindformers.git
synced 2025-12-06 11:29:59 +08:00
!6218 [br_feature_mcore] Add test cases for VocabParallelCrossEntropy in the new Mcore interface
Merge pull request !6218 from JavaZero/mcore_add_ce_test
@@ -54,10 +54,10 @@ class _LogSoftmax(nn.Cell):
     def __init__(self, parallel_config=default_dpmp_config):
         super(_LogSoftmax, self).__init__()
         dp = parallel_config.data_parallel
-        mp = parallel_config.model_parallel
+        mp = parallel_config.tensor_parallel
         # on/off value for onehot, for smooth labeling, modify the off_value
-        self.on_value = Tensor(1.0, mstype.float32)
-        self.off_value = Tensor(0.0, mstype.float32)
+        self.on_value = Tensor(1.0, mstype.int32)
+        self.off_value = Tensor(0.0, mstype.int32)

         self.sum = SumExt().shard(((dp, mp),))
         self.max = ArgMaxWithValue(axis=1, keep_dims=True).shard(
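The substantive change here is the config-field rename: the loss cells now read the tensor-parallel degree from the Mcore-style `tensor_parallel` field instead of the legacy `model_parallel`, matching the `TransformerConfig` the new tests construct below. The same one-line rename recurs in the `_NLLLoss` and `CrossEntropyLoss` hunks that follow. A minimal sketch of what the rename affects (the `DummyConfig` class is hypothetical, for illustration only):

```python
# The loss cells derive their shard strategies from the config's parallel
# degrees, so they must read the Mcore-style `tensor_parallel` field.
class DummyConfig:
    data_parallel = 2
    tensor_parallel = 4

cfg = DummyConfig()
dp, mp = cfg.data_parallel, cfg.tensor_parallel
# e.g. the strategy for a sum over a 2-D tensor partitioned (dp, mp):
sum_strategy = ((dp, mp),)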
@@ -108,7 +108,7 @@ class _NLLLoss(nn.Cell):
     def __init__(self, parallel_config=default_dpmp_config):
         super(_NLLLoss, self).__init__()
         dp = parallel_config.data_parallel
-        mp = parallel_config.model_parallel
+        mp = parallel_config.tensor_parallel
         self.repeat_loss = 1
         self.gather_d = GatherD()
         self.expand_dims = ExpandDims()
@@ -225,7 +225,7 @@ class CrossEntropyLoss(nn.Cell):
                  calculate_per_token_loss=False, seq_split_num=1, **kwargs):
         super(CrossEntropyLoss, self).__init__()
         dp = parallel_config.data_parallel
-        mp = parallel_config.model_parallel
+        mp = parallel_config.tensor_parallel
         self.seq_pipe = seq_split_num > 1
         self.kwargs = kwargs
         self.enable_force_redistribute = False
@@ -0,0 +1,50 @@ (new file: data_gen_utils.py)
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data generation utilities for VocabParallelCrossEntropy tests with random data."""
import numpy as np


def get_init_params(batch_size, seq_length, vocab_size):
    """Generate random initial parameters (inputs) for VocabParallelCrossEntropy."""
    np.random.seed(42)

    logits_shape = (batch_size * seq_length, vocab_size)
    logits = 0.01 * np.random.randn(*logits_shape).astype(np.float32)

    target_shape = (batch_size * seq_length,)
    target = np.random.randint(0, vocab_size, size=target_shape).astype(np.int32)
    input_mask = np.random.randint(0, 2, size=target_shape).astype(np.float32)

    # Guard against an all-zero mask, which would give a zero loss denominator.
    if np.sum(input_mask) == 0 and input_mask.size > 0:
        input_mask[0] = 1.0

    return {
        "logits": logits,
        "target": target,
        "input_mask": input_mask,
    }


GOLDEN_DATA = {
    "numerator": np.array(41.57823944091796875),
    "denominator": np.array(6.),
}

GPU_DATA = {
    "numerator": np.array(41.57823944091796875),
    "denominator": np.array(6.),
}
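For context, numerator/denominator-style reference values of this form can be derived in plain NumPy from the same generated inputs. The sketch below (with the hypothetical helper `masked_ce_terms`, written as if appended to data_gen_utils.py) is illustrative only; the checked-in GOLDEN_DATA and GPU_DATA values come from golden and GPU baseline runs, not from this snippet:

```python
def masked_ce_terms(logits, target, input_mask):
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Per-token negative log-likelihood of the target class.
    nll = -log_probs[np.arange(target.shape[0]), target]
    # Masked loss sum and valid-token count.
    return (nll * input_mask).sum(), input_mask.sum()

params = get_init_params(batch_size=4, seq_length=4, vocab_size=1024)
numerator, denominator = masked_ce_terms(
    params["logits"], params["target"], params["input_mask"])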
@@ -0,0 +1,124 @@ (new file: run_vocab_parallel_cross_entropy.py)
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Run VocabParallelCrossEntropy accuracy test with configurable parameters via args."""

import argparse
import os
from pathlib import Path
import numpy as np
import mindspore as ms
from mindspore.communication import init
from mindformers.experimental.graph.loss_func import VocabParallelCrossEntropy
from mindformers.experimental.graph.transformer.transformer_config import TransformerConfig
from data_gen_utils import get_init_params

SCRIPT_DIR = Path(__file__).parent.resolve()
class VocabParallelCrossEntropyRunner:
    """Class to manage the VocabParallelCrossEntropy model and its test data."""

    def __init__(self, args_from_parser):
        self.args = args_from_parser
        self.check_for_nan_in_loss_and_grad = self.args.check_for_nan_in_loss_and_grad
        self.calculate_per_token_loss = self.args.calculate_per_token_loss

        self.vocab_size = self.args.vocab_size
        self.batch_size = self.args.batch_size
        self.seq_length = self.args.seq_length

        # RANK_ID is only set under msrun; its absence signals a single-card run.
        rank_id_str = os.environ.get("RANK_ID")
        self.rank_id = int(rank_id_str) if rank_id_str is not None else None

        self.worker_num = int(os.environ.get("MS_WORKER_NUM", "1"))

        if self.rank_id is not None:
            ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, full_batch=True)
            init()

        self.config = TransformerConfig(
            data_parallel=self.worker_num // self.args.tensor_parallel,
            tensor_parallel=self.args.tensor_parallel,
            num_attention_heads=self.args.tensor_parallel,
        )

        init_params_data = get_init_params(self.batch_size, self.seq_length, self.vocab_size)

        logits = init_params_data.get("logits")

        self.logits = ms.Tensor(logits, dtype=ms.float32)
        # Reorder the flat batch-major target/mask vectors into sequence-major layout.
        self.target = ms.Tensor(
            init_params_data.get("target").reshape((self.batch_size, self.seq_length)).transpose((1, 0)).reshape(-1),
            dtype=ms.int32,
        )
        self.input_mask = ms.Tensor(
            init_params_data.get("input_mask")
            .reshape((self.batch_size, self.seq_length))
            .transpose((1, 0))
            .reshape(-1),
            dtype=ms.int32,
        )
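The reshape/transpose/reshape chain above converts the generator's flat batch-major vectors into sequence-major order. A small NumPy illustration with hypothetical sizes B=2, S=4:

```python
import numpy as np

# Batch-major flat order: [b0s0, b0s1, b0s2, b0s3, b1s0, ...]
flat = np.arange(8)
seq_major = flat.reshape(2, 4).transpose(1, 0).reshape(-1)
print(seq_major)  # [0 4 1 5 2 6 3 7] -> [b0s0, b1s0, b0s1, b1s1, ...]
```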
    def build_model(self):
        """Build the VocabParallelCrossEntropy model."""
        net = VocabParallelCrossEntropy(
            parallel_config=self.config,
            check_for_nan_in_loss_and_grad=self.check_for_nan_in_loss_and_grad,
            calculate_per_token_loss=self.calculate_per_token_loss,
        )
        return net

    def run(self):
        """Run the model with the given inputs."""
        net = self.build_model()

        result = net(self.logits, self.target, self.input_mask)

        output_ms = {}
        if not self.calculate_per_token_loss:
            output_ms["loss"] = result
        else:
            numerator, denominator = result
            output_ms["numerator"] = numerator
            output_ms["denominator"] = denominator

        # Only rank 0 (or a single-card run) saves the outputs.
        if self.rank_id is None or self.rank_id == 0:
            output_np = {k: v.asnumpy().astype(np.float32) for k, v in output_ms.items() if v is not None}
            output_path = self.args.output_path
            np.savez(output_path, **output_np)
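With `calculate_per_token_loss=True` the cell returns the masked loss sum and the valid-token count separately; assuming the usual masked-mean semantics, the scalar loss is their ratio. A minimal sketch reading the file that `run()` saves:

```python
import numpy as np

out = np.load("output_ms_loss.npz")
loss = out["numerator"] / out["denominator"]  # masked mean over valid tokens
```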
def main():
    parser = argparse.ArgumentParser(description="Run VocabParallelCrossEntropy test")
    parser.add_argument("--vocab_size", type=int, default=1024)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--seq_length", type=int, default=8)
    parser.add_argument("--check_for_nan_in_loss_and_grad", type=lambda x: x.lower() == "true", default="false")
    parser.add_argument("--calculate_per_token_loss", type=lambda x: x.lower() == "true", default="false")
    parser.add_argument("--output_path", type=str, default="output_ms_loss.npz")
    parser.add_argument("--tensor_parallel", type=int, default=1)

    args = parser.parse_args()

    ms.context.set_context(deterministic="ON")
    ms.set_context(mode=ms.GRAPH_MODE)
    ms.set_seed(42)

    runner = VocabParallelCrossEntropyRunner(args)
    runner.run()


if __name__ == "__main__":
    main()
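For a quick single-card sanity run outside the pytest harness, the script can be launched directly; this mirrors the `python` branch of `build_msrun_command_list` in the test module below:

```python
import subprocess

subprocess.run(
    ["python", "run_vocab_parallel_cross_entropy.py",
     "--vocab_size=1024", "--batch_size=4", "--seq_length=4",
     "--calculate_per_token_loss=true", "--output_path=output_ms_loss.npz"],
    check=True,
)
```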
@@ -0,0 +1,234 @@ (new test file)
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test VocabParallelCrossEntropy with various configurations."""

from pathlib import Path
import subprocess
import pytest
import numpy as np
from data_gen_utils import GOLDEN_DATA, GPU_DATA
from tests.utils.double_benchmark import DoubleBenchmarkStandard, DoubleBenchmarkComparator

TOTAL_VOCAB_SIZE = 1024
BATCH_SIZE = 4
SEQ_LENGTH = 4
SINGLE_CARD_TEST_PARAM = "model_args, data_keys, expect_error"
SINGLE_CARD_TEST_CASES = [
    # Test Case 1: Single Card, check_for_nan=False, calculate_per_token=True
    (
        {"check_for_nan_in_loss_and_grad": False, "calculate_per_token_loss": True},
        {"numerator": "numerator", "denominator": "denominator"},
        False,
    ),
    # Test Case 2: Single Card, check_for_nan=True, calculate_per_token=True
    (
        {"check_for_nan_in_loss_and_grad": True, "calculate_per_token_loss": True},
        {"numerator": "numerator", "denominator": "denominator"},
        False,
    ),
]

FOUR_CARD_TEST_PARAM = "model_args, data_keys, expect_error, tensor_parallel"
FOUR_CARD_TEST_CASES = [
    # Test Case 3: Four Cards (DP=2, TP=2), check_for_nan=False, calculate_per_token=True
    (
        {"check_for_nan_in_loss_and_grad": False, "calculate_per_token_loss": True},
        {"numerator": "numerator", "denominator": "denominator"},
        False,
        2,
    ),
]
def build_msrun_command_list(
        worker_num,
        local_worker_num,
        log_dir,
        run_script_path,
        vocab_size,
        batch_size,
        seq_length,
        check_for_nan_in_loss_and_grad,
        calculate_per_token_loss,
        output_path_param,
        tensor_parallel,
):
    """Build the msrun (or plain python) command with the specified parameters for VocabParallelCrossEntropy."""
    if tensor_parallel == 1:
        cmd_list = ["python"]
    else:
        cmd_list = [
            "msrun",
            f"--worker_num={worker_num}",
            f"--local_worker_num={local_worker_num}",
            "--master_port=8167",
            f"--log_dir={log_dir}",
            "--join=True",
        ]
    cmd_list += [
        str(run_script_path),
        f"--vocab_size={vocab_size}",
        f"--batch_size={batch_size}",
        f"--seq_length={seq_length}",
        f"--check_for_nan_in_loss_and_grad={str(check_for_nan_in_loss_and_grad).lower()}",
        f"--calculate_per_token_loss={str(calculate_per_token_loss).lower()}",
        f"--output_path={output_path_param}",
        f"--tensor_parallel={tensor_parallel}",
    ]
    print(f"Equivalent shell command for debugging (approximate): {' '.join(cmd_list)}")
    return cmd_list
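For the four-card case the helper assembles an `msrun` launcher invocation. A usage example with the values of Test Case 3 above:

```python
cmd = build_msrun_command_list(
    worker_num=4, local_worker_num=4, log_dir="msrun_log_loss",
    run_script_path="run_vocab_parallel_cross_entropy.py",
    vocab_size=1024, batch_size=4, seq_length=4,
    check_for_nan_in_loss_and_grad=False, calculate_per_token_loss=True,
    output_path_param="output_ms_loss.npz", tensor_parallel=2,
)
# cmd[:3] == ['msrun', '--worker_num=4', '--local_worker_num=4']
```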
class TestVocabParallelCrossEntropy:
    """Test class for VocabParallelCrossEntropy with different configurations."""

    OUTPUT_MS_FILENAME = "output_ms_loss.npz"
    LOG_DIR_NAME = "msrun_log_loss"
    WORKER_LOG_FILENAME = "worker_0.log"

    def setup_method(self):
        """Prepare the test environment."""
        self.sh_path = Path(__file__).parent.resolve()
        self.run_script_path = self.sh_path / "run_vocab_parallel_cross_entropy.py"

    def check_acc(self, output_ms_dict, data_keys):
        """Compare output_ms with GOLDEN_DATA and GPU_DATA using DoubleBenchmarkComparator."""
        standard = DoubleBenchmarkStandard(dtype="float32")

        for key, data_key in data_keys.items():
            assert key in output_ms_dict, f"Key '{key}' not found in MindSpore output."
            npu_data = output_ms_dict.get(key)

            assert data_key in GOLDEN_DATA, f"Golden data key '{data_key}' not found."
            golden_data = GOLDEN_DATA.get(data_key)

            gpu_data = GPU_DATA.get(data_key)

            DoubleBenchmarkComparator.check_pass_or_not(
                npu_data=npu_data, gpu_data=gpu_data, golden_data=golden_data, standard=standard
            )
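The double-benchmark check accepts the NPU result if it tracks the golden data roughly as well as the GPU baseline does. A rough NumPy sketch of that idea follows; this is an assumption about the intent, and the actual `DoubleBenchmarkComparator` rules and tolerances may differ:

```python
import numpy as np

# Hypothetical illustration of a double-benchmark acceptance rule.
def double_benchmark_ok(npu, gpu, golden, ratio=1.0, atol=1e-6):
    npu_err = np.abs(np.asarray(npu) - np.asarray(golden))
    gpu_err = np.abs(np.asarray(gpu) - np.asarray(golden))
    # NPU passes if its error is no worse than the GPU baseline (up to `ratio`).
    return bool(np.all(npu_err <= ratio * gpu_err + atol))
```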
    def check_output_keys(self, output_ms_dict, model_args):
        """Check that the expected keys are present in the output based on model_args."""
        output_keys = list(output_ms_dict.keys())
        calculate_per_token_loss = model_args["calculate_per_token_loss"]

        if calculate_per_token_loss:
            assert "numerator" in output_keys, (
                f"The 'numerator' key is expected when calculate_per_token_loss is True. Found keys: {output_keys}"
            )
            assert "denominator" in output_keys, (
                f"The 'denominator' key is expected when calculate_per_token_loss is True. Found keys: {output_keys}"
            )
            assert "loss" not in output_keys, (
                f"The 'loss' key is NOT expected when calculate_per_token_loss is True. Found keys: {output_keys}"
            )
        else:
            assert "loss" in output_keys, (
                f"The 'loss' key is expected when calculate_per_token_loss is False. Found keys: {output_keys}"
            )
            assert "numerator" not in output_keys, (
                f"The 'numerator' key is NOT expected when calculate_per_token_loss is False. Found keys: {output_keys}"
            )
            assert "denominator" not in output_keys, (
                f"The 'denominator' key is NOT expected when calculate_per_token_loss is False. "
                f"Found keys: {output_keys}"
            )
    def run_test(
            self,
            worker_num,
            local_worker_num,
            model_args,
            data_keys,
            tmp_path,
            tensor_parallel=1,
            expect_error=False,
    ):
        """Run the test script and check its results."""
        output_file_path = tmp_path / self.OUTPUT_MS_FILENAME
        log_dir_path = tmp_path / self.LOG_DIR_NAME
        log_dir_path.mkdir(parents=True, exist_ok=True)

        cmd_list = build_msrun_command_list(
            worker_num=worker_num,
            local_worker_num=local_worker_num,
            log_dir=log_dir_path,
            run_script_path=self.run_script_path,
            vocab_size=TOTAL_VOCAB_SIZE,
            batch_size=BATCH_SIZE,
            seq_length=SEQ_LENGTH,
            check_for_nan_in_loss_and_grad=model_args["check_for_nan_in_loss_and_grad"],
            calculate_per_token_loss=model_args["calculate_per_token_loss"],
            output_path_param=output_file_path,
            tensor_parallel=tensor_parallel,
        )

        result = subprocess.run(cmd_list, shell=False, capture_output=True, text=True, check=False)

        if expect_error:
            assert result.returncode != 0, (
                f"Expected an error but the test script passed. Stdout:\n{result.stdout}\nStderr:\n{result.stderr}"
            )
        else:
            assert result.returncode == 0, (
                f"Test script failed with non-zero exit code: "
                f"{result.returncode}.\nStdout:\n{result.stdout}\nStderr:\n{result.stderr}"
            )
            assert output_file_path.exists(), f"Output file {output_file_path} was not created."

            output_ms_dict = np.load(output_file_path)
            self.check_output_keys(output_ms_dict, model_args)
            self.check_acc(output_ms_dict, data_keys)
    @pytest.mark.level0
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    @pytest.mark.parametrize(SINGLE_CARD_TEST_PARAM, SINGLE_CARD_TEST_CASES)
    def test_single_card_cases(self, model_args, data_keys, expect_error, tmp_path):
        """Test single card with various configurations."""
        self.run_test(
            worker_num=1,
            local_worker_num=1,
            model_args=model_args,
            data_keys=data_keys,
            expect_error=expect_error,
            tmp_path=tmp_path,
            tensor_parallel=1,
        )

    @pytest.mark.level0
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_single
    @pytest.mark.parametrize(FOUR_CARD_TEST_PARAM, FOUR_CARD_TEST_CASES)
    def test_four_cards_case(self, model_args, data_keys, expect_error, tensor_parallel, tmp_path):
        """Test four cards with various configurations."""
        self.run_test(
            worker_num=4,
            local_worker_num=4,
            model_args=model_args,
            data_keys=data_keys,
            expect_error=expect_error,
            tmp_path=tmp_path,
            tensor_parallel=tensor_parallel,
        )