!591 fix cifar inversion case

Merge pull request !591 from 刘思铭/master
This commit is contained in:
刘思铭
2025-04-14 07:55:59 +00:00
committed by Gitee
4 changed files with 25 additions and 21 deletions

View File

@@ -159,7 +159,8 @@ def create_dataset_cifar(data_path, image_height, image_width, repeat_num=1, tra
"""
create data for next use such as training or inferring
"""
cifar_ds = ds.Cifar10Dataset(data_path)
usage = "train" if training else "test"
cifar_ds = ds.Cifar10Dataset(dataset_dir=data_path, usage=usage)
resize_height = image_height # 224
resize_width = image_width # 224
rescale = 1.0 / 255.0
@@ -190,11 +191,11 @@ def create_dataset_cifar(data_path, image_height, image_width, repeat_num=1, tra
return cifar_ds
def generate_dataset_cifar(data_path, batch_size, repeat_num=1):
def generate_dataset_cifar(data_path, batch_size, usage, repeat_num=1):
"""
create data for next use such as training or inferring
"""
cifar_ds = ds.Cifar10Dataset(data_path)
cifar_ds = ds.Cifar10Dataset(data_path, usage=usage)
resize_height = 32
resize_width = 32
rescale = 1.0 / 255.0

View File

@@ -38,7 +38,8 @@ def cifar_train(epoch_size, lr, momentum):
Generate Dataset and Train
"""
mnist_path = "../../dataset/CIFAR10"
ds = create_dataset_cifar(os.path.join(mnist_path, "train"), 32, 32, repeat_num=1)
# ds = create_dataset_cifar(os.path.join(mnist_path, "train"), 32, 32, repeat_num=1)
ds = create_dataset_cifar(mnist_path, 32, 32, repeat_num=1)
network = CIFAR10CNN()
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
@@ -58,11 +59,11 @@ def cifar_train(epoch_size, lr, momentum):
ckpt_file_name = "trained_ckpt_file/checkpoint_cifar-10_1562.ckpt"
param_dict = load_checkpoint(ckpt_file_name)
load_param_into_net(network, param_dict)
ds_eval = create_dataset_cifar(os.path.join(mnist_path, "test"), 32, 32, repeat_num=1)
ds_eval = create_dataset_cifar(mnist_path, 32, 32, repeat_num=1, training=False)
acc = model.eval(ds_eval, dataset_sink_mode=False)
LOGGER.info(TAG, "============== Accuracy: %s ==============", acc)
if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
cifar_train(10, 0.01, 0.9)

View File

@@ -22,7 +22,8 @@ import mindspore as ms
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import Tensor, context
from mindarmour.privacy.evaluation.model_inversion_attack import ModelInversionAttack
from mindarmour.privacy.evaluation.inversion_attack.inversion_net import CIFAR10CNNDecoderConv11
# from mindarmour.privacy.evaluation.inversion_attack.inversion_net import CIFAR10CNNDecoderConv11
from examples.privacy.inversion_attack.inversion_net import CIFAR10CNNDecoderConv11
from mindarmour.utils.logger import LogUtil
from examples.common.networks.cifar10cnn.cifar10cnn_net import CIFAR10CNN
@@ -49,17 +50,17 @@ def cifar_inversion_attack(net, inv_net, ckptpath):
load_param_into_net(net, load_dict)
# get original data and their inferred fearures
data_list = "../../common/dataset/CIFAR10/train"
ds = generate_dataset_cifar(data_list, 32, repeat_num=1)
data_list = "../..//common/dataset/CIFAR10/test"
ds_test = generate_dataset_cifar(data_list, 32, repeat_num=1)
data_list = "../../common/dataset/CIFAR10" #/train
ds = generate_dataset_cifar(data_list, 32, usage="train", repeat_num=1)
data_list = "../../common/dataset/CIFAR10" #/test
ds_test = generate_dataset_cifar(data_list, 32, usage="test", repeat_num=1)
i = 0
batch_num = 1
sample_num = 10
for data in ds_test.create_tuple_iterator(output_numpy=True):
i += 1
images = data[0].astype(np.float32)
target_features = net.getlayeroutput(Tensor(images), 'conv11')[:sample_num]
target_features = net.get_layer_output(Tensor(images), 'conv11')[:sample_num]
if i >= batch_num:
break
@@ -91,6 +92,6 @@ def cifar_inversion_attack(net, inv_net, ckptpath):
if __name__ == '__main__':
# device_target can be "CPU", "GPU" or "Ascend"
context.set_context(mode=ms.PYNATIVE_MODE, device_target="GPU")
context.set_context(mode=ms.PYNATIVE_MODE, device_target="CPU")
ckpt_path = '../../common/networks/cifar10cnn/trained_ckpt_file/checkpoint_cifar-10_1562.ckpt'
cifar_inversion_attack(CIFAR10CNN(), CIFAR10CNNDecoderConv11(), ckpt_path)

View File

@@ -48,7 +48,7 @@ class ModelInversionLoss(nn.Cell):
self._network.set_train(False)
def construct(self, inputs):
orginal_model_output = self._network.getLayerOutput(inputs, self._target_layer)
orginal_model_output = self._network.get_layer_output(inputs, self._target_layer)
decoder_model_output = self._invnet(orginal_model_output)
loss = self._loss_fn(inputs, decoder_model_output)
return loss
@@ -72,9 +72,9 @@ class ModelInversionAttack:
"""
def __init__(self, network, inv_network, input_shape, ckpoint_path=None, split_layer='conv1'):
self._network = check_param_type('network', network, nn.Cell)
self._invnetwork = check_param_type('inv_network', inv_network, nn.Cell)
self._invnet = check_param_type('inv_network', inv_network, nn.Cell)
self._split_layer = check_param_type('split_layer', split_layer, str)
self._ckpath = check_param_type('ckpoint_path', ckpoint_path, str)
self._ckpath = ckpoint_path #check_param_type('ckpoint_path', ckpoint_path, str)
self.check_inv_network(input_shape)
if self._ckpath is None:
self._ckpath = './trained_inv_ckpt_file'
@@ -85,8 +85,8 @@ class ModelInversionAttack:
def check_inv_network(self, input_shape):
input_shape = check_param_type('input_shape', input_shape, tuple)
inputs = ms.numpy.ones((1,) + input_shape)
orginal_model_output = self._network.getLayerOutput(inputs, self._split_layer)
inv_model_output = self._invnetwork(orginal_model_output)
orginal_model_output = self._network.get_layer_output(inputs, self._split_layer)
inv_model_output = self._invnet(orginal_model_output)
if inputs.shape != inv_model_output.shape:
msg = "InvModel error, input shape is {}, but invmodel output shape is {}" \
.format(inputs.shape, inv_model_output.shape)
@@ -121,6 +121,8 @@ class ModelInversionAttack:
optim = nn.Adam(self._invnet.trainable_params(), learning_rate=learningrate, eps=eps, use_amsgrad=apply_ams)
net = ModelInversionLoss(self._network, self._invnet, net_loss, self._split_layer)
net = nn.TrainOneStepCell(net, optim)
if not os.path.exists(self._ckpath):
os.makedirs(self._ckpath)
for epoch in range(epochs):
loss = 0
@@ -128,8 +130,7 @@ class ModelInversionAttack:
loss += net(Tensor(inputs)).asnumpy()
LOGGER.info(TAG, "Epoch: {}, Loss: {}".format(epoch, loss))
if epoch % 10 == 0:
ms.save_checkpoint(self._invnet, os.path.join(self._ckpath, '/invmodel_{}_{}.ckpt'
.format(self._split_layer, epoch)))
ms.save_checkpoint(self._invnet, self._ckpath + '/invmodel_{}_{}.ckpt'.format(self._split_layer, epoch))
def evaluate(self, dataset):
"""
@@ -149,7 +150,7 @@ class ModelInversionAttack:
total_psnr = 0
size = 0
for inputs, _ in dataset.create_tuple_iterator():
orginal_model_output = self._network.getLayerOutput(Tensor(inputs), self._split_layer)
orginal_model_output = self._network.get_layer_output(Tensor(inputs), self._split_layer)
decoder_model_output = self._invnet(orginal_model_output)
decoder_model_output = decoder_model_output.clip(0, 1)
for i in range(inputs.shape[0]):