mirror of
https://gitee.com/mindspore/mindformers.git
synced 2025-12-06 11:29:59 +08:00
修复test_pma用例超时,去掉test_all_reduce中无用用例
This commit is contained in:
@@ -25,7 +25,7 @@ from mindformers.tools.logger import logger
|
||||
|
||||
|
||||
_LEVEL_0_TASK_TIME = 0
|
||||
_LEVEL_1_TASK_TIME = 124
|
||||
_LEVEL_1_TASK_TIME = 436
|
||||
_TASK_TYPE = TaskType.FOUR_CARDS_TASK
|
||||
|
||||
|
||||
|
||||
@@ -63,30 +63,6 @@ class TestHelperFunctions(unittest.TestCase):
|
||||
self.assertEqual(loss, 0.8)
|
||||
self.assertFalse(overflow)
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_get_weight_norm(self):
|
||||
"""Test _get_weight_norm function"""
|
||||
# Create mock network
|
||||
mock_network = Mock()
|
||||
param1 = Mock()
|
||||
param1.to.return_value = param1
|
||||
param1.norm.return_value = Tensor(np.array([2.0]))
|
||||
param2 = Mock()
|
||||
param2.to.return_value = param2
|
||||
param2.norm.return_value = Tensor(np.array([3.0]))
|
||||
|
||||
mock_network.trainable_params.return_value = [param1, param2]
|
||||
|
||||
with patch('mindspore.ops.functional.stack') as mock_stack:
|
||||
mock_stack.return_value = Tensor(np.array([3.605551]))
|
||||
|
||||
# pylint: disable=W0212
|
||||
norm = callback_module._get_weight_norm(mock_network)
|
||||
|
||||
self.assertAlmostEqual(norm, 3.605551, places=5)
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
||||
Reference in New Issue
Block a user