!481 add: add VarianceTuningMomentumIterativeMethod attack method. [School of Cyber Science and Engineering, Huazhong University of Science and Technology]

Merge pull request !481 from YechaoZhang/master
i-robot
2023-05-09 11:42:09 +00:00
committed by Gitee

@@ -647,3 +647,109 @@ def _transform_inputs(inputs, prob, low=29, high=33, full_aug=False):
    if not np.any(tran_outputs - raw_inputs):
        LOGGER.error(TAG, 'the transform function does not take effect.')
    return tran_outputs
class VarianceTuningMomentumIterativeMethod(MomentumIterativeMethod):
    """
    VMI-FGSM is a momentum iterative method that, in each iteration, aggregates the gradient with the
    gradient variance of the neighborhood of the input data, which improves the transferability of the
    generated adversarial examples.

    Reference: `X. Wang, K. He, "Enhancing the Transferability of Adversarial Attacks through Variance Tuning"
    in CVPR, 2021 <https://arxiv.org/abs/2103.15571>`_.
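
    Per the reference paper, each iteration updates the accumulated gradient :math:`g` and the
    variance term :math:`v` as

    .. math::
        g_{t+1} = \mu \cdot g_t + \frac{\nabla_x J(x_t^{adv}, y) + v_t}
        {\lVert \nabla_x J(x_t^{adv}, y) + v_t \rVert_1}, \qquad
        v_{t+1} = \frac{1}{N} \sum_{i=1}^{N} \nabla_x J(x_i, y) - \nabla_x J(x_t^{adv}, y),

    where the :math:`x_i` are :math:`N` examples sampled uniformly in the neighborhood of :math:`x_t^{adv}`.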

    Args:
        network (Cell): Target model.
        eps (float): Proportion of adversarial perturbation generated by the
            attack to the data range. Default: ``0.3``.
        eps_iter (float): Proportion of perturbation in each step. Default: ``0.1``.
        bounds (tuple): Upper and lower bounds of data, indicating the data range,
            in the form of (clip_min, clip_max). Default: ``(0.0, 1.0)``.
        is_targeted (bool): If ``True``, targeted attack. If ``False``, untargeted
            attack. Default: ``False``.
        nb_iter (int): Number of iterations. Default: ``5``.
        decay_factor (float): Momentum factor. Default: ``1.0``.
        nb_neighbor (int): Number of examples sampled in the neighborhood. Default: ``5``.
        neighbor_beta (float): Upper bound of the neighborhood. Default: ``3/2``.
        norm_level (Union[int, str, numpy.inf]): Order of the norm. Possible values:
            ``np.inf``, ``1`` or ``2``. Default: ``'inf'``.
        loss_fn (Union[Loss, None]): Loss function for optimization. If ``None``, the input network
            is already equipped with a loss function. Default: ``None``.

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore.ops import operations as P
        >>> from mindarmour.adv_robustness.attacks import VarianceTuningMomentumIterativeMethod
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self._softmax = P.Softmax()
        ...     def construct(self, inputs):
        ...         out = self._softmax(inputs)
        ...         return out
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
        >>> attack = VarianceTuningMomentumIterativeMethod(net, nb_neighbor=5, neighbor_beta=3/2, loss_fn=loss_fn)
        >>> inputs = np.asarray([[0.1, 0.2, 0.7]], np.float32)
        >>> labels = np.asarray([2], np.int32)
        >>> labels = np.eye(3)[labels].astype(np.float32)
        >>> adv_x = attack.generate(inputs, labels)
    """

    def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0),
                 is_targeted=False, nb_iter=5, decay_factor=1.0, nb_neighbor=5, neighbor_beta=3 / 2,
                 norm_level='inf', loss_fn=None):
        super(VarianceTuningMomentumIterativeMethod, self).__init__(network,
                                                                    eps=eps,
                                                                    eps_iter=eps_iter,
                                                                    bounds=bounds,
                                                                    nb_iter=nb_iter,
                                                                    loss_fn=loss_fn)
        self._is_targeted = check_param_type('is_targeted', is_targeted, bool)
        self._decay_factor = check_value_positive('decay_factor', decay_factor)
        self._norm_level = check_norm_level(norm_level)
        self._nb_neighbor = check_int_positive('nb_neighbor', nb_neighbor)
        self._neighbor_beta = check_value_positive('neighbor_beta', neighbor_beta)

    def generate(self, inputs, labels):
        """
        Generate adversarial examples based on input data and original/target labels.

        Args:
            inputs (Union[numpy.ndarray, tuple]): Benign input samples used as references to
                create adversarial examples.
            labels (Union[numpy.ndarray, tuple]): Original/target labels. For each input,
                if it has more than one label, they are wrapped in a tuple.

        Returns:
            numpy.ndarray, generated adversarial examples.
        """
        inputs_image, inputs, labels = check_inputs_labels(inputs, labels)
        ndim = np.ndim(inputs_image)
        if ndim < 4:
            for _ in range(4 - ndim):
                inputs_image = np.expand_dims(inputs_image, axis=0)
        momentum = 0
        v = 0
        adv_x = copy.deepcopy(inputs_image)
        clip_min, clip_max = self._bounds
        clip_diff = clip_max - clip_min
        for _ in range(self._nb_iter):
            adv_grad = self._gradient(adv_x, labels)
            # Accumulate momentum with the variance-tuned gradient, normalized by
            # its L1 norm (mean of absolute values over the non-batch axes).
            grad = (adv_grad + v) / (np.mean(np.abs(adv_grad + v), axis=(1, 2, 3), keepdims=True) + 1e-12)
            grad = grad + momentum * self._decay_factor
            momentum = grad
            # Estimate the gradient variance from nb_neighbor examples sampled
            # uniformly in the neighborhood of the current adversarial example.
            gv_grad = 0
            for _ in range(self._nb_neighbor):
                neighbor_x = adv_x + np.random.uniform(-self._neighbor_beta * self._eps,
                                                       self._eps * self._neighbor_beta,
                                                       size=inputs_image.shape).astype(np.float32)
                gv_grad = self._gradient(neighbor_x, labels) + gv_grad
            v = gv_grad / self._nb_neighbor - adv_grad
            # Take a signed step, then clip the total perturbation to the eps ball.
            adv_x = adv_x + self._eps_iter * np.sign(grad)
            perturs = np.clip(adv_x - inputs_image, (0 - self._eps) * clip_diff,
                              self._eps * clip_diff)
            adv_x = inputs_image + perturs
        return adv_x
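
For quick review, below is a minimal standalone NumPy sketch of the variance-tuning loop implemented in `generate` above. The `toy_grad` function is a hypothetical stand-in for `self._gradient` (the network's loss gradient) and is not part of this patch; bounds handling is simplified to a unit data range.

```python
import numpy as np

def toy_grad(x):
    # Hypothetical stand-in for self._gradient: gradient of a
    # simple quadratic loss w.r.t. the input, for illustration only.
    return 2.0 * (x - 1.0)

eps, eps_iter, nb_iter = 0.3, 0.1, 5
decay_factor, nb_neighbor, beta = 1.0, 5, 3 / 2

x = np.random.uniform(0.0, 1.0, size=(1, 3, 2, 2)).astype(np.float32)
adv_x = x.copy()
momentum, v = 0.0, 0.0
for _ in range(nb_iter):
    adv_grad = toy_grad(adv_x)
    # g_{t+1} = mu * g_t + (grad + v_t) / ||grad + v_t||_1
    grad = (adv_grad + v) / (np.mean(np.abs(adv_grad + v),
                                     axis=(1, 2, 3), keepdims=True) + 1e-12)
    momentum = grad + momentum * decay_factor
    # v_{t+1}: mean gradient over N uniform neighbors minus grad(x_t)
    gv_grad = 0.0
    for _ in range(nb_neighbor):
        neighbor = adv_x + np.random.uniform(-beta * eps, beta * eps,
                                             size=x.shape).astype(np.float32)
        gv_grad = gv_grad + toy_grad(neighbor)
    v = gv_grad / nb_neighbor - adv_grad
    # Signed step, then project the total perturbation back to the eps ball.
    adv_x = adv_x + eps_iter * np.sign(momentum)
    adv_x = x + np.clip(adv_x - x, -eps, eps)
```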