Mirror of https://gitee.com/mindspore/mindformers.git (synced 2025-12-06 11:29:59 +08:00)

Compare commits: 3988632900 ... 3d81fa1cb1 (4 commits)

Commits:
- 3d81fa1cb1
- a9530fbbe2
- fa9210dabd
- 3d61efd427
@@ -650,7 +650,7 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig):
                 "When using moe_dry_run, moe_token_dispatcher_type must be 'alltoall' or 'alltoall_deredundency'."
             )
 
-        if self.position_embedding_type not in ["rope", "yarn", "none", "relative", "learned_absolute"]:
+        if self.position_embedding_type not in ["rope", "yarn", "none", "relative", "learned_absolute", "partial_rope"]:
             raise ValueError(
                 f"The current value of position_embedding_type is {self.position_embedding_type},"
                 " but position_embedding_type must be one of: 'rope', 'yarn', 'none', 'relative', 'learned_absolute'."
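For context, a minimal sketch of what the relaxed check now accepts, assuming TransformerConfig can be constructed with the same keyword arguments the unit tests further below use (num_layers, num_attention_heads):

```python
from mindformers.parallel_core.transformer_config import TransformerConfig

# 'partial_rope' now passes the allowed-values check shown above;
# any value outside the list still raises ValueError during validation.
config = TransformerConfig(num_layers=1, num_attention_heads=1,
                           position_embedding_type="partial_rope")
```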
@@ -168,13 +168,6 @@ torch_path: directory path where the torch-format weights are stored
 mindspore_path: file name for the saved weights; a custom save path can be specified
 ```
 
-2. Obtain the converted weights provided by MindFormers directly from the links below.
-
-- [TeleChat2-7B](https://telechat-docker.obs.cn-north-4.myhuaweicloud.com/model_weight/Telechat_7B/Telechat_7B.zip)
-- [TeleChat2-35B](https://telechat-docker.obs.cn-north-4.myhuaweicloud.com/model_weight/Telechat_35B/Telechat_35B.zip)
-- [TeleChat2-115B](https://telechat-docker.obs.cn-north-4.myhuaweicloud.com/model_weight/Telechat_115B/Telechat_115B.zip)
-- [Telechat2-39B-A12B](https://telechat-docker.obs.cn-north-4.myhuaweicloud.com/model_weight/Telechat_39B_A12.tar): for 8-card inference only; see the [Telechat2-39B-A12B inference](#Telechat2-39B-A12B推理) section for usage.
-
 ### Splitting and Merging Distributed Weights
 
 The weight files produced by distributed training/fine-tuning are sharded according to the parallel slicing strategy; the shards must be merged manually before they can be used for evaluation and inference.
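To illustrate the merge step described above, here is a minimal sketch that assumes MindSpore's ms.transform_checkpoints utility; all paths and the checkpoint prefix are placeholders, not files from this repository.

```python
import mindspore as ms

# Placeholder layout: src_checkpoints_dir contains rank_*/ subfolders with the
# strategy-sharded checkpoints; dst_strategy_file=None requests a single merged
# checkpoint that can be loaded for evaluation or inference.
ms.transform_checkpoints(
    src_checkpoints_dir="./output/checkpoint/",
    dst_checkpoints_dir="./output/merged_checkpoint/",
    ckpt_prefix="telechat2",
    src_strategy_file="./output/strategy/ckpt_strategy.ckpt",
    dst_strategy_file=None,
)
```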
0  tests/st/test_ut/test_models/test_glm4/__init__.py  Normal file
@@ -0,0 +1,47 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unit tests for Glm4Config."""
import pytest

from mindformers.models.glm4.configuration_glm4 import Glm4Config


class TestGlm4Config:
    """Validates default behaviors for the Glm4 configuration."""

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    def test_default_configuration_fields(self):
        """Test that the Glm4Config initializes with expected default values."""
        config = Glm4Config()

        assert config.vocab_size == 151552
        assert config.hidden_size == 4096
        assert config.num_hidden_layers == 40
        assert config.num_attention_heads == 32
        assert config.num_key_value_heads == 2
        assert config.position_embedding_type == "partial_rope"
        assert config.model_type == "glm4"
        assert "layers.*.self_attn.q_proj" in Glm4Config.base_model_tp_plan

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    def test_override_arguments_apply(self):
        """Test that arguments passed to Glm4Config constructor correctly override the defaults."""
        config = Glm4Config(vocab_size=10, num_attention_heads=8, eos_token_id=(1,))

        assert config.vocab_size == 10
        assert config.num_attention_heads == 8
        assert config.eos_token_id == (1,)
46  tests/st/test_ut/test_models/test_glm4/test_modeling_glm4.py  Normal file
@@ -0,0 +1,46 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""UTs for Glm4 modeling API."""
import os
import pytest

from mindformers.models.glm4.configuration_glm4 import Glm4Config
from mindformers.models.glm4.modeling_glm4 import Glm4ForCausalLM
from mindformers.models.glm4.modeling_glm4_infer import InferenceGlm4ForCausalLM


class TestGlm4ForCausalLM:
    """Ensure Glm4ForCausalLM routes to the proper implementation."""

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    def test_init_model_in_predict_mode(self):
        """When RUN_MODE is unset/predict, the inference model should be instantiated."""
        os.environ['RUN_MODE'] = "predict"
        config = Glm4Config()

        model = Glm4ForCausalLM(config)

        assert isinstance(model, InferenceGlm4ForCausalLM)
        assert model.config is config

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    def test_init_model_in_train_mode(self):
        """RUN_MODE=train should raise an explicit NotImplementedError."""
        os.environ['RUN_MODE'] = "train"

        with pytest.raises(NotImplementedError):
            Glm4ForCausalLM(Glm4Config())
@@ -0,0 +1,48 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unit tests for Glm4MoeConfig."""
import pytest

from mindformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig


class TestGlm4MoeConfig:
    """Tests covering the Glm4Moe configuration helper."""

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    def test_default_configuration_values(self):
        """Ensure defaults from the spec are propagated to attributes."""
        config = Glm4MoeConfig()

        assert config.vocab_size == 151552
        assert config.hidden_size == 4096
        assert config.num_hidden_layers == 46
        assert config.num_attention_heads == 96
        assert config.moe_intermediate_size == 1408
        assert config.num_experts_per_tok == 8
        assert config.norm_topk_prob is True
        assert config.model_type == "glm4_moe"
        assert "layers.*.self_attn.q_proj" in Glm4MoeConfig.base_model_tp_plan

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    def test_rope_scaling_type_key_is_renamed(self):
        """When rope scaling contains 'type', it should be copied to 'rope_type'."""
        rope_scaling = {"type": "yarn", "factor": 2.0}
        config = Glm4MoeConfig(rope_scaling=rope_scaling)

        assert config.rope_scaling["rope_type"] == "yarn"
        assert config.rope_scaling["factor"] == 2.0
@@ -0,0 +1,46 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""UTs for Glm4Moe modeling API."""
import os
import pytest

from mindformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig
from mindformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeForCausalLM
from mindformers.models.glm4_moe.modeling_glm4_moe_infer import InferenceGlm4MoeForCausalLM


class TestGlm4MoeForCausalLM:
    """Ensure Glm4MoeForCausalLM routes to the proper implementation."""

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    def test_init_model_in_predict_mode(self):
        """When RUN_MODE is unset/predict, the inference model should be instantiated."""
        os.environ['RUN_MODE'] = "predict"
        config = Glm4MoeConfig()

        model = Glm4MoeForCausalLM(config)

        assert isinstance(model, InferenceGlm4MoeForCausalLM)
        assert model.config is config

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    def test_init_model_in_train_mode(self):
        """RUN_MODE=train should raise an explicit NotImplementedError."""
        os.environ['RUN_MODE'] = "train"

        with pytest.raises(NotImplementedError):
            Glm4MoeForCausalLM(Glm4MoeConfig())
@@ -0,0 +1,174 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""UTs for tensor-parallel mapping helpers."""
from functools import partial

import numpy as np
import pytest

import mindspore as ms
from mindspore import Tensor
import mindspore.common.dtype as mstype

from mindformers.parallel_core.inference.parallel_state import ProcessGroup
from mindformers.parallel_core.inference.tensor_parallel import mappings


ms.context.set_context(deterministic="ON")
jit_level = "O0"
infer_boost = "on"
ms.set_context(device_target="Ascend",
               mode=ms.GRAPH_MODE,
               jit_config={"jit_level": jit_level, "infer_boost": infer_boost})


class FakeGather:
    """Mock AllGather operator recording inputs."""

    def __init__(self):
        self.calls = []

    def __call__(self, tensor):
        self.calls.append(tensor)
        return tensor


class FakeReduceScatter:
    """Mock ReduceScatter returning half-size tensor."""

    def __init__(self):
        self.calls = []

    def __call__(self, tensor):
        self.calls.append(tensor)
        # Return the first split chunk
        return tensor[:tensor.shape[0] // 2]


class FakeAllReduce:
    """Mock AllReduce returning tensor doubled."""

    def __init__(self):
        self.calls = []

    def __call__(self, tensor):
        self.calls.append(tensor)
        return tensor * 2


class FakeSplit:
    """Mock Split op returning chunks."""

    def __init__(self, axis, output_num):
        self.axis = axis
        self.output_num = output_num

    def __call__(self, tensor):
        return tuple(np.split(tensor.asnumpy(), self.output_num, axis=self.axis))


class TestTensorParallelMappings:
    """Groups mapping tests into a single suite."""

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_gather_returns_input_when_group_size_one(self):
        """
        Test that gather_from_model_parallel_region returns the original tensor unchanged
        when the process group size is 1.
        """
        group = ProcessGroup(group=None, rank=0, size=1)
        # pylint: disable=W0212
        group._is_group_created = True
        tensor = Tensor(np.ones((2, 2), dtype=np.float32), dtype=mstype.float32)

        output = mappings.gather_from_model_parallel_region(tensor, group, dim=-1)

        assert np.array_equal(output.asnumpy(), tensor.asnumpy())

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_gather_transposes_when_dim_nonzero(self, monkeypatch):
        """
        Test that gather_from_model_parallel_region correctly handles gathering along a non-last dimension.
        """
        fake_gather = FakeGather()
        monkeypatch.setattr(mappings.ops, "AllGather", lambda group: fake_gather)
        group = ProcessGroup(group="test", rank=0, size=2)
        # pylint: disable=W0212
        group._is_group_created = True
        tensor = Tensor(np.arange(6).reshape(3, 2).astype(np.float32), dtype=mstype.float32)

        output = mappings.gather_from_model_parallel_region(tensor, group, dim=1)

        assert output.shape == tensor.shape

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_reduce_allreduce_invoked(self, monkeypatch):
        """
        Test that reduce_from_model_parallel_region performs an AllReduce operation.
        """
        fake_reduce = FakeAllReduce()
        monkeypatch.setattr(mappings.ops, "AllReduce", lambda group: fake_reduce)
        group = ProcessGroup(group="test", rank=0, size=2)
        # pylint: disable=W0212
        group._is_group_created = True
        tensor = Tensor(np.ones((2, 2), dtype=np.float32), dtype=mstype.float32)

        output = mappings.reduce_from_model_parallel_region(tensor, group)

        assert np.array_equal(output.asnumpy(), (tensor * 2).asnumpy())

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_reduce_scatter_returns_split(self, monkeypatch):
        """
        Test that reduce_scatter_to_model_parallel_region performs a ReduceScatter operation.
        """
        fake_reduce_scatter = FakeReduceScatter()
        monkeypatch.setattr(mappings.ops, "ReduceScatter", lambda group: fake_reduce_scatter)
        group = ProcessGroup(group="test", rank=0, size=2)
        # pylint: disable=W0212
        group._is_group_created = True
        tensor = Tensor(np.ones((4, 2), dtype=np.float32), dtype=mstype.float32)

        output = mappings.reduce_scatter_to_model_parallel_region(tensor, group)

        assert output.shape[0] == tensor.shape[0] // 2

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_scatter_returns_rank_chunk(self, monkeypatch):
        """
        Test that scatter_to_model_parallel_region splits the input tensor along the specified dimension.
        """
        monkeypatch.setattr(mappings.ops, "Split", partial(FakeSplit))
        group = ProcessGroup(group="test", rank=1, size=2)
        # pylint: disable=W0212
        group._is_group_created = True
        tensor = Tensor(np.arange(8).reshape(2, 4).astype(np.float32), dtype=mstype.float32)

        output = mappings.scatter_to_model_parallel_region(tensor, group, dim=1)

        assert output.shape == (2, 2)
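For intuition, the scatter test above corresponds to the plain NumPy split below; this is only a shape reference for what a rank receives, not the mappings implementation itself.

```python
import numpy as np

# A (2, 4) tensor scattered along dim=1 across a tensor-parallel group of size 2:
# rank 0 keeps columns 0-1, rank 1 keeps columns 2-3, so each shard has shape (2, 2).
tensor = np.arange(8, dtype=np.float32).reshape(2, 4)
chunks = np.split(tensor, 2, axis=1)
assert chunks[1].shape == (2, 2)  # the chunk rank 1 would receive
```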
@@ -0,0 +1,92 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""UTs for FusedScaleMaskSoftmax."""
import numpy as np
import pytest

import mindspore as ms
from mindspore import Tensor
import mindspore.common.dtype as mstype

from mindformers.parallel_core.inference.transformer.fused_softmax import FusedScaleMaskSoftmax


ms.context.set_context(deterministic="ON")
jit_level = "O0"
infer_boost = "on"
ms.set_context(device_target="Ascend",
               mode=ms.GRAPH_MODE,
               jit_config={"jit_level": jit_level, "infer_boost": infer_boost})


def simple_mask(tensor, mask):
    """Mask function for tests that adds the mask to the input scores."""
    return tensor + mask


class TestFusedScaleMaskSoftmax:
    """Tests covering the fused softmax helper."""

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_forward_pass_with_scale_and_mask(self):
        """
        Test the forward pass of FusedScaleMaskSoftmax with both scaling and a mask applied.

        Verifies that the module correctly applies the scale factor to the input tensor,
        applies the provided attention mask, and computes the softmax, returning an output
        with the expected shape. This tests the core functionality under typical conditions.
        """
        fused_softmax = FusedScaleMaskSoftmax(mask_func=simple_mask, scale=0.5, softmax_compute_type=mstype.float32)
        x = Tensor(np.array([[2.0, 0.0]], dtype=np.float32), dtype=mstype.float32)
        mask = Tensor(np.array([[0.0, -1.0]], dtype=np.float32), dtype=mstype.float32)

        output = fused_softmax(x, mask)

        assert output.shape == (1, 2)

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_precision_casts_to_fp32_when_needed(self):
        """
        Test that FusedScaleMaskSoftmax automatically casts inputs to float32 when required.

        Verifies that when the softmax computation type is set to float32 but the input
        tensor is in float16, the module performs the necessary precision casting to fp32
        for the softmax operation, ensuring numerical stability, and returns an output
        with the correct shape.
        """
        fused_softmax = FusedScaleMaskSoftmax(mask_func=simple_mask, scale=None, softmax_compute_type=mstype.float32)
        x = Tensor(np.array([[1.0, 1.0]], dtype=np.float16), dtype=mstype.float16)

        output = fused_softmax(x, mask=None)

        assert output.shape == (1, 2)

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_invalid_scale_precision_combination_raises(self):
        """
        Test that FusedScaleMaskSoftmax raises a ValueError for invalid precision combinations.

        Verifies that the module enforces the rule that if a scale factor is applied,
        the softmax computation must be performed in float32 to maintain precision.
        Attempting to use a scale with float16 computation should raise a ValueError.
        """
        with pytest.raises(ValueError):
            FusedScaleMaskSoftmax(mask_func=simple_mask, scale=0.1, softmax_compute_type=mstype.float16)
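As a reference for the first test above, the expected numbers can be reproduced with plain NumPy, assuming the module applies the scale, then the additive mask, then softmax, which is the ordering the test inputs imply:

```python
import numpy as np

x = np.array([[2.0, 0.0]], dtype=np.float32)
mask = np.array([[0.0, -1.0]], dtype=np.float32)

scaled = x * 0.5                 # [[1.0, 0.0]]
masked = scaled + mask           # [[1.0, -1.0]], simple_mask adds the mask
exps = np.exp(masked - masked.max(axis=-1, keepdims=True))
probs = exps / exps.sum(axis=-1, keepdims=True)
# probs ≈ [[0.8808, 0.1192]] with shape (1, 2), matching the asserted output shape
```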
@@ -0,0 +1,70 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unit tests for LowerTriangularMaskWithDynamic."""
import numpy as np
import pytest

import mindspore as ms
from mindspore import Tensor
import mindspore.common.dtype as mstype

from mindformers.parallel_core.inference.transformer.lower_triangular_mask import (
    LowerTriangularMaskWithDynamic,
)


ms.context.set_context(deterministic="ON")
jit_level = "O0"
infer_boost = "on"
ms.set_context(device_target="Ascend",
               mode=ms.GRAPH_MODE,
               jit_config={"jit_level": jit_level, "infer_boost": infer_boost})


class TestLowerTriangularMask:
    """Validates lower-triangular mask generation."""

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_lower_triangular_mask_in_prefill(self):
        """Prefill path should directly return the static fa mask."""
        lower_triangular_mask = LowerTriangularMaskWithDynamic(seq_length=4, compute_type=mstype.float16)
        lower_triangular_mask.is_prefill = True

        mask = lower_triangular_mask(positions=Tensor(np.zeros((1, 4)), dtype=mstype.int32))
        assert mask.shape == lower_triangular_mask.fa_lower_triangle_mask.shape

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_lower_triangular_mask_in_decode(self):
        """Decode path should gather using provided positions."""
        lower_triangular_mask = LowerTriangularMaskWithDynamic(seq_length=4, compute_type=mstype.float16)
        lower_triangular_mask.is_prefill = False
        positions = Tensor(np.array([0, 2], dtype=np.int32))

        mask = lower_triangular_mask(positions=positions)
        expected_shape = (positions.shape[0], lower_triangular_mask.pa_lower_triangle_mask.shape[1])
        assert mask.shape == expected_shape

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_lower_triangular_mask_bfloat16_in_prefill_mask(self):
        """When using bf16 compute type, mask coefficient becomes +1."""
        lower_triangular_mask = LowerTriangularMaskWithDynamic(seq_length=4, compute_type=mstype.bfloat16)
        mask = lower_triangular_mask.prefill()
        assert mask.shape == lower_triangular_mask.fa_lower_triangle_mask.shape
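For background, a generic lower-triangular (causal) mask for seq_length=4 can be written with NumPy as below; the exact fill values, dtype handling, and the bf16 coefficient mentioned in the last test are internal to LowerTriangularMaskWithDynamic and are not visible in this diff.

```python
import numpy as np

seq_length = 4
# Ones strictly above the diagonal mark future positions that must be masked out.
upper = np.triu(np.ones((seq_length, seq_length), dtype=np.float32), k=1)
# A typical additive causal mask places a large negative value on masked positions.
causal_mask = upper * -10000.0
# causal_mask[i, j] == 0 for j <= i and is strongly negative for j > i.
```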
@@ -0,0 +1,114 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""UTs for `moe_utils.py`."""
import numpy as np
import pytest

import mindspore as ms
from mindspore import Tensor
import mindspore.common.dtype as mstype

from mindformers.parallel_core.inference.transformer.moe.moe_utils import (
    group_limited_topk,
    topk_routing_with_score_function,
)


ms.context.set_context(deterministic="ON")
jit_level = "O0"
infer_boost = "on"
ms.set_context(device_target="Ascend",
               mode=ms.GRAPH_MODE,
               jit_config={"jit_level": jit_level, "infer_boost": infer_boost})


class TestTopkRoutingWithScoreFunction:
    """Unit tests for the top-k routing helper."""

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_softmax_routing_returns_normalized_weights(self):
        """The weights should sum to one per token when normalization is enabled."""
        logits = Tensor(
            np.array([[1.0, 2.0, 0.5, -0.5], [-1.0, 0.0, 2.5, 1.0]], dtype=np.float32),
            dtype=mstype.float32,
        )
        expert_weight, routing_map = topk_routing_with_score_function(
            logits=logits,
            topk=2,
            num_experts=4,
            score_function="softmax",
            norm_topk_prob=True,
        )

        assert expert_weight.shape == (2, 2)
        assert routing_map.shape == (2, 2)

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_sigmoid_routing_with_bias_without_normalization(self):
        """Bias should affect the chosen experts while weights stay unnormalized when disabled."""
        logits = Tensor(
            np.array([[0.0, -2.0, 2.0, 1.0]], dtype=np.float32),
            dtype=mstype.float32,
        )
        expert_bias = Tensor(np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32), dtype=mstype.float32)

        expert_weight, routing_map = topk_routing_with_score_function(
            logits=logits,
            topk=2,
            num_experts=4,
            score_function="sigmoid",
            expert_bias=expert_bias,
            norm_topk_prob=False,
        )

        assert expert_weight.shape == (1, 2)
        assert routing_map.shape == (1, 2)

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_group_limited_topk_only_selects_from_best_group(self):
        """group_limited_topk should not route experts outside the best group subset."""
        scores = Tensor(np.array([[0.9, 0.8, 0.1, 0.2]], dtype=np.float32), dtype=mstype.float32)

        probs, top_indices = group_limited_topk(
            scores=scores,
            topk=2,
            num_experts=4,
            num_groups=2,
            group_topk=1,
        )

        assert probs.shape == (1, 2)
        assert top_indices.shape == (1, 2)

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_invalid_score_function_raises(self):
        """An unsupported score function name should raise ValueError."""
        logits = Tensor(np.zeros((1, 2), dtype=np.float32), dtype=mstype.float32)

        with pytest.raises(ValueError):
            topk_routing_with_score_function(
                logits=logits,
                topk=1,
                num_experts=2,
                score_function="unsupported",
            )
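For reference, the softmax routing path exercised above can be sketched in plain NumPy as follows; this mirrors the common formulation (softmax scores, top-k selection, optional renormalization) and is not taken from moe_utils itself.

```python
import numpy as np

def reference_topk_routing(logits, topk, norm_topk_prob=True):
    """Reference routing: softmax scores, pick top-k experts per token, optionally renormalize."""
    exps = np.exp(logits - logits.max(axis=-1, keepdims=True))
    scores = exps / exps.sum(axis=-1, keepdims=True)
    top_indices = np.argsort(-scores, axis=-1)[:, :topk]            # (num_tokens, topk)
    top_weights = np.take_along_axis(scores, top_indices, axis=-1)
    if norm_topk_prob:
        top_weights = top_weights / top_weights.sum(axis=-1, keepdims=True)
    return top_weights, top_indices

logits = np.array([[1.0, 2.0, 0.5, -0.5], [-1.0, 0.0, 2.5, 1.0]], dtype=np.float32)
weights, indices = reference_topk_routing(logits, topk=2)
assert weights.shape == (2, 2) and indices.shape == (2, 2)
```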
@@ -0,0 +1,331 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unit tests for transformer utils helpers."""
import sys
from types import SimpleNamespace

import numpy as np
import pytest

from mindspore import Tensor, Parameter
import mindspore.common.dtype as mstype

from mindformers.parallel_core.transformer_config import TransformerConfig
import mindformers.parallel_core.inference.utils as transformer_utils


class DummySubCell:
    """Subcell exposing a sharded state dict."""

    def __init__(self):
        self.param = Parameter(
            Tensor(np.ones((2, 2), dtype=np.float32), dtype=mstype.float32), name="sub.param"
        )

    def sharded_state_dict(self):
        return {
            "sub.param": {
                "shape": self.param.shape,
                "shard": (1, 2),
            }
        }

    def name_cells(self):
        return {"self": self}


class DummyNetwork:
    """Minimal network exposing parameters and cells."""

    def __init__(self):
        self.sub = DummySubCell()
        self.head = Parameter(Tensor(np.ones((2,), dtype=np.float32), dtype=mstype.float32), name="head.bias")

    def name_cells(self):
        return {"self": self, "sub": self.sub}

    def parameters_dict(self):
        return {"sub.param": self.sub.param, "head.bias": self.head}


class TestAttnMaskHelpers:
    """Tests for attention mask helpers."""

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_attn_mask_fill_applies_value(self):
        """
        Test 'attn_mask_fill' function correctly applies the fill value to masked positions.
        """
        func = transformer_utils.get_attn_mask_func("attn_mask_fill")
        scores = Tensor(np.ones((1, 2), dtype=np.float32), dtype=mstype.float32)
        mask = Tensor(np.array([[False, True]]), dtype=mstype.bool_)
        output = func(scores, mask, fill_value=-9.0)

        output_np = output.asnumpy()
        assert output_np[0, 0] == pytest.approx(1.0, rel=1e-6)
        assert output_np[0, 1] == pytest.approx(-9.0, rel=1e-6)

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_attn_mask_add_casts_mask(self):
        """
        Test 'attn_mask_add' function adding a float mask to attention scores.
        """
        func = transformer_utils.get_attn_mask_func("attn_mask_add")
        scores = Tensor(np.zeros((1, 2), dtype=np.float32), dtype=mstype.float32)
        mask = Tensor(np.array([[0.0, -5.0]], dtype=np.float32), dtype=mstype.float32)
        output = func(scores, mask)

        output_np = output.asnumpy()
        assert output_np.shape == (1, 2)

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_get_attn_mask_func_with_invalid_name(self):
        """
        Test get_attn_mask_func raising a KeyError for an unsupported mask function type.
        """
        with pytest.raises(KeyError):
            transformer_utils.get_attn_mask_func("unknown")


class TestStateDictGeneration:
    """Tests for sharded state dict utilities."""

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_generate_state_dict_includes_sharded_and_full_params(self, monkeypatch):
        """
        Test that generate_state_dict correctly includes both sharded and non-sharded parameters.
        """
        monkeypatch.setattr(transformer_utils, "get_group_size", lambda: 2)
        state_dict = transformer_utils.generate_state_dict(DummyNetwork())

        assert state_dict["total_rank"] == 2
        assert "sub.param" in state_dict["model"]
        assert "head.bias" in state_dict["model"]
        assert state_dict["model"]["head.bias"]["shard"] == (1,)


class TestCommAndTopologyHelpers:
    """Tests targeting communication helper utilities."""

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_update_comm_config_single_tp_multi_dp(self, monkeypatch):
        """
        Test update_comm_config for a configuration with single tensor parallel group and multiple data parallel groups.
        """
        monkeypatch.setattr(transformer_utils, "get_tensor_model_parallel_world_size", lambda: 1)
        monkeypatch.setattr(transformer_utils, "get_data_parallel_world_size", lambda: 2)
        monkeypatch.setattr(transformer_utils, "get_moe_tensor_parallel_world_size", lambda: 1)

        config = TransformerConfig(num_layers=1, num_attention_heads=1)
        updated = transformer_utils.update_comm_config(config)

        assert updated.use_alltoall is True
        assert updated.attn_allreduce is False
        assert updated.ffn_allreduce is False

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_update_comm_config_moe_tp_enabled(self, monkeypatch):
        """
        Test update_comm_config when MOE tensor parallelism is enabled.
        """
        monkeypatch.setattr(transformer_utils, "get_tensor_model_parallel_world_size", lambda: 1)
        monkeypatch.setattr(transformer_utils, "get_data_parallel_world_size", lambda: 2)
        monkeypatch.setattr(transformer_utils, "get_moe_tensor_parallel_world_size", lambda: 2)

        config = TransformerConfig(num_layers=1, num_attention_heads=1)
        updated = transformer_utils.update_comm_config(config)

        assert updated.attn_allgather is True
        assert updated.ffn_reduce_scatter is True
        assert updated.ffn_allreduce is False

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_get_num_layers_and_offset_with_pp_offsets(self, monkeypatch):
        """
        Test get_num_layers_and_offset with a valid pipeline parallel offset configuration.
        """
        monkeypatch.setattr(transformer_utils, "get_pipeline_model_parallel_world_size", lambda: 2)
        monkeypatch.setattr(transformer_utils, "get_pipeline_model_parallel_rank", lambda: 1)

        config = TransformerConfig(num_layers=5, offset=[1, 0], num_attention_heads=1)

        layers, offset = transformer_utils.get_num_layers_and_offset(config)

        assert layers == 2
        assert offset == 3

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_get_num_layers_and_offset_raises_for_small_model(self, monkeypatch):
        """
        Test that get_num_layers_and_offset raises RuntimeError when the model has too few layers.
        """
        monkeypatch.setattr(transformer_utils, "get_pipeline_model_parallel_world_size", lambda: 8)

        config = TransformerConfig(num_layers=4, num_attention_heads=1)

        with pytest.raises(RuntimeError):
            transformer_utils.get_num_layers_and_offset(config)

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_get_num_layers_and_offset_invalid_offset_shape(self, monkeypatch):
        """
        Test that get_num_layers_and_offset raises ValueError for an offset list with incorrect length.
        """
        monkeypatch.setattr(transformer_utils, "get_pipeline_model_parallel_world_size", lambda: 2)

        config = TransformerConfig(num_layers=6, offset=[1, 0, 0], num_attention_heads=1)

        with pytest.raises(ValueError):
            transformer_utils.get_num_layers_and_offset(config)


class TestMathHelpers:
    """Tests for small math helpers."""

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_divide_checks_divisibility(self):
        """
        Test that the divide function checks for exact divisibility.
        """
        assert transformer_utils.divide(6, 3) == 2
        with pytest.raises(ValueError):
            transformer_utils.divide(5, 3)


class TestCustomOpsToggle:
    """Tests for custom ops toggling."""

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_use_ms_custom_ops_false_when_module_missing(self, monkeypatch):
        """
        Test that use_ms_custom_ops returns False when the 'ms_custom_ops' module is not imported.

        Ensures the fallback mechanism works correctly if the custom operators package is unavailable.
        """
        monkeypatch.setitem(sys.modules, "ms_custom_ops", None)
        assert transformer_utils.use_ms_custom_ops() is False

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_use_ms_custom_ops_true_when_module_present(self, monkeypatch):
        """
        Test that use_ms_custom_ops returns True when the 'ms_custom_ops' module is present and not on 310p.

        Verifies the primary condition for enabling custom operators based on module availability.
        """
        monkeypatch.setitem(sys.modules, "ms_custom_ops", SimpleNamespace())
        monkeypatch.setattr(transformer_utils, "is_310p", lambda: False)
        assert transformer_utils.use_ms_custom_ops() is True

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_use_ms_custom_ops_false_when_310p(self, monkeypatch):
        """
        Test that use_ms_custom_ops returns False even if the module is present when running on 310p.

        Confirms the hardware-specific override that disables custom operators on the Ascend 310P platform.
        """
        monkeypatch.setitem(sys.modules, "ms_custom_ops", SimpleNamespace())
        monkeypatch.setattr(transformer_utils, "is_310p", lambda: True)
        assert transformer_utils.use_ms_custom_ops() is False


class TestParameterUtility:
    """Covers helpers related to parameter creation."""

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_create_empty_parameter_returns_expected_shape(self):
        """
        Test that create_empty_parameter creates a Parameter with the specified shape and data type.
        """
        param = transformer_utils.create_empty_parameter((2, 3), dtype=mstype.float32, name="dummy")
        assert param.shape == (2, 3)
        assert param.dtype == mstype.float32


class TestWorldSizeFallbacks:
    """Ensure fallback logic returns non-zero defaults."""

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_world_size_helpers_default_to_one(self, monkeypatch):
        """
        Test that world size helper functions default to 1 when underlying query functions return 0.

        Ensures robustness by providing safe defaults for parallelism degrees, preventing division by zero.
        """
        monkeypatch.setattr(transformer_utils, "get_tensor_model_parallel_world_size", lambda: 0)
        monkeypatch.setattr(transformer_utils, "get_moe_tensor_parallel_world_size", lambda: 0)
        monkeypatch.setattr(transformer_utils, "get_moe_expert_parallel_world_size", lambda: 0)
        monkeypatch.setattr(transformer_utils, "get_data_parallel_world_size", lambda: 0)

        assert transformer_utils.get_tp_world_size() == 1
        assert transformer_utils.get_moe_tp_world_size() == 1
        assert transformer_utils.get_moe_ep_world_size() == 1
        assert transformer_utils.get_dp_world_size() == 1


class TestPaddingIndexGeneration:
    """Tests for generate_padding_index helper."""

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_generate_padding_index_single_dp(self, monkeypatch):
        """
        Test generate_padding_index for a simple case with single data parallel group.

        Verifies that the function generates padding and unpadding indices with the correct shape
        based on the input sequence lengths.
        """
        monkeypatch.setattr(transformer_utils, "get_tensor_model_parallel_world_size", lambda: 1)
        monkeypatch.setattr(transformer_utils, "get_data_parallel_world_size", lambda: 1)
        monkeypatch.setattr(transformer_utils, "get_data_parallel_group",
                            lambda: SimpleNamespace(rank=0, group=None))

        q_seq_lens = Tensor(np.array([[2]], dtype=np.int32))
        attn_pad, attn_unpad, ffn_pad, ffn_unpad = transformer_utils.generate_padding_index(q_seq_lens)

        assert attn_pad.shape == (2,)
        assert attn_unpad.shape == (2,)
        assert ffn_pad.shape == (2,)
        assert ffn_unpad.shape == (2,)
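For orientation, the sharded state dict asserted by the DummyNetwork test has roughly the layout sketched below; the keys shown follow the test's assertions, and the real generate_state_dict may record additional metadata per parameter.

```python
# Layout implied by test_generate_state_dict_includes_sharded_and_full_params:
expected_layout = {
    "total_rank": 2,  # from the patched get_group_size()
    "model": {
        "sub.param": {"shape": (2, 2), "shard": (1, 2)},  # reported by sharded_state_dict()
        "head.bias": {"shard": (1,)},                     # full (unsharded) parameter
    },
}
```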