You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utility.sereg.h 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. /**
  2. * \file src/opr/impl/utility.sereg.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/opr/utility.h"
  12. #include "megbrain/serialization/sereg.h"
  13. #if MGB_ENABLE_FBS_SERIALIZATION
  14. #include "megbrain/serialization/internal/mgb_cpp_opr_generated.h"
  15. #endif
  16. namespace mgb {
  17. namespace serialization {
  18. template<>
  19. struct OprLoadDumpImpl<opr::AssertEqual, 0> {
  20. static void dump(OprDumpContext &ctx,
  21. const cg::OperatorNodeBase &opr_) {
  22. auto &&opr = opr_.cast_final_safe<opr::AssertEqual>();
  23. ctx.write_param(opr.param());
  24. }
  25. static cg::OperatorNodeBase* load(
  26. OprLoadContext &ctx, const cg::VarNodeArray &inputs,
  27. const OperatorNodeConfig &config) {
  28. auto param = ctx.read_param<megdnn::param::AssertEqual>();
  29. SymbolVar out;
  30. if (inputs.size() == 2) {
  31. // from python
  32. out = opr::AssertEqual::make(
  33. inputs[0], inputs[1], param, config);
  34. } else {
  35. // from sereg or copy
  36. mgb_assert(inputs.size() == 3);
  37. out = opr::AssertEqual::make(
  38. inputs[0], inputs[1], inputs[2], param, config);
  39. }
  40. return out.node()->owner_opr();
  41. }
  42. };
  43. #if !MGB_BUILD_SLIM_SERVING
  44. template <>
  45. struct OprLoadDumpImpl<opr::VirtualDep, 0> {
  46. static void dump(OprDumpContext& ctx,
  47. const cg::OperatorNodeBase& opr_) {}
  48. static cg::OperatorNodeBase* load(OprLoadContext& ctx,
  49. const cg::VarNodeArray& inputs,
  50. const OperatorNodeConfig& config) {
  51. return opr::VirtualDep::make(to_symbol_var_array(inputs), config)
  52. .node()
  53. ->owner_opr();
  54. }
  55. };
  56. #if MGB_ENABLE_FBS_SERIALIZATION
  57. namespace fbs {
  58. template <>
  59. struct ParamConverter<opr::Sleep::Param> {
  60. using FlatBufferType = param::MGBSleep;
  61. static opr::Sleep::Param to_param(const param::MGBSleep* fb) {
  62. return {fb->seconds(), {fb->device(), fb->host()}};
  63. }
  64. static flatbuffers::Offset<param::MGBSleep> to_flatbuffer(
  65. flatbuffers::FlatBufferBuilder& builder,
  66. const opr::Sleep::Param& p) {
  67. return param::CreateMGBSleep(builder, p.type.device, p.type.host,
  68. p.seconds);
  69. }
  70. };
  71. } // namespace fbs
  72. #endif
  73. #endif
  74. } // namespace serialization
  75. namespace opr {
  // Serialization registrations for the utility operators. The second macro
  // argument selects the OprLoadDumpImpl<Opr, N> used for load/dump; for
  // AssertEqual it is 0, matching the hand-written
  // OprLoadDumpImpl<opr::AssertEqual, 0> specialization in namespace
  // serialization of this file.
  // NOTE(review): N presumably is the input arity, with 0 meaning a custom
  // loader — confirm against megbrain/serialization/sereg.h.
  76. MGB_SEREG_OPR(MarkDynamicVar, 1);
  77. MGB_SEREG_OPR(MarkNoBroadcastElemwise, 1);
  78. MGB_SEREG_OPR(Identity, 1);
  79. MGB_SEREG_OPR(AssertEqual, 0);
  80. #if MGB_ENABLE_GRAD
  // Only compiled when gradient support is enabled (MGB_ENABLE_GRAD).
  // NOTE(review): registered with 2 — presumably VirtualGrad takes two
  // inputs; confirm against the operator definition.
  81. MGB_SEREG_OPR(VirtualGrad, 2);
  82. cg::OperatorNodeBase* opr_shallow_copy_set_grad(
  83. const serialization::OprShallowCopyContext &ctx,
  84. const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
  85. const OperatorNodeConfig &config) {
  86. mgb_assert(inputs.size() == 1);
  87. auto &&opr = opr_.cast_final_safe<SetGrad>();
  88. return SetGrad::make(inputs[0], opr.grad_getter(), config).
  89. node()->owner_opr();
  90. }
  91. MGB_REG_OPR_SHALLOW_COPY(SetGrad, opr_shallow_copy_set_grad);
  92. cg::OperatorNodeBase* opr_shallow_copy_virtual_loss(
  93. const serialization::OprShallowCopyContext& ctx,
  94. const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
  95. const OperatorNodeConfig& config) {
  96. return inputs[0]->owner_graph()->insert_opr(
  97. std::make_unique<VirtualLoss>(inputs, config));
  98. }
  99. MGB_REG_OPR_SHALLOW_COPY(VirtualLoss, opr_shallow_copy_virtual_loss);
  100. cg::OperatorNodeBase* opr_shallow_copy_invalid_grad(
  101. const serialization::OprShallowCopyContext& ctx,
  102. const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
  103. const OperatorNodeConfig& config) {
  104. mgb_assert(inputs.size() == 1);
  105. auto&& opr = opr_.cast_final_safe<InvalidGrad>();
  106. return inputs[0]->owner_opr()->owner_graph()->insert_opr(
  107. std::make_unique<InvalidGrad>(inputs[0], opr.grad_opr(),
  108. opr.inp_idx()));
  109. }
  110. MGB_REG_OPR_SHALLOW_COPY(InvalidGrad, opr_shallow_copy_invalid_grad)
  111. #endif
  112. cg::OperatorNodeBase* opr_shallow_copy_callback_injector(
  113. const serialization::OprShallowCopyContext &ctx,
  114. const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
  115. const OperatorNodeConfig &config) {
  116. auto &&opr = opr_.cast_final_safe<CallbackInjector>();
  117. return CallbackInjector::make(cg::to_symbol_var_array(inputs), opr.param(), config).
  118. node()->owner_opr();
  119. }
  120. MGB_REG_OPR_SHALLOW_COPY(CallbackInjector,
  121. opr_shallow_copy_callback_injector);
  122. cg::OperatorNodeBase* opr_shallow_copy_require_input_dynamic_storage(
  123. const serialization::OprShallowCopyContext &ctx,
  124. const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
  125. const OperatorNodeConfig &config) {
  126. mgb_assert(inputs.size() == 1);
  127. return RequireInputDynamicStorage::make(inputs[0], config).
  128. node()->owner_opr();
  129. }
  130. MGB_REG_OPR_SHALLOW_COPY(RequireInputDynamicStorage,
  131. opr_shallow_copy_require_input_dynamic_storage);
  132. #if !MGB_BUILD_SLIM_SERVING
  // Sleep and VirtualDep are left out of slim serving builds, matching the
  // same guard around the VirtualDep loader and the Sleep flatbuffers
  // converter in namespace serialization of this file.
  133. MGB_SEREG_OPR(Sleep, 1);
  134. MGB_SEREG_OPR(VirtualDep, 0);
  135. #endif
  // Registered in every build configuration.
  136. MGB_SEREG_OPR(PersistentOutputStorage, 1);
  137. cg::OperatorNodeBase* opr_shallow_copy_shape_hint(
  138. const serialization::OprShallowCopyContext &ctx,
  139. const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
  140. const OperatorNodeConfig &config) {
  141. auto &&opr = opr_.cast_final_safe<ShapeHint>();
  142. mgb_assert(inputs.size() == 1);
  143. return ShapeHint::make(inputs[0], opr.shape(), opr.is_const(), config)
  144. .node()->owner_opr();
  145. }
  146. MGB_REG_OPR_SHALLOW_COPY(ShapeHint, opr_shallow_copy_shape_hint);
  147. } // namespace opr
  148. } // namespace mgb
  149. // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。