
grad_override.cpp 9.3 kB

/**
 * \file imperative/python/src/grad_override.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./grad.h"
#include "megbrain/imperative/ops/autogen.h"

namespace mgb::imperative::python {
namespace {

// Helpers that express backward computations by re-applying primitive ops.
std::shared_ptr<Tensor> get_shape(Tensor* x) {
    static auto op = GetVarShape::make();
    return python::apply(op, x)[0];
}

std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) {
    static auto op = Reduce::make();
    return python::apply(op, x, s)[0];
}

std::shared_ptr<Tensor> reshape_to(Tensor* x, Tensor* s) {
    static auto op = Reshape::make();
    return python::apply(op, x, s)[0];
}

std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) {
    static auto op = Broadcast::make();
    return python::apply(op, x, s)[0];
}

// Create a tensor filled with `v` (default 0), broadcast to the given shape.
std::shared_ptr<Tensor> make_tensor(CompNode cn, Tensor* shape, float v = 0) {
    HostTensorND scalar{cn, {{1}, dtype::Float32()}};
    scalar.ptr<float>()[0] = v;
    interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false);
    auto&& t = std::make_shared<Tensor>(handle);
    auto res = broadcast_to(t.get(), shape);
    return res;
}

// ADD: reduce the incoming gradient back to each input's shape.
apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto& op = ctx.op->cast_final_safe<Elemwise>();
    if (op.mode == Elemwise::Mode::ADD) {
        mgb_assert(ctx.nargs == 2);
        std::array<std::shared_ptr<Tensor>, 2> input_shapes;
        for (size_t i = 0; i < 2; ++i) {
            if (input_requires_grad(ctx, i)) {
                input_shapes[i] = get_shape(ctx.args[i]);
            }
        }
        maker.output_size(1).output_captured(0, false);
        maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
            mgb_assert(ngrads == 1);
            Tensor* grad = grads[0];
            apply_result_t ret(2);
            if (!grad) {
                return ret;
            }
            for (size_t i = 0; i < 2; ++i) {
                if (shapes[i]) {
                    ret[i] = reduce_to(grad, shapes[i].get());
                }
            }
            return ret;
        });
        return apply(ctx);
    }
    throw GradRuleFallback();
}

// Reshape: reshape the gradient back to the original input shape.
apply_result_t reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    mgb_assert(ctx.nargs == 2);
    std::array<std::shared_ptr<Tensor>, 2> input_shapes;
    for (size_t i = 0; i < 2; ++i) {
        if (input_requires_grad(ctx, i)) {
            input_shapes[i] = get_shape(ctx.args[i]);
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(2);
        if (!grad) {
            return ret;
        }
        for (size_t i = 0; i < 2; ++i) {
            if (shapes[i]) {
                ret[i] = reshape_to(grad, shapes[i].get());
            }
        }
        return ret;
    });
    return apply(ctx);
}

// Subtensor: scatter the gradient into a zero tensor of the input's shape.
apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<Subtensor>();
    auto&& grad_op = SetSubtensor::make(op.items);
    SmallVector<std::shared_ptr<Tensor>> inputs;
    if (input_requires_grad(ctx, 0)) {
        inputs.push_back(get_shape(ctx.args[0]));
        for (size_t i = 1; i < ctx.nargs; ++i) {
            inputs.push_back(ctx.args[i]->copy());
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(1);
        if (grad && inputs[0]) {
            SmallVector<Tensor*> args_(inputs.size()+1);
            auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
            args_[0] = zeros.get();
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
                args_[i+1] = inputs[i].get();
            }
            ret[0] = python::apply(grad_op_, args_)[0];
        }
        return ret;
    });
    return apply(ctx);
}

// IndexingMultiAxisVec: same zero-scatter pattern as Subtensor.
apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<IndexingMultiAxisVec>();
    auto&& grad_op = IndexingSetMultiAxisVec::make(op.items);
    SmallVector<std::shared_ptr<Tensor>> inputs;
    if (input_requires_grad(ctx, 0)) {
        inputs.push_back(get_shape(ctx.args[0]));
        for (size_t i = 1; i < ctx.nargs; ++i) {
            inputs.push_back(ctx.args[i]->copy());
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(1);
        if (grad && inputs[0]) {
            SmallVector<Tensor*> args_(inputs.size()+1);
            auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
            args_[0] = zeros.get();
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
                args_[i+1] = inputs[i].get();
            }
            ret[0] = python::apply(grad_op_, args_)[0];
        }
        return ret;
    });
    return apply(ctx);
}

// SUM reduce: broadcast the gradient back to the input's shape.
apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto& op = ctx.op->cast_final_safe<Reduce>();
    if (op.mode == Reduce::Mode::SUM) {
        mgb_assert(ctx.nargs == 1);
        std::array<std::shared_ptr<Tensor>, 1> input_shapes;
        if (input_requires_grad(ctx, 0)) {
            input_shapes[0] = get_shape(ctx.args[0]);
        }
        maker.output_size(1).output_captured(0, false);
        maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
            mgb_assert(ngrads == 1);
            Tensor* grad = grads[0];
            apply_result_t ret(1);
            if (grad && shapes[0]) {
                ret[0] = broadcast_to(grad, shapes[0].get());
            }
            return ret;
        });
        return apply(ctx);
    }
    throw GradRuleFallback();
}

// AddAxis: the backward op removes the added axes (sorted descending).
apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<AddAxis>();
    mgb_assert(ctx.nargs == 1);
    bool flag = input_requires_grad(ctx, 0);
    auto&& grad_op = RemoveAxis::make(op.axis);
    std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>());
    maker.output_size(1).output_captured(0, false);
    maker.backward([grad_op_=std::move(grad_op), flag_=flag](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(1);
        if (grad && flag_) {
            ret[0] = python::apply(grad_op_, grad)[0];
        }
        return ret;
    });
    return apply(ctx);
}

// RemoveAxis: the backward op re-inserts the removed axes (sorted ascending).
apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<RemoveAxis>();
    mgb_assert(ctx.nargs == 1);
    bool flag = input_requires_grad(ctx, 0);
    auto&& grad_op = AddAxis::make(op.axis);
    std::sort(grad_op->axis.begin(), grad_op->axis.end());
    maker.output_size(1).output_captured(0, false);
    maker.backward([grad_op_=std::move(grad_op), flag_=flag](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(1);
        if (grad && flag_) {
            ret[0] = python::apply(grad_op_, grad)[0];
        }
        return ret;
    });
    return apply(ctx);
}

// FastpathCopy: identity; the gradient passes through unchanged.
apply_result_t fastpathcopy_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    mgb_assert(ctx.nargs == 1);
    maker.output_size(1).output_captured(0, false);
    maker.backward([](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(1);
        if (grad) {
            ret[0] = grad->shared_from_this();
        }
        return ret;
    });
    return apply(ctx);
}

// Register the custom grad rules by op type at static-initialization time.
struct Init {
    Init() {
        auto& reg = grad_rule_registry();
        reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule);
        reg.emplace(Reshape::typeinfo(), reshape_grad_rule);
        reg.emplace(Subtensor::typeinfo(), subtensor_grad_rule);
        reg.emplace(IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
        reg.emplace(Reduce::typeinfo(), reduce_grad_rule);
        reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule);
        reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule);
        reg.emplace(FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
    }
} _;

} // namespace
} // namespace mgb::imperative::python
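
For orientation, here is a minimal Python-side sketch (not part of this file) of how these registered rules are exercised: differentiating through an addition, a reshape, and a SUM reduction from Python dispatches to the elemwise, reshape, and reduce rules above. It assumes a standard MegEngine install and the documented GradManager API; the shapes are illustrative only.

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    from megengine.autodiff import GradManager

    x = mge.tensor(np.random.randn(2, 3).astype("float32"))
    y = mge.tensor(np.random.randn(1, 3).astype("float32"))  # broadcast against x

    gm = GradManager()
    gm.attach([x, y])
    with gm:
        # ADD, Reshape and SUM reduce are all covered by the grad rules above.
        z = F.sum((x + y).reshape(-1))
        gm.backward(z)

    # Gradients are reduced back to each input's shape: (2, 3) and (1, 3).
    print(x.grad.shape, y.grad.shape)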

The MegEngine installation package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU or GPU build. If you want to run GPU programs, make sure the machine has a GPU installed along with a working driver. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
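
As a quick sanity check before running GPU code (a hedged sketch; it assumes the megengine.is_cuda_available() helper exposed by your MegEngine version), you can verify from Python whether a usable GPU and driver are visible:

    import megengine as mge

    # Assumption: is_cuda_available() reports whether the bundled CUDA runtime can see a GPU.
    if mge.is_cuda_available():
        print("GPU detected; tensors will be placed on the GPU by default")
    else:
        print("No usable GPU/driver found; computation falls back to CPU")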