Mirror of https://gitee.com/mindspore/mindformers.git, synced 2025-12-06 11:29:59 +08:00
!7820 [master][bugfix][logging] Checkpoint-related logging: add logger.error before each raised Error so that a corresponding entry appears in error.log
Merge pull request !7820 from SaiYao/add_error_log_in_ckpt
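The change applies one pattern everywhere in the diff below: build the error message once, write it with logger.error so it also lands in error.log, then raise the exception with that same message. A minimal standalone sketch of the pattern follows; it uses Python's standard logging module rather than the repository's own logger, and check_strategy_file is a hypothetical helper, not a function from this change.

import logging
import os

logger = logging.getLogger(__name__)

def check_strategy_file(strategy_path):
    """Hypothetical helper illustrating the log-then-raise pattern from this commit."""
    if not os.path.exists(strategy_path):
        err_msg = f"strategy_path: {strategy_path} not found!"
        logger.error(err_msg)      # message is recorded in the error log ...
        raise ValueError(err_msg)  # ... and the exception carries the same text
    return strategy_path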
@@ -58,6 +58,7 @@ __all__ = ['TransformCkpt']
 
 class TransformCkpt:
     """Transform src_checkpoint from src_strategy to dst_strategy."""
 
     def __init__(self,
                  auto_trans_ckpt: bool = False,
                  rank_id: Optional[int] = None,
@@ -212,23 +213,29 @@ class TransformCkpt:
         if self.world_size > 1:
             dst_strategy_list = glob(os.path.join(self.dst_strategy_dir, f"*_rank_{self.rank_id}.ckpt"))
             if not dst_strategy_list:
-                raise RuntimeError(f"The `dst_strategy`={self.dst_strategy_dir} \
-                    does not contain strategy file of rank_{self.rank_id}.")
+                err_msg = (f"The `dst_strategy`={self.dst_strategy_dir} "
+                           f"does not contain strategy file of rank_{self.rank_id}.")
+                logger.error(err_msg)
+                raise RuntimeError(err_msg)
             if len(dst_strategy_list) > 1:
-                raise RuntimeError(f"There can only be one strategy file corresponding to rank_{self.rank_id}, \
-                    but multiple strategy files corresponding to rank_{self.rank_id} were found \
-                    in {self.dst_strategy_dir}.")
+                err_msg = (f"There can only be one strategy file corresponding to rank_{self.rank_id}, "
+                           f"but multiple strategy files corresponding to rank_{self.rank_id} "
+                           f"were found in {self.dst_strategy_dir}.")
+                logger.error(err_msg)
+                raise RuntimeError(err_msg)
             dst_strategy = dst_strategy_list[0]
         else:
             dst_strategy = None
 
         if check_in_modelarts():
             if not mox.file.exists(self.transformed_checkpoint_dir_obs):
-                raise ValueError(f"transformed_checkpoint_dir_obs: "
-                                 f"{self.transformed_checkpoint_dir_obs} is not found!")
+                err_msg = f"transformed_checkpoint_dir_obs: {self.transformed_checkpoint_dir_obs} is not found!"
+                logger.error(err_msg)
+                raise ValueError(err_msg)
             if self.world_size > 1 and not mox.file.exists(self.dst_strategy_dir_obs):
-                raise ValueError(f"dst_strategy_dir_obs: {self.dst_strategy_dir_obs} is not found!")
-
+                err_msg = f"dst_strategy_dir_obs: {self.dst_strategy_dir_obs} is not found!"
+                logger.error(err_msg)
+                raise ValueError(err_msg)
 
         # Get final dst_strategy in auto_trans_ckpt mode.
         dst_strategy = self.get_dst_strategy(dst_strategy)
@@ -247,13 +254,7 @@ class TransformCkpt:
         barrier_world(f"Remake {dst_ckpt_dir} by main rank.")
 
         logger.info("The transformed checkpoint will be saved under %s.", dst_ckpt_dir)
-        self.transform_ckpt(
-            src_checkpoint=src_ckpt_dir,
-            dst_checkpoint_dir=dst_ckpt_dir,
-            src_strategy=src_strategy,
-            dst_strategy=dst_strategy,
-            prefix=prefix
-        )
+        self.transform_ckpt(src_ckpt_dir, dst_ckpt_dir, src_strategy, dst_strategy, prefix)
 
         self.clear_cache()
         return dst_checkpoint_dir
@@ -267,7 +268,9 @@ class TransformCkpt:
         """Transform ckpt using mindspore.transform_checkpoint"""
         self.check_src_checkpoint_and_strategy(src_checkpoint, src_strategy)
         if src_strategy is None and dst_strategy is None:
-            raise ValueError("`src_strategy` and `dst_strategy` cannot both be None!")
+            err_msg = "`src_strategy` and `dst_strategy` cannot both be None!"
+            logger.error(err_msg)
+            raise ValueError(err_msg)
         if check_in_modelarts():
             dst_checkpoint_dir_obs = os.path.join(self.transformed_checkpoint_dir_obs,
                                                   os.path.basename(dst_checkpoint_dir))
@@ -339,7 +342,7 @@ class TransformCkpt:
                        dst_strategy):
         """transform checkpoints using mindspore.transform_checkpoint_by_rank"""
         for current_transform_rank_id in \
-                range(self.rank_id, self.rank_id + self.world_size // self.transform_process_num):
+                range(self.rank_id, self.rank_id + self.world_size // self.transform_process_num):
             logger.info(".........Transforming Ckpt For Rank: %d.........", current_transform_rank_id)
             src_rank_list = ms.rank_list_for_transform(current_transform_rank_id,
                                                        src_strategy,
@@ -371,11 +374,15 @@ class TransformCkpt:
         def build_soft_link_of_checkpoint(checkpoint, soft_link_dir):
             """Build softlink of src checkpoint"""
             if os.path.isdir(checkpoint) and not check_rank_folders(checkpoint, 0) and \
-                    not check_ckpt_file_exist(checkpoint):
-                raise ValueError(f"No rank_0 folder or ckpt files are found under {checkpoint}.")
+                    not check_ckpt_file_exist(checkpoint):
+                err_msg = f"No rank_0 folder or ckpt files are found under {checkpoint}."
+                logger.error(err_msg)
+                raise ValueError(err_msg)
             if os.path.isfile(checkpoint) and not checkpoint.endswith('.ckpt'):
-                raise ValueError(f"The value of load_checkpoint must be a folder or a file with suffix '.ckpt', "
-                                 f"but got {checkpoint}")
+                err_msg = (f"The value of load_checkpoint must be a folder or a file with suffix '.ckpt', "
+                           f"but got {checkpoint}")
+                logger.error(err_msg)
+                raise ValueError(err_msg)
 
             if os.path.isdir(checkpoint):
                 if check_rank_folders(checkpoint, 0):
@@ -418,7 +425,9 @@ class TransformCkpt:
             return None
 
         if not os.path.exists(strategy_path):
-            raise ValueError(f'strategy_path: {strategy_path} not found!')
+            err_msg = f'strategy_path: {strategy_path} not found!'
+            logger.error(err_msg)
+            raise ValueError(err_msg)
 
         if os.path.isfile(strategy_path):
             return strategy_path
@@ -454,8 +463,9 @@ class TransformCkpt:
 
         if not (dst_strategy.endswith(f"_rank_{self.rank_id}.ckpt") and
                 os.path.exists(dst_strategy)):
-            raise ValueError(f"dst_strategy: {dst_strategy} is not found!")
-
+            err_msg = f"dst_strategy: {dst_strategy} is not found!"
+            logger.error(err_msg)
+            raise ValueError(err_msg)
 
         logger.info(".........Collecting strategy.........")
         if check_in_modelarts():
@@ -521,16 +531,19 @@ class TransformCkpt:
         """
         # Before obtaining transform_rank_id_list, check 1 ≤ transform_process_num ≤ world_size.
         if transform_process_num < 1:
-            raise ValueError("transform_process_num should not smaller than 1,"
-                             f"but got {transform_process_num}.")
+            err_msg = f"transform_process_num should not smaller than 1, but got {transform_process_num}."
+            logger.error(err_msg)
+            raise ValueError(err_msg)
         if transform_process_num > self.world_size:
             logger.warning(f"transform_process_num: {transform_process_num} should not "
                            f"bigger than world_size: {self.world_size}. "
                            f"transform_process_num is set to {self.world_size}.")
             transform_process_num = self.world_size
         if self.world_size % transform_process_num != 0:
-            raise ValueError(f"transform_process_num: {transform_process_num} "
-                             f"should be divided by world_size: {self.world_size}.")
+            err_msg = (f"transform_process_num: {transform_process_num} "
+                       f"should be divided by world_size: {self.world_size}.")
+            logger.error(err_msg)
+            raise ValueError(err_msg)
 
         if check_in_modelarts() and 1 < transform_process_num < self.node_num:
             logger.warning("transform_process_num: %d should not smaller than \
@@ -551,15 +564,15 @@ class TransformCkpt:
 
         return transform_rank_id_list
 
 
     @staticmethod
     def check_src_checkpoint_and_strategy(src_checkpoint, src_strategy):
         """check src checkpoint and strategy"""
         check_path(src_checkpoint, "src_checkpoint")
         if not os.path.isdir(src_checkpoint) or not glob(os.path.join(src_checkpoint, "rank_*")):
-            raise ValueError("The load_checkpoint must be a dir and "
-                             "ckpt should be stored in the format of load_checkpoint/rank_x/xxx.ckpt,"
-                             f"but get {src_checkpoint}.")
+            err_msg = ("The load_checkpoint must be a dir and ckpt should be stored "
+                       f"in the format of load_checkpoint/rank_x/xxx.ckpt, but get {src_checkpoint}.")
+            logger.error(err_msg)
+            raise ValueError(err_msg)
         # Check rank_dirs is continuous.
         # For example, rank_0, rank_1, rank_4 is not continuous because it is missing rank_3
         src_checkpoint_rank_dir_list = glob(os.path.join(src_checkpoint, "rank_*"))
@@ -568,7 +581,9 @@ class TransformCkpt:
         src_checkpoint_rank_num = len(src_checkpoint_rank_id_list)
         for i in range(src_checkpoint_rank_num):
             if src_checkpoint_rank_id_list[i] != i:
-                raise FileNotFoundError(f"The rank_{i} folder was not found under src_checkpoint folder.")
+                err_msg = f"The rank_{i} folder was not found under src_checkpoint folder."
+                logger.error(err_msg)
+                raise FileNotFoundError(err_msg)
 
         # A full checkpoint do not require a strategy.
         if len(src_checkpoint_rank_id_list) == 1 and src_strategy:
@@ -576,7 +591,9 @@ class TransformCkpt:
             src_strategy = None
         # Distributed checkpoints must be accompanied by strategy.
         if len(src_checkpoint_rank_id_list) > 1 and src_strategy is None:
-            raise ValueError("`src_strategy` should not be None when `src_checkpoint` is sliced.")
+            err_msg = "`src_strategy` should not be None when `src_checkpoint` is sliced."
+            logger.error(err_msg)
+            raise ValueError(err_msg)
 
     def send_strategy_to_obs(self, strategy):
         """Local rank send strategy file to obs."""
@@ -623,6 +640,7 @@ class TransformCkpt:
             last_strategy_num = dst_strategy_num
             if dst_strategy_num < self.world_size:
                 if time.time() - start_time > 7200:
+                    logger.error("Timeout while collecting all strategy!")
                     raise TimeoutError("Timeout while collecting all strategy!")
                 time.sleep(5)
             else:
@@ -642,7 +660,9 @@ class TransformCkpt:
             transform_failed_txts = glob(os.path.join(ckpt_dir, 'transform_failed_rank_*.txt'))
             transform_succeed_txts = glob(os.path.join(ckpt_dir, 'transform_succeed_rank_*.txt'))
             if transform_failed_txts:
-                raise ValueError(f"Transform failed, find {transform_failed_txts}.")
+                err_msg = f"Transform failed, find {transform_failed_txts}."
+                logger.error(err_msg)
+                raise ValueError(err_msg)
             current_count = len(transform_succeed_txts)
             progress = (current_count / self.transform_process_num) * 100
             if current_count != last_count:
@@ -84,7 +84,9 @@ def get_resume_checkpoint_by_meta(checkpoint_dir, ckpt_format='ckpt',
     if check_in_modelarts():
         resume_record_dir = os.path.join(get_remote_save_url(), "resume_record")
         if not Validator.is_obs_url(resume_record_dir):
-            raise ValueError(f"{resume_record_dir} is not a valid obs path.")
+            err_meg = f"{resume_record_dir} is not a valid obs path."
+            logger.error(err_meg)
+            raise ValueError(err_meg)
     else:
         resume_record_dir = os.path.join(get_output_root_path(), "resume_record")
     remake_folder(resume_record_dir, permissions=0o750)
@@ -167,12 +169,16 @@ def get_resume_ckpt(latest_checkpointed_iteration_txt, rank_id):
 
     if not check_in_modelarts():
         if not os.path.exists(latest_checkpointed_iteration_txt):
-            raise ValueError(f"Can not find {latest_checkpointed_iteration_txt}")
+            err_msg = f"Can not find {latest_checkpointed_iteration_txt}"
+            logger.error(err_msg)
+            raise ValueError(err_msg)
         with open(latest_checkpointed_iteration_txt, 'r', encoding='utf-8') as f:
             resume_info = [line.strip() for line in f.readlines()]
     else:
         if not mox.file.exists(latest_checkpointed_iteration_txt):
-            raise ValueError(f"OBS: Can not find {latest_checkpointed_iteration_txt}")
+            err_msg = f"OBS: Can not find {latest_checkpointed_iteration_txt}"
+            logger.error(err_msg)
+            raise ValueError(err_msg)
         with mox.file.File(latest_checkpointed_iteration_txt, 'r') as f:
             resume_info = [line.strip() for line in f.readlines()]
 
@@ -182,7 +188,9 @@ def get_resume_ckpt(latest_checkpointed_iteration_txt, rank_id):
         return True
 
     if resume_info[0].startswith("Error"):
-        raise ValueError(f"Get resume-able checkpoint failed, due to {resume_info[0]}")
+        err_msg = f"Get resume-able checkpoint failed, due to {resume_info[0]}"
+        logger.error(err_msg)
+        raise ValueError(err_msg)
 
     resume_ckpt = replace_rank_id_in_ckpt_name(resume_info[-1], rank_id)
     logger.info("Get resume checkpoint: %s", resume_ckpt)
@@ -242,7 +250,9 @@ def get_resume_ckpt_list(checkpoint_dir, last_ckpt_file, rank_id, rank_dir_num,
         ckpt_prefix_tmp = ckpt_prefix.replace(f"rank_{original_rank}", f"rank_{rank_id_tmp}")
         checkpoint_rank_dir = os.path.join(checkpoint_dir, f"rank_{rank_id_tmp}")
         if not os.path.exists(checkpoint_rank_dir):
-            raise FileNotFoundError(f"{checkpoint_rank_dir} is not found!")
+            err_msg = f"{checkpoint_rank_dir} is not found!"
+            logger.error(err_msg)
+            raise FileNotFoundError(err_msg)
         for ckpt_file in os.listdir(checkpoint_rank_dir):
             health_ckpt_match = (ckpt_file.startswith(ckpt_prefix_tmp[:ckpt_prefix_tmp.rfind("_")])
                                  and use_checkpoint_health_monitor)
@@ -262,10 +272,14 @@ def get_resume_ckpt_list(checkpoint_dir, last_ckpt_file, rank_id, rank_dir_num,
         ckpt_file = replace_rank_id_in_ckpt_name(ckpts[0], rank_id)
         resume_ckpt = os.path.join(checkpoint_dir, f"rank_{rank_id}", ckpt_file)
         if not os.path.exists(resume_ckpt):
-            raise FileNotFoundError(f"{resume_ckpt} is not found!")
+            err_msg = f"{resume_ckpt} is not found!"
+            logger.error(err_msg)
+            raise FileNotFoundError(err_msg)
         resume_ckpt_list.append(resume_ckpt)
     if not resume_ckpt_list:
-        raise RuntimeError("No checkpoint could be resumed.")
+        err_msg = "No checkpoint could be resumed."
+        logger.error(err_msg)
+        raise RuntimeError(err_msg)
 
     if use_checkpoint_health_monitor:
         resume_ckpt_list.sort(key=lambda x: get_times_epoch_and_step_from_ckpt_name(x, ckpt_format))
@@ -325,8 +339,9 @@ def check_last_timestamp_checkpoints(checkpoint_dir, rank_dir_num, ckpt_format='
         checkpoint_rank_dir = os.path.join(checkpoint_dir, f"rank_{rank_id_tmp}")
         last_checkpoint = get_last_checkpoint(checkpoint_rank_dir, ckpt_format)
         if not last_checkpoint:
-            raise ValueError(f"Checkpoint not found under {checkpoint_rank_dir} "
-                             f"with config.load_ckpt_format:{ckpt_format}.")
+            err_msg = f"Checkpoint not found under {checkpoint_rank_dir} with config.load_ckpt_format:{ckpt_format}."
+            logger.error(err_msg)
+            raise ValueError(err_msg)
         if check_ckpt_file_name(last_checkpoint, ckpt_format):
             compared_checkpoint_name = replace_rank_id_in_ckpt_name(last_checkpoint, 0)
             compared_original_checkpoint_name = os.path.basename(last_checkpoint)
@@ -353,12 +368,16 @@ def check_last_timestamp_checkpoints(checkpoint_dir, rank_dir_num, ckpt_format='
                 compared_checkpoint_name = current_checkpoint_name
                 compared_original_checkpoint_name = original_checkpoint_name
             elif compared_checkpoint_name != current_checkpoint_name:
-                raise ValueError(f"Check name of the checkpoint file with the last timestamp Failed.\n"
-                                 f"1. Find 2 different checkpoints name: {compared_original_checkpoint_name} and "
-                                 f"{original_checkpoint_name}.\n2. Checkpoint file name should follow rule: "
-                                 f"{{prefix}}-{{epoch}}_{{step}}.{ckpt_format}, and not corrupted across all rank "
-                                 f"folders.\n 3. Rename `resume_training` checkpoint such as "
-                                 f"llama_7b_rank_0-3_2.{ckpt_format} may solve the problem.")
+                err_msg = (f"Check name of the checkpoint file with the last timestamp Failed.\n"
+                           f"1. Find 2 different checkpoints name: {compared_original_checkpoint_name} and "
+                           f"{original_checkpoint_name}.\n2. Checkpoint file name should follow rule: "
+                           f"{{prefix}}-{{epoch}}_{{step}}.{ckpt_format}, and not corrupted across all rank "
+                           f"folders.\n 3. Rename `resume_training` checkpoint such as "
+                           f"llama_7b_rank_0-3_2.{ckpt_format} may solve the problem.")
+                logger.error(err_msg)
+                raise ValueError(err_msg)
     if find_diff_ckpt:
-        raise ValueError(f"Some checkpoints follow the {{prefix}}-{{epoch}}_{{step}}.{ckpt_format} "
-                         f"naming convention, while others do not.")
+        err_msg = (f"Some checkpoints follow the {{prefix}}-{{epoch}}_{{step}}.{ckpt_format} "
+                   f"naming convention, while others do not.")
+        logger.error(err_msg)
+        raise ValueError(err_msg)
@@ -1133,7 +1133,9 @@ class BaseTrainer:
             logger.info(".............Start load resume context from common.json..................")
             common_file = os.path.join(config.load_checkpoint, 'common.json')
             if not os.path.exists(common_file):
-                raise FileNotFoundError(f"No common.json found in directory '{config.load_checkpoint}'.")
+                error_msg = f"No common.json found in directory '{config.load_checkpoint}'."
+                logger.error(error_msg)
+                raise FileNotFoundError(error_msg)
             common_info = CommonInfo.load_common(common_file)
             step_scale = common_info.global_batch_size / config.runner_config.global_batch_size
             config.runner_config.initial_step = int(common_info.step_num * step_scale)
@@ -1167,9 +1169,11 @@ class BaseTrainer:
             logger.info("..............Start resume checkpoint path from strategy..............")
             resume_ckpt_path = self.resume_ckpt_path_with_strategy(config)
             if resume_ckpt_path is None:
-                raise ValueError(f"Try to resume from checkpoints with strategy in directory "
-                                 f"'{config.load_checkpoint}' failed, please specify load_checkpoint to "
-                                 f"specific checkpoint file to resume training.")
+                err_msg = (f"Try to resume from checkpoints with strategy in directory "
+                           f"'{config.load_checkpoint}' failed, please specify load_checkpoint to "
+                           f"specific checkpoint file to resume training.")
+                logger.error(err_msg)
+                raise ValueError(err_msg)
             config.load_checkpoint = resume_ckpt_path
             load_resume_context_from_checkpoint(config, dataset)
             resume_dict = {
@@ -1253,8 +1257,9 @@ class BaseTrainer:
         if hasattr(network, "get_model_parameters"):
             model_params.update(network.get_model_parameters())
         else:
-            raise NotImplementedError(f"The {type(network)} has not implemented the interface: "
-                                      f"get_model_parameters.")
+            err_msg = f"The {type(network)} has not implemented the interface: `get_model_parameters`."
+            logger.error(err_msg)
+            raise NotImplementedError(err_msg)
 
         is_moe_model = False
         is_mtp_model = False
@@ -348,7 +348,8 @@ def load_resume_context_from_checkpoint(config, dataset):
     """resume training, load training info from checkpoint to config"""
     if not os.path.realpath(config.load_checkpoint) or \
             not os.path.exists(config.load_checkpoint):
-        raise FileNotFoundError(f"The load_checkpoint must be correct, but get {config.load_checkpoint}")
+        err_log = f"The load_checkpoint must be correct, but get {config.load_checkpoint}"
+        raise FileNotFoundError(err_log)
 
     if os.path.isdir(config.load_checkpoint):
         # When graceful exit is enabled or auto checkpoint transformation is disabled,
@@ -570,7 +571,9 @@ def load_slora_ckpt(checkpoint_dict, config, network):
     logger.info("............Start load slora checkpoint ............")
     adapter_path = os.path.join(pet_config.adapter_path, "lora_adapter.json")
     if not os.path.exists(adapter_path):
-        raise FileNotFoundError(f"The adapter_path must be correct, but get {adapter_path}")
+        err_msg = f"The adapter_path must be correct, but get {adapter_path}"
+        logger.error(err_msg)
+        raise FileNotFoundError(err_msg)
     with open(adapter_path, 'r', encoding='utf-8') as file:
         path_dict = json.load(file)
     adapter_list = []
@@ -686,8 +689,9 @@ def get_load_checkpoint_result(config):
         else:
             checkpoint_dict = load_distributed_checkpoint(config.load_checkpoint)
     else:
-        raise ValueError(f"{config.load_checkpoint} is not a valid path to load checkpoint "
-                         f"when auto_trans_ckpt is False.")
+        err_msg = f"{config.load_checkpoint} is not a valid path to load checkpoint when auto_trans_ckpt is False."
+        logger.error(err_msg)
+        raise ValueError(err_msg)
     return checkpoint_dict if checkpoint_dict else checkpoint_future
@@ -262,7 +262,9 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval
 
     pet_config = config.model.model_config.get("pet_config")
     if pet_config and pet_config.pet_type == "slora" and network.lora_list:
-        raise ValueError(f"slora only support .ckpt file, {config.load_ckpt_format} file will be compatible soon.")
+        err_msg = f"slora only support .ckpt file, {config.load_ckpt_format} file will be compatible soon."
+        logger.error(err_msg)
+        raise ValueError(err_msg)
     ckpt_file_mode = _get_checkpoint_mode(config)
     validate_config_with_file_mode(ckpt_file_mode, config.use_parallel, config.auto_trans_ckpt)
     # reduce compile time in prediction
@@ -422,6 +424,7 @@ def load_safetensors_checkpoint(config, load_checkpoint_files, network, strategy
         logger.info("......obtain name map for HF safetensors.....")
         name_map = origin_network.obtain_name_map(load_checkpoint_files)
     except Exception as e:
+        logger.error(f"Please complete abstract function obtain_name_map. Details: {e}")
         raise TypeError(f"Please complete abstract function obtain_name_map. Details: {e}") from e
     if is_main_rank():
         _convert_index_json(load_ckpt_path, load_ckpt_path, origin_network.convert_map_dict, False)
@@ -439,7 +442,9 @@ def load_safetensors_checkpoint(config, load_checkpoint_files, network, strategy
     hyper_param_file = os.path.join(load_ckpt_path, 'hyper_param.safetensors')
     if optimizer and config.resume_training:
         if not os.path.exists(hyper_param_file):
-            raise FileNotFoundError(rf"No hyper_param.safetensors in given dir: {load_ckpt_path}")
+            err_msg = rf"No hyper_param.safetensors in given dir: {load_ckpt_path}"
+            logger.error(err_msg)
+            raise FileNotFoundError(err_msg)
         logger.info("......Start load hyper param into optimizer......")
         hyper_param_dict = ms.load_checkpoint(ckpt_file_name=hyper_param_file, format='safetensors')
         update_global_step(config, hyper_param_dict)
@@ -562,11 +567,15 @@ def validate_qkv_concat(model_cls_or_instance, qkv_concat_config, load_checkpoin
                 break
 
     if is_qkv_concat and not qkv_concat_config:
-        raise ValueError("The qkv concat check failed! The qkv in the model weights has been concatenated,"
-                         " but qkv_concat is set to false.")
+        err_msg = ("The qkv concat check failed! The qkv in the model weights has been concatenated, "
+                   "but qkv_concat is set to false.")
+        logger.error(err_msg)
+        raise ValueError(err_msg)
    if not is_qkv_concat and qkv_concat_config:
-        raise ValueError("The qkv concat check failed! The qkv in the model weights has been not concatenated,"
-                         " but qkv_concat is set to true.")
+        err_msg = ("The qkv concat check failed! The qkv in the model weights has been not concatenated, "
+                   "but qkv_concat is set to true.")
+        logger.error(err_msg)
+        raise ValueError(err_msg)
    if is_qkv_concat and qkv_concat_config:
        logger.info("The qkv concat check succeed! The qkv in the model weights has been concatenated and "
                    "qkv_concat is set to true.")
@@ -57,7 +57,9 @@ def load_resume_checkpoint(load_checkpoint_path, remove_redundancy, load_ckpt_fo
     """resume training, load training info from checkpoint to config"""
     if not os.path.realpath(load_checkpoint_path) or \
             not os.path.exists(load_checkpoint_path):
-        raise FileNotFoundError(f"The load_checkpoint_path must be correct, but get {load_checkpoint_path}")
+        err_msg = f"The load_checkpoint_path must be correct, but get {load_checkpoint_path}"
+        logger.error(err_msg)
+        raise FileNotFoundError(err_msg)
 
     if os.path.isdir(load_checkpoint_path):
         hyper_param_file = os.path.join(load_checkpoint_path, 'hyper_param.safetensors')