mirror of
https://gitee.com/mindspore/mindarmour.git
synced 2025-12-06 11:59:05 +08:00
!241 fix bug of salt and pepper attack
Merge pull request !241 from ZhidanLiu/master
This commit is contained in:
@@ -86,9 +86,7 @@ def test_salt_and_pepper_attack_on_mnist():
|
||||
|
||||
# attacking
|
||||
is_target = False
|
||||
attack = SaltAndPepperNoiseAttack(model=model,
|
||||
is_targeted=is_target,
|
||||
sparse=True)
|
||||
attack = SaltAndPepperNoiseAttack(model=model, is_targeted=is_target, sparse=True)
|
||||
if is_target:
|
||||
targeted_labels = np.random.randint(0, 10, size=len(true_labels))
|
||||
for i, true_l in enumerate(true_labels):
|
||||
@@ -97,8 +95,7 @@ def test_salt_and_pepper_attack_on_mnist():
|
||||
else:
|
||||
targeted_labels = true_labels
|
||||
LOGGER.debug(TAG, 'input shape is: {}'.format(np.concatenate(test_images).shape))
|
||||
success_list, adv_data, query_list = attack.generate(
|
||||
np.concatenate(test_images), targeted_labels)
|
||||
success_list, adv_data, query_list = attack.generate(np.concatenate(test_images), targeted_labels)
|
||||
success_list = np.arange(success_list.shape[0])[success_list]
|
||||
LOGGER.info(TAG, 'success_list: %s', success_list)
|
||||
LOGGER.info(TAG, 'average of query times is : %s', np.mean(query_list))
|
||||
@@ -110,21 +107,16 @@ def test_salt_and_pepper_attack_on_mnist():
|
||||
adv_preds.extend(pred_logits_adv)
|
||||
adv_preds = np.array(adv_preds)
|
||||
accuracy_adv = np.mean(np.equal(np.max(adv_preds, axis=1), true_labels))
|
||||
LOGGER.info(TAG, "prediction accuracy after attacking is : %g",
|
||||
accuracy_adv)
|
||||
LOGGER.info(TAG, "prediction accuracy after attacking is : %g", accuracy_adv)
|
||||
test_labels_onehot = np.eye(10)[true_labels]
|
||||
attack_evaluate = AttackEvaluate(np.concatenate(test_images),
|
||||
test_labels_onehot, adv_data,
|
||||
adv_preds, targeted=is_target,
|
||||
target_label=targeted_labels)
|
||||
LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s',
|
||||
attack_evaluate.mis_classification_rate())
|
||||
LOGGER.info(TAG, 'The average confidence of adversarial class is : %s',
|
||||
attack_evaluate.avg_conf_adv_class())
|
||||
LOGGER.info(TAG, 'The average confidence of true class is : %s',
|
||||
attack_evaluate.avg_conf_true_class())
|
||||
LOGGER.info(TAG, 'The average distance (l0, l2, linf) between original '
|
||||
'samples and adversarial samples are: %s',
|
||||
LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s', attack_evaluate.mis_classification_rate())
|
||||
LOGGER.info(TAG, 'The average confidence of adversarial class is : %s', attack_evaluate.avg_conf_adv_class())
|
||||
LOGGER.info(TAG, 'The average confidence of true class is : %s', attack_evaluate.avg_conf_true_class())
|
||||
LOGGER.info(TAG, 'The average distance (l0, l2, linf) between original samples and adversarial samples are: %s',
|
||||
attack_evaluate.avg_lp_distance())
|
||||
|
||||
|
||||
|
||||
@@ -15,12 +15,10 @@
|
||||
SaltAndPepperNoise-Attack.
|
||||
"""
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mindarmour.utils._check_param import check_model, check_pair_numpy_param, \
|
||||
check_param_type, check_int_positive, check_param_multi_types
|
||||
from mindarmour.utils._check_param import normalize_value
|
||||
from mindarmour.utils.logger import LogUtil
|
||||
from ..attack import Attack
|
||||
from .black_model import BlackModel
|
||||
@@ -31,26 +29,21 @@ TAG = 'SaltAndPepperNoise-Attack'
|
||||
|
||||
class SaltAndPepperNoiseAttack(Attack):
|
||||
"""
|
||||
Increases the amount of salt and pepper noise to generate adversarial
|
||||
samples.
|
||||
Increases the amount of salt and pepper noise to generate adversarial samples.
|
||||
|
||||
Args:
|
||||
model (BlackModel): Target model.
|
||||
bounds (tuple): Upper and lower bounds of data. In form of (clip_min,
|
||||
clip_max). Default: (0.0, 1.0)
|
||||
max_iter (int): Max iteration to generate an adversarial example.
|
||||
Default: 100
|
||||
is_targeted (bool): If True, targeted attack. If False, untargeted
|
||||
attack. Default: False.
|
||||
sparse (bool): If True, input labels are sparse-encoded. If False,
|
||||
input labels are one-hot-encoded. Default: True.
|
||||
bounds (tuple): Upper and lower bounds of data. In form of (clip_min, clip_max). Default: (0.0, 1.0)
|
||||
max_iter (int): Max iteration to generate an adversarial example. Default: 100
|
||||
is_targeted (bool): If True, targeted attack. If False, untargeted attack. Default: False.
|
||||
sparse (bool): If True, input labels are sparse-encoded. If False, input labels are one-hot-encoded.
|
||||
Default: True.
|
||||
|
||||
Examples:
|
||||
>>> attack = SaltAndPepperNoiseAttack(model)
|
||||
"""
|
||||
|
||||
def __init__(self, model, bounds=(0.0, 1.0), max_iter=100,
|
||||
is_targeted=False, sparse=True):
|
||||
def __init__(self, model, bounds=(0.0, 1.0), max_iter=100, is_targeted=False, sparse=True):
|
||||
super(SaltAndPepperNoiseAttack, self).__init__()
|
||||
self._model = check_model('model', model, BlackModel)
|
||||
self._bounds = check_param_multi_types('bounds', bounds, [tuple, list])
|
||||
@@ -76,12 +69,9 @@ class SaltAndPepperNoiseAttack(Attack):
|
||||
- numpy.ndarray, query times for each sample.
|
||||
|
||||
Examples:
|
||||
>>> adv_list = attack.generate(([[0.1, 0.2, 0.6],
|
||||
>>> [0.3, 0, 0.4]],
|
||||
>>> [1, 2])
|
||||
>>> adv_list = attack.generate(([[0.1, 0.2, 0.6], [0.3, 0, 0.4]], [1, 2])
|
||||
"""
|
||||
arr_x, arr_y = check_pair_numpy_param('inputs', inputs, 'labels',
|
||||
labels)
|
||||
arr_x, arr_y = check_pair_numpy_param('inputs', inputs, 'labels', labels)
|
||||
if not self._sparse:
|
||||
arr_y = np.argmax(arr_y, axis=1)
|
||||
|
||||
@@ -94,9 +84,8 @@ class SaltAndPepperNoiseAttack(Attack):
|
||||
is_adv_list.append(is_adv)
|
||||
adv_list.append(perturbed)
|
||||
query_times_each_adv.append(query_times)
|
||||
LOGGER.info(TAG, 'Finished one sample, adversarial is {}, '
|
||||
'cost time {:.2}s'
|
||||
.format(is_adv, time.time() - start_t))
|
||||
LOGGER.info(TAG, 'Finished one sample, adversarial is {}, cost time {:.2}s'.format(is_adv,
|
||||
time.time() - start_t))
|
||||
is_adv_list = np.array(is_adv_list)
|
||||
adv_list = np.array(adv_list)
|
||||
query_times_each_adv = np.array(query_times_each_adv)
|
||||
@@ -104,14 +93,12 @@ class SaltAndPepperNoiseAttack(Attack):
|
||||
|
||||
def _generate_one(self, one_input, label, epsilons=10):
|
||||
"""
|
||||
Increases the amount of salt and pepper noise to generate adversarial
|
||||
samples.
|
||||
Increases the amount of salt and pepper noise to generate adversarial samples.
|
||||
|
||||
Args:
|
||||
one_input (numpy.ndarray): The original, unperturbed input.
|
||||
label (numpy.ndarray): The target label.
|
||||
epsilons (int) : Number of steps to try probability between 0
|
||||
and 1. Default: 10
|
||||
epsilons (int) : Number of steps to try probability between 0 and 1. Default: 10
|
||||
|
||||
Returns:
|
||||
- numpy.ndarray, bool values for result.
|
||||
@@ -128,9 +115,7 @@ class SaltAndPepperNoiseAttack(Attack):
|
||||
high_ = 1.0
|
||||
query_count = 0
|
||||
input_shape = one_input.shape
|
||||
input_dtype = one_input.dtype
|
||||
one_input = one_input.reshape(-1)
|
||||
depth = np.abs(np.subtract(self._bounds[0], self._bounds[1]))
|
||||
best_adv = np.copy(one_input)
|
||||
best_eps = high_
|
||||
find_adv = False
|
||||
@@ -142,15 +127,11 @@ class SaltAndPepperNoiseAttack(Attack):
|
||||
noise = np.random.uniform(low=low_, high=high_, size=one_input.size)
|
||||
eps = (min_eps + max_eps) / 2
|
||||
# add salt
|
||||
adv[noise < eps] = -depth
|
||||
adv[noise < eps] = self._bounds[0]
|
||||
# add pepper
|
||||
adv[noise >= (high_ - eps)] = depth
|
||||
# normalized sample
|
||||
adv = normalize_value(np.expand_dims(adv, axis=0), 'l2').astype(input_dtype)
|
||||
adv[noise >= (high_ - eps)] = self._bounds[1]
|
||||
query_count += 1
|
||||
ite_bool = self._model.is_adversarial(adv.reshape(input_shape),
|
||||
label,
|
||||
is_targeted=self._is_targeted)
|
||||
ite_bool = self._model.is_adversarial(adv.reshape(input_shape), label, is_targeted=self._is_targeted)
|
||||
if ite_bool:
|
||||
find_adv = True
|
||||
if best_eps > eps:
|
||||
|
||||
@@ -14,22 +14,22 @@
|
||||
"""
|
||||
SaltAndPepper Attack Test
|
||||
"""
|
||||
import os
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.ops.operations as M
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore import context
|
||||
|
||||
from mindarmour import BlackModel
|
||||
from mindarmour.adv_robustness.attacks import SaltAndPepperNoiseAttack
|
||||
from tests.ut.python.utils.mock_net import Net
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
|
||||
# for user
|
||||
class ModelToBeAttacked(BlackModel):
|
||||
"""model to be attack"""
|
||||
|
||||
@@ -43,33 +43,6 @@ class ModelToBeAttacked(BlackModel):
|
||||
return result.asnumpy()
|
||||
|
||||
|
||||
# for user
|
||||
class SimpleNet(Cell):
|
||||
"""
|
||||
Construct the network of target model.
|
||||
|
||||
Examples:
|
||||
>>> net = SimpleNet()
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Introduce the layers used for network construction.
|
||||
"""
|
||||
super(SimpleNet, self).__init__()
|
||||
self._softmax = M.Softmax()
|
||||
|
||||
def construct(self, inputs):
|
||||
"""
|
||||
Construct network.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): Input data.
|
||||
"""
|
||||
out = self._softmax(inputs)
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@@ -79,44 +52,21 @@ def test_salt_and_pepper_attack_method():
|
||||
"""
|
||||
Salt and pepper attack method unit test.
|
||||
"""
|
||||
batch_size = 6
|
||||
np.random.seed(123)
|
||||
net = SimpleNet()
|
||||
inputs = np.random.rand(batch_size, 10)
|
||||
# upload trained network
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
ckpt_path = os.path.join(current_dir,
|
||||
'../../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt')
|
||||
net = Net()
|
||||
load_dict = load_checkpoint(ckpt_path)
|
||||
load_param_into_net(net, load_dict)
|
||||
|
||||
# get one mnist image
|
||||
inputs = np.load(os.path.join(current_dir, '../../../dataset/test_images.npy'))[:3]
|
||||
labels = np.load(os.path.join(current_dir, '../../../dataset/test_labels.npy'))[:3]
|
||||
model = ModelToBeAttacked(net)
|
||||
labels = np.random.randint(low=0, high=10, size=batch_size)
|
||||
labels = np.eye(10)[labels]
|
||||
labels = labels.astype(np.float32)
|
||||
|
||||
attack = SaltAndPepperNoiseAttack(model, sparse=False)
|
||||
attack = SaltAndPepperNoiseAttack(model, sparse=True)
|
||||
_, adv_data, _ = attack.generate(inputs, labels)
|
||||
assert np.any(adv_data[0] != inputs[0]), 'Salt and pepper attack method: ' \
|
||||
'generate value must not be equal' \
|
||||
' to original value.'
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_card
|
||||
@pytest.mark.component_mindarmour
|
||||
def test_salt_and_pepper_attack_in_batch():
|
||||
"""
|
||||
Salt and pepper attack method unit test in batch.
|
||||
"""
|
||||
batch_size = 32
|
||||
np.random.seed(123)
|
||||
net = SimpleNet()
|
||||
inputs = np.random.rand(batch_size*2, 10)
|
||||
|
||||
model = ModelToBeAttacked(net)
|
||||
labels = np.random.randint(low=0, high=10, size=batch_size*2)
|
||||
labels = np.eye(10)[labels]
|
||||
labels = labels.astype(np.float32)
|
||||
|
||||
attack = SaltAndPepperNoiseAttack(model, sparse=False)
|
||||
adv_data = attack.batch_generate(inputs, labels, batch_size=32)
|
||||
assert np.any(adv_data[0] != inputs[0]), 'Salt and pepper attack method: ' \
|
||||
'generate value must not be equal' \
|
||||
assert np.any(adv_data[0] != inputs[0]), 'Salt and pepper attack method: generate value must not be equal' \
|
||||
' to original value.'
|
||||
|
||||
Reference in New Issue
Block a user