
grad.h 5.3 kB

/**
 * \file imperative/python/src/grad.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "./tensor.h"
#include "megbrain/imperative/ops/utility.h"

#include <megbrain/utils/small_vector.h>

#include <memory>

namespace mgb::imperative::python {

apply_result_t apply_grad(ApplyContext& ctx);

// Tracks the tensors attached for gradient computation and the tape of
// backward functions recorded during forward execution.
struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
    std::string name;
    bool active = true;
    GradInfo::head_t free_vars_head;
    std::vector<std::weak_ptr<GradFn>> tape;

    ~GradKey();

    void attach(Tensor* tensor, pybind11::object callback);
    void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
    void cleanup();
};

// Python-facing wrapper around GradKey, exposed through pyext17.
struct GradKeyWrapper {
    using wrap_t = pyext17::wrap<GradKeyWrapper>;
    static constexpr auto tp_name = pybind11::detail::_("GradKey");

    std::shared_ptr<GradKey> m_key;

    inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {}

    PyObject* get_name();
    void set_name(pybind11::handle name);
    void attach(PyObject* const* args, size_t nargs);
    void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
    PyObject* is_attached_to(PyObject* const* args, size_t nargs);
};

// Wraps raw tensors back into Python TensorWrapper objects during backward.
struct BackwardContext {
    PyTypeObject* pytype = nullptr;

    auto wrap_tensor(std::shared_ptr<Tensor> t) {
        if (pytype) {
            return TensorWrapper::make(pytype, std::move(t));
        }
        return TensorWrapper::make(std::move(t));
    }

    auto wrap_tensor(Tensor* t) {
        return wrap_tensor(t->shared_from_this());
    }
};

// A user-defined backward: the backward function itself plus per-input and
// per-output flags, configured through the nested Maker builder.
struct CustomBackward {
    using BackwardFn = std::function<apply_result_t(BackwardContext&, Tensor* const*, size_t)>;

    BackwardFn m_backward;
    SmallVector<bool, 8> m_input_has_grad;
    struct OutputAttr {
        bool requires_grad = true, captured = true;
    };
    SmallVector<OutputAttr> m_output_attrs;

public:
    template <typename T, typename R>
    void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
        size_t nargs = grads.size();
        Tensor* args[nargs];
        for (size_t i = 0; i < nargs; ++i) {
            args[i] = grads[i];
        }
        auto ret = m_backward(ctx, args, nargs);
        // Deliver only the gradients that were actually produced.
        for (size_t i = 0; i < ret.size(); ++i) {
            if (auto&& t = ret[i]) {
                receiver(i, std::move(t));
            }
        }
    }

    bool input_has_grad(size_t i) { return m_input_has_grad[i]; }
    bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; }
    bool output_captured(size_t i) { return m_output_attrs[i].captured; }

    // Fluent builder used by grad rules to configure a CustomBackward.
    class Maker {
        bool output_size_set = false, input_has_grad_initialized = false;
        CustomBackward& target;
        ApplyContext& ctx;

        void init_input_has_grad() {
            if (!input_has_grad_initialized) {
                input_has_grad_initialized = true;
                target.m_input_has_grad.resize(ctx.nargs, true);
            }
        }

    public:
        Maker(CustomBackward& target_, ApplyContext& ctx_) : target(target_), ctx(ctx_) {}

        template <typename F>
        Maker& backward(F&& f) {
            mgb_assert(!target.m_backward);
            target.m_backward = std::forward<F>(f);
            return *this;
        }
        // mandatory
        Maker& output_size(size_t sz) {
            mgb_assert(!output_size_set);
            output_size_set = true;
            target.m_output_attrs.resize(sz);
            return *this;
        }
        // optional, defaults to all true
        Maker& input_has_grad(size_t i, bool v) {
            init_input_has_grad();
            target.m_input_has_grad.at(i) = v;
            return *this;
        }
        // optional, defaults to all true
        Maker& output_requires_grad(size_t i, bool v) {
            target.m_output_attrs.at(i).requires_grad = v;
            return *this;
        }
        // optional, defaults to all true
        Maker& output_captured(size_t i, bool v) {
            target.m_output_attrs.at(i).captured = v;
            return *this;
        }

        void finalize() {
            mgb_assert(output_size_set);
            init_input_has_grad();
        }
    };

    Maker maker(ApplyContext& ctx) { return {*this, ctx}; }
};

using GradRuleFn = std::function<apply_result_t(ApplyContext&, CustomBackward::Maker&)>;

// Per-operator grad rules, keyed by the operator's Typeinfo.
std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry();

inline bool input_requires_grad(const ApplyContext& ctx, size_t i) {
    return bool(ctx.args[i]->m_grad_info.grad_fn);
}

// Thrown by a grad rule to signal that no custom rule applies.
struct GradRuleFallback : std::exception {};

template <typename T>
bool register_grad_rule(Typeinfo* typeinfo, T&& rule) {
    return grad_rule_registry().emplace(typeinfo, std::forward<T>(rule)).second;
}

} // namespace mgb::imperative::python

namespace pybind11::detail {

template <>
struct type_caster<mgb::imperative::python::GradKeyWrapper>
        : mgb::imperative::python::GradKeyWrapper::wrap_t::caster {};

} // namespace pybind11::detail
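For orientation, a `GradRuleFn` pairs the forward application of an operator with a `CustomBackward` configured through the `Maker`. Below is a minimal hypothetical sketch written against this header, for an identity-like op whose gradient passes through unchanged. `Identity::typeinfo()` and the `apply(ctx)` forward call are assumptions about the surrounding module, not declarations in this file, and `finalize()` is assumed to be invoked by the framework after the rule returns.

namespace mgb::imperative::python {

// Hypothetical sketch: grad rule for an identity-like op (d out / d in == 1).
apply_result_t identity_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    mgb_assert(ctx.nargs == 1);
    maker.output_size(1)             // mandatory: declare one output
         .output_captured(0, false); // backward does not need the forward output
    maker.backward([](BackwardContext&, Tensor* const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        apply_result_t ret(1);
        if (grads[0]) {
            // identity: the input grad is the output grad, passed through as-is
            ret[0] = grads[0]->shared_from_this();
        }
        return ret;
    });
    return apply(ctx); // assumed forward entry point of the enclosing module
}

// Registration typically happens once, e.g. at static-initialization time:
// static bool registered =
//         register_grad_rule(Identity::typeinfo(), identity_grad_rule);

} // namespace mgb::imperative::python

Note that because the registry is keyed by `Typeinfo*` and `register_grad_rule` uses `emplace`, there is at most one rule per operator type and a second registration returns false rather than overwriting the first.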

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has a GPU device and that the driver is installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.