mirror of
https://gitee.com/mindspore/mindarmour.git
synced 2025-12-06 11:59:05 +08:00
add examples/model_security/model_attacks/white_box/mnist_attack_vmifgsm.py.
测试variance tuning momentum iterative method. Signed-off-by: weiwan <wanwei_0303@hust.edu.cn> update examples/model_security/model_attacks/white_box/mnist_attack_vmifgsm.py. Signed-off-by: weiwan <wanwei_0303@hust.edu.cn> update examples/model_security/model_attacks/white_box/mnist_attack_vmifgsm.py. Signed-off-by: weiwan <wanwei_0303@hust.edu.cn> update examples/model_security/model_attacks/white_box/mnist_attack_vmifgsm.py. Signed-off-by: weiwan <wanwei_0303@hust.edu.cn> update examples/model_security/model_attacks/white_box/mnist_attack_vmifgsm.py. Signed-off-by: weiwan <wanwei_0303@hust.edu.cn> new file: examples/model_security/model_attacks/white_box/mnist_attack_vmifgsm.py modified: mindarmour/adv_robustness/attacks/__init__.py modified: tests/ut/python/adv_robustness/attacks/test_iterative_gradient_method.py new file: examples/model_security/model_attacks/white_box/mnist_attack_vmifgsm.py modified: mindarmour/adv_robustness/attacks/__init__.py modified: tests/ut/python/adv_robustness/attacks/test_iterative_gradient_method.py new file: examples/model_security/model_attacks/white_box/mnist_attack_vmifgsm.py modified: mindarmour/adv_robustness/attacks/__init__.py modified: tests/ut/python/adv_robustness/attacks/test_iterative_gradient_method.py modified: .jenkins/check/config/filter_pylint.txt new file: examples/model_security/model_attacks/white_box/mnist_attack_vmifgsm.py modified: mindarmour/adv_robustness/attacks/__init__.py modified: tests/ut/python/adv_robustness/attacks/test_iterative_gradient_method.py modified: .jenkins/check/config/filter_pylint.txt new file: examples/model_security/model_attacks/white_box/mnist_attack_vmifgsm.py modified: mindarmour/adv_robustness/attacks/__init__.py modified: tests/ut/python/adv_robustness/attacks/test_iterative_gradient_method.py modified: .jenkins/check/config/filter_pylint.txt new file: examples/model_security/model_attacks/white_box/mnist_attack_vmifgsm.py modified: mindarmour/adv_robustness/attacks/__init__.py modified: tests/ut/python/adv_robustness/attacks/test_iterative_gradient_method.py modified: .jenkins/check/config/filter_pylint.txt new file: examples/model_security/model_attacks/white_box/mnist_attack_vmifgsm.py modified: mindarmour/adv_robustness/attacks/__init__.py modified: tests/ut/python/adv_robustness/attacks/test_iterative_gradient_method.py
This commit is contained in:
@@ -43,6 +43,7 @@
|
|||||||
"mindarmour/examples/model_security/model_attacks/white_box/mnist_attack_lbfgs.py" "missing-docstring"
|
"mindarmour/examples/model_security/model_attacks/white_box/mnist_attack_lbfgs.py" "missing-docstring"
|
||||||
"mindarmour/examples/model_security/model_attacks/white_box/mnist_attack_mdi2fgsm.py" "missing-docstring"
|
"mindarmour/examples/model_security/model_attacks/white_box/mnist_attack_mdi2fgsm.py" "missing-docstring"
|
||||||
"mindarmour/examples/model_security/model_attacks/white_box/mnist_attack_pgd.py" "missing-docstring"
|
"mindarmour/examples/model_security/model_attacks/white_box/mnist_attack_pgd.py" "missing-docstring"
|
||||||
|
"mindarmour/examples/model_security/model_attacks/white_box/mnist_attack_vmifgsm.py" "missing-docstring"
|
||||||
"mindarmour/examples/model_security/model_defenses/mnist_defense_nad.py" "missing-docstring"
|
"mindarmour/examples/model_security/model_defenses/mnist_defense_nad.py" "missing-docstring"
|
||||||
"mindarmour/examples/model_security/model_defenses/mnist_evaluation.py" "missing-docstring"
|
"mindarmour/examples/model_security/model_defenses/mnist_evaluation.py" "missing-docstring"
|
||||||
"mindarmour/examples/model_security/model_defenses/mnist_similarity_detector.py" "missing-docstring"
|
"mindarmour/examples/model_security/model_defenses/mnist_similarity_detector.py" "missing-docstring"
|
||||||
|
|||||||
@@ -0,0 +1,110 @@
|
|||||||
|
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
from scipy.special import softmax
|
||||||
|
|
||||||
|
from mindspore import Model
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore.nn import SoftmaxCrossEntropyWithLogits
|
||||||
|
|
||||||
|
from mindarmour.adv_robustness.attacks import VarianceTuningMomentumIterativeMethod
|
||||||
|
from mindarmour.adv_robustness.evaluations import AttackEvaluate
|
||||||
|
from mindarmour.utils.logger import LogUtil
|
||||||
|
|
||||||
|
from examples.common.dataset.data_processing import generate_mnist_dataset
|
||||||
|
from examples.common.networks.lenet5.lenet5_net import LeNet5
|
||||||
|
|
||||||
|
LOGGER = LogUtil.get_instance()
|
||||||
|
LOGGER.set_level('INFO')
|
||||||
|
TAG = 'VMI_Test'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_variance_tuning_momentum_iterative_method():
|
||||||
|
"""
|
||||||
|
test for CPU device.
|
||||||
|
"""
|
||||||
|
# upload trained network
|
||||||
|
ckpt_path = '../../../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
|
||||||
|
net = LeNet5()
|
||||||
|
load_dict = load_checkpoint(ckpt_path)
|
||||||
|
load_param_into_net(net, load_dict)
|
||||||
|
|
||||||
|
# get test data
|
||||||
|
data_list = "../../../common/dataset/MNIST/test"
|
||||||
|
batch_size = 32
|
||||||
|
ds = generate_mnist_dataset(data_list, batch_size)
|
||||||
|
|
||||||
|
# prediction accuracy before attack
|
||||||
|
model = Model(net)
|
||||||
|
batch_num = 32 # the number of batches of attacking samples
|
||||||
|
test_images = []
|
||||||
|
test_labels = []
|
||||||
|
predict_labels = []
|
||||||
|
i = 0
|
||||||
|
for data in ds.create_tuple_iterator(output_numpy=True):
|
||||||
|
i += 1
|
||||||
|
images = data[0].astype(np.float32)
|
||||||
|
labels = data[1]
|
||||||
|
test_images.append(images)
|
||||||
|
test_labels.append(labels)
|
||||||
|
pred_labels = np.argmax(model.predict(Tensor(images)).asnumpy(),
|
||||||
|
axis=1)
|
||||||
|
predict_labels.append(pred_labels)
|
||||||
|
if i >= batch_num:
|
||||||
|
break
|
||||||
|
predict_labels = np.concatenate(predict_labels)
|
||||||
|
true_labels = np.concatenate(test_labels)
|
||||||
|
accuracy = np.mean(np.equal(predict_labels, true_labels))
|
||||||
|
LOGGER.info(TAG, "prediction accuracy before attacking is : %s", accuracy)
|
||||||
|
|
||||||
|
# attacking
|
||||||
|
loss = SoftmaxCrossEntropyWithLogits(sparse=True)
|
||||||
|
attack = VarianceTuningMomentumIterativeMethod(net, eps=0.3, loss_fn=loss)
|
||||||
|
start_time = time.process_time()
|
||||||
|
adv_data = attack.batch_generate(np.concatenate(test_images),
|
||||||
|
true_labels, batch_size=32)
|
||||||
|
stop_time = time.process_time()
|
||||||
|
np.save('./adv_data', adv_data)
|
||||||
|
pred_logits_adv = model.predict(Tensor(adv_data)).asnumpy()
|
||||||
|
# rescale predict confidences into (0, 1).
|
||||||
|
pred_logits_adv = softmax(pred_logits_adv, axis=1)
|
||||||
|
pred_labels_adv = np.argmax(pred_logits_adv, axis=1)
|
||||||
|
accuracy_adv = np.mean(np.equal(pred_labels_adv, true_labels))
|
||||||
|
LOGGER.info(TAG, "prediction accuracy after attacking is : %s", accuracy_adv)
|
||||||
|
attack_evaluate = AttackEvaluate(np.concatenate(test_images).transpose(0, 2, 3, 1),
|
||||||
|
np.eye(10)[true_labels],
|
||||||
|
adv_data.transpose(0, 2, 3, 1),
|
||||||
|
pred_logits_adv)
|
||||||
|
LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s',
|
||||||
|
attack_evaluate.mis_classification_rate())
|
||||||
|
LOGGER.info(TAG, 'The average confidence of adversarial class is : %s',
|
||||||
|
attack_evaluate.avg_conf_adv_class())
|
||||||
|
LOGGER.info(TAG, 'The average confidence of true class is : %s',
|
||||||
|
attack_evaluate.avg_conf_true_class())
|
||||||
|
LOGGER.info(TAG, 'The average distance (l0, l2, linf) between original '
|
||||||
|
'samples and adversarial samples are: %s',
|
||||||
|
attack_evaluate.avg_lp_distance())
|
||||||
|
LOGGER.info(TAG, 'The average structural similarity between original '
|
||||||
|
'samples and adversarial samples are: %s',
|
||||||
|
attack_evaluate.avg_ssim())
|
||||||
|
LOGGER.info(TAG, 'The average costing time is %s',
|
||||||
|
(stop_time - start_time)/(batch_num*batch_size))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# device_target can be "CPU", "GPU" or "Ascend"
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||||
|
test_variance_tuning_momentum_iterative_method()
|
||||||
@@ -18,7 +18,8 @@ in making adversarial examples.
|
|||||||
from .gradient_method import FastGradientMethod, FastGradientSignMethod, RandomFastGradientMethod, \
|
from .gradient_method import FastGradientMethod, FastGradientSignMethod, RandomFastGradientMethod, \
|
||||||
RandomFastGradientSignMethod, LeastLikelyClassMethod, RandomLeastLikelyClassMethod
|
RandomFastGradientSignMethod, LeastLikelyClassMethod, RandomLeastLikelyClassMethod
|
||||||
from .iterative_gradient_method import IterativeGradientMethod, BasicIterativeMethod, MomentumIterativeMethod, \
|
from .iterative_gradient_method import IterativeGradientMethod, BasicIterativeMethod, MomentumIterativeMethod, \
|
||||||
ProjectedGradientDescent, DiverseInputIterativeMethod, MomentumDiverseInputIterativeMethod
|
ProjectedGradientDescent, DiverseInputIterativeMethod, MomentumDiverseInputIterativeMethod, \
|
||||||
|
VarianceTuningMomentumIterativeMethod
|
||||||
from .deep_fool import DeepFool
|
from .deep_fool import DeepFool
|
||||||
from .jsma import JSMAAttack
|
from .jsma import JSMAAttack
|
||||||
from .carlini_wagner import CarliniWagnerL2Attack
|
from .carlini_wagner import CarliniWagnerL2Attack
|
||||||
@@ -40,6 +41,7 @@ __all__ = ['FastGradientMethod',
|
|||||||
'IterativeGradientMethod',
|
'IterativeGradientMethod',
|
||||||
'BasicIterativeMethod',
|
'BasicIterativeMethod',
|
||||||
'MomentumIterativeMethod',
|
'MomentumIterativeMethod',
|
||||||
|
'VarianceTuningMomentumIterativeMethod',
|
||||||
'ProjectedGradientDescent',
|
'ProjectedGradientDescent',
|
||||||
'DiverseInputIterativeMethod',
|
'DiverseInputIterativeMethod',
|
||||||
'MomentumDiverseInputIterativeMethod',
|
'MomentumDiverseInputIterativeMethod',
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from mindarmour.adv_robustness.attacks import ProjectedGradientDescent
|
|||||||
from mindarmour.adv_robustness.attacks import IterativeGradientMethod
|
from mindarmour.adv_robustness.attacks import IterativeGradientMethod
|
||||||
from mindarmour.adv_robustness.attacks import DiverseInputIterativeMethod
|
from mindarmour.adv_robustness.attacks import DiverseInputIterativeMethod
|
||||||
from mindarmour.adv_robustness.attacks import MomentumDiverseInputIterativeMethod
|
from mindarmour.adv_robustness.attacks import MomentumDiverseInputIterativeMethod
|
||||||
|
from mindarmour.adv_robustness.attacks import VarianceTuningMomentumIterativeMethod
|
||||||
|
|
||||||
|
|
||||||
# for user
|
# for user
|
||||||
@@ -56,6 +56,31 @@ class Net(Cell):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class FlattenNet(Cell):
|
||||||
|
"""
|
||||||
|
Construct the network of target model.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> net = FlattenNet()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(FlattenNet, self).__init__()
|
||||||
|
self._flatten = P.Flatten()
|
||||||
|
self._softmax = P.Softmax()
|
||||||
|
|
||||||
|
def construct(self, inputs):
|
||||||
|
"""
|
||||||
|
Construct flatten network.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs (Tensor): Input data.
|
||||||
|
"""
|
||||||
|
out = self._flatten(inputs)
|
||||||
|
out = self._softmax(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
@pytest.mark.platform_arm_ascend_training
|
@pytest.mark.platform_arm_ascend_training
|
||||||
@pytest.mark.platform_x86_ascend_training
|
@pytest.mark.platform_x86_ascend_training
|
||||||
@@ -354,3 +379,31 @@ def test_error_cpu():
|
|||||||
assert attack.generate(input_np, label)
|
assert attack.generate(input_np, label)
|
||||||
del input_np, label
|
del input_np, label
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_card
|
||||||
|
@pytest.mark.component_mindarmour
|
||||||
|
def test_variance_tuning_momentum_iterative_method_cpu():
|
||||||
|
"""
|
||||||
|
Feature: Variance Tuning Momentum iterative method unit test for cpu
|
||||||
|
Description: Given multiple images, we want to make sure the adversarial examples
|
||||||
|
generated are different from the images
|
||||||
|
Expectation: input_np != ms_adv_x
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||||
|
input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32)
|
||||||
|
label = np.asarray([2], np.int32)
|
||||||
|
label = np.eye(3)[label].astype(np.float32)
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
attack = VarianceTuningMomentumIterativeMethod(FlattenNet(), nb_iter=i + 1,
|
||||||
|
loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
|
||||||
|
ms_adv_x = attack.generate(input_np, label)
|
||||||
|
assert np.any(ms_adv_x != input_np), 'Variance Tuning Momentum iterative method: generate' \
|
||||||
|
' value must not be equal to' \
|
||||||
|
' original value.'
|
||||||
|
del input_np, label, ms_adv_x
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
Reference in New Issue
Block a user