
test_misc.cpp 4.0 kB

/**
 * \file test/test_misc.cpp
 *
 * This file is part of MegEngine, a deep learning framework developed by
 * Megvii.
 *
 * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
 */

#include "lite_build_config.h"

#if LITE_BUILD_WITH_MGE
#include "test_common.h"
#include "../src/decryption/decrypt_base.h"
#include "../src/network_impl_base.h"

#include "megbrain/opr/io.h"
#include "megbrain/tensor.h"
#include "megbrain/utils/metahelper.h"

#include <gtest/gtest.h>

#include <string.h>
#include <chrono>
#include <memory>
#include <random>

using namespace lite;
//! Registering a new decryption method should add exactly one entry to the
//! global registry.
TEST(TestMisc, DecryptionRegister) {
    size_t number = decryption_static_data().decryption_methods.size();
    //! At least one method is registered by lite
    ASSERT_GE(number, 1);
    DecryptionFunc func;
    register_decryption_and_key("AllForTest0", func, {});
    ASSERT_EQ(number + 1, decryption_static_data().decryption_methods.size());
}
//! The function and the key of a registered decryption method can be updated
//! independently of each other.
TEST(TestMisc, DecryptionUpdate) {
    DecryptionFunc func;
    register_decryption_and_key("AllForTest1", func, {});
    func = [](const void*, size_t,
              const std::vector<uint8_t>&) -> std::vector<uint8_t> {
        return {};
    };
    update_decryption_or_key("AllForTest1", func, {});
    ASSERT_NE(decryption_static_data().decryption_methods["AllForTest1"].first,
              nullptr);
    ASSERT_EQ(decryption_static_data()
                      .decryption_methods["AllForTest1"]
                      .second->size(),
              0);
    update_decryption_or_key("AllForTest1", {}, {1, 2, 3});
    ASSERT_EQ(decryption_static_data()
                      .decryption_methods["AllForTest1"]
                      .second->size(),
              3);
}
//! Loading the same model twice through one GraphLoader should reuse the same
//! device memory: the ImmutableTensor values of both loads point to the same
//! raw buffer.
TEST(TestMisc, SharedSameDeviceTensor) {
    using namespace mgb;
    serialization::GraphLoader::LoadConfig mgb_config;
    mgb_config.comp_node_mapper = [](CompNode::Locator& loc) {
        loc = to_compnode_locator(LiteDeviceType::LITE_CPU);
    };
    mgb_config.comp_graph = ComputingGraph::make();

    std::string model_path = "./shufflenet.mge";
    auto inp_file = mgb::serialization::InputFile::make_fs(model_path.c_str());
    auto format =
            serialization::GraphLoader::identify_graph_dump_format(*inp_file);
    mgb_assert(format.valid(),
               "invalid model: unknown model format, please make sure input "
               "file is generated by GraphDumper");
    auto loader =
            serialization::GraphLoader::make(std::move(inp_file), format.val());

    auto load_ret_1 = loader->load(mgb_config, true);
    auto load_ret_2 = loader->load(mgb_config, true);

    ASSERT_EQ(load_ret_1.output_var_list.size(),
              load_ret_2.output_var_list.size());
    ComputingGraph::OutputSpec out_spec_1, out_spec_2;
    for (size_t i = 0; i < load_ret_1.output_var_list.size(); i++) {
        out_spec_1.emplace_back(load_ret_1.output_var_list[i], nullptr);
        out_spec_2.emplace_back(load_ret_2.output_var_list[i], nullptr);
    }
    //! compile one function per load so that the two operator sequences really
    //! come from the two independent load results
    auto func_1 = load_ret_1.graph_compile(out_spec_1);
    auto func_2 = load_ret_2.graph_compile(out_spec_2);

    std::vector<cg::OperatorNodeBase*> oprs_1, oprs_2;
    func_1->iter_opr_seq([&oprs_1](cg::OperatorNodeBase* opr) -> bool {
        if (opr->try_cast_final<opr::ImmutableTensor>()) {
            oprs_1.push_back(opr);
        }
        return true;
    });
    func_2->iter_opr_seq([&oprs_2](cg::OperatorNodeBase* opr) -> bool {
        if (opr->try_cast_final<opr::ImmutableTensor>()) {
            oprs_2.push_back(opr);
        }
        return true;
    });
    ASSERT_EQ(oprs_1.size(), oprs_2.size());
    for (size_t i = 0; i < oprs_1.size(); i++) {
        auto tensor_1 =
                oprs_1[i]->try_cast_final<opr::ImmutableTensor>()->value();
        auto tensor_2 =
                oprs_2[i]->try_cast_final<opr::ImmutableTensor>()->value();
        ASSERT_EQ(tensor_1.raw_ptr(), tensor_2.raw_ptr());
    }
}
#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
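The two decryption tests above exercise the registry maintained by register_decryption_and_key and update_decryption_or_key. As a minimal sketch of how such a registered method might be consumed when loading an encrypted model through the Lite C++ API, consider the following; the "SimpleXOR" scheme, the key, the model path, and the exact Config field are illustrative assumptions rather than code from this repository.

#include <cstddef>
#include <cstdint>
#include <vector>

#include "lite/global.h"   // assumed public header for register_decryption_and_key
#include "lite/network.h"  // assumed public header for lite::Config / lite::Network

// Hypothetical callback matching the DecryptionFunc signature used in the
// tests above: it receives the encrypted blob plus the registered key and
// returns the decrypted model bytes.
static std::vector<uint8_t> simple_xor_decrypt(
        const void* data, size_t size, const std::vector<uint8_t>& key) {
    const uint8_t* src = static_cast<const uint8_t*>(data);
    if (key.empty()) {
        return std::vector<uint8_t>(src, src + size);
    }
    std::vector<uint8_t> out(size);
    for (size_t i = 0; i < size; ++i) {
        out[i] = src[i] ^ key[i % key.size()];  // cycle the key over the blob
    }
    return out;
}

int main() {
    // Register the method once under a name, together with its key.
    lite::register_decryption_and_key("SimpleXOR", simple_xor_decrypt,
                                      {0x12, 0x34, 0x56});

    // Ask the network to run the registered method on the model bytes before
    // parsing them (field name taken from the Lite Config; treat it as an
    // assumption here).
    lite::Config config;
    config.bare_model_cryption_name = "SimpleXOR";

    lite::Network network(config);
    network.load_model("./encrypted_shufflenet.mge");  // hypothetical path
    // ... set inputs, then run network.forward() / network.wait() as usual
    return 0;
}

After registration, the function or the key can later be swapped with update_decryption_or_key, which is exactly the behaviour the DecryptionUpdate test checks.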

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has a GPU installed along with the proper driver. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.