
grad_override.cpp 7.7 kB

/**
 * \file imperative/python/src/grad_override.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./grad.h"
#include "megbrain/imperative/ops/autogen.h"

namespace mgb::imperative::python {
namespace {
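
// Helpers that each wrap a single imperative op and return its sole output.

// Returns the shape of x as a tensor, via the GetVarShape op.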
std::shared_ptr<Tensor> get_shape(Tensor* x) {
    static auto op = GetVarShape::make();
    return python::apply(op, x)[0];
}
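
// Reduces x to shape s; the backward passes below use it to collapse a
// gradient back to the shape of a (possibly broadcast) input.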
std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) {
    static auto op = Reduce::make();
    return python::apply(op, x, s)[0];
}
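
// Reshapes x to shape s.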
std::shared_ptr<Tensor> reshape_to(Tensor* x, Tensor* s) {
    static auto op = Reshape::make();
    return python::apply(op, x, s)[0];
}
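
// Broadcasts x to shape s.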
std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) {
    static auto op = Broadcast::make();
    return python::apply(op, x, s)[0];
}
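
// Builds a tensor on comp node cn filled with v (0 by default), broadcast to
// the given shape; the grad rules below use it to create the zero tensor that
// gradients are scattered into.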
std::shared_ptr<Tensor> make_tensor(CompNode cn, Tensor* shape, float v = 0) {
    HostTensorND scalar{cn, {{1}, dtype::Float32()}};
    scalar.ptr<float>()[0] = v;
    interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar);
    auto&& t = std::make_shared<Tensor>(handle);
    auto&& res = broadcast_to(t.get(), shape);
    return res;
}
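
// Grad rule for Elemwise. Only ADD is handled: the gradient of each input is
// the output gradient reduced back to that input's shape, undoing any
// broadcasting done in the forward pass. Every other mode falls back to the
// generic gradient path via GradRuleFallback.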
apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto& op = ctx.op->cast_final_safe<Elemwise>();
    if (op.mode == Elemwise::Mode::ADD) {
        mgb_assert(ctx.nargs == 2);
        std::array<std::shared_ptr<Tensor>, 2> input_shapes;
        for (size_t i = 0; i < 2; ++i) {
            if (input_requires_grad(ctx, i)) {
                input_shapes[i] = get_shape(ctx.args[i]);
            }
        }
        maker.output_size(1).output_captured(0, false);
        maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
            mgb_assert(ngrads == 1);
            Tensor* grad = grads[0];
            apply_result_t ret(2);
            for (size_t i = 0; i < 2; ++i) {
                if (shapes[i]) {
                    ret[i] = reduce_to(grad, shapes[i].get());
                }
            }
            return ret;
        });
        return apply(ctx);
    }
    throw GradRuleFallback();
}
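
// Grad rule for Reshape: the backward pass reshapes the incoming gradient
// back to each input's recorded shape.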
apply_result_t reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    mgb_assert(ctx.nargs == 2);
    std::array<std::shared_ptr<Tensor>, 2> input_shapes;
    for (size_t i = 0; i < 2; ++i) {
        if (input_requires_grad(ctx, i)) {
            input_shapes[i] = get_shape(ctx.args[i]);
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(2);
        for (size_t i = 0; i < 2; ++i) {
            if (shapes[i]) {
                ret[i] = reshape_to(grad, shapes[i].get());
            }
        }
        return ret;
    });
    return apply(ctx);
}
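
// Grad rule for Subtensor: the backward pass scatters the incoming gradient
// into a zero tensor of the input's shape using SetSubtensor with the same
// items and index arguments as the forward op.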
apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<Subtensor>();
    auto&& grad_op = SetSubtensor::make(op.items);
    SmallVector<std::shared_ptr<Tensor>> inputs;
    if (input_requires_grad(ctx, 0)) {
        inputs.push_back(get_shape(ctx.args[0]));
        for (size_t i = 1; i < ctx.nargs; ++i) {
            inputs.push_back(ctx.args[i]->copy());
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        apply_result_t ret(1);
        if (inputs[0]) {
            SmallVector<Tensor*> args_(inputs.size()+1);
            Tensor* grad = grads[0];
            auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
            args_[0] = zeros.get();
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
                args_[i+1] = inputs[i].get();
            }
            ret[0] = python::apply(grad_op_, args_)[0];
        }
        return ret;
    });
    return apply(ctx);
}
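
// Grad rule for IndexingMultiAxisVec: the same scatter-into-zeros scheme as
// subtensor_grad_rule, with IndexingSetMultiAxisVec as the inverse op.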
apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<IndexingMultiAxisVec>();
    auto&& grad_op = IndexingSetMultiAxisVec::make(op.items);
    SmallVector<std::shared_ptr<Tensor>> inputs;
    if (input_requires_grad(ctx, 0)) {
        inputs.push_back(get_shape(ctx.args[0]));
        for (size_t i = 1; i < ctx.nargs; ++i) {
            inputs.push_back(ctx.args[i]->copy());
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        apply_result_t ret(1);
        if (inputs[0]) {
            SmallVector<Tensor*> args_(inputs.size()+1);
            Tensor* grad = grads[0];
            auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
            args_[0] = zeros.get();
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
                args_[i+1] = inputs[i].get();
            }
            ret[0] = python::apply(grad_op_, args_)[0];
        }
        return ret;
    });
    return apply(ctx);
}
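
// Grad rule for Reduce. Only SUM is handled: its gradient is the output
// gradient broadcast back to the input's shape. Other modes fall back to the
// generic gradient path.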
apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto& op = ctx.op->cast_final_safe<Reduce>();
    if (op.mode == Reduce::Mode::SUM) {
        mgb_assert(ctx.nargs == 1);
        std::array<std::shared_ptr<Tensor>, 1> input_shapes;
        if (input_requires_grad(ctx, 0)) {
            input_shapes[0] = get_shape(ctx.args[0]);
        }
        maker.output_size(1).output_captured(0, false);
        maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
            mgb_assert(ngrads == 1);
            Tensor* grad = grads[0];
            apply_result_t ret(1);
            if (shapes[0]) {
                ret[0] = broadcast_to(grad, shapes[0].get());
            }
            return ret;
        });
        return apply(ctx);
    }
    throw GradRuleFallback();
}
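
// Shared grad rule for AddAxis and RemoveAxis: the two ops are mutual
// inverses, so the backward pass applies the opposite op (U) on the same axes.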
template<typename T, typename U>
apply_result_t axisAddRemove_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<T>();
    mgb_assert(ctx.nargs == 1);
    auto&& grad_op = U::make(op.axis);
    maker.output_size(1).output_captured(0, false);
    maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(1);
        ret[0] = python::apply(grad_op_, grad)[0];
        return ret;
    });
    return apply(ctx);
}
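
// Registers all of the rules above in the global grad-rule registry when the
// module is loaded.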
struct Init {
    Init() {
        auto& reg = grad_rule_registry();
        reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule);
        reg.emplace(Reshape::typeinfo(), reshape_grad_rule);
        reg.emplace(Subtensor::typeinfo(), subtensor_grad_rule);
        reg.emplace(IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
        reg.emplace(Reduce::typeinfo(), reduce_grad_rule);
        reg.emplace(AddAxis::typeinfo(), axisAddRemove_grad_rule<AddAxis, RemoveAxis>);
        reg.emplace(RemoveAxis::typeinfo(), axisAddRemove_grad_rule<RemoveAxis, AddAxis>);
    }
} _;

} // namespace
} // namespace mgb::imperative::python

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there are no separate CPU and GPU builds. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.