|
|
@@ -12,6 +12,7 @@ |
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
# limitations under the License. |
|
|
|
"""defense example using nad""" |
|
|
|
import os |
|
|
|
import sys |
|
|
|
|
|
|
|
import numpy as np |
|
|
from mindspore import Tensor
|
|
|
from mindspore import context |
|
|
|
from mindspore import nn |
|
|
|
from mindspore.nn import SoftmaxCrossEntropyWithLogits |
|
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net |
|
|
|
from mindspore.train import Model |
|
|
|
from mindspore.train.callback import LossMonitor |
|
|
|
|
|
|
|
from lenet5_net import LeNet5 |
|
|
|
from mindarmour.attacks import FastGradientSignMethod |
|
|
|
from mindarmour.defenses import NaturalAdversarialDefense |
|
|
|
from mindarmour.utils.logger import LogUtil |
|
|
|
|
|
|
|
from lenet5_net import LeNet5 |
|
|
|
|
|
|
|
sys.path.append("..") |
|
|
|
from data_processing import generate_mnist_dataset |
|
|
|
|
|
|
|
|
|
|
|
# Shared MindArmour logger instance used by this example script.
LOGGER = LogUtil.get_instance()



# Minimum severity: emit INFO-level records and above.
LOGGER.set_level("INFO")



# Tag string attached to every log record emitted by this example.
TAG = 'Nad_Example'
|
|
|
|
|
|
|
|
|
|
|
def test_nad_method(): |
|
|
|
""" |
|
|
|
NAD-Defense test for CPU device. |
|
|
|
NAD-Defense test. |
|
|
|
""" |
|
|
|
# 1. load trained network |
|
|
|
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' |
|
|
|
mnist_path = "./MNIST_unzip/" |
|
|
|
batch_size = 32 |
|
|
|
# 1. train original model |
|
|
|
ds_train = generate_mnist_dataset(os.path.join(mnist_path, "train"), |
|
|
|
batch_size=batch_size, repeat_size=1) |
|
|
|
net = LeNet5() |
|
|
|
load_dict = load_checkpoint(ckpt_name) |
|
|
|
load_param_into_net(net, load_dict) |
|
|
|
|
|
|
|
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) |
|
|
|
opt = nn.Momentum(net.trainable_params(), 0.01, 0.09) |
|
|
|
|
|
|
|
nad = NaturalAdversarialDefense(net, loss_fn=loss, optimizer=opt, |
|
|
|
bounds=(0.0, 1.0), eps=0.3) |
|
|
|
model = Model(net, loss, opt, metrics=None) |
|
|
|
model.train(10, ds_train, callbacks=[LossMonitor()], |
|
|
|
dataset_sink_mode=False) |
|
|
|
|
|
|
|
# 2. get test data |
|
|
|
data_list = "./MNIST_unzip/test" |
|
|
|
batch_size = 32 |
|
|
|
ds_test = generate_mnist_dataset(data_list, batch_size=batch_size) |
|
|
|
ds_test = generate_mnist_dataset(os.path.join(mnist_path, "test"), |
|
|
|
batch_size=batch_size, repeat_size=1) |
|
|
|
inputs = [] |
|
|
|
labels = [] |
|
|
|
for data in ds_test.create_tuple_iterator(): |
|
|
@@ -73,16 +76,15 @@ def test_nad_method(): |
|
|
|
label_pred = np.argmax(logits, axis=1) |
|
|
|
acc_list.append(np.mean(batch_labels == label_pred)) |
|
|
|
|
|
|
|
LOGGER.debug(TAG, 'accuracy of TEST data on original model is : %s', |
|
|
|
np.mean(acc_list)) |
|
|
|
LOGGER.info(TAG, 'accuracy of TEST data on original model is : %s', |
|
|
|
np.mean(acc_list)) |
|
|
|
|
|
|
|
# 4. get adv of test data |
|
|
|
attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss) |
|
|
|
adv_data = attack.batch_generate(inputs, labels) |
|
|
|
LOGGER.debug(TAG, 'adv_data.shape is : %s', adv_data.shape) |
|
|
|
LOGGER.info(TAG, 'adv_data.shape is : %s', adv_data.shape) |
|
|
|
|
|
|
|
# 5. get accuracy of adv data on original model |
|
|
|
net.set_train(False) |
|
|
|
acc_list = [] |
|
|
|
batchs = adv_data.shape[0] // batch_size |
|
|
|
for i in range(batchs): |
|
|
@@ -92,11 +94,13 @@ def test_nad_method(): |
|
|
|
label_pred = np.argmax(logits, axis=1) |
|
|
|
acc_list.append(np.mean(batch_labels == label_pred)) |
|
|
|
|
|
|
|
LOGGER.debug(TAG, 'accuracy of adv data on original model is : %s', |
|
|
|
np.mean(acc_list)) |
|
|
|
LOGGER.info(TAG, 'accuracy of adv data on original model is : %s', |
|
|
|
np.mean(acc_list)) |
|
|
|
|
|
|
|
# 6. defense |
|
|
|
net.set_train() |
|
|
|
nad = NaturalAdversarialDefense(net, loss_fn=loss, optimizer=opt, |
|
|
|
bounds=(0.0, 1.0), eps=0.3) |
|
|
|
nad.batch_defense(inputs, labels, batch_size=32, epochs=10) |
|
|
|
|
|
|
|
# 7. get accuracy of test data on defensed model |
|
|
@@ -110,8 +114,8 @@ def test_nad_method(): |
|
|
|
label_pred = np.argmax(logits, axis=1) |
|
|
|
acc_list.append(np.mean(batch_labels == label_pred)) |
|
|
|
|
|
|
|
LOGGER.debug(TAG, 'accuracy of TEST data on defensed model is : %s', |
|
|
|
np.mean(acc_list)) |
|
|
|
LOGGER.info(TAG, 'accuracy of TEST data on defensed model is : %s', |
|
|
|
np.mean(acc_list)) |
|
|
|
|
|
|
|
# 8. get accuracy of adv data on defensed model |
|
|
|
acc_list = [] |
|
|
@@ -123,11 +127,11 @@ def test_nad_method(): |
|
|
|
label_pred = np.argmax(logits, axis=1) |
|
|
|
acc_list.append(np.mean(batch_labels == label_pred)) |
|
|
|
|
|
|
|
LOGGER.debug(TAG, 'accuracy of adv data on defensed model is : %s', |
|
|
|
np.mean(acc_list)) |
|
|
|
LOGGER.info(TAG, 'accuracy of adv data on defensed model is : %s', |
|
|
|
np.mean(acc_list)) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # device_target can be "CPU", "GPU" or "Ascend"
    # NOTE(review): the original span held two conflicting set_context calls
    # ("CPU" then "Ascend") — the removed/added pair of an unresolved diff.
    # The added line ("Ascend") is kept; only one set_context must run.
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_nad_method()