
opr_utility.cpp 5.6 kB

/**
 * \file imperative/src/test/opr_utility.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/opr_utility.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/test/helper.h"

using namespace mgb;
using namespace opr;

TEST(TestOprUtility, InputCallback) {
    // InputCallback injects a device tensor produced outside the graph.
    HostTensorGenerator<> gen;
    DeviceTensorND dv;
    auto hv = gen({2, 3});
    dv.copy_from(*hv).sync();
    auto graph = ComputingGraph::make();
    auto callback = [dv]() { return dv; };
    auto outputs = opr::InputCallback::make(
            *graph, callback, dv.comp_node(), dv.dtype(), {2, 3});
    HostTensorND hout;
    ComputingGraph::OutputSpec outspec{make_callback_copy(outputs[0], hout)};
    auto func = graph->compile(outspec);
    func->execute();
    MGB_ASSERT_TENSOR_EQ(hout, *hv);
}

TEST(TestOprUtility, OutputCallback) {
    // OutputCallback observes a var's device value during execution.
    HostTensorGenerator<> gen;
    auto hx = gen({2, 3});
    auto graph = ComputingGraph::make();
    auto x = opr::Host2DeviceCopy::make(*graph, hx);
    HostTensorND hy;
    auto callback = [&hy](DeviceTensorND dv) { hy.copy_from(dv); };
    auto dummy = opr::OutputCallback::make({callback}, x);
    auto y = opr::VirtualDep::make({x, dummy});
    ComputingGraph::OutputSpec outspec{{y, [](DeviceTensorND&) {}}};
    auto func = graph->compile(outspec);
    func->execute();
    MGB_ASSERT_TENSOR_EQ(hy, *hx);
}

TEST(TestOprUtility, OutputCallbackPreferHost) {
    // With prefer_host_value set, the callback should receive the
    // statically inferable host value (here the shape) on the default CPU
    // node instead of a device copy.
    HostTensorGenerator<> gen;
    auto hx = gen({2, 3});
    auto graph = ComputingGraph::make();
    auto x = opr::Host2DeviceCopy::make(*graph, hx);
    x = opr::GetVarShape::make(x);
    HostTensorND hy;
    auto callback = [&hy](DeviceTensorND dv) { hy.copy_from(dv); };
    opr::OutputCallback::Param param{callback};
    param.prefer_host_value = true;
    auto dummy = opr::OutputCallback::make(param, x);
    auto y = opr::VirtualDep::make({x, dummy});
    ComputingGraph::OutputSpec outspec{{y, [](DeviceTensorND&) {}}};
    auto func = graph->compile(outspec);
    func->execute();
    ASSERT_TRUE(hy.comp_node() == CompNode::default_cpu());
    ASSERT_EQ(hy.ptr<int>()[0], 2);
    ASSERT_EQ(hy.ptr<int>()[1], 3);
}

TEST(TestOprUtility, NopCallback) {
    // NopCallback runs a side-effect-only callback after its dependencies
    // execute; its output var is only usable as a dependency.
    HostTensorGenerator<> gen;
    auto hx = gen({2, 3});
    auto graph = ComputingGraph::make();
    auto x = opr::Host2DeviceCopy::make(*graph, hx);
    bool fired = false;
    auto callback = [&fired]() { fired = true; };
    auto dummy = opr::NopCallback::make(*graph, callback, x.node()->comp_node(), {x});
    auto y = opr::VirtualDep::make({x, dummy});
    ComputingGraph::OutputSpec outspec{{y, [](DeviceTensorND&) {}}};
    auto func = graph->compile(outspec);
    func->execute();
    ASSERT_TRUE(fired);
}

TEST(TestOprUtility, NopCallbackMixedInput) {
    // Dependencies may live on different devices and have different dtypes.
    REQUIRE_XPU(2);
    auto graph = ComputingGraph::make();
    auto x0 = opr::Host2DeviceCopy::make(
            *graph, HostTensorGenerator<dtype::Int32>()({2, 3}),
            OperatorNodeConfig(CompNode::load("xpu0")));
    auto x1 = opr::Host2DeviceCopy::make(
            *graph, HostTensorGenerator<dtype::Float32>()({2, 3}),
            OperatorNodeConfig(CompNode::load("xpu1")));
    bool fired = false;
    auto callback = [&fired]() { fired = true; };
    auto dummy = opr::NopCallback::make(*graph, callback, CompNode::load("xpux"), {x0, x1});
    auto y = opr::VirtualDep::make({x0, dummy});
    ComputingGraph::OutputSpec outspec{{y, [](DeviceTensorND&) {}}};
    auto func = graph->compile(outspec);
    func->execute();
    ASSERT_TRUE(fired);
}

TEST(TestOprUtility, CallbackChain) {
    // Chain InputCallback -> OutputCallback -> NopCallback through dummy
    // vars to enforce execution order; the input storage is released inside
    // InputCallback and must be fully freed by the time NopCallback fires.
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    HostTensorGenerator<dtype::Int16> gen;
    SymbolVar x, dummy;
    DeviceTensorND dev_x, dev_y;
    auto host_x = gen({2, 3});
    dev_x.copy_from(*host_x).sync();
    auto cn = dev_x.comp_node();
    auto dev_x_weakptr = std::weak_ptr<dt_byte>(dev_x.storage().raw_storage());
    {
        auto callback = [&dev_x]() {
            DeviceTensorND ret = dev_x;
            dev_x.storage({});
            return ret;
        };
        auto out = opr::InputCallback::make(*graph, callback, cn, dev_x.dtype(), {2, 3});
        x = out[0];
        dummy = out[1];
    }
    {
        x = opr::TypeCvt::make(x, dtype::Int32());
        x = opr::TypeCvt::make(x, dtype::Int16());
        auto callback = [&](DeviceTensorND y) {
            // dev_x.storage has been reset in InputCallback
            mgb_assert(!dev_x.storage().comp_node_valid());
            dev_y = y;
        };
        dummy = opr::OutputCallback::make({callback}, {x, dummy});
    }
    bool fired = false;
    {
        auto callback = [&]() {
            fired = true;
            ASSERT_FALSE(dev_x_weakptr.lock());
        };
        dummy = opr::NopCallback::make(*graph, callback, cn, {dummy});
    }
    {
        auto out = opr::VirtualDep::make({x.make_scalar(0), dummy});
        ComputingGraph::OutputSpec outspec{{out, [](DeviceTensorND&) {}}};
        auto func = graph->compile(outspec);
        func->execute();
    }
    ASSERT_TRUE(fired);
    HostTensorND host_y;
    host_y.copy_from(dev_y).sync();
    MGB_ASSERT_TENSOR_EQ(host_y, *host_x);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose from. To run GPU programs, make sure the machine has a GPU installed along with the proper driver. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
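
Because one package serves both CPU and GPU machines, device selection happens at runtime; the tests above rely on this via the generic "xpux" specifier, which resolves to a GPU compute node when one is available and falls back to CPU otherwise. Below is a minimal standalone sketch of that resolution, reusing CompNode::load as the tests do; the main() harness and the to_string() printout are illustrative additions, not part of the test file.

// A minimal sketch, not part of opr_utility.cpp: resolve the generic "xpux"
// specifier the same way the NopCallbackMixedInput test does, and report
// which physical device it mapped to. Assumes the program links against
// megbrain.
#include "megbrain/comp_node.h"
#include <cstdio>

int main() {
    // "xpux" picks a GPU when the machine (and the bundled CUDA runtime)
    // exposes one, and falls back to a CPU node otherwise.
    auto cn = mgb::CompNode::load("xpux");
    printf("xpux resolved to: %s\n", cn.to_string().c_str());
    return 0;
}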