mirror of https://gitee.com/mindspore/mindformers.git
synced 2025-12-06 19:42:57 +08:00

Compare commits: 3d81fa1cb1 ... r1.6.0-bet (4 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 983b95ebc0 | |
| | 5d6081e56f | |
| | 556c5f36c1 | |
| | fa3d2c9e74 | |
@@ -1335,6 +1335,7 @@ class FreqsMgrDynamicNTK(Cell):
     def __init__(self,
                  head_dim,
                  max_position_embedding,
+                 base_seqlen=None,
                  rotary_dtype=mstype.float16,
                  theta=10000,
                  parallel_config=None,
@@ -1378,7 +1379,7 @@ class FreqsMgrDynamicNTK(Cell):
 
         self.base = theta
         self.max_position_embedding = max_position_embedding
-        self.max_position_embedding_inverse = 1 / max_position_embedding
+        self.max_position_embedding_inverse = 1 / base_seqlen if base_seqlen else 1 / max_position_embedding
         self.log_scale_inverse = 1 / math.log(2)
         self.log_scale = math.log(2)
         self.min_ntk_alpha = 1.0
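Note: with this change the dynamic-NTK reference length becomes the configured `base_seqlen` (when present) instead of `max_position_embedding`. A rough standalone sketch of the usual dynamic-NTK rotary-base rescaling, with hypothetical names (`scaled_rotary_base`, `ntk_alpha`) that are not taken from `FreqsMgrDynamicNTK`:

```python
import math

# Hypothetical illustration of dynamic-NTK base scaling; the exact alpha
# schedule here is an assumption, not the FreqsMgrDynamicNTK implementation.
def scaled_rotary_base(seq_len, base_seqlen, head_dim, base=10000, min_ntk_alpha=1.0):
    # Ratio of the current sequence length to the reference length; with the
    # diff above the reference is base_seqlen when it is configured.
    ratio = seq_len / base_seqlen
    # Round the ratio up to a power of two, mirroring the log2 bookkeeping
    # (log_scale = math.log(2)) kept by the class.
    ntk_alpha = max(2 ** math.ceil(math.log(ratio, 2)), min_ntk_alpha)
    # Standard NTK-aware rescaling of the rotary base.
    return base * ntk_alpha ** (head_dim / (head_dim - 2))

print(scaled_rotary_base(seq_len=16384, base_seqlen=8192, head_dim=128))
```

Under that reading, a smaller `base_seqlen` makes the rescaling kick in earlier, which is consistent with letting it override `max_position_embedding` as the reference length.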
@@ -1567,17 +1567,21 @@ class TopkRouterV2(Cell):
         expert_gate = self._normalize(expert_gate) # (dp, N, k) <-- (dp, N, k)
 
         expert_gate = ops.mul(self.moe_config.routed_scaling_factor, expert_gate) # (dp, N, k) <-- (dp, N, k)
 
+        if self.moe_config.balance_via_topk_bias and (self.aux_loss_config.get("expert", 0) > 0 \
+            or self.aux_loss_config.get("device", 0) > 0 or self.aux_loss_config.get("comm", 0) > 0):
+            _, expert_index_for_aux = self.topk(router_prob_for_aux, self.num_experts_chosen)
+        else:
+            expert_index_for_aux = expert_index
         if self.aux_loss_config.get("expert", 0):
-            expert_load_loss = self._expert_load_balancing(router_prob_for_aux, expert_index,
+            expert_load_loss = self._expert_load_balancing(router_prob_for_aux, expert_index_for_aux,
                                                            self.aux_loss_config.get("expert"), seq_chunk=seq_chunk)
             extra_loss = self.add_scalar(extra_loss, expert_load_loss)
         if self.aux_loss_config.get("device", 0):
-            device_load_loss = self._device_load_balancing(router_prob_for_aux, expert_index,
+            device_load_loss = self._device_load_balancing(router_prob_for_aux, expert_index_for_aux,
                                                            self.aux_loss_config.get("device"))
             extra_loss = self.add_scalar(extra_loss, device_load_loss)
         if self.aux_loss_config.get("comm", 0):
-            comm_load_loss = self._comm_load_balancing(router_prob_for_aux, expert_index,
+            comm_load_loss = self._comm_load_balancing(router_prob_for_aux, expert_index_for_aux,
                                                        self.aux_loss_config.get("comm"))
             extra_loss = self.add_scalar(extra_loss, comm_load_loss)
 
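Note: the same pattern appears in both routers touched by this change. When `balance_via_topk_bias` adds a per-expert bias before top-k routing, the auxiliary balancing losses are fed indices re-selected from the unbiased router probabilities. A minimal NumPy sketch of that idea (the names `router_prob`, `topk_bias`, `topk_indices` are illustrative, not fields of `TopkRouterV2`):

```python
import numpy as np

# Sketch only: dispatch uses the biased scores, aux losses use indices
# re-derived from the unbiased probabilities.
def topk_indices(scores, k):
    return np.argsort(-scores, axis=-1)[..., :k]

router_prob = np.random.rand(4, 8)        # (tokens, experts), unbiased probabilities
topk_bias = np.random.randn(8) * 0.1      # hypothetical per-expert balancing bias
k = 2

expert_index = topk_indices(router_prob + topk_bias, k)   # used to dispatch tokens
expert_index_for_aux = topk_indices(router_prob, k)       # used for the aux losses
print(expert_index, expert_index_for_aux, sep="\n")
```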
@@ -241,7 +241,11 @@ class MoEV3(Cell):
             router_coeff = expert_gate
             router_coeff = self.mul(self.moe_config.routed_scaling_factor, router_coeff)
             # float32 <-- (dp, N, E) fp32, (dp, N, k) int32, float32
-            router_aux_loss = self._expert_load_balancing(router_prob_for_aux, expert_index, self.aux_loss_factor,
+            if self.moe_config.balance_via_topk_bias:
+                _, expert_index_for_aux = self.topk(router_prob_for_aux, self.num_experts_chosen)
+            else:
+                expert_index_for_aux = expert_index
+            router_aux_loss = self._expert_load_balancing(router_prob_for_aux, expert_index_for_aux, self.aux_loss_factor,
                                                           seq_chunk=seq_chunk)
 
             if self.enable_deredundency or self.use_3d_tensor_parallel:
@@ -15,9 +15,9 @@
 """Telechat models' APIs."""
 import numpy as np
 
 import mindspore as ms
 import mindspore.common.dtype as mstype
-from mindspore import Tensor, ops, mint
-from mindspore.communication import get_group_size
+from mindspore import Tensor, ops, mint, mutable
 from mindspore.communication._comm_helper import _is_initialized
 
 from mindformers.experimental.infer.core.layers import ColumnParallelLinear
@@ -25,9 +25,14 @@ from mindformers.experimental.parallel_core.pynative.parallel_state import get_g
 from mindformers.experimental.infer.models.llama.utils import convert_model_config
 from mindformers.models.modeling_utils import PreTrainedModel
 from mindformers.modules import Linear
-from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister
-from mindformers.tools.utils import get_predict_run_mode
+from mindformers.tools.logger import logger
+from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister
+from mindformers.tools.utils import get_predict_run_mode, is_pynative
+from mindformers.experimental.parallel_core.pynative.parallel_state import (
+    get_data_parallel_group,
+    get_tensor_model_parallel_group,
+)
 from mindformers.models.utils import jit
 
 from research.telechat2.infer.telechat_transformers import TelechatParallelTransformer
 from research.telechat2.telechat_config import TelechatConfig
@@ -49,7 +54,7 @@ class TelechatPreTrainedModel(PreTrainedModel):
 @MindFormerRegister.register(MindFormerModuleType.MODELS)
 class ParallelTelechatForCausalLM(TelechatPreTrainedModel):
     r"""
-    Provide llama training loss or logits through network.
+    Provide telechat training loss or logits through network.
 
     Args:
         config (TelechatConfig): The config of llama model.
@@ -61,15 +66,24 @@ class ParallelTelechatForCausalLM(TelechatPreTrainedModel):
 
     def __init__(self, config):
        super().__init__(config, auto_prefix=True)
-        if get_group_info('tp').group is None and _is_initialized():
-            initialize_model_parallel(get_group_size(), order='tp')
         self.config = convert_model_config(config)
         self.config.out_proj_has_bias = True
+        tp_group = get_group_info('tp').group is None
+        dp_group = get_group_info('dp').group is None
+        logger.info(f"tp_group is:{tp_group}")
+        logger.info(f"dp_group is:{dp_group}")
+        all_groups_initialized = tp_group and dp_group
+        if all_groups_initialized and _is_initialized():
+            initialize_model_parallel(tensor_model_parallel_size=self.config.parallel_config.model_parallel,
+                                      order='tp-dp')
+            logger.info(f"data_parallel_group:{get_data_parallel_group()}")
+            logger.info(f"tensor_model_parallel_group:{get_tensor_model_parallel_group()}")
         self.ignore_token_id = config.ignore_token_id
         self.pad_token_id = config.pad_token_id
         self.use_past = config.use_past
         self.vocab_size = config.vocab_size
         self.is_first_iteration = True
+        self.is_pynative = is_pynative()
 
         self.shape = ops.Shape()
         self.reshape = ops.Reshape()
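Note: assuming the usual Megatron-style ordering semantics, `order='tp-dp'` makes the tensor-parallel dimension the fastest-varying one when ranks are grouped. A back-of-envelope sketch with made-up sizes (this is not code from `initialize_model_parallel`):

```python
# Hypothetical 'tp-dp' rank layout: consecutive ranks share a tp group and
# dp groups stride across them. world_size and tp below are example values.
world_size, tp = 8, 2
dp = world_size // tp
tp_groups = [[r for r in range(world_size) if r // tp == g] for g in range(world_size // tp)]
dp_groups = [[r for r in range(world_size) if r % tp == t] for t in range(tp)]
assert dp == len(dp_groups[0])
print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(dp_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```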
@@ -79,7 +93,7 @@ class ParallelTelechatForCausalLM(TelechatPreTrainedModel):
         self.mul = ops.Mul()
         self.add = ops.Add()
         self.ones = ops.Ones()
-        self.gather = ops.Gather()
+        self.gather = ops.Gather(1) if self.is_pynative else ops.Gather()
         self.sub_batch_valid_len = ops.Sub()
         self.model = TelechatParallelTransformer(config=config)
         if config.parallel_config.vocab_emb_dp:
@@ -106,6 +120,8 @@ class ParallelTelechatForCausalLM(TelechatPreTrainedModel):
         self.predict_run_mode = get_predict_run_mode()
 
         self.use_past = config.use_past
+        self.npu_mem_size = config.npu_mem_size if hasattr(config, "npu_mem_size") else 2
+        self.return_hidden_states = config.return_hidden_states
 
     # pylint: disable=W0613
     def prepare_inputs_for_predict_layout(self, input_ids, **kwargs):
@@ -117,22 +133,73 @@ class ParallelTelechatForCausalLM(TelechatPreTrainedModel):
         prefix_keys_values = Tensor(kwargs["prefix_keys_values"]) if "prefix_keys_values" in kwargs else None
         return input_ids, labels, None, None, None, None, None, None, None, None, None, slot_mapping, prefix_keys_values
 
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         """
         prepare inputs for generation.
         A model class needs to define a `prepare_inputs_for_generation` method
         in order to use `.generate()`
 
         """
         model_inputs = {"input_ids": Tensor.from_numpy(input_ids.astype(np.int32))}
+        batch_valid_length = kwargs.get("valid_length_each_example")
+        prefill = kwargs.get("prefill")
+
+        if self.is_pynative:
+            model_inputs = {}
+            if self.config.is_dynamic and "origin_inputs" in kwargs and self.use_past:
+                input_ids = kwargs["origin_inputs"]
+            model_inputs["input_ids"] = Tensor.from_numpy(input_ids.astype(np.int32))
+        else:
+            if self.config.is_dynamic:
+                if prefill and "origin_inputs" in kwargs:
+                    origin_inputs = kwargs["origin_inputs"]
+                    slot_mapping = kwargs.get("slot_mapping")
+                    model_inputs = self._prepare_inputs_for_prefill_flatten(origin_inputs,
+                                                                            batch_valid_length,
+                                                                            slot_mapping,
+                                                                            model_inputs)
+                position_ids = batch_valid_length - 1
+                model_inputs["position_ids"] = ms.Tensor(position_ids, dtype=ms.int32).reshape(-1)
+
+                if not prefill:
+                    q_seq_lens = np.ones(batch_valid_length.shape, dtype=np.int32).reshape(-1)
+                else:
+                    q_seq_lens = batch_valid_length.astype(np.int32).reshape(-1)
+                model_inputs["q_seq_lens"] = Tensor.from_numpy(q_seq_lens)
+
+                model_inputs["attention_mask"] = self.model.casual_mask.gen_attention_mask(prefill)
         return model_inputs
 
     def set_dynamic_inputs(self, **kwargs):
         """Set dynamic input for telechat."""
         dynamic_input_ids = Tensor(shape=[None, None], dtype=mstype.int32)
         dynamic_batch_valid_length = Tensor(shape=[None, None], dtype=mstype.int32)
         dynamic_block_tables = Tensor(shape=[None, None], dtype=mstype.int32)
         dynamic_slot_mapping = Tensor(shape=[None], dtype=mstype.int32)
+        dynamic_position_ids = Tensor(shape=[None], dtype=mstype.int32)
+        dynamic_q_seq_lens = Tensor(shape=[None], dtype=mstype.int32)
+        dynamic_attention_mask = Tensor(shape=[None, None], dtype=self.config.compute_dtype)
         have_prefix_keys_values = getattr(kwargs, "have_prefix_keys_values", False)
+
+        def get_input():
+            if self.npu_mem_size > 0:
+                return None
+            cache_list = []
+            for _ in self.model.layers:
+                cache_list.append(Tensor(shape=[None, None, None, None], dtype=self.config.compute_dtype))
+            return mutable(cache_list)
+
+        key_cache = get_input()
+        value_cache = get_input()
         if have_prefix_keys_values:
             dynamic_prefix_keys_values = Tensor(shape=[2, None, None, None, None], dtype=mstype.float16)
             self.set_inputs(dynamic_input_ids, None, None, None, None, None, None,
                             dynamic_batch_valid_length, None, None, dynamic_block_tables,
-                            dynamic_slot_mapping, dynamic_prefix_keys_values, None)
+                            dynamic_slot_mapping, dynamic_prefix_keys_values, None, key_cache, value_cache)
         else:
-            self.set_inputs(dynamic_input_ids, None, None, None, None, None, None,
+            self.set_inputs(dynamic_input_ids, None, None, dynamic_position_ids, dynamic_attention_mask, None, None,
                             dynamic_batch_valid_length, None, None, dynamic_block_tables,
-                            dynamic_slot_mapping, None, None)
+                            dynamic_slot_mapping, None, None, dynamic_q_seq_lens, key_cache, value_cache)
         logger.info("Set dynamic input for telechat.")
 
     def add_flags_custom(self, is_first_iteration):

@@ -145,22 +212,30 @@ class ParallelTelechatForCausalLM(TelechatPreTrainedModel):
             layer.attention.paged_attention_mgr.add_flags(is_first_iteration=is_first_iteration)
 
     # pylint: disable=W0613
     @jit
     def construct(self, input_ids, labels=None, input_position=None, position_ids=None, attention_mask=None,
                   input_embeds=None, init_reset=None, batch_valid_length=None, batch_index=None, zactivate_len=None,
-                  block_tables=None, slot_mapping=None, prefix_keys_values=None, llm_boost_inputs=None):
+                  block_tables=None, slot_mapping=None, prefix_keys_values=None, llm_boost_inputs=None,
+                  q_seq_lens=None, key_cache=None, value_cache=None):
         """
-        Forward of llama model.
+        Forward of telechat model.
         """
         bsz, _ = self.shape(input_ids)
-        if batch_valid_length is not None:
-            batch_valid_length = batch_valid_length.reshape(-1,)
-        else:
-            batch_valid_length = self.ones((bsz,), mstype.int32)
+        if self.use_past:
+            if not isinstance(batch_valid_length, Tensor):
+                batch_valid_length = self.ones((bsz,), mstype.int32)
+            else:
+                batch_valid_length = self.reshape(batch_valid_length, (-1,))
         output = self.model(input_ids, batch_valid_length, batch_index, zactivate_len, block_tables,
-                            slot_mapping, prefix_keys_values)
+                            slot_mapping, prefix_keys_values, position_ids=position_ids, attention_mask=attention_mask,
+                            q_seq_lens=q_seq_lens, key_cache=key_cache, value_cache=value_cache)
+        if self.return_hidden_states:
+            output = self.reshape(output, (-1, output.shape[-1]))
+            return output
         pre_gather = (not self.use_past or self.is_first_iteration) and batch_valid_length is not None
         if pre_gather:
-            batch_valid_length = mint.cumsum(batch_valid_length, 0)
+            if not self.is_pynative:
+                batch_valid_length = mint.cumsum(batch_valid_length, 0)
             output = self.gather(output, self.sub_batch_valid_len(batch_valid_length, 1), 1)
         logits = self.lm_head(output)
 
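Note: one detail worth spelling out from the construct() hunk above is the pre_gather indexing. In the flattened prefill layout, the last token of each request is located by a cumulative sum over batch_valid_length. A rough NumPy illustration (shapes are made up; this is not the MindSpore code path itself):

```python
import numpy as np

# Rough illustration of the pre_gather step: with requests of lengths [3, 5, 2]
# flattened into one 10-token axis, cumsum - 1 points at the last token of each
# request, and only those positions are fed to lm_head.
batch_valid_length = np.array([3, 5, 2], dtype=np.int32)
hidden = np.arange(10 * 4, dtype=np.float32).reshape(10, 4)   # (sum(lengths), hidden_size)
last_token_pos = np.cumsum(batch_valid_length) - 1            # [2, 7, 9]
last_hidden = hidden[last_token_pos]                          # (3, hidden_size)
print(last_token_pos, last_hidden.shape)
```

The diff skips this cumsum in PyNative mode, where the output is presumably kept in a per-sample layout rather than flattened, hence the separate `ops.Gather(1)` construction earlier.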
@@ -204,8 +204,8 @@ class TelechatParallelAttention(ParallelAttention):
     """
 
     def construct(self, x, batch_valid_length, block_tables, slot_mapping, freqs_cis=None,
-                  attn_mask=None, alibi_mask=None, prefix_keys_values=None, encoder_output=None,
-                  key_cache=None, value_cache=None):
+                  attn_mask=None, alibi_mask=None, encoder_output=None, prefix_keys_values=None,
+                  q_seq_lens=None, key_cache=None, value_cache=None):
         """Construct function of attention block."""
         # hidden_states: [B, S, H]
         ori_dtype = x.dtype

@@ -283,7 +283,7 @@ class TelechatParallelAttention(ParallelAttention):
                 context_layer = self.core_attention(query, key, value, attn_mask)
             else:
                 context_layer = self.paged_attention_mgr.paged_attn(query, batch_valid_length, block_tables,
-                                                                    key_cache=key_cache, value_cache=value_cache)
+                                                                    attn_mask, q_seq_lens, key_cache, value_cache)
         else:
             # [B, S, H] -> [B, N, S, D]
             query = query.reshape(bs, seq_len, -1, self.head_dim).transpose((0, 2, 1, 3))
@@ -390,8 +390,8 @@ class TelechatParallelTransformerLayer(ParallelTransformerLayer):
         # MLP
         self.expert_num = 1 if config.moe_config is None else config.moe_config.expert_num
         self.use_moe_infer = config.use_past and self.expert_num > 1
-        config.moe_config.router_dense_type = config.router_dense_type
         if self.use_moe_infer:
+            config.moe_config.router_dense_type = config.router_dense_type
             self.feed_forward = TelechatParallelMoE(
                 ffn=TelechatRoutedParallelMLP(config),
                 hidden_size=config.hidden_size,

@@ -436,8 +436,11 @@ class TelechatParallelTransformer(ParallelTransformer):
         self.enable_dynamic_ntk = False
         if config.extend_method == 'DYNAMIC_NTK':
             self.enable_dynamic_ntk = True
+            base_seqlen = config.base_seqlen \
+                if hasattr(config, "base_seqlen") and config.base_seqlen else None
             self.freqs_mgr = FreqsMgrDynamicNTK(head_dim=self.head_dim,
                                                 max_position_embedding=config.max_position_embedding,
+                                                base_seqlen=base_seqlen,
                                                 rotary_dtype=config.rotary_dtype,
                                                 theta=config.theta,
                                                 parallel_config=config.parallel_config,
@@ -451,7 +454,8 @@ class TelechatParallelTransformer(ParallelTransformer):
 
     # pylint: disable=W0613
     def construct(self, tokens: Tensor, batch_valid_length=None, batch_index=None, zactivate_len=None,
-                  block_tables=None, slot_mapping=None, prefix_keys_values=None, key_cache=None, value_cache=None):
+                  block_tables=None, slot_mapping=None, prefix_keys_values=None, position_ids=None, attention_mask=None,
+                  q_seq_lens=None, key_cache=None, value_cache=None):
         """
         Forward of ParallelTransformer.
 

@@ -466,7 +470,7 @@ class TelechatParallelTransformer(ParallelTransformer):
         """
         # preprocess
         bs, seq_len = self.shape(tokens)
-        mask = None
+        mask = attention_mask
         if self.use_past:
             if self.is_first_iteration:
                 if self.enable_dynamic_ntk:

@@ -476,8 +480,6 @@ class TelechatParallelTransformer(ParallelTransformer):
 
                 if self.is_pynative:
                     mask = self.casual_mask(tokens)
-                else:
-                    mask = self.casual_mask.prefill()
 
         if prefix_keys_values is not None:
             if mask is None:
@@ -504,7 +506,7 @@ class TelechatParallelTransformer(ParallelTransformer):
             value_cache_i = value_cache[i] if value_cache is not None else None
             hidden_states = self.layers[i](hidden_states, freqs_cis, mask, batch_valid_length=batch_valid_length,
                                            block_tables=block_tables, slot_mapping=slot_mapping,
-                                           prefix_keys_values=prefix_kv,
+                                           prefix_keys_values=prefix_kv, q_seq_lens=q_seq_lens,
                                            key_cache=key_cache_i, value_cache=value_cache_i)
 
         if self.post_norm:
@@ -19,6 +19,7 @@ parallel:
 parallel_config:
   data_parallel: 1
   model_parallel: 8
+  vocab_emb_dp: False
 
 # mindspore context init config
 context:

@@ -19,6 +19,7 @@ parallel:
 parallel_config:
   data_parallel: 1
   model_parallel: 2
+  vocab_emb_dp: False
 
 # mindspore context init config
 context:

@@ -33,6 +33,7 @@ parallel:
 parallel_config:
   data_parallel: 1
   model_parallel: 2
+  vocab_emb_dp: False
 
 # mindspore context init config
 context:

@@ -20,6 +20,7 @@ parallel:
 parallel_config:
   data_parallel: 1
   model_parallel: 1
+  vocab_emb_dp: False
 
 # mindspore context init config
 context:
@@ -93,6 +93,7 @@ class TelechatConfig(PretrainedConfig):
             The maximum number of tokens in one block can have when using paged attention.
         num_blocks (`int`, *optional*, defaults to 512):
             The maximum number of blocks when using paged attention.
+        return_hidden_states (bool, optional): Whether to return hidden states. Default: ``False``.
     Returns:
         Class, TelechatConfig.
     """

@@ -155,6 +156,7 @@ class TelechatConfig(PretrainedConfig):
                  quant: str = "",
                  sigma: float = 0.0048,
                  mean: float = 0.0,
+                 return_hidden_states: bool = False,
                  **kwargs):
         super(TelechatConfig, self).__init__(**kwargs)
         if isinstance(parallel_config, dict):

@@ -218,3 +220,4 @@ class TelechatConfig(PretrainedConfig):
         self.num_blocks = num_blocks
         self.quant = quant
         self.qkv_concat = qkv_concat
+        self.return_hidden_states = return_hidden_states
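Note: the new flag is plumbed straight through: it is stored on the config, copied onto the model as `self.return_hidden_states`, and short-circuits construct() before the lm_head projection, as the hunks above show. A hedged usage sketch (the import path follows the research/telechat2 layout from the imports above; treat the call with only this argument as illustrative):

```python
# Illustrative only: constructing the config with the new flag. When True, the
# model returns hidden states reshaped to (-1, hidden_size) instead of logits,
# per the construct() diff above.
from research.telechat2.telechat_config import TelechatConfig

config = TelechatConfig(return_hidden_states=True)
print(config.return_hidden_states)  # True
```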
@@ -52,17 +52,17 @@ class TestInferParallel:
         Expectation: AssertionError
         """
         commands = [
-            (f"export ASCEND_RT_VISIBLE_DEVICES=0,1 && export LCAL_IF_PORT=10068 && "
+            (f"export ASCEND_RT_VISIBLE_DEVICES=0,1 && export LCAL_COMM_ID=127.0.0.1:10068 && "
             f"export HCCL_IF_BASE_PORT=61000 && msrun --worker_num=2 "
             f"--local_worker_num=2 --master_port=8222 --log_dir=parallel_qwen2_0_5b_predict_mp2 --join=True "
             f"{cur_dir}/run_parallel.py --mode parallel_qwen2_0_5b_predict_mp2",
             'parallel_qwen2_0_5b_predict_mp2/worker_0.log'),
-            (f"export ASCEND_RT_VISIBLE_DEVICES=2,3 && export LCAL_IF_PORT=10069 && "
+            (f"export ASCEND_RT_VISIBLE_DEVICES=2,3 && export LCAL_COMM_ID=127.0.0.1:10069 && "
             f"export HCCL_IF_BASE_PORT=61100 && msrun --worker_num=2 "
             f"--local_worker_num=2 --master_port=8226 --log_dir=parallel_glm3_6b_predict_mp2 --join=True "
             f"{cur_dir}/run_parallel.py --mode parallel_glm3_6b_predict_mp2",
             'parallel_glm3_6b_predict_mp2/worker_0.log'),
-            (f"export ASCEND_RT_VISIBLE_DEVICES=4,5 && export LCAL_IF_PORT=10070 && "
+            (f"export ASCEND_RT_VISIBLE_DEVICES=4,5 && export LCAL_COMM_ID=127.0.0.1:10070 && "
             f"export HCCL_IF_BASE_PORT=61200 && msrun --worker_num=2 "
             f"--local_worker_num=2 --master_port=8228 --log_dir=parallel_shared_expert_predict_mp2 --join=True "
             f"{cur_dir}/run_parallel.py --mode parallel_shared_expert_predict_mp2",