|
- /**
- * \file test/test_misc.cpp
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- */
-
- #include "lite_build_config.h"
-
- #if LITE_BUILD_WITH_MGE
- #include "test_common.h"
- #include "../src/decryption/decrypt_base.h"
- #include "../src/network_impl_base.h"
-
- #include "megbrain/opr/io.h"
- #include "megbrain/tensor.h"
- #include "megbrain/utils/metahelper.h"
-
- #include <gtest/gtest.h>
-
- #include <string.h>
- #include <chrono>
- #include <memory>
- #include <random>
-
- using namespace lite;
-
- //! Registering a new decryption method must grow the global method table by
- //! exactly one entry.
- TEST(TestMisc, DecryptionRegister) {
-     //! current number of entries in the global decryption method table
-     auto method_count = [] {
-         return decryption_static_data().decryption_methods.size();
-     };
-
-     const size_t before = method_count();
-     //! lite itself registers at least one decryption method
-     ASSERT_GE(before, 1);
-
-     //! an empty function is enough: only the registration itself is checked
-     register_decryption_and_key("AllForTest0", DecryptionFunc{}, {});
-     ASSERT_EQ(before + 1, method_count());
- }
-
- //! Register a method with an empty function and key, then update the function
- //! and the key separately and check what is stored after each update.
- TEST(TestMisc, DecryptionUpdate) {
-     //! size of the key currently stored for the "AllForTest1" method
-     auto stored_key_size = [] {
-         return decryption_static_data()
-                 .decryption_methods["AllForTest1"]
-                 .second->size();
-     };
-
-     register_decryption_and_key("AllForTest1", DecryptionFunc{}, {});
-
-     //! replace the (empty) function with one that always yields no bytes
-     DecryptionFunc returns_nothing =
-             [](const void*, size_t,
-                const std::vector<uint8_t>&) -> std::vector<uint8_t> {
-         return {};
-     };
-     update_decryption_or_key("AllForTest1", returns_nothing, {});
-     ASSERT_NE(decryption_static_data().decryption_methods["AllForTest1"].first,
-               nullptr);
-     ASSERT_EQ(stored_key_size(), 0);
-
-     //! now an update with an empty function and a three-byte key
-     update_decryption_or_key("AllForTest1", {}, {1, 2, 3});
-     ASSERT_EQ(stored_key_size(), 3);
- }
-
- //! Load the same model twice through one GraphLoader, compile each load
- //! result, and check that every ImmutableTensor of the second compiled
- //! function points at the same device buffer (raw_ptr) as its counterpart in
- //! the first one.
- TEST(TestMisc, SharedSameDeviceTensor) {
-     using namespace mgb;
-     serialization::GraphLoader::LoadConfig mgb_config;
-     mgb_config.comp_node_mapper = [](CompNode::Locator& loc) {
-         loc = to_compnode_locator(LiteDeviceType::LITE_CPU);
-     };
-     mgb_config.comp_graph = ComputingGraph::make();
-     std::string model_path = "./shufflenet.mge";
-
-     auto inp_file = mgb::serialization::InputFile::make_fs(model_path.c_str());
-     auto format =
-             serialization::GraphLoader::identify_graph_dump_format(*inp_file);
-     mgb_assert(format.valid(),
-                "invalid model: unknown model format, please make sure input "
-                "file is generated by GraphDumper");
-     auto loader =
-             serialization::GraphLoader::make(std::move(inp_file), format.val());
-     //! both loads go through the same loader with the same config
-     auto load_ret_1 = loader->load(mgb_config, true);
-     auto load_ret_2 = loader->load(mgb_config, true);
-     ASSERT_EQ(load_ret_1.output_var_list.size(),
-               load_ret_2.output_var_list.size());
-
-     ComputingGraph::OutputSpec out_spec_1, out_spec_2;
-     for (size_t i = 0; i < load_ret_1.output_var_list.size(); i++) {
-         out_spec_1.emplace_back(load_ret_1.output_var_list[i], nullptr);
-         out_spec_2.emplace_back(load_ret_2.output_var_list[i], nullptr);
-     }
-     //! compile each load result with its OWN outputs; using out_spec_1 for
-     //! both would make the two functions cover the same vars
-     auto func_1 = load_ret_1.graph_compile(out_spec_1);
-     auto func_2 = load_ret_2.graph_compile(out_spec_2);
-
-     //! gather the ImmutableTensor oprs of a compiled function in the order
-     //! iter_opr_seq visits them; one helper so both functions are scanned
-     //! identically (a duplicated scan previously read func_1 twice)
-     auto collect_immutable_tensors = [](const auto& func) {
-         std::vector<opr::ImmutableTensor*> oprs;
-         func->iter_opr_seq([&oprs](cg::OperatorNodeBase* opr) -> bool {
-             if (auto* imm = opr->try_cast_final<opr::ImmutableTensor>()) {
-                 oprs.push_back(imm);
-             }
-             return true;
-         });
-         return oprs;
-     };
-     auto oprs_1 = collect_immutable_tensors(func_1);
-     auto oprs_2 = collect_immutable_tensors(func_2);
-
-     ASSERT_EQ(oprs_1.size(), oprs_2.size());
-     for (size_t i = 0; i < oprs_1.size(); i++) {
-         //! bind by reference: only the storage pointer is compared
-         const auto& tensor_1 = oprs_1[i]->value();
-         const auto& tensor_2 = oprs_2[i]->value();
-         ASSERT_EQ(tensor_1.raw_ptr(), tensor_2.raw_ptr());
-     }
- }
-
- #endif
-
- // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
|