You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_global.py 3.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import os
  10. import unittest
  11. from ctypes import *
  12. import numpy as np
  13. from megenginelite import *
  14. set_log_level(2)  # module-level logging setup, runs at import; NOTE(review): meaning of level 2 is not visible here - confirm against megenginelite
  15. class TestShuffleNet(unittest.TestCase):
  16. source_dir = os.getenv("LITE_TEST_RESOURCE")
  17. input_data_path = os.path.join(source_dir, "input_data.npy")
  18. correct_data_path = os.path.join(source_dir, "output_data.npy")
  19. correct_data = np.load(correct_data_path).flatten()
  20. input_data = np.load(input_data_path)
  21. def check_correct(self, out_data, error=1e-4):
  22. out_data = out_data.flatten()
  23. assert np.isfinite(out_data.sum())
  24. assert self.correct_data.size == out_data.size
  25. for i in range(out_data.size):
  26. assert abs(out_data[i] - self.correct_data[i]) < error
  27. def do_forward(self, network, times=3):
  28. input_name = network.get_input_name(0)
  29. input_tensor = network.get_io_tensor(input_name)
  30. output_name = network.get_output_name(0)
  31. output_tensor = network.get_io_tensor(output_name)
  32. input_tensor.set_data_by_copy(self.input_data)
  33. for i in range(times):
  34. network.forward()
  35. network.wait()
  36. output_data = output_tensor.to_numpy()
  37. self.check_correct(output_data)
  38. class TestGlobal(TestShuffleNet):
  39. def test_device_count(self):
  40. LiteGlobal.try_coalesce_all_free_memory()
  41. count = LiteGlobal.get_device_count(LiteDeviceType.LITE_CPU)
  42. assert count > 0
  43. def test_register_decryption_method(self):
  44. @decryption_func
  45. def function(in_arr, key_arr, out_arr):
  46. if not out_arr:
  47. return in_arr.size
  48. else:
  49. for i in range(in_arr.size):
  50. out_arr[i] = in_arr[i] ^ key_arr[0] ^ key_arr[0]
  51. return out_arr.size
  52. LiteGlobal.register_decryption_and_key("just_for_test", function, [15])
  53. config = LiteConfig()
  54. config.bare_model_cryption_name = "just_for_test".encode("utf-8")
  55. network = LiteNetwork()
  56. model_path = os.path.join(self.source_dir, "shufflenet.mge")
  57. network.load(model_path)
  58. self.do_forward(network)
  59. def test_set_get_memory_pair(self):
  60. if LiteGlobal.get_device_count(LiteDeviceType.LITE_AX) > 0:
  61. arr1 = np.ones([2, 3])
  62. arr2 = np.ones([2, 3])
  63. vir_ptr = arr1.ctypes.data_as(c_void_p)
  64. phy_ptr = arr2.ctypes.data_as(c_void_p)
  65. LiteGlobal.register_memory_pair(
  66. vir_ptr, phy_ptr, 10, LiteDeviceType.LITE_AX
  67. )
  68. phy_ptr2 = LiteGlobal.lookup_physic_ptr(vir_ptr, LiteDeviceType.LITE_AX)
  69. assert phy_ptr.value == phy_ptr2.value
  70. LiteGlobal.clear_memory_pair(vir_ptr, phy_ptr, LiteDeviceType.LITE_AX)