|
|
|
|
@@ -13,17 +13,95 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""test for load_checkpoint_utils."""
|
|
|
|
|
# pylint: disable=W0621
|
|
|
|
|
import tempfile
|
|
|
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
import numpy as np
|
|
|
|
|
from mindspore import Parameter
|
|
|
|
|
|
|
|
|
|
from mindformers.tools.register import MindFormerConfig
|
|
|
|
|
from mindformers.checkpoint.utils import compile_model
|
|
|
|
|
from mindformers.utils.load_checkpoint_utils import CkptFormat, _get_checkpoint_mode, CheckpointFileMode, \
|
|
|
|
|
_check_checkpoint_path
|
|
|
|
|
from mindformers.models.modeling_utils import PreTrainedModel
|
|
|
|
|
from mindformers.utils.load_checkpoint_utils import (
|
|
|
|
|
CkptFormat, _get_checkpoint_mode, CheckpointFileMode, _check_checkpoint_path,
|
|
|
|
|
extract_suffix, get_last_checkpoint, validate_config_with_file_mode,
|
|
|
|
|
update_global_step, unify_safetensors, _revise_remove_redundancy_with_file,
|
|
|
|
|
_get_origin_network, get_load_path_after_hf_convert, _get_src_strategy,
|
|
|
|
|
_get_src_file_suffix, _get_src_file, load_safetensors_checkpoint,
|
|
|
|
|
process_hf_checkpoint, validate_qkv_concat, get_merged_src_strategy_path,
|
|
|
|
|
get_merged_dst_strategy_path, process_for_stand_alone_mode,
|
|
|
|
|
load_checkpoint_with_safetensors
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_config():
    """Build a throwaway configuration double pre-filled with baseline settings."""

    def _bare_object():
        # Instance of a fresh anonymous class: an empty holder for ad-hoc attributes.
        return type('', (), {})()

    class MockConfig:
        """Attribute-based config stand-in that also offers a ``get`` accessor."""

        def __init__(self):
            # Flat options read directly from the config object.
            self.load_checkpoint = "/path/to/checkpoint"
            self.load_ckpt_format = "safetensors"
            self.use_parallel = False
            self.auto_trans_ckpt = False
            self.resume_training = None
            self.remove_redundancy = False
            self.output_dir = "/output"
            self.src_strategy_path_or_dir = None
            self.load_ckpt_async = False

            # Nested sections: each starts as an empty holder and is filled in place.
            self.context = _bare_object()
            self.context.mode = "GRAPH_MODE"

            self.runner_config = _bare_object()
            runner_defaults = (
                ("sink_mode", True),
                ("epochs", 1),
                ("sink_size", 1),
                ("step_scale", 2.0),
            )
            for attr_name, attr_value in runner_defaults:
                setattr(self.runner_config, attr_name, attr_value)

            self.model = _bare_object()
            self.model.model_config = {}

            self.parallel = _bare_object()
            self.parallel.parallel_mode = "DATA_PARALLEL"

        def get(self, key, default=None):
            """Return the attribute named ``key``, or ``default`` when it is absent."""
            return getattr(self, key, default)

    return MockConfig()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_network():
    """Provide a MagicMock network whose ``cells()`` call returns an empty list."""
    network = MagicMock()
    # No child cells, so code that walks ``cells()`` has nothing to descend into.
    network.cells.return_value = []
    return network
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_model():
    """Provide a MagicMock model whose ``config.model_type`` is ``"test_model"``."""
    model_config = MagicMock()
    model_config.model_type = "test_model"

    model = MagicMock()
    model.config = model_config
    return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_file():
    """Provide a MagicMock file handle whose ``metadata()`` returns ``None`` by default."""
    file_double = MagicMock()
    # Tests override this return value when they need specific metadata.
    file_double.metadata.return_value = None
    return file_double
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestCommonCheckpointMethod:
|
|
|
|
|
"""A test class for testing common methods"""
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
|
|
|
|
|
@pytest.mark.platform_x86_cpu
|
|
|
|
|
@pytest.mark.env_onecard
|
|
|
|
|
def test_support_type(self):
|
|
|
|
|
"""test CkptFormat support type"""
|
|
|
|
|
# run the test
|
|
|
|
|
@@ -32,25 +110,709 @@ class TestCommonCheckpointMethod:
|
|
|
|
|
# verify the results
|
|
|
|
|
assert result == ['ckpt', 'safetensors']
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_checkpoint_path_with_non_string_pathlike(self):
    """Check that a non-string checkpoint path is rejected with ValueError.

    The integer ``123`` is passed to ``_check_checkpoint_path`` and the raised
    message is matched against the backtick-quoted wording.
    """
    path = 123
    # Exactly one ``match`` pattern: passing the keyword twice is a syntax error.
    # NOTE(review): the backticked wording is assumed to be the current message -- confirm.
    with pytest.raises(ValueError,
                       match=r"config.load_checkpoint must be a `str`, but got `123` as type `<class 'int'>`."):
        _check_checkpoint_path(path)
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_checkpoint_path_with_nonexistent_path(self):
    """Check that a path which does not exist is rejected with FileNotFoundError.

    The name ``'NoneExistPath'`` is passed to ``_check_checkpoint_path`` and the
    raised message is matched against the backtick-quoted wording.
    """
    path = 'NoneExistPath'
    # A single expectation: stacking a second ``pytest.raises`` around this one
    # would leave the outer context without an exception to observe.
    # NOTE(review): the backticked wording is assumed to be the current message -- confirm.
    with pytest.raises(FileNotFoundError, match=r"config.load_checkpoint `NoneExistPath` does not exist."):
        _check_checkpoint_path(path)
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_checkpoint_path_with_valid_path(self):
    """Check that an existing directory is returned unchanged, with or without a trailing slash."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # First the plain directory path, then the same path with '/' appended;
        # both are expected to come back as the plain directory path.
        for candidate in (tmpdir, tmpdir + '/'):
            checked = _check_checkpoint_path(candidate)
            assert checked == tmpdir
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
    "file_path, expected",
    [
        # {prefix}_rank_{rank_id}-{epoch}_{step}.safetensors
        ("model_rank_0-10_200.safetensors", "-10_200"),
        # {prefix}_rank_{rank_id}_{task_id}-{epoch}_{step}.safetensors
        ("model_rank_0_1-10_200.safetensors", "_1-10_200"),
        # a name that follows neither pattern
        ("invalid_filename.safetensors", "invalid_filename")
    ]
)
def test_extract_suffix(self, file_path, expected):
    """Check the suffix that extract_suffix derives from each parametrized file name."""
    assert extract_suffix(file_path) == expected
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_last_checkpoint(self):
    """Check get_last_checkpoint for a populated directory, no matching files and a non-directory."""
    # Mocked modification times, keyed by the full paths used in the first scenario.
    mtimes = {
        "/test/model_0.ckpt": 100,
        "/test/model_1.ckpt": 200,
        "/test/model_2.ckpt": 300
    }
    with patch('os.path.isdir') as isdir_mock, \
            patch('os.path.exists') as exists_mock, \
            patch('os.listdir') as listdir_mock, \
            patch('os.path.getmtime') as getmtime_mock:
        isdir_mock.return_value = True
        exists_mock.return_value = True
        listdir_mock.return_value = ["model_0.ckpt", "model_1.ckpt", "model_2.ckpt"]
        getmtime_mock.side_effect = lambda x: mtimes[x]

        # Three ckpt files present: the one with the highest mocked mtime is expected.
        assert get_last_checkpoint("/test", "ckpt") == "/test/model_2.ckpt"

        # Directory listing holds no ckpt file: None is expected.
        listdir_mock.return_value = ["other_file.txt"]
        assert get_last_checkpoint("/test", "ckpt") is None

        # Path is reported as not being a directory: NotADirectoryError is expected.
        isdir_mock.return_value = False
        with pytest.raises(NotADirectoryError):
            get_last_checkpoint("/invalid/dir", "ckpt")
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
    "file_mode, use_parallel, auto_trans_ckpt, expected_exception",
    [
        # single checkpoint file mode combined with parallel
        (CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value, True, False, ValueError),
        # multi checkpoint file mode with parallel but without auto_trans_ckpt
        (CheckpointFileMode.MULTI_CHECKPOINT_FILE.value, True, False, ValueError),
        # multi checkpoint file with rank id mode without parallel
        (CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value, False, False, ValueError),
        # unknown mode string
        ("invalid_mode", False, False, ValueError),
        # combinations expected to pass without raising
        (CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value, False, False, None),
        (CheckpointFileMode.MULTI_CHECKPOINT_FILE.value, True, True, None),
        (CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value, True, False, None)
    ]
)
def test_validate_config_with_file_mode(self, file_mode, use_parallel, auto_trans_ckpt, expected_exception):
    """Check which mode/parallel/auto-transform combinations validate_config_with_file_mode rejects."""
    if not expected_exception:
        # Accepted combination: the call simply has to complete.
        validate_config_with_file_mode(file_mode, use_parallel, auto_trans_ckpt)
        return

    with pytest.raises(expected_exception):
        validate_config_with_file_mode(file_mode, use_parallel, auto_trans_ckpt)
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
    "step_scale, initial_global_step, expected_global_step, expected_in_dict",
    [
        (2.0, 100, 200, True),
        (None, 100, 100, True),
        (2.0, None, None, False)
    ]
)
def test_update_global_step(self, step_scale, initial_global_step, expected_global_step, expected_in_dict):
    """Check the ``global_step`` entry left in the hyper-parameter dict by update_global_step."""
    # Minimal config: only ``runner_config.step_scale`` is populated.
    runner_section = type('', (), {})()
    runner_section.step_scale = step_scale
    config = type('', (), {})()
    config.runner_config = runner_section

    # The dict carries a ``global_step`` Parameter only when an initial value is given.
    hyper_param_dict = {}
    if initial_global_step is not None:
        hyper_param_dict["global_step"] = Parameter(np.array(initial_global_step, dtype=np.int32))

    update_global_step(config, hyper_param_dict)

    if not expected_in_dict:
        assert "global_step" not in hyper_param_dict
        return

    assert "global_step" in hyper_param_dict
    assert hyper_param_dict["global_step"].asnumpy() == expected_global_step
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unify_safetensors(self):
    """Check unify_safetensors calls on the main rank, on another rank and without parallel."""
    paths = ("/src/checkpoint", "/src/strategy", "/dst/unified")
    with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as main_rank_mock, \
            patch('mindformers.utils.load_checkpoint_utils.barrier') as barrier_mock, \
            patch('mindspore.unified_safetensors') as unify_mock:
        # Main rank: one unification call and one barrier are expected.
        main_rank_mock.return_value = True
        unify_safetensors(*paths, True, "-10_200", False)
        unify_mock.assert_called_once()
        barrier_mock.assert_called_once()

        # Not the main rank: the barrier is hit again, the unification is not.
        main_rank_mock.return_value = False
        barrier_mock.reset_mock()
        unify_safetensors(*paths, True, "-10_200", False)
        unify_mock.assert_called_once()  # count still stems from the first scenario only
        barrier_mock.assert_called_once()

        # Without parallel: no barrier is expected.
        main_rank_mock.return_value = True
        barrier_mock.reset_mock()
        unify_safetensors(*paths, False, "-10_200", False)
        barrier_mock.assert_not_called()
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
    "config_remove_redundancy, metadata, expected_result",
    [
        # metadata says "True" while the config flag is False
        (False, {"remove_redundancy": "True"}, True),
        # metadata says "False" while the config flag is True
        (True, {"remove_redundancy": "False"}, False),
        # metadata and config agree
        (True, {"remove_redundancy": "True"}, True),
        # file carries no metadata at all
        (True, None, True),
        # metadata present but without a remove_redundancy key
        (True, {"other_key": "value"}, True)
    ]
)
def test__revise_remove_redundancy_with_file(self, config_remove_redundancy, metadata, expected_result, mock_file):
    """Check the flag returned by _revise_remove_redundancy_with_file for each metadata/config pairing."""
    # The file double reports the parametrized metadata.
    mock_file.metadata.return_value = metadata
    revised = _revise_remove_redundancy_with_file(config_remove_redundancy, mock_file)
    assert revised == expected_result
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
    "network_has_convert_name, child_has_convert_name, expected_found",
    [
        # the network itself carries convert_name
        (True, False, True),
        # only a direct child carries convert_name
        (False, True, True),
        # neither the network nor any child carries convert_name
        (False, False, False)
    ]
)
def test__get_origin_network(self, network_has_convert_name, child_has_convert_name, expected_found):
    """Check the ``found`` flag returned by _get_origin_network for three network layouts."""
    with patch('mindformers.utils.load_checkpoint_utils.logger'):
        mock_network = MagicMock()

        if network_has_convert_name:
            # convert_name lives on the network itself; it has no children to walk.
            mock_network.convert_name = MagicMock()
            mock_network.cells.return_value = []
        elif child_has_convert_name:
            # convert_name lives on a single leaf child returned by cells().
            mock_child = MagicMock()
            mock_child.convert_name = MagicMock()
            mock_child.cells.return_value = []
            mock_network.cells.return_value = [mock_child]
        else:
            # No children, and any convert_name attribute is removed from the mock.
            mock_network.cells.return_value = []
            if hasattr(mock_network, 'convert_name'):
                delattr(mock_network, 'convert_name')

        _, found = _get_origin_network(mock_network)
        assert found == expected_found
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_load_path_after_hf_convert(self, mock_config, mock_network):
    """Check the path returned by get_load_path_after_hf_convert with and without HF weights."""
    with patch('mindformers.utils.load_checkpoint_utils.is_hf_safetensors_dir') as hf_dir_mock, \
            patch('mindformers.utils.load_checkpoint_utils.'
                  'check_safetensors_addition_param_support') as support_mock:
        # Not a HuggingFace safetensors directory: the configured path is expected back.
        hf_dir_mock.return_value = False
        plain_result = get_load_path_after_hf_convert(mock_config, mock_network)
        assert plain_result == "/path/to/checkpoint"

        # HuggingFace directory, qkv_concat off, addition-param support off:
        # the value returned by the patched process_hf_checkpoint is expected.
        hf_dir_mock.return_value = True
        support_mock.return_value = False
        mock_config.model.model_config = {"qkv_concat": False}

        with patch('mindformers.utils.load_checkpoint_utils.process_hf_checkpoint',
                   return_value="/path/to/converted"), \
                patch('mindformers.utils.load_checkpoint_utils.barrier'):
            converted_result = get_load_path_after_hf_convert(mock_config, mock_network)
            assert converted_result == "/path/to/converted"
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test__get_src_strategy(self, mock_config):
    """Check _get_src_strategy with an explicit strategy, a missing one and a discovered one."""
    with patch('os.path.isfile') as isfile_mock, \
            patch('os.path.isdir') as isdir_mock, \
            patch('os.path.join') as join_mock, \
            patch('os.path.exists') as exists_mock, \
            patch('os.path.dirname') as dirname_mock, \
            patch('mindformers.utils.load_checkpoint_utils.logger'):
        # Scenario 1: src_strategy_path_or_dir is set explicitly and is expected back.
        mock_config.load_checkpoint = "/test/checkpoint.ckpt"
        mock_config.src_strategy_path_or_dir = "/input/strategy"
        isdir_mock.return_value = True
        explicit = _get_src_strategy(mock_config)
        assert explicit == "/input/strategy"

        # Scenario 2: nothing configured and the exists check fails, so ValueError is expected.
        mock_config.src_strategy_path_or_dir = None
        isfile_mock.return_value = True
        exists_mock.return_value = False
        expected_message = "when use checkpoint after train/finetune, src_strategy_path_or_dir should be set"
        with pytest.raises(ValueError, match=expected_message):
            _get_src_strategy(mock_config)

        # Scenario 3: load_checkpoint is treated as a directory and the exists check passes.
        isfile_mock.return_value = False
        exists_mock.return_value = True
        # Path helpers are mocked, so pin the values they hand back.
        dirname_mock.return_value = "/test"
        join_mock.return_value = "/test/strategy"

        mock_config.load_checkpoint = "/test/checkpoint_dir"
        mock_config.src_strategy_path_or_dir = None

        discovered = _get_src_strategy(mock_config)
        assert discovered == "/test/strategy"
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test__get_src_file_suffix(self, mock_config):
    """Check the (path, suffix) pair from _get_src_file_suffix across rank and input variations."""
    with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as main_rank_mock, \
            patch('mindformers.utils.load_checkpoint_utils.get_last_checkpoint') as last_ckpt_mock, \
            patch('os.path.isfile') as isfile_mock, \
            patch('os.path.isdir') as isdir_mock:
        # Main rank, resume_training given as a checkpoint file name.
        main_rank_mock.return_value = True
        mock_config.resume_training = "checkpoint-10_200.safetensors"
        mock_config.load_checkpoint = "/path/to/checkpoint"
        mock_config.load_ckpt_format = "safetensors"

        with patch('mindformers.utils.load_checkpoint_utils.extract_suffix', return_value="-10_200"):
            from_resume_name = _get_src_file_suffix(mock_config)
        assert from_resume_name == ("/path/to/checkpoint", "-10_200")

        # Main rank, load_checkpoint pointing at a file.
        isfile_mock.return_value = True
        isdir_mock.return_value = False
        mock_config.resume_training = None
        mock_config.load_checkpoint = "/path/to/rank_0/checkpoint-10_200.safetensors"

        with patch('mindformers.utils.load_checkpoint_utils.extract_suffix', return_value="-10_200"):
            from_file = _get_src_file_suffix(mock_config)
        assert from_file == ("/path/to", "-10_200")

        # Main rank, load_checkpoint pointing at a directory.
        isfile_mock.return_value = False
        isdir_mock.return_value = True
        mock_config.load_checkpoint = "/path/to/checkpoint"
        last_ckpt_mock.return_value = "/path/to/checkpoint/rank_0/checkpoint-10_200.safetensors"

        with patch('mindformers.utils.load_checkpoint_utils.extract_suffix', return_value="-10_200"):
            from_dir = _get_src_file_suffix(mock_config)
        assert from_dir == ("/path/to/checkpoint", "-10_200")

        # Not the main rank: a pair of Nones is expected.
        main_rank_mock.return_value = False
        off_rank = _get_src_file_suffix(mock_config)
        assert off_rank == (None, None)
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test__get_src_file(self):
    """Check _get_src_file with a named checkpoint, without a name and with a missing file."""
    with patch('os.path.exists') as exists_mock, \
            patch('os.path.join') as join_mock, \
            patch('mindformers.utils.load_checkpoint_utils.get_real_rank') as rank_mock, \
            patch('mindformers.utils.load_checkpoint_utils.get_last_checkpoint') as last_ckpt_mock:
        # Named checkpoint that exists: the mocked joined path is expected.
        rank_mock.return_value = 0
        join_mock.return_value = "/test/rank_0/checkpoint.ckpt"
        exists_mock.return_value = True

        named = _get_src_file("/test", "checkpoint.ckpt", "ckpt")
        assert named == "/test/rank_0/checkpoint.ckpt"

        # No name given: the value from the patched get_last_checkpoint is expected.
        last_ckpt_mock.return_value = "/test/rank_0/last_checkpoint.ckpt"
        latest = _get_src_file("/test", None, "ckpt")
        assert latest == "/test/rank_0/last_checkpoint.ckpt"

        # Existence check fails: FileNotFoundError is expected.
        exists_mock.return_value = False
        with pytest.raises(FileNotFoundError):
            _get_src_file("/test", "non_existent.ckpt", "ckpt")
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_load_safetensors_checkpoint(self, mock_config, mock_network, mock_file):
    """Check that load_safetensors_checkpoint loads parameters into the network exactly once."""
    with patch('mindformers.utils.load_checkpoint_utils._get_origin_network') as origin_network_mock, \
            patch('mindformers.utils.load_checkpoint_utils.ms') as ms_mock, \
            patch('mindformers.utils.load_checkpoint_utils.logger'), \
            patch('mindformers.utils.load_checkpoint_utils.safe_open') as safe_open_mock, \
            patch('mindformers.utils.load_checkpoint_utils.is_hf_safetensors_dir') as hf_dir_mock:
        # Canned answers for the patched collaborators.
        origin_network_mock.return_value = (MagicMock(), False)
        ms_mock.load_checkpoint.return_value = {"param1": MagicMock()}
        hf_dir_mock.return_value = False

        # ``with safe_open(...) as f`` yields the file double from the fixture.
        safe_open_mock.return_value.__enter__.return_value = mock_file

        load_safetensors_checkpoint(
            mock_config,
            ["/path/to/checkpoint.safetensors"],
            mock_network,
            "/path/to/strategy",
            "/path/to/checkpoint",
            None,
        )
        ms_mock.load_param_into_net.assert_called_once()
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_process_hf_checkpoint(self, mock_model, tmp_path):
    """Check process_hf_checkpoint for a clean conversion, a failing conversion and a non-main rank."""

    def _process_double(exitcode):
        # Stand-in for the object the patched Process constructor returns.
        proc = MagicMock()
        proc.exitcode = exitcode
        return proc

    # Real temporary locations; only the parent of the input path is created.
    output_dir = tmp_path / "output" / "dir"
    input_checkpoint = tmp_path / "input" / "checkpoint"
    input_checkpoint.parent.mkdir(parents=True, exist_ok=True)

    with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as main_rank_mock, \
            patch('mindformers.utils.load_checkpoint_utils.barrier_world') as barrier_world_mock, \
            patch('mindformers.utils.load_checkpoint_utils.Process') as process_cls_mock:
        # Main rank, exit code 0: the converted-weight path is expected back.
        main_rank_mock.return_value = True
        worker = _process_double(0)
        process_cls_mock.return_value = worker

        result = process_hf_checkpoint(mock_model, str(output_dir), str(input_checkpoint))
        assert result == str(output_dir / "test_model_ms_converted_weight")
        worker.start.assert_called_once()
        worker.join.assert_called_once()
        barrier_world_mock.assert_called_once()

        # Main rank, exit code 1: RuntimeError is expected.
        process_cls_mock.reset_mock()
        worker = _process_double(1)
        process_cls_mock.return_value = worker
        with pytest.raises(RuntimeError, match="convert HuggingFace weight failed."):
            process_hf_checkpoint(mock_model, str(output_dir), str(input_checkpoint))

        # Not the main rank: the process must never be started.
        process_cls_mock.reset_mock()
        worker = _process_double(0)
        process_cls_mock.return_value = worker
        main_rank_mock.return_value = False
        process_hf_checkpoint(mock_model, str(output_dir), str(input_checkpoint))
        worker.start.assert_not_called()
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize(
    "model, qkv_concat_config, check_safetensors_key_return, "
    "has_concat_keys, expected_exception, should_log_warning",
    [
        # object that is not a PreTrainedModel
        ("not_a_model", False, False, False, None, True),
        # PreTrainedModel double that reports no concat keys
        (MagicMock(spec=PreTrainedModel), False, False, False, None, False),
        # key check returns True and qkv_concat_config is True
        (MagicMock(spec=PreTrainedModel), True, True, True, None, False),
        # key check returns False while qkv_concat_config is True
        (MagicMock(spec=PreTrainedModel), True, False, True, ValueError, False),
        # key check returns True while qkv_concat_config is False
        (MagicMock(spec=PreTrainedModel), False, True, True, ValueError, False)
    ]
)
def test_validate_qkv_concat(self, model, qkv_concat_config,
                             check_safetensors_key_return, has_concat_keys, expected_exception, should_log_warning):
    """Check validate_qkv_concat outcomes: silent pass, ValueError, or a logged warning."""
    with patch('mindformers.utils.load_checkpoint_utils.logger') as logger_mock, \
            patch('mindformers.utils.load_checkpoint_utils.check_safetensors_key') as key_check_mock:
        key_check_mock.return_value = check_safetensors_key_return

        # Only model doubles expose obtain_qkv_ffn_concat_keys; the plain string does not.
        if hasattr(model, 'obtain_qkv_ffn_concat_keys'):
            concat_keys = ["qkv_concat_key"] if has_concat_keys else None
            model.obtain_qkv_ffn_concat_keys.return_value = concat_keys

        if not expected_exception:
            validate_qkv_concat(model, qkv_concat_config, "/path/to/checkpoint")
        else:
            with pytest.raises(expected_exception, match="The qkv concat check failed!"):
                validate_qkv_concat(model, qkv_concat_config, "/path/to/checkpoint")

        # Warning expectation for this parameter set.
        if should_log_warning:
            logger_mock.warning.assert_called_once()
        else:
            logger_mock.warning.assert_not_called()
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_merged_src_strategy_path(self, mock_config):
    """Check get_merged_src_strategy_path on the main rank and on another rank."""
    with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as main_rank_mock, \
            patch('mindformers.utils.load_checkpoint_utils.barrier') as barrier_mock, \
            patch('mindformers.utils.load_checkpoint_utils._get_src_strategy') as src_strategy_mock, \
            patch('mindformers.utils.load_checkpoint_utils.ms.merge_pipeline_strategys') as merge_mock, \
            patch('os.makedirs'):
        # Main rank: one merge, one barrier, and the merged path under output_dir.
        main_rank_mock.return_value = True
        src_strategy_mock.return_value = "/input/strategy"

        merged_path = get_merged_src_strategy_path(mock_config)
        assert merged_path == "/output/merged_strategy/src_strategy.ckpt"
        merge_mock.assert_called_once()
        barrier_mock.assert_called_once()

        # Not the main rank: the barrier is hit again, the merge is not.
        main_rank_mock.return_value = False
        barrier_mock.reset_mock()
        get_merged_src_strategy_path(mock_config)
        merge_mock.assert_called_once()  # count still stems from the first scenario only
        barrier_mock.assert_called_once()
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_merged_dst_strategy_path(self, mock_config):
    """Check get_merged_dst_strategy_path for parallel, stand-alone and non-parallel configs."""
    strategy_path = "/path/to/strategy.ckpt"
    with patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as main_rank_mock, \
            patch('mindformers.utils.load_checkpoint_utils.barrier') as barrier_mock, \
            patch('mindformers.utils.load_checkpoint_utils.ms.merge_pipeline_strategys') as merge_mock, \
            patch('os.makedirs'):
        # use_parallel and auto_trans_ckpt on, parallel mode not stand-alone:
        # a merged destination strategy under output_dir is expected.
        main_rank_mock.return_value = True
        mock_config.use_parallel = True
        mock_config.auto_trans_ckpt = True
        mock_config.parallel.parallel_mode = "DATA_PARALLEL"

        merged_path = get_merged_dst_strategy_path(mock_config, strategy_path)
        assert merged_path == "/output/merged_strategy/dst_strategy.ckpt"
        merge_mock.assert_called_once()
        barrier_mock.assert_called_once()

        # Stand-alone mode: the incoming strategy path is expected back.
        mock_config.parallel.parallel_mode = "STAND_ALONE"
        standalone_path = get_merged_dst_strategy_path(mock_config, strategy_path)
        assert standalone_path == "/path/to/strategy.ckpt"

        # use_parallel off: the incoming strategy path is expected back as well.
        mock_config.use_parallel = False
        serial_path = get_merged_dst_strategy_path(mock_config, strategy_path)
        assert serial_path == "/path/to/strategy.ckpt"
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_process_for_stand_alone_mode(self, mock_config, mock_network):
    """Exercise process_for_stand_alone_mode inside and outside stand-alone mode.

    Three scenarios are run against patched helpers: stand-alone with an existing
    strategy directory, stand-alone without one, and a non-stand-alone parallel mode.
    """
    strategy_path = "/path/to/strategy.ckpt"

    with patch('mindformers.utils.load_checkpoint_utils._pynative_executor'), \
            patch('mindformers.utils.load_checkpoint_utils.is_main_rank') as main_rank_mock, \
            patch('mindformers.utils.load_checkpoint_utils.barrier') as barrier_mock, \
            patch('mindformers.utils.load_checkpoint_utils.generate_state_dict') as state_dict_mock, \
            patch('mindformers.utils.load_checkpoint_utils.save_strategy_file') as save_strategy_mock, \
            patch('os.makedirs') as makedirs_mock, \
            patch('shutil.rmtree') as rmtree_mock, \
            patch('os.path.exists') as exists_mock:
        # mocks whose call history is cleared between scenarios, in the order they are reset
        tracked_mocks = (barrier_mock, rmtree_mock, makedirs_mock, state_dict_mock, save_strategy_mock)

        # scenario 1: stand-alone mode, os.path.exists reports True
        main_rank_mock.return_value = True
        exists_mock.return_value = True
        mock_config.parallel.parallel_mode = "STAND_ALONE"
        mock_config.use_parallel = True

        process_for_stand_alone_mode(mock_config, mock_network, strategy_path)
        rmtree_mock.assert_called_once()
        makedirs_mock.assert_called_once()
        state_dict_mock.assert_called_once()
        save_strategy_mock.assert_called_once()
        barrier_mock.assert_called()

        for tracked in tracked_mocks:
            tracked.reset_mock()

        # scenario 2: os.path.exists reports False, so nothing may be removed
        exists_mock.return_value = False
        process_for_stand_alone_mode(mock_config, mock_network, strategy_path)
        rmtree_mock.assert_not_called()

        for tracked in tracked_mocks:
            tracked.reset_mock()

        # scenario 3: a parallel mode other than stand-alone
        mock_config.parallel.parallel_mode = "DATA_PARALLEL"
        process_for_stand_alone_mode(mock_config, mock_network, strategy_path)
        rmtree_mock.assert_not_called()
        barrier_mock.assert_not_called()
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_load_checkpoint_with_safetensors(self, mock_config, mock_model, mock_network):
    """Run load_checkpoint_with_safetensors under eval, predict, parallel and resume settings.

    The helpers named in the with-statement are patched, so each scenario only
    checks which of those patched helpers were invoked.
    """
    with patch('mindformers.utils.load_checkpoint_utils._check_checkpoint_path') as check_path_mock, \
            patch('mindformers.utils.load_checkpoint_utils._get_checkpoint_mode') as ckpt_mode_mock, \
            patch('mindformers.utils.load_checkpoint_utils.'
                  'validate_config_with_file_mode') as validate_mock, \
            patch('mindformers.utils.load_checkpoint_utils.compile_model') as compile_mock, \
            patch('mindformers.utils.load_checkpoint_utils.validate_qkv_concat'), \
            patch('mindformers.utils.load_checkpoint_utils.process_for_stand_alone_mode'), \
            patch('mindformers.utils.load_checkpoint_utils.'
                  'get_merged_dst_strategy_path') as dst_strategy_mock, \
            patch('mindformers.utils.load_checkpoint_utils.'
                  'load_safetensors_checkpoint') as load_mock, \
            patch('mindformers.utils.load_checkpoint_utils.logger'), \
            patch('mindformers.utils.load_checkpoint_utils.barrier'):
        # canned answers for the patched helpers
        check_path_mock.return_value = "/valid/checkpoint"
        ckpt_mode_mock.return_value = CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value
        dst_strategy_mock.return_value = "/path/to/merged/strategy"

        input_data = MagicMock()
        optimizer = None

        def run_load(do_eval, do_predict):
            """Call the function under test with the shared fixtures and the given flags."""
            load_checkpoint_with_safetensors(mock_config, mock_model, mock_network, input_data,
                                             do_eval=do_eval,
                                             do_predict=do_predict,
                                             optimizer=optimizer)

        # scenario 1: evaluation
        run_load(do_eval=True, do_predict=False)
        check_path_mock.assert_called_once()
        ckpt_mode_mock.assert_called_once()
        validate_mock.assert_called_once()
        load_mock.assert_called_once()

        # scenario 2: prediction
        load_mock.reset_mock()
        run_load(do_eval=False, do_predict=True)
        load_mock.assert_called_once()

        # scenario 3: parallel run, the model is expected to be compiled once
        mock_config.use_parallel = True
        load_mock.reset_mock()
        compile_mock.reset_mock()
        run_load(do_eval=False, do_predict=False)
        compile_mock.assert_called_once()
        load_mock.assert_called_once()

        # scenario 4: resumed training
        mock_config.resume_training = True
        # the protected attribute is set on purpose for this test
        # pylint: disable=W0212
        mock_model._train_network = MagicMock()
        load_mock.reset_mock()
        run_load(do_eval=False, do_predict=False)
        load_mock.assert_called_once()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestBuildModel:
