diff --git a/mindformers/parallel_core/transformer_config.py b/mindformers/parallel_core/transformer_config.py
index 1c90ae88a..02f343ffa 100644
--- a/mindformers/parallel_core/transformer_config.py
+++ b/mindformers/parallel_core/transformer_config.py
@@ -650,7 +650,7 @@ class TransformerConfig(ModelParallelConfig, MFModelConfig):
                 "When using moe_dry_run, moe_token_dispatcher_type must be 'alltoall' or 'alltoall_deredundency'."
             )
 
-        if self.position_embedding_type not in ["rope", "yarn", "none", "relative", "learned_absolute"]:
+        if self.position_embedding_type not in ["rope", "yarn", "none", "relative", "learned_absolute", "partial_rope"]:
             raise ValueError(
                 f"The current value of position_embedding_type is {self.position_embedding_type},"
-                " but position_embedding_type must be one of: 'rope', 'yarn', 'none', 'relative', 'learned_absolute'."
+                " but position_embedding_type must be one of: 'rope', 'yarn', 'none', 'relative', 'learned_absolute', 'partial_rope'."
diff --git a/tests/st/test_ut/test_models/test_glm4/__init__.py b/tests/st/test_ut/test_models/test_glm4/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/st/test_ut/test_models/test_glm4/test_configuration_glm4.py b/tests/st/test_ut/test_models/test_glm4/test_configuration_glm4.py
new file mode 100644
index 000000000..aacc90465
--- /dev/null
+++ b/tests/st/test_ut/test_models/test_glm4/test_configuration_glm4.py
@@ -0,0 +1,47 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""Unit tests for Glm4Config.""" +import pytest + +from mindformers.models.glm4.configuration_glm4 import Glm4Config + + +class TestGlm4Config: + """Validates default behaviors for the Glm4 configuration.""" + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_default_configuration_fields(self): + """Test that the Glm4Config initializes with expected default values.""" + config = Glm4Config() + + assert config.vocab_size == 151552 + assert config.hidden_size == 4096 + assert config.num_hidden_layers == 40 + assert config.num_attention_heads == 32 + assert config.num_key_value_heads == 2 + assert config.position_embedding_type == "partial_rope" + assert config.model_type == "glm4" + assert "layers.*.self_attn.q_proj" in Glm4Config.base_model_tp_plan + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_override_arguments_apply(self): + """Test that arguments passed to Glm4Config constructor correctly override the defaults.""" + config = Glm4Config(vocab_size=10, num_attention_heads=8, eos_token_id=(1,)) + + assert config.vocab_size == 10 + assert config.num_attention_heads == 8 + assert config.eos_token_id == (1,) diff --git a/tests/st/test_ut/test_models/test_glm4/test_modeling_glm4.py b/tests/st/test_ut/test_models/test_glm4/test_modeling_glm4.py new file mode 100644 index 000000000..ab46542b0 --- /dev/null +++ b/tests/st/test_ut/test_models/test_glm4/test_modeling_glm4.py @@ -0,0 +1,46 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""UTs for Glm4 modeling API.""" +import os +import pytest + +from mindformers.models.glm4.configuration_glm4 import Glm4Config +from mindformers.models.glm4.modeling_glm4 import Glm4ForCausalLM +from mindformers.models.glm4.modeling_glm4_infer import InferenceGlm4ForCausalLM + + +class TestGlm4ForCausalLM: + """Ensure Glm4ForCausalLM routes to the proper implementation.""" + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_init_model_in_predict_mode(self): + """When RUN_MODE is unset/predict, the inference model should be instantiated.""" + os.environ['RUN_MODE'] = "predict" + config = Glm4Config() + + model = Glm4ForCausalLM(config) + + assert isinstance(model, InferenceGlm4ForCausalLM) + assert model.config is config + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_init_model_in_train_mode(self): + """RUN_MODE=train should raise an explicit NotImplementedError.""" + os.environ['RUN_MODE'] = "train" + + with pytest.raises(NotImplementedError): + Glm4ForCausalLM(Glm4Config()) diff --git a/tests/st/test_ut/test_models/test_glm4_moe/__init__.py b/tests/st/test_ut/test_models/test_glm4_moe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/st/test_ut/test_models/test_glm4_moe/test_configuration_glm4_moe.py b/tests/st/test_ut/test_models/test_glm4_moe/test_configuration_glm4_moe.py new file mode 100644 index 000000000..4e10edeb8 --- /dev/null +++ b/tests/st/test_ut/test_models/test_glm4_moe/test_configuration_glm4_moe.py @@ -0,0 +1,48 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Unit tests for Glm4MoeConfig.""" +import pytest + +from mindformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig + + +class TestGlm4MoeConfig: + """Tests covering the Glm4Moe configuration helper.""" + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_default_configuration_values(self): + """Ensure defaults from the spec are propagated to attributes.""" + config = Glm4MoeConfig() + + assert config.vocab_size == 151552 + assert config.hidden_size == 4096 + assert config.num_hidden_layers == 46 + assert config.num_attention_heads == 96 + assert config.moe_intermediate_size == 1408 + assert config.num_experts_per_tok == 8 + assert config.norm_topk_prob is True + assert config.model_type == "glm4_moe" + assert "layers.*.self_attn.q_proj" in Glm4MoeConfig.base_model_tp_plan + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_rope_scaling_type_key_is_renamed(self): + """When rope scaling contains 'type', it should be copied to 'rope_type'.""" + rope_scaling = {"type": "yarn", "factor": 2.0} + config = Glm4MoeConfig(rope_scaling=rope_scaling) + + assert config.rope_scaling["rope_type"] == "yarn" + assert config.rope_scaling["factor"] == 2.0 diff --git a/tests/st/test_ut/test_models/test_glm4_moe/test_modeling_glm4_moe.py b/tests/st/test_ut/test_models/test_glm4_moe/test_modeling_glm4_moe.py new file mode 100644 index 000000000..900c73704 --- /dev/null +++ b/tests/st/test_ut/test_models/test_glm4_moe/test_modeling_glm4_moe.py @@ -0,0 +1,46 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""UTs for Glm4Moe modeling API.""" +import os +import pytest + +from mindformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig +from mindformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeForCausalLM +from mindformers.models.glm4_moe.modeling_glm4_moe_infer import InferenceGlm4MoeForCausalLM + + +class TestGlm4MoeForCausalLM: + """Ensure Glm4MoeForCausalLM routes to the proper implementation.""" + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_init_model_in_predict_mode(self): + """When RUN_MODE is unset/predict, the inference model should be instantiated.""" + os.environ['RUN_MODE'] = "predict" + config = Glm4MoeConfig() + + model = Glm4MoeForCausalLM(config) + + assert isinstance(model, InferenceGlm4MoeForCausalLM) + assert model.config is config + + @pytest.mark.level1 + @pytest.mark.platform_x86_cpu + def test_init_model_in_train_mode(self): + """RUN_MODE=train should raise an explicit NotImplementedError.""" + os.environ['RUN_MODE'] = "train" + + with pytest.raises(NotImplementedError): + Glm4MoeForCausalLM(Glm4MoeConfig()) diff --git a/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/test_mapping/__init__.py b/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/test_mapping/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/test_mapping/test_infer_mapping.py b/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/test_mapping/test_infer_mapping.py new file mode 100644 index 000000000..b81e9ffad --- /dev/null +++ b/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/test_mapping/test_infer_mapping.py @@ -0,0 +1,174 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""UTs for tensor-parallel mapping helpers."""
+from functools import partial
+
+import numpy as np
+import pytest
+
+import mindspore as ms
+from mindspore import Tensor
+import mindspore.common.dtype as mstype
+
+from mindformers.parallel_core.inference.parallel_state import ProcessGroup
+from mindformers.parallel_core.inference.tensor_parallel import mappings
+
+
+ms.context.set_context(deterministic="ON")
+jit_level = "O0"
+infer_boost = "on"
+ms.set_context(device_target="Ascend",
+               mode=ms.GRAPH_MODE,
+               jit_config={"jit_level": jit_level, "infer_boost": infer_boost})
+
+
+class FakeGather:
+    """Mock AllGather operator recording inputs."""
+
+    def __init__(self):
+        self.calls = []
+
+    def __call__(self, tensor):
+        self.calls.append(tensor)
+        return tensor
+
+
+class FakeReduceScatter:
+    """Mock ReduceScatter returning half-size tensor."""
+
+    def __init__(self):
+        self.calls = []
+
+    def __call__(self, tensor):
+        self.calls.append(tensor)
+        # Return the first split chunk
+        return tensor[:tensor.shape[0] // 2]
+
+
+class FakeAllReduce:
+    """Mock AllReduce that doubles the tensor, mimicking a sum over a size-2 group."""
+
+    def __init__(self):
+        self.calls = []
+
+    def __call__(self, tensor):
+        self.calls.append(tensor)
+        return tensor * 2
+
+
+class FakeSplit:
+    """Mock Split op returning chunks."""
+
+    def __init__(self, axis, output_num):
+        self.axis = axis
+        self.output_num = output_num
+
+    def __call__(self, tensor):
+        return tuple(np.split(tensor.asnumpy(), self.output_num, axis=self.axis))
+
+class TestTensorParallelMappings:
+    """Groups mapping tests into a single suite."""
+
+    @pytest.mark.level1
+    @pytest.mark.platform_arm_ascend910b_training
+    @pytest.mark.env_onecard
+    def test_gather_returns_input_when_group_size_one(self):
+        """
+        Test that gather_from_model_parallel_region returns the original tensor unchanged
+        when the process group size is 1.
+        """
+        group = ProcessGroup(group=None, rank=0, size=1)
+        # pylint: disable=W0212
+        group._is_group_created = True
+        tensor = Tensor(np.ones((2, 2), dtype=np.float32), dtype=mstype.float32)
+
+        output = mappings.gather_from_model_parallel_region(tensor, group, dim=-1)
+
+        assert np.array_equal(output.asnumpy(), tensor.asnumpy())
+
+
+    @pytest.mark.level1
+    @pytest.mark.platform_arm_ascend910b_training
+    @pytest.mark.env_onecard
+    def test_gather_transposes_when_dim_nonzero(self, monkeypatch):
+        """
+        Test that gather_from_model_parallel_region transposes internally when gathering along a nonzero dimension.
+        """
+        fake_gather = FakeGather()
+        monkeypatch.setattr(mappings.ops, "AllGather", lambda group: fake_gather)
+        group = ProcessGroup(group="test", rank=0, size=2)
+        # pylint: disable=W0212
+        group._is_group_created = True
+        tensor = Tensor(np.arange(6).reshape(3, 2).astype(np.float32), dtype=mstype.float32)
+
+        output = mappings.gather_from_model_parallel_region(tensor, group, dim=1)
+
+        assert output.shape == tensor.shape
+
+
+    @pytest.mark.level1
+    @pytest.mark.platform_arm_ascend910b_training
+    @pytest.mark.env_onecard
+    def test_reduce_allreduce_invoked(self, monkeypatch):
+        """
+        Test that reduce_from_model_parallel_region performs an AllReduce operation.
+ """ + fake_reduce = FakeAllReduce() + monkeypatch.setattr(mappings.ops, "AllReduce", lambda group: fake_reduce) + group = ProcessGroup(group="test", rank=0, size=2) + # pylint: disable=W0212 + group._is_group_created = True + tensor = Tensor(np.ones((2, 2), dtype=np.float32), dtype=mstype.float32) + + output = mappings.reduce_from_model_parallel_region(tensor, group) + + assert np.array_equal(output.asnumpy(), (tensor * 2).asnumpy()) + + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_reduce_scatter_returns_split(self, monkeypatch): + """ + Test that reduce_scatter_to_model_parallel_region performs a ReduceScatter operation. + """ + fake_reduce_scatter = FakeReduceScatter() + monkeypatch.setattr(mappings.ops, "ReduceScatter", lambda group: fake_reduce_scatter) + group = ProcessGroup(group="test", rank=0, size=2) + # pylint: disable=W0212 + group._is_group_created = True + tensor = Tensor(np.ones((4, 2), dtype=np.float32), dtype=mstype.float32) + + output = mappings.reduce_scatter_to_model_parallel_region(tensor, group) + + assert output.shape[0] == tensor.shape[0] // 2 + + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_scatter_returns_rank_chunk(self, monkeypatch): + """ + Test that scatter_to_model_parallel_region splits the input tensor along the specified dimension. + """ + monkeypatch.setattr(mappings.ops, "Split", partial(FakeSplit)) + group = ProcessGroup(group="test", rank=1, size=2) + # pylint: disable=W0212 + group._is_group_created = True + tensor = Tensor(np.arange(8).reshape(2, 4).astype(np.float32), dtype=mstype.float32) + + output = mappings.scatter_to_model_parallel_region(tensor, group, dim=1) + + assert output.shape == (2, 2) diff --git a/tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_fused_softmax/__init__.py b/tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_fused_softmax/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_fused_softmax/test_infer_fused_softmax.py b/tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_fused_softmax/test_infer_fused_softmax.py new file mode 100644 index 000000000..c90120d82 --- /dev/null +++ b/tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_fused_softmax/test_infer_fused_softmax.py @@ -0,0 +1,92 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""UTs for FusedScaleMaskSoftmax."""
+import numpy as np
+import pytest
+
+import mindspore as ms
+from mindspore import Tensor
+import mindspore.common.dtype as mstype
+
+from mindformers.parallel_core.inference.transformer.fused_softmax import FusedScaleMaskSoftmax
+
+
+ms.context.set_context(deterministic="ON")
+jit_level = "O0"
+infer_boost = "on"
+ms.set_context(device_target="Ascend",
+               mode=ms.GRAPH_MODE,
+               jit_config={"jit_level": jit_level, "infer_boost": infer_boost})
+
+
+def simple_mask(tensor, mask):
+    """Mask function for tests that adds the mask to the scores."""
+    return tensor + mask
+
+
+class TestFusedScaleMaskSoftmax:
+    """Tests covering the fused softmax helper."""
+
+    @pytest.mark.level1
+    @pytest.mark.platform_arm_ascend910b_training
+    @pytest.mark.env_onecard
+    def test_forward_pass_with_scale_and_mask(self):
+        """
+        Test the forward pass of FusedScaleMaskSoftmax with both scaling and a mask applied.
+
+        Verifies that the module correctly applies the scale factor to the input tensor,
+        applies the provided attention mask, and computes the softmax, returning an output
+        with the expected shape. This tests the core functionality under typical conditions.
+        """
+        fused_softmax = FusedScaleMaskSoftmax(mask_func=simple_mask, scale=0.5, softmax_compute_type=mstype.float32)
+        x = Tensor(np.array([[2.0, 0.0]], dtype=np.float32), dtype=mstype.float32)
+        mask = Tensor(np.array([[0.0, -1.0]], dtype=np.float32), dtype=mstype.float32)
+
+        output = fused_softmax(x, mask)
+
+        assert output.shape == (1, 2)
+
+    @pytest.mark.level1
+    @pytest.mark.platform_arm_ascend910b_training
+    @pytest.mark.env_onecard
+    def test_precision_casts_to_fp32_when_needed(self):
+        """
+        Test that FusedScaleMaskSoftmax automatically casts inputs to float32 when required.
+
+        Verifies that when the softmax computation type is set to float32 but the input
+        tensor is in float16, the module performs the necessary precision casting to fp32
+        for the softmax operation, ensuring numerical stability, and returns an output
+        with the correct shape.
+        """
+        fused_softmax = FusedScaleMaskSoftmax(mask_func=simple_mask, scale=None, softmax_compute_type=mstype.float32)
+        x = Tensor(np.array([[1.0, 1.0]], dtype=np.float16), dtype=mstype.float16)
+
+        output = fused_softmax(x, mask=None)
+
+        assert output.shape == (1, 2)
+
+    @pytest.mark.level1
+    @pytest.mark.platform_arm_ascend910b_training
+    @pytest.mark.env_onecard
+    def test_invalid_scale_precision_combination_raises(self):
+        """
+        Test that FusedScaleMaskSoftmax raises a ValueError for invalid precision combinations.
+
+        Verifies that the module enforces the rule that if a scale factor is applied,
+        the softmax computation must be performed in float32 to maintain precision.
+        Attempting to use a scale with float16 computation should raise a ValueError.
+ """ + with pytest.raises(ValueError): + FusedScaleMaskSoftmax(mask_func=simple_mask, scale=0.1, softmax_compute_type=mstype.float16) diff --git a/tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_lower_triangular_mask/__init__.py b/tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_lower_triangular_mask/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_lower_triangular_mask/test_infer_lower_triangular_mask.py b/tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_lower_triangular_mask/test_infer_lower_triangular_mask.py new file mode 100644 index 000000000..2b2a69855 --- /dev/null +++ b/tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_lower_triangular_mask/test_infer_lower_triangular_mask.py @@ -0,0 +1,70 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Unit tests for LowerTriangularMaskWithDynamic.""" +import numpy as np +import pytest + +import mindspore as ms +from mindspore import Tensor +import mindspore.common.dtype as mstype + +from mindformers.parallel_core.inference.transformer.lower_triangular_mask import ( + LowerTriangularMaskWithDynamic, +) + + +ms.context.set_context(deterministic="ON") +jit_level = "O0" +infer_boost = "on" +ms.set_context(device_target="Ascend", + mode=ms.GRAPH_MODE, + jit_config={"jit_level": jit_level, "infer_boost": infer_boost}) + + +class TestLowerTriangularMask: + """Validates lower-triangular mask generation.""" + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_lower_triangular_mask_in_prefill(self): + """Prefill path should directly return the static fa mask.""" + lower_triangular_mask = LowerTriangularMaskWithDynamic(seq_length=4, compute_type=mstype.float16) + lower_triangular_mask.is_prefill = True + + mask = lower_triangular_mask(positions=Tensor(np.zeros((1, 4)), dtype=mstype.int32)) + assert mask.shape == lower_triangular_mask.fa_lower_triangle_mask.shape + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_lower_triangular_mask_in_decode(self): + """Decode path should gather using provided positions.""" + lower_triangular_mask = LowerTriangularMaskWithDynamic(seq_length=4, compute_type=mstype.float16) + lower_triangular_mask.is_prefill = False + positions = Tensor(np.array([0, 2], dtype=np.int32)) + + mask = lower_triangular_mask(positions=positions) + expected_shape = (positions.shape[0], lower_triangular_mask.pa_lower_triangle_mask.shape[1]) + assert mask.shape == expected_shape + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_lower_triangular_mask_bfloat16_in_prefill_mask(self): + """When using bf16 compute type, mask coefficient becomes +1.""" + 
lower_triangular_mask = LowerTriangularMaskWithDynamic(seq_length=4, compute_type=mstype.bfloat16) + mask = lower_triangular_mask.prefill() + assert mask.shape == lower_triangular_mask.fa_lower_triangle_mask.shape diff --git a/tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_moe/test_moe_utils/__init__.py b/tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_moe/test_moe_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_moe/test_moe_utils/test_infer_moe_utils.py b/tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_moe/test_moe_utils/test_infer_moe_utils.py new file mode 100644 index 000000000..c11495f0b --- /dev/null +++ b/tests/st/test_ut/test_parallel_core/test_inference/test_transformer/test_moe/test_moe_utils/test_infer_moe_utils.py @@ -0,0 +1,114 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""UTs for `moe_utils.py`.""" +import numpy as np +import pytest + +import mindspore as ms +from mindspore import Tensor +import mindspore.common.dtype as mstype + +from mindformers.parallel_core.inference.transformer.moe.moe_utils import ( + group_limited_topk, + topk_routing_with_score_function, +) + + +ms.context.set_context(deterministic="ON") +jit_level = "O0" +infer_boost = "on" +ms.set_context(device_target="Ascend", + mode=ms.GRAPH_MODE, + jit_config={"jit_level": jit_level, "infer_boost": infer_boost}) + + +class TestTopkRoutingWithScoreFunction: + """Unit tests for the top-k routing helper.""" + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_softmax_routing_returns_normalized_weights(self): + """The weights should sum to one per token when normalization is enabled.""" + logits = Tensor( + np.array([[1.0, 2.0, 0.5, -0.5], [-1.0, 0.0, 2.5, 1.0]], dtype=np.float32), + dtype=mstype.float32, + ) + expert_weight, routing_map = topk_routing_with_score_function( + logits=logits, + topk=2, + num_experts=4, + score_function="softmax", + norm_topk_prob=True, + ) + + assert expert_weight.shape == (2, 2) + assert routing_map.shape == (2, 2) + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_sigmoid_routing_with_bias_without_normalization(self): + """Bias should affect the chosen experts while weights stay unnormalized when disabled.""" + logits = Tensor( + np.array([[0.0, -2.0, 2.0, 1.0]], dtype=np.float32), + dtype=mstype.float32, + ) + expert_bias = Tensor(np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32), dtype=mstype.float32) + + expert_weight, routing_map = topk_routing_with_score_function( + logits=logits, + topk=2, + num_experts=4, + score_function="sigmoid", + expert_bias=expert_bias, + norm_topk_prob=False, + ) + + assert expert_weight.shape == (1, 2) + assert routing_map.shape == 
(1, 2) + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_group_limited_topk_only_selects_from_best_group(self): + """group_limited_topk should not route experts outside the best group subset.""" + scores = Tensor(np.array([[0.9, 0.8, 0.1, 0.2]], dtype=np.float32), dtype=mstype.float32) + + probs, top_indices = group_limited_topk( + scores=scores, + topk=2, + num_experts=4, + num_groups=2, + group_topk=1, + ) + + assert probs.shape == (1, 2) + assert top_indices.shape == (1, 2) + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_invalid_score_function_raises(self): + """An unsupported score function name should raise ValueError.""" + logits = Tensor(np.zeros((1, 2), dtype=np.float32), dtype=mstype.float32) + + with pytest.raises(ValueError): + topk_routing_with_score_function( + logits=logits, + topk=1, + num_experts=2, + score_function="unsupported", + ) diff --git a/tests/st/test_ut/test_parallel_core/test_inference/test_utils/__init__.py b/tests/st/test_ut/test_parallel_core/test_inference/test_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/st/test_ut/test_parallel_core/test_inference/test_utils/test_utils.py b/tests/st/test_ut/test_parallel_core/test_inference/test_utils/test_utils.py new file mode 100644 index 000000000..a7e47235c --- /dev/null +++ b/tests/st/test_ut/test_parallel_core/test_inference/test_utils/test_utils.py @@ -0,0 +1,331 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Unit tests for transformer utils helpers.""" +import sys +from types import SimpleNamespace + +import numpy as np +import pytest + +from mindspore import Tensor, Parameter +import mindspore.common.dtype as mstype + +from mindformers.parallel_core.transformer_config import TransformerConfig +import mindformers.parallel_core.inference.utils as transformer_utils + + +class DummySubCell: + """Subcell exposing a sharded state dict.""" + + def __init__(self): + self.param = Parameter( + Tensor(np.ones((2, 2), dtype=np.float32), dtype=mstype.float32), name="sub.param" + ) + + def sharded_state_dict(self): + return { + "sub.param": { + "shape": self.param.shape, + "shard": (1, 2), + } + } + + def name_cells(self): + return {"self": self} + + +class DummyNetwork: + """Minimal network exposing parameters and cells.""" + + def __init__(self): + self.sub = DummySubCell() + self.head = Parameter(Tensor(np.ones((2,), dtype=np.float32), dtype=mstype.float32), name="head.bias") + + def name_cells(self): + return {"self": self, "sub": self.sub} + + def parameters_dict(self): + return {"sub.param": self.sub.param, "head.bias": self.head} + + +class TestAttnMaskHelpers: + """Tests for attention mask helpers.""" + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_attn_mask_fill_applies_value(self): + """ + Test 'attn_mask_fill' function correctly applies the fill value to masked positions. + """ + func = transformer_utils.get_attn_mask_func("attn_mask_fill") + scores = Tensor(np.ones((1, 2), dtype=np.float32), dtype=mstype.float32) + mask = Tensor(np.array([[False, True]]), dtype=mstype.bool_) + output = func(scores, mask, fill_value=-9.0) + + output_np = output.asnumpy() + assert output_np[0, 0] == pytest.approx(1.0, rel=1e-6) + assert output_np[0, 1] == pytest.approx(-9.0, rel=1e-6) + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_attn_mask_add_casts_mask(self): + """ + Test 'attn_mask_add' function adding a float mask to attention scores. + """ + func = transformer_utils.get_attn_mask_func("attn_mask_add") + scores = Tensor(np.zeros((1, 2), dtype=np.float32), dtype=mstype.float32) + mask = Tensor(np.array([[0.0, -5.0]], dtype=np.float32), dtype=mstype.float32) + output = func(scores, mask) + + output_np = output.asnumpy() + assert output_np.shape == (1, 2) + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_get_attn_mask_func_with_invalid_name(self): + """ + Test get_attn_mask_func raising a KeyError for an unsupported mask function type. + """ + with pytest.raises(KeyError): + transformer_utils.get_attn_mask_func("unknown") + + +class TestStateDictGeneration: + """Tests for sharded state dict utilities.""" + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_generate_state_dict_includes_sharded_and_full_params(self, monkeypatch): + """ + Test that generate_state_dict correctly includes both sharded and non-sharded parameters. 
+ """ + monkeypatch.setattr(transformer_utils, "get_group_size", lambda: 2) + state_dict = transformer_utils.generate_state_dict(DummyNetwork()) + + assert state_dict["total_rank"] == 2 + assert "sub.param" in state_dict["model"] + assert "head.bias" in state_dict["model"] + assert state_dict["model"]["head.bias"]["shard"] == (1,) + + +class TestCommAndTopologyHelpers: + """Tests targeting communication helper utilities.""" + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_update_comm_config_single_tp_multi_dp(self, monkeypatch): + """ + Test update_comm_config for a configuration with single tensor parallel group and multiple data parallel groups. + """ + monkeypatch.setattr(transformer_utils, "get_tensor_model_parallel_world_size", lambda: 1) + monkeypatch.setattr(transformer_utils, "get_data_parallel_world_size", lambda: 2) + monkeypatch.setattr(transformer_utils, "get_moe_tensor_parallel_world_size", lambda: 1) + + config = TransformerConfig(num_layers=1, num_attention_heads=1) + updated = transformer_utils.update_comm_config(config) + + assert updated.use_alltoall is True + assert updated.attn_allreduce is False + assert updated.ffn_allreduce is False + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_update_comm_config_moe_tp_enabled(self, monkeypatch): + """ + Test update_comm_config when MOE tensor parallelism is enabled. + """ + monkeypatch.setattr(transformer_utils, "get_tensor_model_parallel_world_size", lambda: 1) + monkeypatch.setattr(transformer_utils, "get_data_parallel_world_size", lambda: 2) + monkeypatch.setattr(transformer_utils, "get_moe_tensor_parallel_world_size", lambda: 2) + + config = TransformerConfig(num_layers=1, num_attention_heads=1) + updated = transformer_utils.update_comm_config(config) + + assert updated.attn_allgather is True + assert updated.ffn_reduce_scatter is True + assert updated.ffn_allreduce is False + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_get_num_layers_and_offset_with_pp_offsets(self, monkeypatch): + """ + Test get_num_layers_and_offset with a valid pipeline parallel offset configuration. + """ + monkeypatch.setattr(transformer_utils, "get_pipeline_model_parallel_world_size", lambda: 2) + monkeypatch.setattr(transformer_utils, "get_pipeline_model_parallel_rank", lambda: 1) + + config = TransformerConfig(num_layers=5, offset=[1, 0], num_attention_heads=1) + + layers, offset = transformer_utils.get_num_layers_and_offset(config) + + assert layers == 2 + assert offset == 3 + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_get_num_layers_and_offset_raises_for_small_model(self, monkeypatch): + """ + Test that get_num_layers_and_offset raises RuntimeError when the model has too few layers. + """ + monkeypatch.setattr(transformer_utils, "get_pipeline_model_parallel_world_size", lambda: 8) + + config = TransformerConfig(num_layers=4, num_attention_heads=1) + + with pytest.raises(RuntimeError): + transformer_utils.get_num_layers_and_offset(config) + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_get_num_layers_and_offset_invalid_offset_shape(self, monkeypatch): + """ + Test that get_num_layers_and_offset raises ValueError for an offset list with incorrect length. 
+ """ + monkeypatch.setattr(transformer_utils, "get_pipeline_model_parallel_world_size", lambda: 2) + + config = TransformerConfig(num_layers=6, offset=[1, 0, 0], num_attention_heads=1) + + with pytest.raises(ValueError): + transformer_utils.get_num_layers_and_offset(config) + + +class TestMathHelpers: + """Tests for small math helpers.""" + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_divide_checks_divisibility(self): + """ + Test that the divide function checks for exact divisibility. + """ + assert transformer_utils.divide(6, 3) == 2 + with pytest.raises(ValueError): + transformer_utils.divide(5, 3) + + +class TestCustomOpsToggle: + """Tests for custom ops toggling.""" + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_use_ms_custom_ops_false_when_module_missing(self, monkeypatch): + """ + Test that use_ms_custom_ops returns False when the 'ms_custom_ops' module is not imported. + + Ensures the fallback mechanism works correctly if the custom operators package is unavailable. + """ + monkeypatch.setitem(sys.modules, "ms_custom_ops", None) + assert transformer_utils.use_ms_custom_ops() is False + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_use_ms_custom_ops_true_when_module_present(self, monkeypatch): + """ + Test that use_ms_custom_ops returns True when the 'ms_custom_ops' module is present and not on 310p. + + Verifies the primary condition for enabling custom operators based on module availability. + """ + monkeypatch.setitem(sys.modules, "ms_custom_ops", SimpleNamespace()) + monkeypatch.setattr(transformer_utils, "is_310p", lambda: False) + assert transformer_utils.use_ms_custom_ops() is True + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_use_ms_custom_ops_false_when_310p(self, monkeypatch): + """ + Test that use_ms_custom_ops returns False even if the module is present when running on 310p. + + Confirms the hardware-specific override that disables custom operators on the Ascend 310P platform. + """ + monkeypatch.setitem(sys.modules, "ms_custom_ops", SimpleNamespace()) + monkeypatch.setattr(transformer_utils, "is_310p", lambda: True) + assert transformer_utils.use_ms_custom_ops() is False + + +class TestParameterUtility: + """Covers helpers related to parameter creation.""" + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_create_empty_parameter_returns_expected_shape(self): + """ + Test that create_empty_parameter creates a Parameter with the specified shape and data type. + """ + param = transformer_utils.create_empty_parameter((2, 3), dtype=mstype.float32, name="dummy") + assert param.shape == (2, 3) + assert param.dtype == mstype.float32 + + +class TestWorldSizeFallbacks: + """Ensure fallback logic returns non-zero defaults.""" + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_world_size_helpers_default_to_one(self, monkeypatch): + """ + Test that world size helper functions default to 1 when underlying query functions return 0. + + Ensures robustness by providing safe defaults for parallelism degrees, preventing division by zero. 
+ """ + monkeypatch.setattr(transformer_utils, "get_tensor_model_parallel_world_size", lambda: 0) + monkeypatch.setattr(transformer_utils, "get_moe_tensor_parallel_world_size", lambda: 0) + monkeypatch.setattr(transformer_utils, "get_moe_expert_parallel_world_size", lambda: 0) + monkeypatch.setattr(transformer_utils, "get_data_parallel_world_size", lambda: 0) + + assert transformer_utils.get_tp_world_size() == 1 + assert transformer_utils.get_moe_tp_world_size() == 1 + assert transformer_utils.get_moe_ep_world_size() == 1 + assert transformer_utils.get_dp_world_size() == 1 + + +class TestPaddingIndexGeneration: + """Tests for generate_padding_index helper.""" + + @pytest.mark.level1 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_generate_padding_index_single_dp(self, monkeypatch): + """ + Test generate_padding_index for a simple case with single data parallel group. + + Verifies that the function generates padding and unpadding indices with the correct shape + based on the input sequence lengths. + """ + monkeypatch.setattr(transformer_utils, "get_tensor_model_parallel_world_size", lambda: 1) + monkeypatch.setattr(transformer_utils, "get_data_parallel_world_size", lambda: 1) + monkeypatch.setattr(transformer_utils, "get_data_parallel_group", + lambda: SimpleNamespace(rank=0, group=None)) + + q_seq_lens = Tensor(np.array([[2]], dtype=np.int32)) + attn_pad, attn_unpad, ffn_pad, ffn_unpad = transformer_utils.generate_padding_index(q_seq_lens) + + assert attn_pad.shape == (2,) + assert attn_unpad.shape == (2,) + assert ffn_pad.shape == (2,) + assert ffn_unpad.shape == (2,)