Compare commits

...

4 Commits

Author SHA1 Message Date
i-robot
d63b217451 !7777 【master】【bugfix】修复FlashAttention里面的reducemax没有配置shard
Merge pull request !7777 from JavaZero/fix_reduce_max
2025-12-03 08:11:14 +00:00
i-robot
53d6451e85 !7734 【master】增加load_checkpoint_utils.py以及run_check.py的测试用例
Merge pull request !7734 from AAA碧根果批发赵少/testcase
2025-12-03 07:59:10 +00:00
yiyison
b446a39fe7 增加load_checkpoint_utils.py以及run_check.py的测试用例 2025-12-03 10:53:03 +08:00
JavaZero
4dc65d8960 fix redistribution op in flash_attn 2025-12-01 15:27:03 +08:00
4 changed files with 861 additions and 57 deletions

View File

@@ -181,6 +181,7 @@ class FlashAttention(Cell):
parallel_optimizer=False, requires_grad=False
)
self.reduce_max = aclnn_ops.ReduceMax()
self.reduce_max.add_prim_attr("self_define_shard", True)
self.assign_add = ops.AssignAdd()
self.assign_add.add_prim_attr("self_define_shard", True)
@@ -383,4 +384,14 @@ class FlashAttention(Cell):
in_strategy=(layout("tp"), layout("tp")),
out_strategy=(layout("tp"),)
)
if self.input_layout == "BNSD":
self.reduce_max.shard(
in_strategy=(layout("None", "tp", "None", "None"),),
out_strategy=(layout("tp"),)
)
elif self.input_layout == "TND":
self.reduce_max.shard(
in_strategy=(layout("None", "tp", "None"),),
out_strategy=(layout("tp"),)
)
return self

View File

@@ -65,7 +65,7 @@ def _get_origin_network(network):
"""recursive find if cells which have function <convert_name>"""
if 'convert_name' in dir(network):
return network, True
#DFS for network
# DFS for network
for cell in list(network.cells()):
network, find_cell = _get_origin_network(cell)
if find_cell:
@@ -314,7 +314,7 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval
if config.resume_training or (config.get('remove_redundancy', False) and not do_predict):
# pylint: disable=W0212
network = model._train_network
#build model
# build model
if config.use_parallel:
compile_model(
model=model,
@@ -325,7 +325,7 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval
sink_size=config.runner_config.sink_size,
do_eval=do_eval, do_predict=do_predict
)
#wait generate all rank strategy files
# wait generate all rank strategy files
barrier()
# only execute qkv concat check on the main rank in predict mode
@@ -337,7 +337,7 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval
barrier()
process_for_stand_alone_mode(config, network, strategy_path)
#merge dst strategy
# merge dst strategy
strategy_path = get_merged_dst_strategy_path(config, strategy_path)
load_safetensors_checkpoint(config, load_checkpoint_files, network, strategy_path, load_checkpoint, optimizer)
@@ -457,7 +457,7 @@ def load_safetensors_checkpoint(config, load_checkpoint_files, network, strategy
format=config.load_ckpt_format
))
if not config.model.model_config.get("qkv_concat", False) \
and is_hf_safetensors_dir(load_ckpt_path, origin_network):
and is_hf_safetensors_dir(load_ckpt_path, origin_network):
logger.info("......HuggingFace weights convert name......")
params_dict = origin_network.convert_weight_dict(params_dict, model_config=config.model.model_config)
if optimizer and config.resume_training:

View File

@@ -0,0 +1,15 @@
"""Test for run_check function"""
import pytest
from mindformers import run_check
@pytest.mark.level0
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
def test_run_check():
"""
Feature: Test run_check function
Description: Call run_check to check if MindSpore, MindFormers, CANN and driver versions are compatible
Expectation: No exceptions raised, all checks pass
"""
run_check()

View File

