You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_global.py 3.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import os
  10. import unittest
  11. import numpy as np
  12. from megenginelite import *
  13. set_log_level(2)  # NOTE(review): presumably supplied by the megenginelite star import; the meaning of level 2 is not shown here - confirm
  class TestShuffleNet(unittest.TestCase):
      """Shared fixture data and helpers for ShuffleNet inference tests.

      The reference input and output arrays are loaded once, when the class
      body is executed, from the directory named by the ``LITE_TEST_RESOUCE``
      environment variable.
      """

      # ``os.getenv`` returns None when the variable is unset; the
      # ``os.path.join`` calls below then raise TypeError at class definition.
      source_dir = os.getenv("LITE_TEST_RESOUCE")
      input_data_path = os.path.join(source_dir, "input_data.npy")
      correct_data_path = os.path.join(source_dir, "output_data.npy")
      # Reference output, flattened so it lines up with a flattened result.
      correct_data = np.load(correct_data_path).flatten()
      input_data = np.load(input_data_path)

      def check_correct(self, out_data, error=1e-4):
          """Assert that ``out_data`` matches the reference output.

          Args:
              out_data: numpy array to verify; it is flattened before the
                  comparison, so its shape does not matter.
              error: exclusive upper bound on the absolute difference
                  tolerated for each element.

          Raises:
              AssertionError: if the sum of the output is not finite, if the
                  element count differs from the reference, or if any element
                  deviates from the reference by ``error`` or more.
          """
          out_data = out_data.flatten()
          assert np.isfinite(out_data.sum())
          assert self.correct_data.size == out_data.size
          # A single vectorized comparison replaces the per-element Python
          # loop; the bound stays strict (``<``), exactly as before.
          abs_diff = np.abs(out_data - self.correct_data)
          assert np.all(abs_diff < error), (
              "max abs diff %g is not below %g" % (abs_diff.max(), error)
          )

      def do_forward(self, network, times=3):
          """Run ``network`` on the reference input and verify its output.

          Args:
              network: object providing ``get_input_name``,
                  ``get_output_name``, ``get_io_tensor``, ``forward`` and
                  ``wait``; its first input and first output are used.
              times: number of ``forward()``/``wait()`` rounds performed
                  before the output tensor is read.

          Raises:
              AssertionError: propagated from ``check_correct``.
          """
          input_name = network.get_input_name(0)
          input_tensor = network.get_io_tensor(input_name)
          output_name = network.get_output_name(0)
          output_tensor = network.get_io_tensor(output_name)
          input_tensor.set_data_by_copy(self.input_data)

          for _ in range(times):
              network.forward()
              network.wait()

          output_data = output_tensor.to_numpy()
          self.check_correct(output_data)
  class TestGlobal(TestShuffleNet):
      """Tests for the ``LiteGlobal`` calls used below."""

      def _load_aes_encrypted_network(self):
          """Build a network configured for "AES_default" and load the model.

          Returns:
              The ``LiteNetwork`` after ``shufflenet_crypt_aes.mge`` has been
              loaded into it.
          """
          config = LiteConfig()
          config.bare_model_cryption_name = "AES_default".encode("utf-8")
          network = LiteNetwork(config)
          model_path = os.path.join(self.source_dir, "shufflenet_crypt_aes.mge")
          network.load(model_path)
          return network

      def test_device_count(self):
          """The reported CPU device count must be positive."""
          LiteGlobal.try_coalesce_all_free_memory()
          count = LiteGlobal.get_device_count(LiteDeviceType.LITE_CPU)
          assert count > 0

      def test_register_decryption_method(self):
          """Register a custom decryption method, then load and run a model."""

          @decryption_func
          def function(in_arr, key_arr, out_arr):
              # A falsy ``out_arr`` means only the output size is reported.
              if not out_arr:
                  return in_arr.size
              else:
                  # XOR-ing with the same key byte twice cancels out, so the
                  # data is copied through unchanged.
                  for i in range(in_arr.size):
                      out_arr[i] = in_arr[i] ^ key_arr[0] ^ key_arr[0]
                  return out_arr.size

          LiteGlobal.register_decryption_and_key("just_for_test", function, [15])
          config = LiteConfig()
          config.bare_model_cryption_name = "just_for_test".encode("utf-8")

          # NOTE(review): ``config`` is built but not handed to LiteNetwork(),
          # so the registered method may never be exercised here. Left as-is
          # because the effect of passing it cannot be confirmed from this
          # file - please verify the intent.
          network = LiteNetwork()
          model_path = os.path.join(self.source_dir, "shufflenet.mge")
          network.load(model_path)

          self.do_forward(network)

      def test_update_decryption_key(self):
          """Loading must fail with a wrong AES key and succeed with the right one."""
          wrong_key = [0] * 32
          right_key = list(range(32))

          LiteGlobal.update_decryption_key("AES_default", wrong_key)
          try:
              with self.assertRaises(RuntimeError):
                  self._load_aes_encrypted_network()
          finally:
              # The key is global state: restore it even when the assertion
              # above fails, so a failure here cannot leak the wrong key into
              # other tests.
              LiteGlobal.update_decryption_key("AES_default", right_key)

          network = self._load_aes_encrypted_network()
          self.do_forward(network)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台