|
|
|
|
|
"""A test class for testing build_model"""
|
|
|
|
|
runner_config = {'sink_mode': True, 'epochs': 1, 'sink_size': 1}
|
|
|
|
|
config = {'runner_config': runner_config}
|
|
|
|
|
config = {
|
|
|
|
|
'runner_config': runner_config,
|
|
|
|
|
'context': {'mode': 0} # Add context.mode to fix AttributeError
|
|
|
|
|
}
|
|
|
|
|
model = MagicMock()
|
|
|
|
|
dataset = MagicMock()
|
|
|
|
|
|
|
|
|
|
@@ -124,59 +886,75 @@ class TestBuildModel:
|
|
|
|
|
|
|
|
|
|
class TestGetCheckpointMode:
|
|
|
|
|
"""A test class for testing get_checkpoint_mode"""
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_single_checkpoint_file(self):
    """test single checkpoint file

    os.path.isfile is patched to report True and os.path.isdir to report False,
    and _get_checkpoint_mode is expected to return SINGLE_CHECKPOINT_FILE.
    """
    # Only the context-manager form is kept: the decorator-injected variant left an
    # empty `def` and a body that used the mocks before they were bound.
    with patch('os.path.isfile') as mock_isfile, \
            patch('os.path.isdir') as mock_isdir:
        mock_isfile.return_value = True
        mock_isdir.return_value = False
        # bare attribute holder standing in for the real config object
        config = type('', (), {})()
        config.load_checkpoint = '/test/checkpoint_file.safetensors'
        assert _get_checkpoint_mode(config) == CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_multi_checkpoint_file_with_rank_id(self):
    """test multi checkpoint file with rank id

    The path is reported as a directory whose listing is ['rank_0'], and
    _get_checkpoint_mode is expected to return MULTI_CHECKPOINT_FILE_WITH_RANK_ID.
    """
    # Only the context-manager form is kept: the decorator-injected variant left an
    # empty `def` and a body that used the mocks before they were bound.
    with patch('os.path.isfile') as mock_isfile, \
            patch('os.path.isdir') as mock_isdir:
        mock_isfile.return_value = False
        mock_isdir.return_value = True
        with patch('os.listdir', return_value=['rank_0']):
            # bare attribute holder standing in for the real config object
            config = type('', (), {})()
            config.load_checkpoint = '/test/checkpoint_dir/'
            assert _get_checkpoint_mode(config) == CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_multi_checkpoint_file(self):
    """test multi checkpoint file

    The path is reported as a directory whose listing is ['checkpoint.safetensors'],
    load_ckpt_format is '.safetensors', and _get_checkpoint_mode is expected to
    return MULTI_CHECKPOINT_FILE.
    """
    # Only the context-manager form is kept: the decorator-injected variant left an
    # empty `def` and a body that used the mocks before they were bound.
    with patch('os.path.isfile') as mock_isfile, \
            patch('os.path.isdir') as mock_isdir:
        mock_isfile.return_value = False
        mock_isdir.return_value = True
        with patch('os.listdir', return_value=['checkpoint.safetensors']):
            # bare attribute holder standing in for the real config object
            config = type('', (), {})()
            config.load_checkpoint = '/test/checkpoint_dir/'
            config.load_ckpt_format = '.safetensors'
            assert _get_checkpoint_mode(config) == CheckpointFileMode.MULTI_CHECKPOINT_FILE.value
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_invalid_path(self):
    """test invalid path

    Both os.path.isfile and os.path.isdir are patched to report False, and
    _get_checkpoint_mode is expected to raise ValueError.
    """
    # Only the context-manager form is kept: the decorator-injected variant left an
    # empty `def` and a body that used the mocks before they were bound.
    with patch('os.path.isfile') as mock_isfile, \
            patch('os.path.isdir') as mock_isdir:
        mock_isfile.return_value = False
        mock_isdir.return_value = False
        # bare attribute holder standing in for the real config object
        config = type('', (), {})()
        config.load_checkpoint = 'invalid_path'
        # `match` is applied as a regular expression search on the error message
        with pytest.raises(ValueError, match="Provided path is neither a file nor a directory."):
            _get_checkpoint_mode(config)
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_no_valid_checkpoint_files(self):
    """Expect a ValueError when the checkpoint directory lists no usable file.

    The path is reported as a directory whose only entry is 'not_a_checkpoint_file'.
    """
    # bare attribute holder standing in for the real config object
    cfg = type('', (), {})()
    cfg.load_checkpoint = '/test/checkpoint_dir/'
    cfg.load_ckpt_format = '.safetensors'

    with patch('os.path.isfile', return_value=False), \
            patch('os.path.isdir', return_value=True), \
            patch('os.listdir', return_value=['not_a_checkpoint_file']):
        with pytest.raises(ValueError, match="not support mode: no valid checkpoint files found"):
            _get_checkpoint_mode(cfg)
|
|
|
|
|
|