
grad.cpp 24 kB

/**
 * \file imperative/python/src/grad.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"

#include "./grad.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/utils/mempool.h"

#include "range/v3/all.hpp"

namespace py = pybind11;
namespace views = ranges::views;

namespace mgb::imperative::python {

using scoped_disable = ApplyContext::scoped_disable;
using Flags = Tensor::Flags;

namespace {

struct GradSlotWeakPtr {
    std::weak_ptr<GradFn> grad_fn;
    size_t idx;
};

struct BackwardGraphCache : std::unordered_map<uint64_t, std::shared_ptr<OptimizedBackwardGraphResult>>, CompNodeDepedentObject {
    std::shared_ptr<void> on_comp_node_finalize() override {
        clear();
        return {};
    }
} backward_graph_cache;
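
// make_backward_graph builds (and memoizes) the optimized backward graph for a
// forward op. The cache key is an XXHash over the op's own hash plus, for each
// input, its dtype, its comp_node, and whether it currently carries grad info,
// so the same op applied to differently-typed or differently-placed inputs gets
// a distinct cache entry.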
std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
        ApplyContext& ctx, const apply_result_t& outputs) {
    // hash
    static_assert(alignof(size_t) % alignof(bool) == 0);
    size_t buf_size = (1 + ctx.nargs * 2) * sizeof(size_t) + ctx.nargs * sizeof(bool);
    alignas(alignof(size_t)) std::byte buf[buf_size];
    size_t* size_t_ptr = reinterpret_cast<size_t*>(buf);
    bool* bool_ptr = reinterpret_cast<bool*>(size_t_ptr + (1 + ctx.nargs * 2));
    bool* bool_ptr0 = bool_ptr;
    *(size_t_ptr++) = ctx.op->hash();
    for (size_t i = 0; i < ctx.nargs; ++i) {
        *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
        *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
        *(bool_ptr++) = !ctx.args[i]->m_grad_info_dict.empty();
    }
    mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
               bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
    uint64_t key = XXHash{}.update(buf, buf_size).digest();

    auto&& iter = backward_graph_cache.find(key);
    if (iter != backward_graph_cache.end()) {
        return iter->second;
    }

    // slow path
    SmallVector<LogicalTensorDesc> inputs(ctx.nargs);
    SmallVector<bool> input_requires_grad(ctx.nargs, false);
    SmallVector<bool> output_has_grad(outputs.size(), true);
    for (size_t i = 0; i < ctx.nargs; ++i) {
        inputs[i].comp_node = ctx.args[i]->comp_node();
        inputs[i].layout.dtype = ctx.args[i]->dtype();
        input_requires_grad[i] = python::input_requires_grad(ctx, i);
    }
    std::shared_ptr<OptimizedBackwardGraphResult> ret;
    auto bg = OpDef::make_backward_graph(
            *ctx.op, inputs, input_requires_grad, output_has_grad);
    if (!bg.backward.empty()) {
        ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
    }
    backward_graph_cache.emplace(key, ret);
    return ret;
}
struct BackwardGraphWithClosure {
    std::shared_ptr<OptimizedBackwardGraphResult> backward_graph;
    SmallVector<std::shared_ptr<Tensor>> closure;
    size_t output_mask_offset;
    size_t grad_mask_offset;

    BackwardGraphWithClosure(std::shared_ptr<OptimizedBackwardGraphResult> backward_graph_,
                             ApplyContext& ctx, const apply_result_t& outputs)
            : backward_graph(backward_graph_),
              output_mask_offset(ctx.nargs),
              grad_mask_offset(ctx.nargs + outputs.size()) {
        // save_for_backward[0:nargs]:
        //     whether input is kept for backward
        //
        // save_for_backward[nargs:nargs+outputs.size()]:
        //     whether output is kept for backward
        //
        // save_for_backward[-outputs.size():]:
        //     whether gradient of output can propagate to any input
        //
        // Example:
        //     perform c = a * b, with a.requires_grad == True and
        //     b.requires_grad == False, save_for_backward = [0, 1, 0, 1]
        auto& save_for_backward = backward_graph->save_for_backward;
        mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size());
        size_t count = std::count_if(save_for_backward.begin(),
                                     save_for_backward.end(),
                                     ranges::identity{});
        if (!backward_graph->precomp.empty()) {
            auto&& irng = ranges::span(ctx.args, ctx.nargs);
            auto&& orng = views::transform(outputs, [](auto&& i){return i.get();});
            auto precomp = apply(backward_graph->precomp, views::concat(irng, orng));
            closure.reserve(precomp.size() + count);
            std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure));
        } else {
            closure.reserve(count);
        }
        for (size_t i = 0; i < ctx.nargs; ++i) {
            if (save_for_backward[i]) {
                closure.push_back(ctx.args[i]->shared_from_this());
            }
        }
        for (size_t i = 0; i < outputs.size(); ++i) {
            if (save_for_backward[ctx.nargs + i]) {
                closure.push_back(outputs[i]);
            }
        }
    }

    template <typename T, typename R>
    void operator()(BackwardContext&, T&& grads, R&& receiver) {
        Tensor* args[closure.size() + grads.size()];
        size_t nargs = 0;
        for (auto&& t : closure) {
            args[nargs++] = t.get();
        }
        bool null_grad = false;
        for (size_t i = 0; i < grads.size(); ++i) {
            if (backward_graph->save_for_backward[grad_mask_offset + i]) {
                if (grads[i]) {
                    if (null_grad) {
                        PyErr_SetString(PyExc_NotImplementedError, "report to devs");
                        throw py::error_already_set();
                    }
                    args[nargs++] = grads[i];
                } else {
                    null_grad = true;
                }
            }
        }
        if (null_grad) return;

        auto igrads = apply(backward_graph->backward, args, nargs);
        auto&& it = igrads.begin();
        for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) {
            if (p) {
                receiver(i, std::move(*it));
                ++it;
            }
        }
    }

    bool input_has_grad(size_t i) {
        return backward_graph->input_has_grad[i];
    }

    bool output_requires_grad(size_t i) {
        return backward_graph->save_for_backward[grad_mask_offset + i];
    }

    bool output_captured(size_t i) {
        return backward_graph->save_for_backward[output_mask_offset + i];
    }
};
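
// PythonBackward dispatches backward to the Python callable returned by the
// op's _grad_rule. Output grads are wrapped as Python tensors (None for
// missing grads); the rule must return either None, a single tensor, or a
// sequence whose length matches the number of forward inputs.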
struct PythonBackward {
    py::object pyfunc;
    size_t input_size;

    PythonBackward(py::object f, size_t nin)
            : pyfunc(f), input_size(nin) {}

    template <typename T, typename R>
    void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
        auto args = py::tuple(grads.size());
        for (size_t i = 0; i < grads.size(); ++i) {
            auto&& g = grads[i];
            args[i] = g ? ctx.wrap_tensor(g) : py::none();
        }
        auto input_grads = py::reinterpret_steal<py::object>(PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr));
        if (!input_grads) throw py::error_already_set();
        if (input_grads.is_none()) return;
        if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) {
            if (input_size != 1) {
                throw py::value_error("custom grad rule returned wrong number of grads");
            }
            if (!ctx.pytype) {
                ctx.pytype = Py_TYPE(input_grads.ptr());
            }
            receiver(0, tw->m_tensor);
            return;
        }
        if (py::len(input_grads) != input_size) {
            throw py::value_error("custom grad rule returned wrong number of grads");
        }
        for (auto [i, g] : views::enumerate(input_grads)) {
            if (g.is_none()) continue;
            auto* tw = TensorWrapper::try_cast(g.ptr());
            if (!tw) {
                throw py::type_error("custom grad rule returned non-tensor");
            }
            if (!ctx.pytype) {
                ctx.pytype = Py_TYPE(g.ptr());
            }
            receiver(i, tw->m_tensor);
        }
    }

    static constexpr bool input_has_grad(size_t) {return true;}
    static constexpr bool output_requires_grad(size_t) {return true;}
    static constexpr bool output_captured(size_t) {return true;}
};

} // namespace
struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> {
    using Base = intrusive_list::Node<GradProducerRecord>;

    GradProducerRecord() = default;
    GradProducerRecord(GradProducerRecord::head_t& head) : Base(intrusive_list::after_t{}, head) {}
    // GradProducerRecord(GradProducerRecord&&) = default;
    // GradProducerRecord& operator=(GradProducerRecord&) = default;
    // GradProducerRecord& operator=(GradProducerRecord&&) = default;
};

struct GradSlot {
    std::shared_ptr<Tensor> grad;
    py::object callback;
    GradProducerRecord::head_t producer_head;
};

struct GradSlotProducerPtr : GradSlotPtr {
    GradProducerRecord producer_record;

    GradSlotProducerPtr() = default;
    GradSlotProducerPtr(GradInfo& info) : GradSlotPtr(info), producer_record(info->producer_head) {}
};

struct GradFn : std::enable_shared_from_this<GradFn> {
    static MemPool<GradFn> pool;

    std::weak_ptr<GradKey> key;
    // slots for receiving and accumulating grads
    // same length as outputs (of forward op)
    SmallVector<GradSlot> slots;
    // where to send and accumulate grads
    // same length as inputs (of forward op)
    SmallVector<GradSlotProducerPtr> dsts;
    // encapsulates the actual function to compute gradient
    std::variant<std::monostate, BackwardGraphWithClosure, PythonBackward, CustomBackward> backward;
    // a flag used during backward
    bool in_ref_keeper = false;

    static void deleter(GradFn* ptr) {
        pool.free(ptr);
    }

    static std::shared_ptr<GradFn> make() {
        return std::shared_ptr<GradFn>(pool.alloc(), &deleter);
    }

    void clear() {
        key.reset();
        slots.clear();
        dsts.clear();
        backward.emplace<std::monostate>();
    }
};
GradSlotPtr::operator bool() const {
    return bool(grad_fn);
}

GradSlot* GradSlotPtr::operator->() {
    return &grad_fn->slots[idx];
}

namespace {

class GradFnHelper {
    std::shared_ptr<GradFn> grad_fn;

    GradFn* get() {
        if (!grad_fn) {
            grad_fn = std::make_shared<GradFn>();
        }
        return grad_fn.get();
    }

    friend apply_result_t imperative::python::apply_grad(ApplyContext&);

public:
    template<typename T, typename... Args>
    auto& emplace(Args&&... args) {
        return get()->backward.emplace<T>(std::forward<Args>(args)...);
    }

    void reset() { grad_fn = nullptr; }
};
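
// Fallback grad rule: re-run the forward op on FastpathCopy-ed inputs (so
// tracing does not create a new InputNode for every usage), ask OpDef for an
// optimized backward graph, and store that graph together with the tensors it
// needs as a BackwardGraphWithClosure on the pending GradFn.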
apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
    // copy inputs first, or trace will make InputNodes for each usage
    ApplyContext ctx_dup = ctx;
    SmallVector<std::shared_ptr<Tensor>> inputs_copy;
    SmallVector<Tensor*> inputs_copy_weak;
    for (size_t i = 0; i < ctx.nargs; ++i) {
        Tensor* input = ctx.args[i];
        inputs_copy.push_back(python::apply(FastpathCopy::make(), input)[0]);
        inputs_copy_weak.push_back(inputs_copy.back().get());
        inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict;
        if (input->m_flags & Flags::GRAD) {
            inputs_copy.back()->m_flags |= Flags::GRAD;
        }
    }
    ctx_dup.args = inputs_copy_weak.data();

    auto outputs = apply(ctx_dup);

    auto backward_graph = make_backward_graph(ctx_dup, outputs);
    if (!backward_graph) {
        return outputs;
    }
    ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx_dup, outputs);

    return outputs;
}

apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
    auto* op = ctx.op->try_cast_final<GenericPyOp>();
    py::tuple pyin(ctx.nargs);
    for (size_t i = 0; i < ctx.nargs; ++i) {
        pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
    }
    auto grad_rule = py::getattr(op->obj, "_grad_rule");
    auto pyret = py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr));
    if (!pyret) throw py::error_already_set();
    auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret);
    ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs);
    if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) {
        return {tw->m_tensor};
    }
    apply_result_t ret;
    ret.reserve(py::len(outputs));
    for (auto&& i : outputs) {
        auto* tw = TensorWrapper::try_cast(i.ptr());
        mgb_assert(tw);
        ret.push_back(tw->m_tensor);
    }
    return ret;
}

} // namespace
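
// Entry point for grad-aware op dispatch. It collects the active GradKeys
// attached to the inputs, runs the forward computation with GRAD disabled
// (through a Python grad rule, a registered custom rule, or the generic
// backward-graph fallback), and then, for each active key, builds a GradFn
// that wires the inputs (dsts) to the outputs (slots) and records it on that
// key's tape.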
apply_result_t apply_grad(ApplyContext& ctx) {
    std::unordered_set<std::shared_ptr<GradKey>> grad_keys;
    for (size_t i = 0; i < ctx.nargs; ++i) {
        auto* tensor = ctx.args[i];
        if (!tensor->m_grad_info_dict.empty()) {
            size_t grad_cnt = 0;
            for (auto&& grad_info: tensor->m_grad_info_dict) {
                auto input_grad_key = grad_info.grad_fn->key.lock();
                if (input_grad_key && input_grad_key->active && !input_grad_key->is_blocked()) {
                    grad_keys.insert(input_grad_key);
                    grad_cnt++;
                }
            }
            if (!grad_cnt) {
                tensor->m_flags &= ~Flags::GRAD;
            }
        } else {
            tensor->m_flags &= ~Flags::GRAD;
        }
    }

    ctx.flags &= ~Flags::GRAD;

    if (grad_keys.empty()) {
        return apply(ctx);
    } else if (grad_keys.size() > 1 && !GradKey::allow_higher_order_directive) {
        PyErr_SetString(
                PyExc_NotImplementedError,
                "second order directive not enabled, please call "
                "'megengine.experimental.enable_higher_order_directive'");
        throw pyext17::py_err_set();
    }

    GradFnHelper grad_fn_holder;
    auto outputs = [&]() {
        auto _ = scoped_disable(Flags::GRAD);
        if (ctx.op->same_type<GenericPyOp>()) {
            return python_grad_rule(ctx, grad_fn_holder);
        }
        auto&& registry = grad_rule_registry();
        auto&& it = registry.find(ctx.op->dyn_typeinfo());
        if (it != registry.end()) {
            auto&& maker = grad_fn_holder.emplace<CustomBackward>().maker(ctx);
            try {
                auto ret = it->second(ctx, maker);
                maker.finalize();
                return ret;
            } catch (GradRuleFallback&) {
                grad_fn_holder.reset();
            }
        }
        return backward_graph_grad_rule(ctx, grad_fn_holder);
    }();

    if (!grad_fn_holder.grad_fn) {
        return outputs;
    }

    for (auto&& grad_key: grad_keys) {
        auto grad_fn = std::make_shared<GradFn>();
        grad_fn->backward = grad_fn_holder.grad_fn->backward;
        grad_fn->key = grad_key;
        grad_fn->slots.resize(outputs.size());
        grad_fn->dsts.reserve(ctx.nargs);

        std::visit([&](auto& backward) {
            using T = std::decay_t<decltype(backward)>;
            if constexpr (std::is_same_v<T, std::monostate>) {
                mgb_assert(0);
            } else {
                for (size_t i = 0; i < ctx.nargs; ++i) {
                    if (backward.input_has_grad(i) && input_requires_grad(ctx, i) && ctx.args[i]->m_grad_info_dict.count(grad_key.get())) {
                        auto& input_grad_info = ctx.args[i]->m_grad_info_dict.at(grad_key.get());
                        grad_fn->dsts.emplace_back(input_grad_info);
                        // register as grad producer
                        grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
                    } else {
                        grad_fn->dsts.emplace_back();
                    }
                }
                for (size_t i = 0; i < outputs.size(); ++i) {
                    if (backward.output_requires_grad(i)) {
                        if (backward.output_captured(i)) {
                            // avoid reference cycle [Tensor <-> GradFn]
                            static std::shared_ptr<OpDef> op = std::make_shared<FastpathCopy>();
                            outputs[i] = python::apply(op, outputs[i])[0];
                        }
                        // populate grad info of output tensor
                        auto& grad_info = outputs[i]->m_grad_info_dict[grad_key.get()];
                        grad_info.grad_fn = grad_fn;
                        grad_info.idx = i;
                        grad_info.insert_after(grad_key->free_vars_head);
                        outputs[i]->m_flags |= Flags::GRAD;
                    }
                }
            }
        }, grad_fn->backward);

        // record forward history
        grad_key->tape.emplace_back(grad_fn);
    }
    return outputs;
}

PyObject* GradKeyWrapper::get_priority() {
    return py::cast(m_key->priority).release().ptr();
}

void GradKeyWrapper::set_priority(pybind11::handle priority) {
    m_key->priority = py::cast<int>(priority);
}

void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
    if (nargs != 2) {
        throw py::type_error("expect 2 arguments");
    }
    auto* tw = TensorWrapper::try_cast(args[0]);
    if (!tw) {
        throw py::type_error("argument 1 must be Tensor");
    }
    auto* tensor = tw->m_tensor.get();
    py::object callback;
    if (args[1] != Py_None) {
        callback = py::reinterpret_borrow<py::object>(args[1]);
    }
    m_key->attach(tensor, std::move(callback));
}
//! GradKey is weakly referenced by tensor->m_grad_info.grad_fn->key after attach
void GradKey::attach(Tensor* tensor, pybind11::object callback) {
    if (!active) {
        throw py::value_error("grad key finalized");
    }

    if (tensor->m_grad_info_dict.count(this)) {
        if (tensor->m_grad_info_dict.at(this)->callback) {
            throw py::value_error("callback already set on this tensor");
        }
    } else {
        auto& grad_info = tensor->m_grad_info_dict[this];
        grad_info.idx = 0;
        auto& grad_fn = grad_info.grad_fn;
        grad_fn = std::make_shared<GradFn>();
        grad_fn->key = shared_from_this();
        grad_fn->slots.resize(1);
        grad_info.insert_after(free_vars_head);
        tensor->m_flags |= Flags::GRAD;
    }
    tensor->m_grad_info_dict.at(this).grad_fn->slots[0].callback = std::move(callback);
}
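
// Accumulate an incoming gradient into `grad`: the first contribution is moved
// in directly, subsequent ones are summed via an Elemwise ADD apply.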
template<typename T>
void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) {
    if (!grad) {
        grad = std::forward<T>(delta);
        return;
    }
    static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD));
    grad = apply(op, grad, std::forward<T>(delta))[0];
}
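
// Backward pass for one GradKey: seed the requested output grads, then walk
// the tape in reverse, invoking each GradFn's backward variant, accumulating
// results into its dsts, and firing a slot's user callback once its last
// producer has contributed. Each GradFn is cleared as soon as it is consumed.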
void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
    if (!active) {
        throw py::value_error("finalized");
    }
    if (tensors.size() != grads.size()) {
        throw py::value_error("tensor and grad size mismatch");
    }

    // this GradKey is marked inactive here
    active = false;
    struct CleanupGuard {
        GradKey* owner;
        size_t priority_backup;
        CleanupGuard(GradKey* this_) : owner(this_) {
            priority_backup = sm_min_priority;
            sm_min_priority = owner->priority + 1;
        }
        ~CleanupGuard() {
            owner->cleanup();
            sm_min_priority = priority_backup;
        }
    } _cleanup_guard(this);

    if (tape.empty()) return;

    BackwardContext bctx;
    if (!grads.empty()) {
        bctx.pytype = Py_TYPE(grads[0]->self().ptr());
    }

    for (size_t i = 0; i < tensors.size(); ++i) {
        if (tensors[i]->m_tensor->m_grad_info_dict.count(this) == 0) {
            continue;
        }
        auto& grad_info = tensors[i]->m_tensor->m_grad_info_dict.at(this);
        grad_info->grad = grads[i]->m_tensor;
    }

    std::vector<std::shared_ptr<GradFn>> ref_keeper;
    ref_keeper.reserve(tape.size());

    // back-propagation in reverse order
    for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
        auto&& grad_fn = tape[k].lock();
        if (!grad_fn) continue;

        auto grad_receiver = [&](size_t i, auto&& g) {
            auto& dst = grad_fn->dsts[i];
            if (dst) {
                accum_grad(dst->grad, std::forward<decltype(g)>(g));
            }
        };
        std::visit([&](auto&& backward) {
            using T = std::decay_t<decltype(backward)>;
            if constexpr (std::is_same_v<T, std::monostate>) {
                mgb_assert(0);
            } else {
                auto&& grads = views::transform(grad_fn->slots, [](auto&& slot) {return slot.grad.get();});
                backward(bctx, std::forward<decltype(grads)>(grads), grad_receiver);
            }
        }, grad_fn->backward);

        for (auto&& dst : grad_fn->dsts) {
            if (!dst.grad_fn) continue;
            if (!dst.grad_fn->in_ref_keeper) {
                // after grad_fn is cleared, refcnt of subsequent grad_fn
                // could drop to 0
                dst.grad_fn->in_ref_keeper = true;
                ref_keeper.push_back(dst.grad_fn);
            }
            if (!dst.producer_record.next && dst->callback && dst->grad) {
                // I'm the last grad producer, invoke callback
                dst->callback(bctx.wrap_tensor(dst->grad));
            }
        }
        grad_fn->clear();
    } // finish tape loop
}

void GradKey::cleanup() {
    active = false;
    tape.clear();
    for (intrusive_list::Iterator it(free_vars_head); it;) {
        it->grad_fn.reset();
        (it++)->unlink();
    }
}

void GradKeyWrapper::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
    m_key->backward(std::move(tensors), std::move(grads));
}

PyObject* GradKeyWrapper::get_name() {
    return py::cast(m_key->name).release().ptr();
}

void GradKeyWrapper::set_name(py::handle name) {
    m_key->name = py::cast<std::string>(name);
}

PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) {
    if (nargs != 1) {
        PyErr_SetString(PyExc_TypeError, "expect 1 argument");
        return nullptr;
    }
    auto* tw = TensorWrapper::try_cast(args[0]);
    if (!tw) {
        PyErr_SetString(PyExc_TypeError, "expect Tensor");
        return nullptr;
    }
    if (tw->m_tensor->m_grad_info_dict.count(m_key.get())) {
        Py_RETURN_TRUE;
    }
    Py_RETURN_FALSE;
}

int GradKey::sm_min_priority = std::numeric_limits<int>::min();

GradKey::~GradKey() {
    cleanup();
}

std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() {
    static std::unordered_map<Typeinfo*, GradRuleFn> registry;
    return registry;
}

void GradInfoCollection::_shrink() {
    auto pred = [](GradInfo& info){ return !(info.grad_fn) || info.grad_fn->key.expired(); };
    auto iter = std::remove_if(m_storage.begin(), m_storage.end(), pred);
    m_storage.erase(iter, m_storage.end());
}

bool GradInfoCollection::contains(GradKey* key) {
    _shrink();
    for (auto&& grad_info: m_storage) {
        if (grad_info.grad_fn->key.lock().get() == key) {
            return true;
        }
    }
    return false;
}

GradInfo& GradInfoCollection::operator[](GradKey* key) {
    _shrink();
    for (auto&& grad_info: m_storage) {
        if (grad_info.grad_fn->key.lock().get() == key) {
            return grad_info;
        }
    }
    m_storage.emplace_back();
    return m_storage.back();
}

GradInfo& GradInfoCollection::at(GradKey* key) {
    _shrink();
    for (auto&& grad_info: m_storage) {
        if (grad_info.grad_fn->key.lock().get() == key) {
            return grad_info;
        }
    }
    mgb_assert(false);
}

} // namespace mgb::imperative::python
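
For orientation, here is a minimal sketch of how this machinery is usually driven from the Python side, assuming the standard MegEngine 1.x public API (megengine.autodiff.GradManager, which attaches tensors to a GradKey, records GradFns on its tape during forward execution, and then invokes GradKey::backward). This is an illustrative example, not part of grad.cpp.

# Illustrative sketch only; assumes the public MegEngine 1.x Python API.
import megengine as mge
from megengine.autodiff import GradManager

w = mge.tensor([2.0])
x = mge.tensor([3.0])

gm = GradManager()
gm.attach([w])            # goes through GradKeyWrapper::attach / GradKey::attach

with gm:                  # forward ops pass through apply_grad and are taped
    y = w * x
gm.backward(y)            # GradKey::backward walks the tape in reverse

print(w.grad)             # gradient accumulated via accum_grad; equals x here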

The MegEngine installation package bundles the CUDA environment needed to run code on GPUs, so there is no need to distinguish between CPU and GPU builds. If you want to run GPU programs, make sure the machine has GPU hardware and that the driver is properly installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.