You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_misc.cpp 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. /**
  2. * \file test/test_misc.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "lite_build_config.h"
  12. #if LITE_BUILD_WITH_MGE
  13. #include "test_common.h"
  14. #include "../src/decryption/decrypt_base.h"
  15. #include "../src/network_impl_base.h"
  16. #include "megbrain/opr/io.h"
  17. #include "megbrain/tensor.h"
  18. #include "megbrain/utils/metahelper.h"
  19. #include <gtest/gtest.h>
  20. #include <string.h>
  21. #include <chrono>
  22. #include <memory>
  23. #include <random>
  24. using namespace lite;
  25. TEST(TestMisc, DecryptionRegister) {
  26. size_t number = decryption_static_data().decryption_methods.size();
  27. //! At least one method is register by lite
  28. ASSERT_GE(number, 1);
  29. DecryptionFunc func;
  30. register_decryption_and_key("AllForTest0", func, {});
  31. ASSERT_EQ(number + 1, decryption_static_data().decryption_methods.size());
  32. }
  33. TEST(TestMisc, DecryptionUpdate) {
  34. DecryptionFunc func;
  35. register_decryption_and_key("AllForTest1", func, {});
  36. func = [](const void*, size_t,
  37. const std::vector<uint8_t>&) -> std::vector<uint8_t> {
  38. return {};
  39. };
  40. update_decryption_or_key("AllForTest1", func, {});
  41. ASSERT_NE(decryption_static_data().decryption_methods["AllForTest1"].first,
  42. nullptr);
  43. ASSERT_EQ(decryption_static_data()
  44. .decryption_methods["AllForTest1"]
  45. .second->size(),
  46. 0);
  47. update_decryption_or_key("AllForTest1", {}, {1, 2, 3});
  48. ASSERT_EQ(decryption_static_data()
  49. .decryption_methods["AllForTest1"]
  50. .second->size(),
  51. 3);
  52. }
  53. TEST(TestMisc, SharedSameDeviceTensor) {
  54. using namespace mgb;
  55. serialization::GraphLoader::LoadConfig mgb_config;
  56. mgb_config.comp_node_mapper = [](CompNode::Locator& loc) {
  57. loc = to_compnode_locator(LiteDeviceType::LITE_CPU);
  58. };
  59. mgb_config.comp_graph = ComputingGraph::make();
  60. std::string model_path = "./shufflenet.mge";
  61. auto inp_file = mgb::serialization::InputFile::make_fs(model_path.c_str());
  62. auto format =
  63. serialization::GraphLoader::identify_graph_dump_format(*inp_file);
  64. mgb_assert(format.valid(),
  65. "invalid model: unknown model format, please make sure input "
  66. "file is generated by GraphDumper");
  67. auto loader =
  68. serialization::GraphLoader::make(std::move(inp_file), format.val());
  69. auto load_ret_1 = loader->load(mgb_config, true);
  70. auto load_ret_2 = loader->load(mgb_config, true);
  71. ASSERT_EQ(load_ret_1.output_var_list.size(),
  72. load_ret_2.output_var_list.size());
  73. ComputingGraph::OutputSpec out_spec_1, out_spec_2;
  74. for (size_t i = 0; i < load_ret_1.output_var_list.size(); i++) {
  75. out_spec_1.emplace_back(load_ret_1.output_var_list[i], nullptr);
  76. out_spec_2.emplace_back(load_ret_2.output_var_list[i], nullptr);
  77. }
  78. auto func_1 = load_ret_1.graph_compile(out_spec_1);
  79. auto func_2 = load_ret_2.graph_compile(out_spec_1);
  80. std::vector<cg::OperatorNodeBase*> oprs_1, oprs_2;
  81. func_1->iter_opr_seq([&oprs_1](cg::OperatorNodeBase* opr) -> bool {
  82. if (opr->try_cast_final<opr::ImmutableTensor>()) {
  83. oprs_1.push_back(opr);
  84. }
  85. return true;
  86. });
  87. func_1->iter_opr_seq([&oprs_2](cg::OperatorNodeBase* opr) -> bool {
  88. if (opr->try_cast_final<opr::ImmutableTensor>()) {
  89. oprs_2.push_back(opr);
  90. }
  91. return true;
  92. });
  93. ASSERT_EQ(oprs_1.size(), oprs_2.size());
  94. for (size_t i = 0; i < oprs_1.size(); i++) {
  95. auto tensor_1 =
  96. oprs_1[i]->try_cast_final<opr::ImmutableTensor>()->value();
  97. auto tensor_2 =
  98. oprs_2[i]->try_cast_final<opr::ImmutableTensor>()->value();
  99. ASSERT_EQ(tensor_1.raw_ptr(), tensor_2.raw_ptr());
  100. }
  101. }
  102. #endif
  103. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。