
grad.h

/**
 * \file imperative/python/src/grad.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "./tensor.h"
#include "megbrain/imperative/ops/utility.h"

#include <megbrain/utils/small_vector.h>

#include <memory>

namespace mgb::imperative::python {

apply_result_t apply_grad(ApplyContext& ctx);

// Identifies one grad-recording session; tensors attached to this key record
// their backward functions on the tape.
struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
    std::string name;
    bool active = true;
    GradInfo::head_t free_vars_head;
    std::vector<std::weak_ptr<GradFn>> tape;
    int priority = 0;

    ~GradKey();

    void attach(Tensor* tensor, pybind11::object callback);
    void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
    void cleanup();

    bool is_blocked() const {
        return priority < sm_min_priority;
    }

private:
    static int sm_min_priority;
};

// Python-facing wrapper around GradKey, exposed through pyext17.
struct GradKeyWrapper {
    using wrap_t = pyext17::wrap<GradKeyWrapper>;
    static constexpr auto tp_name = pybind11::detail::_("GradKey");

    std::shared_ptr<GradKey> m_key;

    inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {}

    PyObject* get_name();
    void set_name(pybind11::handle name);
    PyObject* get_priority();
    void set_priority(pybind11::handle priority);
    void attach(PyObject* const* args, size_t nargs);
    void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
    PyObject* is_attached_to(PyObject* const* args, size_t nargs);
};

// Wraps tensors produced during backward into Python TensorWrapper objects,
// using a custom subtype if one was supplied.
struct BackwardContext {
    PyTypeObject* pytype = nullptr;

    auto wrap_tensor(std::shared_ptr<Tensor> t) {
        if (pytype) {
            return TensorWrapper::make(pytype, std::move(t));
        }
        return TensorWrapper::make(std::move(t));
    }

    auto wrap_tensor(Tensor* t) {
        return wrap_tensor(t->shared_from_this());
    }
};

// Type-erased backward implementation plus per-input/per-output metadata.
struct CustomBackward {
    using BackwardFn = std::function<apply_result_t(BackwardContext&, Tensor* const*, size_t)>;

    BackwardFn m_backward;
    SmallVector<bool, 8> m_input_has_grad;
    struct OutputAttr {
        bool requires_grad = true, captured = true;
    };
    SmallVector<OutputAttr> m_output_attrs;

public:
    template <typename T, typename R>
    void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
        size_t nargs = grads.size();
        Tensor* args[nargs];  // VLA: relies on a GCC/Clang extension
        for (size_t i = 0; i < nargs; ++i) {
            args[i] = grads[i];
        }
        auto ret = m_backward(ctx, args, nargs);
        for (size_t i = 0; i < ret.size(); ++i) {
            if (auto&& t = ret[i]) {
                receiver(i, std::move(t));
            }
        }
    }

    bool input_has_grad(size_t i) { return m_input_has_grad[i]; }
    bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; }
    bool output_captured(size_t i) { return m_output_attrs[i].captured; }

    // Fluent builder used by grad rules to configure a CustomBackward in place.
    class Maker {
        bool output_size_set = false, input_has_grad_initialized = false;
        CustomBackward& target;
        ApplyContext& ctx;

        void init_input_has_grad() {
            if (!input_has_grad_initialized) {
                input_has_grad_initialized = true;
                target.m_input_has_grad.resize(ctx.nargs, true);
            }
        }

    public:
        Maker(CustomBackward& target_, ApplyContext& ctx_) : target(target_), ctx(ctx_) {}

        template <typename F>
        Maker& backward(F&& f) {
            mgb_assert(!target.m_backward);
            target.m_backward = std::forward<F>(f);
            return *this;
        }

        // mandatory
        Maker& output_size(size_t sz) {
            mgb_assert(!output_size_set);
            output_size_set = true;
            target.m_output_attrs.resize(sz);
            return *this;
        }

        // optional, defaults to all true
        Maker& input_has_grad(size_t i, bool v) {
            init_input_has_grad();
            target.m_input_has_grad.at(i) = v;
            return *this;
        }

        // optional, defaults to all true
        Maker& output_requires_grad(size_t i, bool v) {
            target.m_output_attrs.at(i).requires_grad = v;
            return *this;
        }

        // optional, defaults to all true
        Maker& output_captured(size_t i, bool v) {
            target.m_output_attrs.at(i).captured = v;
            return *this;
        }

        void finalize() {
            mgb_assert(output_size_set);
            init_input_has_grad();
        }
    };

    Maker maker(ApplyContext& ctx) { return {*this, ctx}; }
};

using GradRuleFn = std::function<apply_result_t(ApplyContext&, CustomBackward::Maker&)>;

std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry();

inline bool input_requires_grad(const ApplyContext& ctx, size_t i) {
    return !ctx.args[i]->m_grad_info_dict.empty();
}

struct GradRuleFallback : std::exception {};

template <typename T>
bool register_grad_rule(Typeinfo* typeinfo, T&& rule) {
    return grad_rule_registry().emplace(typeinfo, std::forward<T>(rule)).second;
}

} // namespace mgb::imperative::python

namespace pybind11::detail {

template <>
struct type_caster<mgb::imperative::python::GradKeyWrapper>
        : mgb::imperative::python::GradKeyWrapper::wrap_t::caster {};

} // namespace pybind11::detail
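Taken together, grad_rule_registry() and CustomBackward::Maker form the extension point for per-op backward rules: a rule receives the forward ApplyContext plus a Maker, configures the backward via the fluent setters (output_size() is the only mandatory one), and then performs the forward computation. The sketch below is illustrative only, not code from the MegEngine sources: it assumes a hypothetical single-input, single-output op MyOp with the usual typeinfo() static, and assumes that apply(ApplyContext&) from tensor.h is the regular forward dispatcher.

// Hypothetical grad rule for a single-input, single-output identity-like op.
apply_result_t my_op_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    maker.output_size(1)               // mandatory: number of forward outputs
         .input_has_grad(0, true)      // optional: would default to all-true anyway
         .backward([](BackwardContext&, Tensor* const* grads, size_t ngrads) {
             // identity backward: pass the output grad through as the input grad
             apply_result_t ret(ngrads);
             if (ngrads && grads[0]) {
                 ret[0] = grads[0]->shared_from_this();
             }
             return ret;
         });
    return apply(ctx);                 // forward computation proceeds as usual
}

// Registered once, keyed by the op's Typeinfo; since the registry uses
// emplace(), a duplicate registration returns false.
static bool registered = register_grad_rule(MyOp::typeinfo(), my_op_grad_rule);

Judging from its declaration next to the registry, GradRuleFallback appears intended to be thrown by a rule that decides at apply time to defer to the default autograd path, with the dispatcher catching it.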

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.