!7811 【master】【bugfix】【覆盖率】mindformers用例覆盖率较低,需补充用例

Merge pull request !7811 from Yule100/test_ut
This commit is contained in:
i-robot
2025-12-06 02:18:42 +00:00
committed by Gitee
6 changed files with 923 additions and 0 deletions

View File

@@ -0,0 +1,57 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test base model
"""
import os
import tempfile
import pytest
from mindformers import MindFormerConfig
from mindformers.models.base_config import BaseConfig
from mindformers.models.base_model import BaseModel
NUM_LAYERS = 1
class TestBaseModel:
"""A test class for testing model.save_pretrained() method."""
def setup_method(self):
"""init test class."""
with tempfile.TemporaryDirectory() as temp_dir_path:
self.path = temp_dir_path
@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_base_model(self):
"""
Feature: Base_model save_pretrained()
Description: Test llama save pretrained
Expectation: Run successfully.
"""
config = BaseConfig(num_layers=NUM_LAYERS)
model = BaseModel(config)
model.save_pretrained(self.path, save_name="mindspore_model")
yaml_path = self.path + "/" + "mindspore_model.yaml"
model_path = self.path + "/" + "mindspore_model.ckpt"
assert os.path.exists(yaml_path)
assert os.path.exists(model_path)
mf_config = MindFormerConfig(yaml_path)
assert mf_config.model.model_config.num_layers == NUM_LAYERS
# pylint: disable=W0212
model._get_config_args(pretrained_model_name_or_dir=self.path)

View File

@@ -0,0 +1,79 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test module for testing build_model for mindformers.
"""
import pytest
from mindformers.models.build_model import build_encoder, build_head
from mindformers.tools.register import MindFormerModuleType, MindFormerRegister
class DummyEncoder:
def __init__(self, **kwargs):
self.kwargs = kwargs
class DummyHead:
def __init__(self, num_classes=10):
self.num_classes = num_classes
MindFormerRegister.register(MindFormerModuleType.ENCODER, "dummy_enc")(DummyEncoder)
MindFormerRegister.register(MindFormerModuleType.HEAD, "dummy_head")(DummyHead)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_encoder():
"""
Feature: build_encoder()
Description: Test build_encoder().
Expectation: Run successfully.
"""
encoder_config = None
class_name = None
encoder = build_encoder(encoder_config, class_name=class_name)
assert encoder is None
encoder = build_encoder(class_name=DummyEncoder)
assert encoder is not None
encoder_config = {"type": DummyEncoder}
encoder = build_encoder(encoder_config)
assert encoder is not None
encoder_config = [{"type": DummyEncoder}, {"type": DummyEncoder}]
encoder = build_encoder(encoder_config)
assert encoder is not None
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_head():
"""
Feature: build_head()
Description: Test build_head().
Expectation: Run successfully.
"""
head_config = None
class_name = None
head = build_head(head_config, class_name=class_name)
assert head is None
head = build_head(class_name=DummyHead)
assert head is not None
head_config = {"type": DummyHead}
head = build_head(head_config)
assert head is not None
head_config = [{"type": DummyHead}, {"type": DummyHead}]
head = build_head(head_config)
assert head is not None

View File

