|
- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-
- import os
- import unittest
-
- import numpy as np
-
- from megenginelite import *
-
- set_log_level(2)
-
-
class TestShuffleNet(unittest.TestCase):
    """Shared fixture for tests that run the ShuffleNet reference model.

    The reference input and the expected output are loaded once, while the
    class body is executed, from the directory named by the
    ``LITE_TEST_RESOUCE`` environment variable.
    """

    # NOTE(review): the variable name is spelled "RESOUCE". It is kept
    # byte-for-byte because whatever launches these tests presumably exports
    # exactly that name -- confirm before renaming.
    source_dir = os.getenv("LITE_TEST_RESOUCE")
    input_data_path = os.path.join(source_dir, "input_data.npy")
    correct_data_path = os.path.join(source_dir, "output_data.npy")
    # Expected output, flattened so it can be compared element-wise against
    # the flattened network output in check_correct().
    correct_data = np.load(correct_data_path).flatten()
    input_data = np.load(input_data_path)

    def check_correct(self, out_data, error=1e-4):
        """Fail the test unless ``out_data`` matches the reference output.

        Args:
            out_data: array produced by the network; it is flattened before
                being compared with ``correct_data``.
            error: exclusive upper bound on the absolute difference allowed
                for each element.

        Raises:
            AssertionError: if the sum of the output is not finite, if the
                element counts differ, or if any element is ``error`` or
                more away from the reference.
        """
        out_data = out_data.flatten()
        # unittest assertions are used instead of bare ``assert`` so the
        # checks still run when Python is started with -O.
        self.assertTrue(
            np.isfinite(out_data.sum()), "sum of the network output is not finite"
        )
        self.assertEqual(self.correct_data.size, out_data.size)

        # One vectorized comparison instead of a Python-level loop. Negating
        # the "< error" test (rather than testing ">= error") keeps NaN
        # differences classified as mismatches.
        within_tolerance = np.abs(out_data - self.correct_data) < error
        mismatched = np.flatnonzero(~within_tolerance)
        if mismatched.size:
            first = int(mismatched[0])
            self.fail(
                "{} of {} elements are not within {} of the reference; "
                "first mismatch at index {}: got {}, expected {}".format(
                    mismatched.size,
                    out_data.size,
                    error,
                    first,
                    out_data[first],
                    self.correct_data[first],
                )
            )

    def do_forward(self, network, times=3):
        """Run ``network`` on the reference input and verify its output.

        Args:
            network: a loaded network object; its first input and first
                output tensors are used.
            times: how many forward/wait rounds to execute before the output
                is read back.
        """
        input_name = network.get_input_name(0)
        input_tensor = network.get_io_tensor(input_name)
        output_name = network.get_output_name(0)
        output_tensor = network.get_io_tensor(output_name)

        input_tensor.set_data_by_copy(self.input_data)
        for _ in range(times):
            network.forward()
            network.wait()

        self.check_correct(output_tensor.to_numpy())
-
-
class TestGlobal(TestShuffleNet):
    """Tests for the process-wide settings exposed through ``LiteGlobal``."""

    def _load_aes_shufflenet(self):
        """Create a network configured for "AES_default" and load the encrypted model.

        Returns:
            The loaded network.
        """
        config = LiteConfig()
        config.bare_model_cryption_name = "AES_default".encode("utf-8")
        network = LiteNetwork(config)
        model_path = os.path.join(self.source_dir, "shufflenet_crypt_aes.mge")
        network.load(model_path)
        return network

    def test_device_count(self):
        """At least one CPU device must be reported."""
        LiteGlobal.try_coalesce_all_free_memory()
        count = LiteGlobal.get_device_count(LiteDeviceType.LITE_CPU)
        self.assertGreater(count, 0)

    def test_register_decryption_method(self):
        """Register a custom decryption callback, then load and run the plain model."""

        @decryption_func
        def function(in_arr, key_arr, out_arr):
            # Without an output buffer the call is only a size query.
            # ``is None`` is used because the truth value of a multi-element
            # array is ambiguous and raises ValueError.
            if out_arr is None:
                return in_arr.size
            # XOR-ing twice with the same key byte is the identity, so the
            # data passes through this "decryption" unchanged.
            for i in range(in_arr.size):
                out_arr[i] = in_arr[i] ^ key_arr[0] ^ key_arr[0]
            return out_arr.size

        LiteGlobal.register_decryption_and_key("just_for_test", function, [15])
        # NOTE(review): ``config`` is built but never handed to LiteNetwork
        # below, so loading does not appear to select "just_for_test". Left
        # unchanged; confirm whether LiteNetwork(config) was intended.
        config = LiteConfig()
        config.bare_model_cryption_name = "just_for_test".encode("utf-8")

        network = LiteNetwork()
        model_path = os.path.join(self.source_dir, "shufflenet.mge")
        network.load(model_path)

        self.do_forward(network)

    def test_update_decryption_key(self):
        """Loading fails with a wrong AES key and succeeds once the key is restored."""
        wrong_key = [0] * 32
        # The key this test expects to decrypt shufflenet_crypt_aes.mge.
        right_key = list(range(32))

        LiteGlobal.update_decryption_key("AES_default", wrong_key)
        # The key is set through LiteGlobal, so put the working key back even
        # if an assertion below fails; otherwise the wrong key would remain
        # installed for whatever runs afterwards.
        self.addCleanup(LiteGlobal.update_decryption_key, "AES_default", right_key)

        with self.assertRaises(RuntimeError):
            self._load_aes_shufflenet()

        LiteGlobal.update_decryption_key("AES_default", right_key)
        network = self._load_aes_shufflenet()

        self.do_forward(network)
|