You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

preprocess.py 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """pre process for 310 inference"""
  16. import os
  17. import sys
  18. import six
  19. import lmdb
  20. from PIL import Image
  21. from src.model_utils.config import config
  22. from src.util import CTCLabelConverter
  23. def get_img_from_lmdb(env_, ind):
  24. """Get image_from lmdb."""
  25. with env_.begin(write=False) as txn_:
  26. label_key = 'label-%09d'.encode() % ind
  27. label_ = txn_.get(label_key).decode('utf-8')
  28. img_key = 'image-%09d'.encode() % ind
  29. imgbuf = txn_.get(img_key)
  30. buf = six.BytesIO()
  31. buf.write(imgbuf)
  32. buf.seek(0)
  33. try:
  34. img_ = Image.open(buf).convert('RGB') # for color image
  35. except IOError:
  36. print(f'Corrupted image for {ind}')
  37. # make dummy image and dummy label for corrupted image.
  38. img_ = Image.new('RGB', (config.IMG_W, config.IMG_H))
  39. label_ = '[dummy_label]'
  40. label_ = label_.lower()
  41. return img_, label_
  42. if __name__ == '__main__':
  43. max_len = int((26 + 1) // 2)
  44. converter = CTCLabelConverter(config.CHARACTER)
  45. env = lmdb.open(config.TEST_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
  46. if not env:
  47. print('cannot create lmdb from %s' % (config.TEST_DATASET_PATH))
  48. sys.exit(0)
  49. with env.begin(write=False) as txn:
  50. n_samples = int(txn.get('num-samples'.encode()))
  51. n_samples = n_samples
  52. # Filtering
  53. filtered_index_list = []
  54. for index_ in range(n_samples):
  55. index_ += 1 # lmdb starts with 1
  56. label_key_ = 'label-%09d'.encode() % index_
  57. label = txn.get(label_key_).decode('utf-8')
  58. if len(label) > max_len:
  59. continue
  60. illegal_sample = False
  61. for char_item in label.lower():
  62. if char_item not in config.CHARACTER:
  63. illegal_sample = True
  64. break
  65. if illegal_sample:
  66. continue
  67. filtered_index_list.append(index_)
  68. img_ret = []
  69. text_ret = []
  70. print(f'num of samples in IIIT dataset: {len(filtered_index_list)}')
  71. i = 0
  72. label_dict = {}
  73. for index in filtered_index_list:
  74. img, label = get_img_from_lmdb(env, index)
  75. img_name = os.path.join(config.preprocess_output, str(i) + ".png")
  76. img.save(img_name)
  77. label_dict[str(i)] = label
  78. i += 1
  79. with open('./label.txt', 'w') as file:
  80. for k, v in label_dict.items():
  81. file.write(str(k) + ',' + str(v) + '\n')

MindArmour关注AI的安全和隐私问题。致力于增强模型的安全可信、保护用户的数据隐私。主要包含3个模块:对抗样本鲁棒性模块、Fuzz Testing模块、隐私保护与评估模块。 对抗样本鲁棒性模块 对抗样本鲁棒性模块用于评估模型对于对抗样本的鲁棒性,并提供模型增强方法用于增强模型抗对抗样本攻击的能力,提升模型鲁棒性。对抗样本鲁棒性模块包含了4个子模块:对抗样本的生成、对抗样本的检测、模型防御、攻防评估。