@@ -0,0 +1,147 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test module for testing models utils for mindformers.
"""
import unittest
from unittest.mock import MagicMock
import pytest
from mindformers.models.utils import check_use_3d_tensor_parallel_valid
# Mock helper functions and constants
class ParallelMode:
AUTO_PARALLEL = "auto_parallel"
def check_fine_grain_interleave_valid(fine_grain):
return fine_grain is not None and fine_grain > 1
class TestCheckUse3DTensorParallelValid(unittest.TestCase):
"""A class for testing CheckUse3DTensorParallelValid."""
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_disabled_use_3d_tp(self):
"""Branch: use_3d_tensor_parallel = False → return False"""
config = MagicMock()
config.use_3d_tensor_parallel = False
result = check_use_3d_tensor_parallel_valid(config)
self.assertFalse(result)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_config_none(self):
"""Branch: config is None → return False"""
result = check_use_3d_tensor_parallel_valid(None)
self.assertFalse(result)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_parallel_config_none(self):
"""Branch: config.parallel_config is None → return False"""
config = MagicMock()
config.use_3d_tensor_parallel = True
config.parallel_config = None
result = check_use_3d_tensor_parallel_valid(config)
self.assertFalse(result)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_use_flash_attention_false(self):
"""Raise: use_flash_attention must be True"""
config = self._create_valid_config()
config.use_flash_attention = False
with self.assertRaises(ValueError, msg="use_flash_attention must be True"):
check_use_3d_tensor_parallel_valid(config)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_ulysses_cp_num_gt_1(self):
"""Raise: ulysses cp > 1 not supported"""
config = self._create_valid_config()
config.parallel_config.get_ulysses_cp_num.return_value = 2
with self.assertRaises(ValueError, msg="ulysses cp must be 1"):
check_use_3d_tensor_parallel_valid(config)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_moe_enabled(self):
"""Raise: MoE not supported"""
config = self._create_valid_config()
moe_mock = MagicMock()
moe_mock.expert_num = 8
config.moe_config = moe_mock
with self.assertRaises(ValueError, msg="MoE not supported"):
check_use_3d_tensor_parallel_valid(config)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_seq_parallel_false(self):
"""Raise: use_seq_parallel must be True"""
config = self._create_valid_config()
config.parallel_config.use_seq_parallel = False
with self.assertRaises(ValueError, msg="use_seq_parallel must be True"):
check_use_3d_tensor_parallel_valid(config)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_fine_grain_interleave_invalid(self):
"""Raise: fine_grain_interleave not supported"""
config = self._create_valid_config()
config.fine_grain_interleave = 2 # triggers True in check_fine_grain_interleave_valid
with self.assertRaises(ValueError, msg="fine_grain_interleave not supported"):
check_use_3d_tensor_parallel_valid(config)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_tp_product_mismatch(self):
"""Raise: tp_x * tp_y * tp_z != model_parallel"""
config = self._create_valid_config()
config.tp_x = 2
config.tp_y = 2
config.tp_z = 2
config.parallel_config.model_parallel = 7 # 2*2*2=8 ≠ 7
with self.assertRaises(ValueError, msg="tp product mismatch"):
check_use_3d_tensor_parallel_valid(config)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def _create_valid_config(self):
"""Helper to create a config that passes initial checks"""
config = MagicMock()
config.use_3d_tensor_parallel = True
config.use_flash_attention = True
config.fine_grain_interleave = None # valid
config.moe_config = None
parallel_config = MagicMock()
parallel_config.get_ulysses_cp_num.return_value = 1
parallel_config.use_seq_parallel = True
parallel_config.model_parallel = 4
config.parallel_config = parallel_config
return config

View File

@@ -0,0 +1,116 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test module for testing tools check_rules for mindformers.
"""
import pytest
from mindformers.tools.check_rules import (
_restore_net_type,
_rule_fa_only_for_train,
_check_keyword_gen_dataset,
_check_context_parallel_algo_valid,
_check_recompute
)
from mindformers import MindFormerConfig
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_restore_net_type():
"""
Feature: Check rules
Description: Test check_rules.
Expectation: Run successfully.
"""
config = MindFormerConfig()
config.set_value('model.model_config.compute_dtype', 'bfloat16')
config.set_value('model.model_config.param_init_type', 'float32')
_restore_net_type(config=config)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_rule_fa_only_train():
"""
Feature: Check rules
Description: Test check_rules.
Expectation: Run successfully.
"""
config = MindFormerConfig()
config.set_value('model.model_config.use_flash_attention', True)
_rule_fa_only_for_train(config=config, mode="train")
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_context_parallel_algo_valid():
"""
Feature: Check rules
Description: Test check_rules.
Expectation: Run successfully.
"""
config = MindFormerConfig()
config.set_value('model.model_config.n_kv_heads', None)
config.set_value('model.model_config.multi_query_group_num', 2)
config.set_value('model.model_config.num_heads', None)
config.set_value('model.model_config.num_attention_heads', 32)
config.set_value('parallel_config.context_parallel_algo.value', "ulysses_cp")
with pytest.raises(ValueError, match=r"cp \* mp <= attention head"):
_check_context_parallel_algo_valid(config=config, cp=8, mp=8)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_keyword_gen_dataset():
"""
Feature: Check rules
Description: Test check_rules.
Expectation: Run successfully.
"""
config = MindFormerConfig()
config.set_value('model.model_config.seq_length', 101)
config.set_value('do_eval', False)
config.set_value('metric', [{"type": "ADGENMetric"}, {"type": "PerplexityMetric"}])
# train dataset
config.set_value('train_dataset.data_loader.type', "ADGenDataLoader")
config.set_value('train_dataset.max_source_length', 50)
config.set_value('train_dataset.max_target_length', 50)
# eval dataset
config.set_value('eval_dataset.data_loader.type', "ADGenDataLoader")
config.set_value('eval_dataset.data_loader.phase', "eval")
config.set_value('eval_dataset.max_source_length', 101)
config.set_value('eval_dataset.max_target_length', 20)
config.set_value('eval_dataset_task.dataset_config.data_loader.phase', "eval")
_check_keyword_gen_dataset(config=config, mode='train')
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_recompute():
"""
Feature: Check rules
Description: Test check_rules.
Expectation: Run successfully.
"""
config = MindFormerConfig()
config.set_value("swap_config.swap", True)
config.set_value("recompute_config.recompute", True)
config.set_value("recompute_config.select_recompute", True)
config.set_value("recompute_config.select_comm_recompute", True)
_check_recompute(config=config)

