You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_misc.cpp 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. /**
  2. * \file test/test_misc.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "lite_build_config.h"
  12. #if LITE_BUILD_WITH_MGE
  13. #include "../src/decryption/decrypt_base.h"
  14. #include "../src/network_impl_base.h"
  15. #include "test_common.h"
  16. #include "megbrain/opr/io.h"
  17. #include "megbrain/tensor.h"
  18. #include "megbrain/utils/metahelper.h"
  19. #include <gtest/gtest.h>
  20. #include <string.h>
  21. #include <chrono>
  22. #include <memory>
  23. #include <random>
  24. using namespace lite;
  25. TEST(TestMisc, DecryptionRegister) {
  26. size_t number = decryption_static_data().decryption_methods.size();
  27. //! At least one method is register by lite
  28. ASSERT_GE(number, 1);
  29. DecryptionFunc func;
  30. register_decryption_and_key("AllForTest0", func, {});
  31. ASSERT_EQ(number + 1, decryption_static_data().decryption_methods.size());
  32. }
  33. TEST(TestMisc, DecryptionUpdate) {
  34. DecryptionFunc func;
  35. register_decryption_and_key("AllForTest1", func, {});
  36. func = [](const void*, size_t,
  37. const std::vector<uint8_t>&) -> std::vector<uint8_t> { return {}; };
  38. update_decryption_or_key("AllForTest1", func, {});
  39. ASSERT_NE(
  40. decryption_static_data().decryption_methods["AllForTest1"].first, nullptr);
  41. ASSERT_EQ(
  42. decryption_static_data().decryption_methods["AllForTest1"].second->size(),
  43. 0);
  44. update_decryption_or_key("AllForTest1", {}, {1, 2, 3});
  45. ASSERT_EQ(
  46. decryption_static_data().decryption_methods["AllForTest1"].second->size(),
  47. 3);
  48. }
  49. TEST(TestMisc, SharedSameDeviceTensor) {
  50. using namespace mgb;
  51. serialization::GraphLoader::LoadConfig mgb_config;
  52. mgb_config.comp_node_mapper = [](CompNode::Locator& loc) {
  53. loc = to_compnode_locator(LiteDeviceType::LITE_CPU);
  54. };
  55. mgb_config.comp_graph = ComputingGraph::make();
  56. std::string model_path = "./shufflenet.mge";
  57. auto inp_file = mgb::serialization::InputFile::make_fs(model_path.c_str());
  58. auto format = serialization::GraphLoader::identify_graph_dump_format(*inp_file);
  59. mgb_assert(
  60. format.valid(),
  61. "invalid model: unknown model format, please make sure input "
  62. "file is generated by GraphDumper");
  63. auto loader = serialization::GraphLoader::make(std::move(inp_file), format.val());
  64. auto load_ret_1 = loader->load(mgb_config, true);
  65. auto load_ret_2 = loader->load(mgb_config, true);
  66. ASSERT_EQ(load_ret_1.output_var_list.size(), load_ret_2.output_var_list.size());
  67. ComputingGraph::OutputSpec out_spec_1, out_spec_2;
  68. for (size_t i = 0; i < load_ret_1.output_var_list.size(); i++) {
  69. out_spec_1.emplace_back(load_ret_1.output_var_list[i], nullptr);
  70. out_spec_2.emplace_back(load_ret_2.output_var_list[i], nullptr);
  71. }
  72. auto func_1 = load_ret_1.graph_compile(out_spec_1);
  73. auto func_2 = load_ret_2.graph_compile(out_spec_1);
  74. std::vector<cg::OperatorNodeBase*> oprs_1, oprs_2;
  75. func_1->iter_opr_seq([&oprs_1](cg::OperatorNodeBase* opr) -> bool {
  76. if (opr->try_cast_final<opr::ImmutableTensor>()) {
  77. oprs_1.push_back(opr);
  78. }
  79. return true;
  80. });
  81. func_1->iter_opr_seq([&oprs_2](cg::OperatorNodeBase* opr) -> bool {
  82. if (opr->try_cast_final<opr::ImmutableTensor>()) {
  83. oprs_2.push_back(opr);
  84. }
  85. return true;
  86. });
  87. ASSERT_EQ(oprs_1.size(), oprs_2.size());
  88. for (size_t i = 0; i < oprs_1.size(); i++) {
  89. auto tensor_1 = oprs_1[i]->try_cast_final<opr::ImmutableTensor>()->value();
  90. auto tensor_2 = oprs_2[i]->try_cast_final<opr::ImmutableTensor>()->value();
  91. ASSERT_EQ(tensor_1.raw_ptr(), tensor_2.raw_ptr());
  92. }
  93. }
  94. #endif
  95. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台