
eval_and_save.py 3.9 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cnnctc eval"""
import numpy as np
import lmdb
from mindspore import Tensor, context
import mindspore.common.dtype as mstype
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.dataset import GeneratorDataset
from cnn_ctc.src.util import CTCLabelConverter
from cnn_ctc.src.dataset import iiit_generator_batch, adv_iiit_generator_batch
from cnn_ctc.src.cnn_ctc import CNNCTC
from cnn_ctc.src.model_utils.config import config

context.set_context(mode=context.GRAPH_MODE, save_graphs=False,
                    save_graphs_path=".")


def test_dataset_creator(is_adv=False):
    """Create the test dataset, either the benign or the adversarial split."""
    if is_adv:
        ds = GeneratorDataset(adv_iiit_generator_batch(), ['img', 'label_indices', 'text',
                                                           'sequence_length', 'label_str'])
    else:
        ds = GeneratorDataset(iiit_generator_batch, ['img', 'label_indices', 'text',
                                                     'sequence_length', 'label_str'])
    return ds


def test(lmdb_save_path):
    """Eval the CNN-CTC model on benign and perturbed data."""
    target = config.device_target
    context.set_context(device_target=target)

    ds = test_dataset_creator(is_adv=config.IS_ADV)

    net = CNNCTC(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)

    ckpt_path = config.CHECKPOINT_PATH
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    print('parameters loaded! from: ', ckpt_path)

    converter = CTCLabelConverter(config.CHARACTER)

    count = 0
    correct_count = 0
    env_save = lmdb.open(lmdb_save_path, map_size=1099511627776)
    with env_save.begin(write=True) as txn_save:
        for data in ds.create_tuple_iterator():
            img, _, text, _, length = data

            img_tensor = Tensor(img, mstype.float32)
            model_predict = net(img_tensor)
            model_predict = np.squeeze(model_predict.asnumpy())

            preds_size = np.array([model_predict.shape[1]] * config.TEST_BATCH_SIZE)
            preds_index = np.argmax(model_predict, 2)
            preds_index = np.reshape(preds_index, [-1])
            preds_str = converter.decode(preds_index, preds_size)
            label_str = converter.reverse_encode(text.asnumpy(), length.asnumpy())

            print("Prediction samples: \n", preds_str[:5])
            print("Ground truth: \n", label_str[:5])

            # Count exact matches and persist every prediction into the LMDB database.
            for pred, label in zip(preds_str, label_str):
                if pred == label:
                    correct_count += 1
                count += 1
                if config.IS_ADV:
                    pred_key = 'adv_pred-%09d'.encode() % count
                else:
                    pred_key = 'pred-%09d'.encode() % count
                txn_save.put(pred_key, pred.encode())
    accuracy = correct_count / count
    return accuracy


if __name__ == '__main__':
    save_path = config.ADV_TEST_DATASET_PATH

    # Evaluate on benign samples first, then on adversarial samples.
    config.IS_ADV = False
    config.TEST_DATASET_PATH = save_path
    ori_acc = test(lmdb_save_path=save_path)

    config.IS_ADV = True
    adv_acc = test(lmdb_save_path=save_path)

    print('Accuracy of benign sample: ', ori_acc)
    print('Accuracy of perturbed sample: ', adv_acc)
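
The predictions written above can be read back from the LMDB database later, for example to compare benign and adversarial outputs offline. Below is a minimal sketch, not part of the original script: the helper name load_predictions and the example path are hypothetical, while the key layout ('pred-%09d' / 'adv_pred-%09d', counting from 1) is taken from the code above.

# Sketch: read back predictions stored by eval_and_save.py (hypothetical helper).
import lmdb

def load_predictions(lmdb_path, num_samples, is_adv=False):
    """Return the decoded prediction strings stored by eval_and_save.py."""
    prefix = 'adv_pred' if is_adv else 'pred'
    preds = []
    env = lmdb.open(lmdb_path, readonly=True, lock=False)
    with env.begin(write=False) as txn:
        # eval_and_save.py increments count before the put, so keys start at 1.
        for idx in range(1, num_samples + 1):
            value = txn.get(('%s-%09d' % (prefix, idx)).encode())
            if value is not None:
                preds.append(value.decode())
    return preds

# Example usage (path and sample count are placeholders):
# benign_preds = load_predictions('./adv_test_dataset', num_samples=3000)
# adv_preds = load_predictions('./adv_test_dataset', num_samples=3000, is_adv=True)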

MindArmour focuses on the security and privacy of AI. It aims to strengthen the trustworthiness of models and to protect users' data privacy. It consists of three main modules: the adversarial example robustness module, the Fuzz Testing module, and the privacy protection and evaluation module.

Adversarial example robustness module

The adversarial example robustness module evaluates a model's robustness against adversarial examples and provides model-hardening methods that improve the model's resistance to adversarial attacks. It contains four submodules: adversarial example generation, adversarial example detection, model defense, and attack/defense evaluation (see the sketch below).
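
For the adversarial example generation submodule, the following is a minimal sketch of typical usage, assuming MindArmour's FastGradientSignMethod (FGSM) attack; the network net, the input shapes, and the helper name generate_adv_batch are placeholders, and the exact label format expected by generate() depends on the MindArmour version and the configured loss function.

# Sketch: generating adversarial examples with MindArmour's FGSM attack
# (illustrative only; names and shapes are placeholders).
import numpy as np
from mindarmour.adv_robustness.attacks import FastGradientSignMethod

def generate_adv_batch(net, images, labels, eps=0.07):
    """Perturb a batch of images with FGSM and return the adversarial batch."""
    attack = FastGradientSignMethod(net, eps=eps)
    # generate() takes numpy inputs/labels and returns perturbed numpy images.
    return attack.generate(images, labels)

# Example usage with placeholder data:
# images = np.random.rand(32, 3, 32, 100).astype(np.float32)
# labels = ...  # ground-truth labels in the format expected by the attack
# adv_images = generate_adv_batch(net, images, labels)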