You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_misc.cpp 3.7 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. #include "lite_build_config.h"
  2. #if LITE_BUILD_WITH_MGE
  3. #include "../src/decryption/decrypt_base.h"
  4. #include "../src/network_impl_base.h"
  5. #include "test_common.h"
  6. #include "megbrain/opr/io.h"
  7. #include "megbrain/tensor.h"
  8. #include "megbrain/utils/metahelper.h"
  9. #include <gtest/gtest.h>
  10. #include <string.h>
  11. #include <chrono>
  12. #include <memory>
  13. #include <random>
  14. using namespace lite;
  15. TEST(TestMisc, DecryptionRegister) {
  16. size_t number = decryption_static_data().decryption_methods.size();
  17. //! At least one method is register by lite
  18. ASSERT_GE(number, 1);
  19. DecryptionFunc func;
  20. register_decryption_and_key("AllForTest0", func, {});
  21. ASSERT_EQ(number + 1, decryption_static_data().decryption_methods.size());
  22. }
  23. TEST(TestMisc, DecryptionUpdate) {
  24. DecryptionFunc func;
  25. register_decryption_and_key("AllForTest1", func, {});
  26. func = [](const void*, size_t,
  27. const std::vector<uint8_t>&) -> std::vector<uint8_t> { return {}; };
  28. update_decryption_or_key("AllForTest1", func, {});
  29. ASSERT_NE(
  30. decryption_static_data().decryption_methods["AllForTest1"].first, nullptr);
  31. ASSERT_EQ(
  32. decryption_static_data().decryption_methods["AllForTest1"].second->size(),
  33. 0);
  34. update_decryption_or_key("AllForTest1", {}, {1, 2, 3});
  35. ASSERT_EQ(
  36. decryption_static_data().decryption_methods["AllForTest1"].second->size(),
  37. 3);
  38. }
  39. TEST(TestMisc, SharedSameDeviceTensor) {
  40. using namespace mgb;
  41. serialization::GraphLoader::LoadConfig mgb_config;
  42. mgb_config.comp_node_mapper = [](CompNode::Locator& loc) {
  43. loc = to_compnode_locator(LiteDeviceType::LITE_CPU);
  44. };
  45. mgb_config.comp_graph = ComputingGraph::make();
  46. std::string model_path = "./shufflenet.mge";
  47. auto inp_file = mgb::serialization::InputFile::make_fs(model_path.c_str());
  48. auto format = serialization::GraphLoader::identify_graph_dump_format(*inp_file);
  49. mgb_assert(
  50. format.valid(),
  51. "invalid model: unknown model format, please make sure input "
  52. "file is generated by GraphDumper");
  53. auto loader = serialization::GraphLoader::make(std::move(inp_file), format.val());
  54. auto load_ret_1 = loader->load(mgb_config, true);
  55. auto load_ret_2 = loader->load(mgb_config, true);
  56. ASSERT_EQ(load_ret_1.output_var_list.size(), load_ret_2.output_var_list.size());
  57. ComputingGraph::OutputSpec out_spec_1, out_spec_2;
  58. for (size_t i = 0; i < load_ret_1.output_var_list.size(); i++) {
  59. out_spec_1.emplace_back(load_ret_1.output_var_list[i], nullptr);
  60. out_spec_2.emplace_back(load_ret_2.output_var_list[i], nullptr);
  61. }
  62. auto func_1 = load_ret_1.graph_compile(out_spec_1);
  63. auto func_2 = load_ret_2.graph_compile(out_spec_1);
  64. std::vector<cg::OperatorNodeBase*> oprs_1, oprs_2;
  65. func_1->iter_opr_seq([&oprs_1](cg::OperatorNodeBase* opr) -> bool {
  66. if (opr->try_cast_final<opr::ImmutableTensor>()) {
  67. oprs_1.push_back(opr);
  68. }
  69. return true;
  70. });
  71. func_1->iter_opr_seq([&oprs_2](cg::OperatorNodeBase* opr) -> bool {
  72. if (opr->try_cast_final<opr::ImmutableTensor>()) {
  73. oprs_2.push_back(opr);
  74. }
  75. return true;
  76. });
  77. ASSERT_EQ(oprs_1.size(), oprs_2.size());
  78. for (size_t i = 0; i < oprs_1.size(); i++) {
  79. auto tensor_1 = oprs_1[i]->try_cast_final<opr::ImmutableTensor>()->value();
  80. auto tensor_2 = oprs_2[i]->try_cast_final<opr::ImmutableTensor>()->value();
  81. ASSERT_EQ(tensor_1.raw_ptr(), tensor_2.raw_ptr());
  82. }
  83. }
  84. #endif
  85. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}