View File

@@ -0,0 +1,482 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test utils"""
import os
import stat
import tempfile
from pathlib import Path
from unittest import mock
import pytest
from mindspore import context
from mindformers.tools.utils import (
check_in_modelarts,
get_output_root_path,
is_version_le,
is_version_ge,
get_epoch_and_step_from_ckpt_name,
str2bool,
parse_value,
set_safe_mode_for_file_or_dir,
PARALLEL_MODE,
MODE,
Validator,
check_obs_url,
get_rank_id_from_ckpt_name,
replace_rank_id_in_ckpt_name,
get_ascend_log_path,
calculate_pipeline_stage,
divide,
is_pynative,
create_and_write_info_to_txt,
check_ckpt_file_name,
get_times_epoch_and_step_from_ckpt_name,
is_last_pipeline_stage,
get_dp_from_dataset_strategy,
check_shared_disk,
replace_tk_to_mindpet,
get_num_nodes_devices,
)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_in_modelarts_true():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with mock.patch.dict(os.environ, {"MA_LOG_DIR": "/tmp"}):
assert check_in_modelarts() is True
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_in_modelarts_false():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with mock.patch.dict(os.environ, clear=True):
assert check_in_modelarts() is False
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_output_root_path_default():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with mock.patch.dict(os.environ, {}, clear=True):
path = get_output_root_path()
assert path == os.path.realpath("./output")
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_output_root_path_env():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with mock.patch.dict(os.environ, {"LOCAL_DEFAULT_PATH": "/custom/output"}):
path = get_output_root_path()
assert path == "/custom/output"
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_is_version_le():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
assert is_version_le("1.8.1", "1.11.0") is True
assert is_version_le("1.11.0", "1.11.0") is True
assert is_version_le("2.0.0", "1.11.0") is False
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_is_version_ge():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
assert is_version_ge("1.11.0", "1.8.1") is True
assert is_version_ge("1.11.0", "1.11.0") is True
assert is_version_ge("1.8.1", "1.11.0") is False
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_epoch_and_step_from_ckpt_name():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
epoch, step = get_epoch_and_step_from_ckpt_name("model-5_100.ckpt")
assert epoch == 5
assert step == 100
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_epoch_and_step_invalid():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with pytest.raises(ValueError, match="Can't match epoch and step"):
get_epoch_and_step_from_ckpt_name("invalid_name.txt")
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_str2bool():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
assert str2bool("True") is True
assert str2bool("False") is False
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_str2bool_invalid():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with pytest.raises(Exception, match="Invalid Bool Value"):
str2bool("maybe")
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_parse_value():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
assert parse_value("123") == 123
assert parse_value("3.14") == 3.14
assert parse_value("True") is True
assert parse_value('{"a": 1}') == {"a": 1}
assert parse_value("hello") == "hello"
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_set_safe_mode_for_file_or_dir():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with tempfile.TemporaryDirectory() as tmpdir:
file_path = Path(tmpdir) / "test.txt"
dir_path = Path(tmpdir) / "subdir"
dir_path.mkdir()
file_path.write_text("test")
set_safe_mode_for_file_or_dir([str(file_path), str(dir_path)])
assert (file_path.stat().st_mode & stat.S_IRUSR) != 0
assert (file_path.stat().st_mode & stat.S_IWUSR) != 0
assert (dir_path.stat().st_mode & stat.S_IXUSR) != 0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_parallel_mode_mapping():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
assert PARALLEL_MODE["DATA_PARALLEL"] == context.ParallelMode.DATA_PARALLEL
assert PARALLEL_MODE[0] == context.ParallelMode.DATA_PARALLEL
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_mode_mapping():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
assert MODE["GRAPH_MODE"] == context.GRAPH_MODE
assert MODE[0] == context.GRAPH_MODE
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_validator_check_type():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
Validator.check_type(42, int)
with pytest.raises(TypeError):
Validator.check_type("42", int)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_obs_url_valid():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
check_obs_url("obs://bucket/path")
check_obs_url("s3://bucket/path")
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_obs_url_invalid():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with pytest.raises(TypeError, match="should be start with obs:// or s3://"):
check_obs_url("/local/path")
with pytest.raises(TypeError, match="type should be a str"):
check_obs_url(123)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_rank_id_from_ckpt_name():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
rank = get_rank_id_from_ckpt_name("llama_7b_rank_3-5_100.ckpt")
assert rank == 3
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_replace_rank_id_in_ckpt_name():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
new_name = replace_rank_id_in_ckpt_name("model_rank_2-1_50.ckpt", 5)
assert new_name == "model_rank_5-1_50.ckpt"
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_ascend_log_path():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
os.environ['ASCEND_PROCESS_LOG_PATH'] = '/home/log'
assert get_ascend_log_path() == '/home/log'
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_calculate_pipeline_stage():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
layers_per_stage = [4, 4]
model_layers = [6]
input_layers_per_stage = layers_per_stage.copy()
result = calculate_pipeline_stage(input_layers_per_stage, model_layers)
expected = [
{
"offset": [1, -1], # [4-3, 2-3]
"start_stage": 0,
"stage_num": 2
}
]
assert result == expected
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_divide():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
res = divide(10, 2)
assert res == 5
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_is_pynative():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
os.environ['ENFORCE_EAGER'] = 'true'
res = is_pynative()
assert res
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_create_and_write_info_to_txt():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with tempfile.TemporaryDirectory() as tmpdir:
txt_path = os.path.join(tmpdir, "output.txt")
info = "Hello, world!"
create_and_write_info_to_txt(txt_path, info)
assert os.path.exists(txt_path)
with open(txt_path, 'r', encoding='utf-8') as f:
content = f.read()
assert content == "Hello, world!"
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_ckpt_file_name():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
ckpt_name = "llama_0-3_1.ckpt"
res = check_ckpt_file_name(ckpt_name)
assert res
ckpt_name = "dsadsdsadasd"
res = check_ckpt_file_name(ckpt_name)
assert not res
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_times_epoch_and_step_from_ckpt_name():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
ckpt_name = "llama_0-3_1.ckpt"
res = check_ckpt_file_name(ckpt_name)
if res:
times, epcoh, step = get_times_epoch_and_step_from_ckpt_name(ckpt_name)
assert times == 0
assert epcoh == 3
assert step == 1
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_is_last_pipeline_stage():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with mock.patch("mindformers.tools.utils.get_real_group_size", return_value=8), \
mock.patch("mindformers.tools.utils.get_real_rank", return_value=6), \
mock.patch("mindspore.get_auto_parallel_context", return_value=2):
assert is_last_pipeline_stage() is True
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_dp_from_dataset_strategy():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
with mock.patch("mindspore.get_auto_parallel_context", return_value=[[2, 1]]):
dp = get_dp_from_dataset_strategy()
assert dp == 2
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_shared_disk():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
disk_path = "/home/workspace"
res = check_shared_disk(disk_path=disk_path)
assert not res
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_replace_tk_to_mindpet():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
ckpt_dict = {"tk_delta": 1}
new_ckpt = replace_tk_to_mindpet(ckpt_dict)
assert new_ckpt['mindpet_delta'] == 1
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_num_nodes_devices():
"""
Feature: Utils functions
Description: Test utils functions.
Expectation: Run successfully.
"""
rank_size = 7
with mock.patch("mindformers.tools.utils.get_device_num_per_node", return_value=8):
num_nodes, num_devices = get_num_nodes_devices(rank_size=rank_size)
assert num_nodes == 1
assert num_devices == rank_size

View File

@@ -0,0 +1,42 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test module for testing import utils for mindformers.
"""
import os
import tempfile
import pytest
from mindformers.utils.import_utils import direct_mindformers_import
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_direct_mindformers_import_success():
"""
Feature: Import utils
Description: Test direct_mindformers_import.
Expectation: Run successfully.
"""
with tempfile.TemporaryDirectory() as tmpdir:
init_file = os.path.join(tmpdir, "__init__.py")
with open(init_file, "w", encoding='utf-8') as f:
f.write('''
def hello():
return "Hello from mocked mindformers!"
''')
module = direct_mindformers_import(tmpdir)
assert module is not None