@@ -13,17 +13,95 @@
# limitations under the License.
# ============================================================================
"""test for load_checkpoint_utils."""
# pylint: disable=W0621
import tempfile
from unittest.mock import patch, MagicMock
import pytest
import numpy as np
from mindspore import Parameter
from mindformers.tools.register import MindFormerConfig
from mindformers.checkpoint.utils import compile_model
from mindformers.utils.load_checkpoint_utils import CkptFormat, _get_checkpoint_mode, CheckpointFileMode, \
_check_checkpoint_path
from mindformers.models.modeling_utils import PreTrainedModel
from mindformers.utils.load_checkpoint_utils import (
CkptFormat, _get_checkpoint_mode, CheckpointFileMode, _check_checkpoint_path,
extract_suffix, get_last_checkpoint, validate_config_with_file_mode,
update_global_step, unify_safetensors, _revise_remove_redundancy_with_file,
_get_origin_network, get_load_path_after_hf_convert, _get_src_strategy,
_get_src_file_suffix, _get_src_file, load_safetensors_checkpoint,
process_hf_checkpoint, validate_qkv_concat, get_merged_src_strategy_path,
get_merged_dst_strategy_path, process_for_stand_alone_mode,
load_checkpoint_with_safetensors
)
@pytest.fixture
def mock_config():
"""Create a mock config with default values"""
class MockConfig:
"""Mock configuration class for testing"""
def __init__(self):
self.load_checkpoint = "/path/to/checkpoint"
self.load_ckpt_format = "safetensors"
self.use_parallel = False
self.auto_trans_ckpt = False
self.resume_training = None
self.remove_redundancy = False
self.output_dir = "/output"
self.src_strategy_path_or_dir = None
self.load_ckpt_async = False
self.context = type('', (), {})()
self.context.mode = "GRAPH_MODE"
self.runner_config = type('', (), {})()
self.runner_config.sink_mode = True
self.runner_config.epochs = 1
self.runner_config.sink_size = 1
self.runner_config.step_scale = 2.0
self.model = type('', (), {})()
self.model.model_config = {}
self.parallel = type('', (), {})()
self.parallel.parallel_mode = "DATA_PARALLEL"
def get(self, key, default=None):
return getattr(self, key, default)
return MockConfig()
@pytest.fixture
def mock_network():
"""Create a mock network"""
mock_net = MagicMock()
mock_net.cells.return_value = []
return mock_net
@pytest.fixture
def mock_model():
"""Create a mock model"""
mock_mod = MagicMock()
mock_mod.config = MagicMock()
mock_mod.config.model_type = "test_model"
return mock_mod
@pytest.fixture
def mock_file():
"""Create a mock file"""
mock_f = MagicMock()
mock_f.metadata.return_value = None
return mock_f
class TestCommonCheckpointMethod:
"""A test class for testing common methods"""
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_support_type(self):
"""test CkptFormat support type"""
# run the test
@@ -32,25 +110,709 @@ class TestCommonCheckpointMethod:
# verify the results
assert result == ['ckpt', 'safetensors']
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_checkpoint_path_with_non_string_pathlike(self):
"""test check checkpoint path with non string pathlike"""
path = 123
with pytest.raises(ValueError,
match=r"config.load_checkpoint must be a str, but got 123 as type <class 'int'>."):
match=r"config.load_checkpoint must be a `str`, but got `123` as type `<class 'int'>`."):
_check_checkpoint_path(path)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_checkpoint_path_with_nonexistent_path(self):
"""test check checkpoint path with nonexistent path"""
path = 'NoneExistPath'
with pytest.raises(FileNotFoundError, match=r"config.load_checkpoint NoneExistPath does not exist."):
with pytest.raises(FileNotFoundError, match=r"config.load_checkpoint `NoneExistPath` does not exist."):
_check_checkpoint_path(path)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_checkpoint_path_with_valid_path(self):
"""test check checkpoint path with valid path"""
# create a temporary directory for testing
with tempfile.TemporaryDirectory() as tmpdir:
# test with directory path
result = _check_checkpoint_path(tmpdir)
assert result == tmpdir
# test with directory path ending with slash
result = _check_checkpoint_path(tmpdir + '/')
assert result == tmpdir
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
"file_path, expected",
[
# test pattern 1: {prefix}_rank_{rank_id}-{epoch}_{step}.safetensors
("model_rank_0-10_200.safetensors", "-10_200"),
# test pattern 2: {prefix}_rank_{rank_id}_{task_id}-{epoch}_{step}.safetensors
("model_rank_0_1-10_200.safetensors", "_1-10_200"),
# test with invalid pattern
("invalid_filename.safetensors", "invalid_filename")
]
)
def test_extract_suffix(self, file_path, expected):
"""test extract_suffix function"""
result = extract_suffix(file_path)
assert result == expected
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_last_checkpoint(self):
"""test get_last_checkpoint function"""
# setup mocks using context managers
with patch('os.path.isdir') as mock_isdir, \
patch('os.path.exists') as mock_exists, \
patch('os.listdir') as mock_listdir, \
patch('os.path.getmtime') as mock_getmtime:
# setup mock return values
mock_isdir.return_value = True
mock_exists.return_value = True
mock_listdir.return_value = ["model_0.ckpt", "model_1.ckpt", "model_2.ckpt"]
mock_getmtime.side_effect = lambda x: {
"/test/model_0.ckpt": 100,
"/test/model_1.ckpt": 200,
"/test/model_2.ckpt": 300
}[x]
# test with valid directory
result = get_last_checkpoint("/test", "ckpt")
assert result == "/test/model_2.ckpt"
# test with no checkpoint files
mock_listdir.return_value = ["other_file.txt"]
result = get_last_checkpoint("/test", "ckpt")
assert result is None
# test with invalid directory
mock_isdir.return_value = False
with pytest.raises(NotADirectoryError):
get_last_checkpoint("/invalid/dir", "ckpt")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
"file_mode, use_parallel, auto_trans_ckpt, expected_exception",
[
# test single checkpoint file mode with parallel
(CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value, True, False, ValueError),
# test multi checkpoint file mode with parallel but no auto_trans_ckpt
(CheckpointFileMode.MULTI_CHECKPOINT_FILE.value, True, False, ValueError),
# test multi checkpoint file with rank id mode without parallel
(CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value, False, False, ValueError),
# test invalid mode
("invalid_mode", False, False, ValueError),
# test valid cases - no exception expected
(CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value, False, False, None),
(CheckpointFileMode.MULTI_CHECKPOINT_FILE.value, True, True, None),
(CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value, True, False, None)
]
)
def test_validate_config_with_file_mode(self, file_mode, use_parallel, auto_trans_ckpt, expected_exception):
"""test validate_config_with_file_mode function"""
if expected_exception:
with pytest.raises(expected_exception):
validate_config_with_file_mode(file_mode, use_parallel, auto_trans_ckpt)
else:
validate_config_with_file_mode(file_mode, use_parallel, auto_trans_ckpt)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
"step_scale, initial_global_step, expected_global_step, expected_in_dict",
[
(2.0, 100, 200, True),
(None, 100, 100, True),
(2.0, None, None, False)
]
)
def test_update_global_step(self, step_scale, initial_global_step, expected_global_step, expected_in_dict):
"""test update_global_step function"""
# setup config
config = type('', (), {})()
config.runner_config = type('', (), {})()
config.runner_config.step_scale = step_scale
# setup hyper_param_dict
hyper_param_dict = {}
if initial_global_step is not None:
hyper_param_dict["global_step"] = Parameter(np.array(initial_global_step, dtype=np.int32))
# test update_global_step
update_global_step(config, hyper_param_dict)
# verify the results
if expected_in_dict:
assert "global_step" in hyper_param_dict
assert hyper_param_dict["global_step"].asnumpy() == expected_global_step
else:
assert "global_step" not in hyper_param_dict
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unify_safetensors(self):
"""test unify_safetensors function"""
# setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as mock_is_main_rank, \
patch('mindformers.utils.load_checkpoint_utils.barrier') as mock_barrier, \
patch('mindspore.unified_safetensors') as mock_unified_safetensors:
# test when is_main_rank is True
mock_is_main_rank.return_value = True
unify_safetensors("/src/checkpoint", "/src/strategy", "/dst/unified", True, "-10_200", False)
mock_unified_safetensors.assert_called_once()
mock_barrier.assert_called_once()
# test when is_main_rank is False
mock_is_main_rank.return_value = False
mock_barrier.reset_mock()
unify_safetensors("/src/checkpoint", "/src/strategy", "/dst/unified", True, "-10_200", False)
mock_unified_safetensors.assert_called_once() # should not be called again
mock_barrier.assert_called_once()
# test without parallel
mock_is_main_rank.return_value = True
mock_barrier.reset_mock()
unify_safetensors("/src/checkpoint", "/src/strategy", "/dst/unified", False, "-10_200", False)
mock_barrier.assert_not_called()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
"config_remove_redundancy, metadata, expected_result",
[
# test with metadata remove_redundancy=True and config remove_redundancy=False
(False, {"remove_redundancy": "True"}, True),
# test with metadata remove_redundancy=False and config remove_redundancy=True
(True, {"remove_redundancy": "False"}, False),
# test with matching metadata and config
(True, {"remove_redundancy": "True"}, True),
# test with no metadata
(True, None, True),
# test with metadata but no remove_redundancy key
(True, {"other_key": "value"}, True)
]
)
def test__revise_remove_redundancy_with_file(self, config_remove_redundancy, metadata, expected_result, mock_file):
"""test _revise_remove_redundancy_with_file function"""
mock_file.metadata.return_value = metadata
result = _revise_remove_redundancy_with_file(config_remove_redundancy, mock_file)
assert result == expected_result
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
"network_has_convert_name, child_has_convert_name, expected_found",
[
# test with network that has convert_name
(True, False, True),
# test with nested network where child has convert_name
(False, True, True),
# test with network that doesn't have convert_name and no children with it
(False, False, False)
]
)
def test__get_origin_network(self, network_has_convert_name, child_has_convert_name, expected_found):
"""test _get_origin_network function"""
# setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils.logger'):
if network_has_convert_name:
# create a mock network with convert_name attribute
mock_network = MagicMock()
mock_network.convert_name = MagicMock()
# Return empty list for cells() to avoid recursion
mock_network.cells.return_value = []
else:
if child_has_convert_name:
# create a mock network without convert_name but with a child that has it
mock_child = MagicMock()
mock_child.convert_name = MagicMock()
# Return empty list for cells() to avoid further recursion
mock_child.cells.return_value = []
# Create a network that returns the child directly when cells() is called
mock_network = MagicMock()
mock_network.cells.return_value = [mock_child]
else:
# create a mock network without convert_name and no children with it
mock_network = MagicMock()
mock_network.cells.return_value = []
# Remove convert_name attribute to simulate network without it
if hasattr(mock_network, 'convert_name'):
delattr(mock_network, 'convert_name')
# run the test
_, found = _get_origin_network(mock_network)
assert found == expected_found
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_load_path_after_hf_convert(self, mock_config, mock_network):
"""test get_load_path_after_hf_convert function"""
# setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils.is_hf_safetensors_dir') as mock_is_hf_safetensors_dir, \
patch('mindformers.utils.load_checkpoint_utils.'
'check_safetensors_addition_param_support') as mock_check_support:
# test when not hf safetensors
mock_is_hf_safetensors_dir.return_value = False
result = get_load_path_after_hf_convert(mock_config, mock_network)
assert result == "/path/to/checkpoint"
# test when hf safetensors but not qkv_concat and not supported
mock_is_hf_safetensors_dir.return_value = True
mock_check_support.return_value = False
mock_config.model.model_config = {"qkv_concat": False}
with patch('mindformers.utils.load_checkpoint_utils.process_hf_checkpoint',
return_value="/path/to/converted"):
with patch('mindformers.utils.load_checkpoint_utils.barrier'):
result = get_load_path_after_hf_convert(mock_config, mock_network)
assert result == "/path/to/converted"
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test__get_src_strategy(self, mock_config):
"""test _get_src_strategy function"""
# setup mocks using context managers
with patch('os.path.isfile') as mock_isfile, \
patch('os.path.isdir') as mock_isdir, \
patch('os.path.join') as mock_join, \
patch('os.path.exists') as mock_exists, \
patch('os.path.dirname') as mock_dirname, \
patch('mindformers.utils.load_checkpoint_utils.logger'):
# Test case 1: input_src_strategy is provided
mock_config.load_checkpoint = "/test/checkpoint.ckpt"
mock_config.src_strategy_path_or_dir = "/input/strategy"
mock_isdir.return_value = True
result = _get_src_strategy(mock_config)
assert result == "/input/strategy"
# Test case 2: no strategy dir exists
mock_config.src_strategy_path_or_dir = None
mock_isfile.return_value = True
mock_exists.return_value = False
with pytest.raises(
ValueError,
match="when use checkpoint after train/finetune, src_strategy_path_or_dir should be set"
):
_get_src_strategy(mock_config)
# Test case 3: config.load_checkpoint is a directory and strategy dir exists
mock_isfile.return_value = False
mock_exists.return_value = True
# Setup mock_dirname to return a valid parent directory
mock_dirname.return_value = "/test"
# Setup mock_join to return a valid path
mock_join.return_value = "/test/strategy"
mock_config.load_checkpoint = "/test/checkpoint_dir"
mock_config.src_strategy_path_or_dir = None
result = _get_src_strategy(mock_config)
assert result == "/test/strategy"
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test__get_src_file_suffix(self, mock_config):
"""test _get_src_file_suffix function"""
# setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as mock_is_main_rank, \
patch('mindformers.utils.load_checkpoint_utils.get_last_checkpoint') as mock_get_last_checkpoint, \
patch('os.path.isfile') as mock_isfile, \
patch('os.path.isdir') as mock_isdir:
# test when is_main_rank is True and resume_training is string
mock_is_main_rank.return_value = True
mock_config.resume_training = "checkpoint-10_200.safetensors"
mock_config.load_checkpoint = "/path/to/checkpoint"
mock_config.load_ckpt_format = "safetensors"
with patch('mindformers.utils.load_checkpoint_utils.extract_suffix', return_value="-10_200"):
result = _get_src_file_suffix(mock_config)
assert result == ("/path/to/checkpoint", "-10_200")
# test when is_main_rank is True and load_checkpoint is file
mock_isfile.return_value = True
mock_isdir.return_value = False
mock_config.resume_training = None
mock_config.load_checkpoint = "/path/to/rank_0/checkpoint-10_200.safetensors"
with patch('mindformers.utils.load_checkpoint_utils.extract_suffix', return_value="-10_200"):
result = _get_src_file_suffix(mock_config)
assert result == ("/path/to", "-10_200")
# test when is_main_rank is True and load_checkpoint is dir
mock_isfile.return_value = False
mock_isdir.return_value = True
mock_config.load_checkpoint = "/path/to/checkpoint"
mock_get_last_checkpoint.return_value = "/path/to/checkpoint/rank_0/checkpoint-10_200.safetensors"
with patch('mindformers.utils.load_checkpoint_utils.extract_suffix', return_value="-10_200"):
result = _get_src_file_suffix(mock_config)
assert result == ("/path/to/checkpoint", "-10_200")
# test when is_main_rank is False
mock_is_main_rank.return_value = False
result = _get_src_file_suffix(mock_config)
assert result == (None, None)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test__get_src_file(self):
"""test _get_src_file function"""
# setup mocks using context managers
with patch('os.path.exists') as mock_exists, \
patch('os.path.join') as mock_join, \
patch('mindformers.utils.load_checkpoint_utils.get_real_rank') as mock_get_real_rank, \
patch('mindformers.utils.load_checkpoint_utils.get_last_checkpoint') as mock_get_last_checkpoint:
# test with checkpoint_name provided
mock_get_real_rank.return_value = 0
mock_join.return_value = "/test/rank_0/checkpoint.ckpt"
mock_exists.return_value = True
result = _get_src_file("/test", "checkpoint.ckpt", "ckpt")
assert result == "/test/rank_0/checkpoint.ckpt"
# test without checkpoint_name
mock_get_last_checkpoint.return_value = "/test/rank_0/last_checkpoint.ckpt"
result = _get_src_file("/test", None, "ckpt")
assert result == "/test/rank_0/last_checkpoint.ckpt"
# test with non-existent file
mock_exists.return_value = False
with pytest.raises(FileNotFoundError):
_get_src_file("/test", "non_existent.ckpt", "ckpt")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_load_safetensors_checkpoint(self, mock_config, mock_network, mock_file):
"""test load_safetensors_checkpoint function"""
# Setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils._get_origin_network') as mock_get_origin_network, \
patch('mindformers.utils.load_checkpoint_utils.ms') as mock_ms, \
patch('mindformers.utils.load_checkpoint_utils.logger'), \
patch('mindformers.utils.load_checkpoint_utils.safe_open') as mock_safe_open, \
patch('mindformers.utils.load_checkpoint_utils.is_hf_safetensors_dir') as mock_is_hf_safetensors_dir:
# Setup mock return values
mock_get_origin_network.return_value = (MagicMock(), False)
mock_ms.load_checkpoint.return_value = {"param1": MagicMock()}
mock_is_hf_safetensors_dir.return_value = False
# Mock the safe_open context manager
mock_safe_open.return_value.__enter__.return_value = mock_file
strategy_path = "/path/to/strategy"
load_ckpt_path = "/path/to/checkpoint"
optimizer = None
load_safetensors_checkpoint(mock_config, ["/path/to/checkpoint.safetensors"], mock_network, strategy_path,
load_ckpt_path,
optimizer)
mock_ms.load_param_into_net.assert_called_once()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_process_hf_checkpoint(self, mock_model, tmp_path):
"""test process_hf_checkpoint function"""
# setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as mock_is_main_rank, \
patch('mindformers.utils.load_checkpoint_utils.barrier_world') as mock_barrier_world, \
patch('mindformers.utils.load_checkpoint_utils.Process') as mock_process:
# test when is_main_rank is True
mock_is_main_rank.return_value = True
mock_process_instance = MagicMock()
mock_process_instance.exitcode = 0
mock_process.return_value = mock_process_instance
# Use tmp_path for output and input paths
output_dir = tmp_path / "output" / "dir"
input_checkpoint = tmp_path / "input" / "checkpoint"
# Create input directory
input_checkpoint.parent.mkdir(parents=True, exist_ok=True)
result = process_hf_checkpoint(mock_model, str(output_dir), str(input_checkpoint))
expected_path = str(output_dir / "test_model_ms_converted_weight")
assert result == expected_path
mock_process_instance.start.assert_called_once()
mock_process_instance.join.assert_called_once()
mock_barrier_world.assert_called_once()
# Reset mocks for next test case
mock_process.reset_mock()
mock_process_instance = MagicMock()
mock_process_instance.exitcode = 1
mock_process.return_value = mock_process_instance
# test when process exits with error
with pytest.raises(RuntimeError, match="convert HuggingFace weight failed."):
process_hf_checkpoint(mock_model, str(output_dir), str(input_checkpoint))
# Reset mocks for next test case
mock_process.reset_mock()
mock_process_instance = MagicMock()
mock_process_instance.exitcode = 0
mock_process.return_value = mock_process_instance
# test when is_main_rank is False
mock_is_main_rank.return_value = False
process_hf_checkpoint(mock_model, str(output_dir), str(input_checkpoint))
mock_process_instance.start.assert_not_called()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
"model, qkv_concat_config, check_safetensors_key_return, "
"has_concat_keys, expected_exception, should_log_warning",
[
# test with non-PreTrainedModel
("not_a_model", False, False, False, None, True),
# test with PreTrainedModel but no concat keys
(MagicMock(spec=PreTrainedModel), False, False, False, None, False),
# Test case where check_safetensors_key returns True and qkv_concat_config is True
(MagicMock(spec=PreTrainedModel), True, True, True, None, False),
# Test case where check_safetensors_key returns False and qkv_concat_config is True
(MagicMock(spec=PreTrainedModel), True, False, True, ValueError, False),
# Test case where check_safetensors_key returns True and qkv_concat_config is False
(MagicMock(spec=PreTrainedModel), False, True, True, ValueError, False)
]
)
def test_validate_qkv_concat(self, model, qkv_concat_config,
check_safetensors_key_return, has_concat_keys, expected_exception, should_log_warning):
"""test validate_qkv_concat function"""
# Setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils.logger') as mock_logger, \
patch('mindformers.utils.load_checkpoint_utils.check_safetensors_key') as mock_check_safetensors_key:
# Setup mock behavior
mock_check_safetensors_key.return_value = check_safetensors_key_return
# If it's a PreTrainedModel, set up obtain_qkv_ffn_concat_keys
if hasattr(model, 'obtain_qkv_ffn_concat_keys'):
model.obtain_qkv_ffn_concat_keys.return_value = ["qkv_concat_key"] if has_concat_keys else None
# Run the test and check results
if expected_exception:
with pytest.raises(expected_exception, match="The qkv concat check failed!"):
validate_qkv_concat(model, qkv_concat_config, "/path/to/checkpoint")
else:
validate_qkv_concat(model, qkv_concat_config, "/path/to/checkpoint")
# Check if warning was logged when expected
if should_log_warning:
mock_logger.warning.assert_called_once()
else:
mock_logger.warning.assert_not_called()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_merged_src_strategy_path(self, mock_config):
"""test get_merged_src_strategy_path function"""
# setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as mock_is_main_rank, \
patch('mindformers.utils.load_checkpoint_utils.barrier') as mock_barrier, \
patch('mindformers.utils.load_checkpoint_utils._get_src_strategy') as mock_get_src_strategy, \
patch('mindformers.utils.load_checkpoint_utils.ms.merge_pipeline_strategys') as mock_merge_strategys, \
patch('os.makedirs'):
# test when is_main_rank is True
mock_is_main_rank.return_value = True
mock_get_src_strategy.return_value = "/input/strategy"
result = get_merged_src_strategy_path(mock_config)
assert result == "/output/merged_strategy/src_strategy.ckpt"
mock_merge_strategys.assert_called_once()
mock_barrier.assert_called_once()
# test when is_main_rank is False
mock_is_main_rank.return_value = False
mock_barrier.reset_mock()
result = get_merged_src_strategy_path(mock_config)
mock_merge_strategys.assert_called_once() # should not be called again
mock_barrier.assert_called_once()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_merged_dst_strategy_path(self, mock_config):
"""test get_merged_dst_strategy_path function"""
# setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as mock_is_main_rank, \
patch('mindformers.utils.load_checkpoint_utils.barrier') as mock_barrier, \
patch('mindformers.utils.load_checkpoint_utils.ms.merge_pipeline_strategys') as mock_merge_strategys, \
patch('os.makedirs'):
# test with use_parallel=True, auto_trans_ckpt=True, not stand_alone
mock_is_main_rank.return_value = True
mock_config.use_parallel = True
mock_config.auto_trans_ckpt = True
mock_config.parallel.parallel_mode = "DATA_PARALLEL"
strategy_path = "/path/to/strategy.ckpt"
result = get_merged_dst_strategy_path(mock_config, strategy_path)
assert result == "/output/merged_strategy/dst_strategy.ckpt"
mock_merge_strategys.assert_called_once()
mock_barrier.assert_called_once()
# test with stand_alone mode
mock_config.parallel.parallel_mode = "STAND_ALONE"
result = get_merged_dst_strategy_path(mock_config, strategy_path)
assert result == "/path/to/strategy.ckpt"
# test with use_parallel=False
mock_config.use_parallel = False
result = get_merged_dst_strategy_path(mock_config, strategy_path)
assert result == "/path/to/strategy.ckpt"
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_process_for_stand_alone_mode(self, mock_config, mock_network):
"""test process_for_stand_alone_mode function"""
strategy_path = "/path/to/strategy.ckpt"
# setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils._pynative_executor'), \
patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as mock_is_main_rank, \
patch('mindformers.utils.load_checkpoint_utils.barrier') as mock_barrier, \
patch('mindformers.utils.load_checkpoint_utils.generate_state_dict') as mock_generate_state_dict, \
patch('mindformers.utils.load_checkpoint_utils.save_strategy_file') as mock_save_strategy_file, \
patch('os.makedirs') as mock_makedirs, \
patch('shutil.rmtree') as mock_rmtree, \
patch('os.path.exists') as mock_exists:
# test with stand_alone mode
mock_is_main_rank.return_value = True
mock_exists.return_value = True
mock_config.parallel.parallel_mode = "STAND_ALONE"
mock_config.use_parallel = True
process_for_stand_alone_mode(mock_config, mock_network, strategy_path)
mock_rmtree.assert_called_once()
mock_makedirs.assert_called_once()
mock_generate_state_dict.assert_called_once()
mock_save_strategy_file.assert_called_once()
mock_barrier.assert_called()
# Reset mocks for next test case
mock_barrier.reset_mock()
mock_rmtree.reset_mock()
mock_makedirs.reset_mock()
mock_generate_state_dict.reset_mock()
mock_save_strategy_file.reset_mock()
# test when strategy dir doesn't exist
mock_exists.return_value = False
process_for_stand_alone_mode(mock_config, mock_network, strategy_path)
mock_rmtree.assert_not_called()
# Reset mocks for next test case
mock_barrier.reset_mock()
mock_rmtree.reset_mock()
mock_makedirs.reset_mock()
mock_generate_state_dict.reset_mock()
mock_save_strategy_file.reset_mock()
# test when not stand_alone mode
mock_config.parallel.parallel_mode = "DATA_PARALLEL"
process_for_stand_alone_mode(mock_config, mock_network, strategy_path)
mock_rmtree.assert_not_called()
mock_barrier.assert_not_called()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_load_checkpoint_with_safetensors(self, mock_config, mock_model, mock_network):
"""test load_checkpoint_with_safetensors function"""
# setup mocks using context managers
with patch('mindformers.utils.load_checkpoint_utils._check_checkpoint_path') as mock_check_checkpoint_path, \
patch('mindformers.utils.load_checkpoint_utils._get_checkpoint_mode') as mock_get_checkpoint_mode, \
patch('mindformers.utils.load_checkpoint_utils.'
'validate_config_with_file_mode') as mock_validate_config_with_file_mode, \
patch('mindformers.utils.load_checkpoint_utils.compile_model') as mock_compile_model, \
patch('mindformers.utils.load_checkpoint_utils.validate_qkv_concat'), \
patch('mindformers.utils.load_checkpoint_utils.process_for_stand_alone_mode'), \
patch('mindformers.utils.load_checkpoint_utils.'
'get_merged_dst_strategy_path') as mock_get_merged_dst_strategy_path, \
patch('mindformers.utils.load_checkpoint_utils.'
'load_safetensors_checkpoint') as mock_load_safetensors_checkpoint, \
patch('mindformers.utils.load_checkpoint_utils.logger'), \
patch('mindformers.utils.load_checkpoint_utils.barrier'):
# setup mocks return values
mock_check_checkpoint_path.return_value = "/valid/checkpoint"
mock_get_checkpoint_mode.return_value = CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value
mock_get_merged_dst_strategy_path.return_value = "/path/to/merged/strategy"
# setup input_data and optimizer
input_data = MagicMock()
optimizer = None
# test with do_eval=True
load_checkpoint_with_safetensors(mock_config, mock_model, mock_network, input_data, do_eval=True,
do_predict=False,
optimizer=optimizer)
mock_check_checkpoint_path.assert_called_once()
mock_get_checkpoint_mode.assert_called_once()
mock_validate_config_with_file_mode.assert_called_once()
mock_load_safetensors_checkpoint.assert_called_once()
# test with do_predict=True
mock_load_safetensors_checkpoint.reset_mock()
load_checkpoint_with_safetensors(mock_config, mock_model, mock_network, input_data, do_eval=False,
do_predict=True,
optimizer=optimizer)
mock_load_safetensors_checkpoint.assert_called_once()
# test with use_parallel=True
mock_config.use_parallel = True
mock_load_safetensors_checkpoint.reset_mock()
mock_compile_model.reset_mock()
load_checkpoint_with_safetensors(mock_config, mock_model, mock_network, input_data, do_eval=False,
do_predict=False,
optimizer=optimizer)
mock_compile_model.assert_called_once()
mock_load_safetensors_checkpoint.assert_called_once()
# test with resume_training=True
mock_config.resume_training = True
# Access protected member for testing purposes
# pylint: disable=W0212
mock_model._train_network = MagicMock()
mock_load_safetensors_checkpoint.reset_mock()
load_checkpoint_with_safetensors(mock_config, mock_model, mock_network, input_data, do_eval=False,
do_predict=False,
optimizer=optimizer)
mock_load_safetensors_checkpoint.assert_called_once()
class TestBuildModel:
"""A test class for testing build_model"""
runner_config = {'sink_mode': True, 'epochs': 1, 'sink_size': 1}
config = {'runner_config': runner_config}
config = {
'runner_config': runner_config,
'context': {'mode': 0} # Add context.mode to fix AttributeError
}
model = MagicMock()
dataset = MagicMock()
@@ -124,59 +886,75 @@ class TestBuildModel:
class TestGetCheckpointMode:
"""A test class for testing get_checkpoint_mode"""
@patch('os.path.isfile')
@patch('os.path.isdir')
def test_single_checkpoint_file(self, mock_isdir, mock_isfile):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_single_checkpoint_file(self):
"""test single checkpoint file"""
mock_isfile.return_value = True
mock_isdir.return_value = False
config = type('', (), {})()
config.load_checkpoint = '/test/checkpoint_file.safetensors'
assert _get_checkpoint_mode(config) == CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value
with patch('os.path.isfile') as mock_isfile, \
patch('os.path.isdir') as mock_isdir:
mock_isfile.return_value = True
mock_isdir.return_value = False
config = type('', (), {})()
config.load_checkpoint = '/test/checkpoint_file.safetensors'
assert _get_checkpoint_mode(config) == CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value
@patch('os.path.isfile')
@patch('os.path.isdir')
def test_multi_checkpoint_file_with_rank_id(self, mock_isdir, mock_isfile):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_multi_checkpoint_file_with_rank_id(self):
"""test multi checkpoint file with rank id"""
mock_isfile.return_value = False
mock_isdir.return_value = True
with patch('os.listdir', return_value=['rank_0']):
config = type('', (), {})()
config.load_checkpoint = '/test/checkpoint_dir/'
assert _get_checkpoint_mode(config) == CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value
with patch('os.path.isfile') as mock_isfile, \
patch('os.path.isdir') as mock_isdir:
mock_isfile.return_value = False
mock_isdir.return_value = True
with patch('os.listdir', return_value=['rank_0']):
config = type('', (), {})()
config.load_checkpoint = '/test/checkpoint_dir/'
assert _get_checkpoint_mode(config) == CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value
@patch('os.path.isfile')
@patch('os.path.isdir')
def test_multi_checkpoint_file(self, mock_isdir, mock_isfile):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_multi_checkpoint_file(self):
""" test multi checkpoint file"""
mock_isfile.return_value = False
mock_isdir.return_value = True
with patch('os.listdir', return_value=['checkpoint.safetensors']):
config = type('', (), {})()
config.load_checkpoint = '/test/checkpoint_dir/'
config.load_ckpt_format = '.safetensors'
assert _get_checkpoint_mode(config) == CheckpointFileMode.MULTI_CHECKPOINT_FILE.value
with patch('os.path.isfile') as mock_isfile, \
patch('os.path.isdir') as mock_isdir:
mock_isfile.return_value = False
mock_isdir.return_value = True
with patch('os.listdir', return_value=['checkpoint.safetensors']):
config = type('', (), {})()
config.load_checkpoint = '/test/checkpoint_dir/'
config.load_ckpt_format = '.safetensors'
assert _get_checkpoint_mode(config) == CheckpointFileMode.MULTI_CHECKPOINT_FILE.value
@patch('os.path.isfile')
@patch('os.path.isdir')
def test_invalid_path(self, mock_isdir, mock_isfile):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_invalid_path(self):
"""test invalid path"""
mock_isfile.return_value = False
mock_isdir.return_value = False
config = type('', (), {})()
config.load_checkpoint = 'invalid_path'
with pytest.raises(ValueError, match="Provided path is neither a file nor a directory."):
_get_checkpoint_mode(config)
@patch('os.path.isfile')
@patch('os.path.isdir')
def test_no_valid_checkpoint_files(self, mock_isdir, mock_isfile):
"""test no valid checkpoint files"""
mock_isfile.return_value = False
mock_isdir.return_value = True
with patch('os.listdir', return_value=['not_a_checkpoint_file']):
with patch('os.path.isfile') as mock_isfile, \
patch('os.path.isdir') as mock_isdir:
mock_isfile.return_value = False
mock_isdir.return_value = False
config = type('', (), {})()
config.load_checkpoint = '/test/checkpoint_dir/'
config.load_ckpt_format = '.safetensors'
with pytest.raises(ValueError, match="not support mode: no valid checkpoint files found"):
config.load_checkpoint = 'invalid_path'
with pytest.raises(ValueError, match="Provided path is neither a file nor a directory."):
_get_checkpoint_mode(config)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_no_valid_checkpoint_files(self):
"""test no valid checkpoint files"""
with patch('os.path.isfile') as mock_isfile, \
patch('os.path.isdir') as mock_isdir:
mock_isfile.return_value = False
mock_isdir.return_value = True
with patch('os.listdir', return_value=['not_a_checkpoint_file']):
config = type('', (), {})()
config.load_checkpoint = '/test/checkpoint_dir/'
config.load_ckpt_format = '.safetensors'
with pytest.raises(ValueError, match="not support mode: no valid checkpoint files found"):
_get_checkpoint_mode(config)