
grad.cpp

/**
 * \file imperative/python/src/grad.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma GCC diagnostic ignored "-Wmissing-field-initializers"

#include "./grad.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/utils/mempool.h"

#include "range/v3/all.hpp"

namespace py = pybind11;
namespace views = ranges::views;

namespace mgb::imperative::python {

using scoped_disable = ApplyContext::scoped_disable;
using Flags = Tensor::Flags;

namespace {

struct GradSlotWeakPtr {
    std::weak_ptr<GradFn> grad_fn;
    size_t idx;
};

std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
        ApplyContext& ctx, const apply_result_t& outputs) {
    // hash
    using OptimizedBackwardGraphCache = OpMethResultCache<
            std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>;
    thread_local OptimizedBackwardGraphCache cache;
    decltype(cache)::key_t cache_key{ctx.op};
    SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs;
    SmallVector<bool>& input_requires_grad = std::get<0>(cache_key.extras);
    input_descs.resize(ctx.nargs);
    input_requires_grad.resize(ctx.nargs);
    for (size_t i = 0; i < ctx.nargs; ++i) {
        input_descs[i].layout.dtype = ctx.args[i]->dtype();
        input_descs[i].comp_node = ctx.args[i]->comp_node();
        input_requires_grad[i] = python::input_requires_grad(ctx, i);
    }
    auto iter = cache.find(cache_key);
    if (iter != cache.end()) {
        return iter->second;
    }
    // slow path
    SmallVector<bool> output_has_grad(outputs.size(), true);
    std::shared_ptr<OptimizedBackwardGraphResult> ret;
    auto bg = OpDef::make_backward_graph(
            *ctx.op, input_descs, input_requires_grad, output_has_grad);
    if (!bg.graph.empty()) {
        ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
    }
    cache.emplace(cache_key, ret);
    return ret;
}
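
// Backward implementation backed by a generated backward graph. The
// constructor decides which forward inputs/outputs must be kept alive (the
// "closure") according to save_for_backward; operator() replays the backward
// graph on the captured tensors plus the incoming output gradients.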
struct BackwardGraphWithClosure {
    std::shared_ptr<OptimizedBackwardGraphResult> backward_graph;
    SmallVector<std::shared_ptr<Tensor>> closure;
    size_t output_mask_offset;
    size_t grad_mask_offset;

    BackwardGraphWithClosure(std::shared_ptr<OptimizedBackwardGraphResult> backward_graph_,
                             ApplyContext& ctx, const apply_result_t& outputs)
            : backward_graph(backward_graph_),
              output_mask_offset(ctx.nargs),
              grad_mask_offset(ctx.nargs + outputs.size()) {
        // save_for_backward[0:nargs]:
        //     whether input is kept for backward
        //
        // save_for_backward[nargs:nargs+outputs.size()]:
        //     whether output is kept for backward
        //
        // save_for_backward[-outputs.size():]:
        //     whether gradient of output can propagate to any input
        //
        // Example:
        //     perform c = a * b, with a.requires_grad == True and
        //     b.requires_grad == False, save_for_backward = [0, 1, 0, 1]
        auto& save_for_backward = backward_graph->save_for_backward;
        mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size());
        size_t count = std::count_if(save_for_backward.begin(),
                                     save_for_backward.end(),
                                     ranges::identity{});
        if (!backward_graph->precomp.empty()) {
            auto&& irng = ranges::span(ctx.args, ctx.nargs);
            auto&& orng = views::transform(outputs, [](auto&& i){return i.get();});
            auto precomp = apply(backward_graph->precomp, views::concat(irng, orng));
            closure.reserve(precomp.size() + count);
            std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure));
        } else {
            closure.reserve(count);
        }
        for (size_t i = 0; i < ctx.nargs; ++i) {
            if (save_for_backward[i]) {
                closure.push_back(ctx.args[i]->shared_from_this());
            }
        }
        for (size_t i = 0; i < outputs.size(); ++i) {
            if (save_for_backward[ctx.nargs + i]) {
                closure.push_back(outputs[i]);
            }
        }
    }

    template <typename T, typename R>
    void operator()(BackwardContext&, T&& grads, R&& receiver) {
        Tensor* args[closure.size() + grads.size()];
        size_t nargs = 0;
        for (auto&& t : closure) {
            args[nargs++] = t.get();
        }
        bool null_grad = false;
        for (size_t i = 0; i < grads.size(); ++i) {
            if (backward_graph->save_for_backward[grad_mask_offset + i]) {
                if (grads[i]) {
                    if (null_grad) {
                        PyErr_SetString(PyExc_NotImplementedError, "report to devs");
                        throw py::error_already_set();
                    }
                    args[nargs++] = grads[i];
                } else {
                    null_grad = true;
                }
            }
        }
        if (null_grad) return;

        auto igrads = apply(backward_graph->backward, args, nargs);
        auto&& it = igrads.begin();
        for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) {
            if (p) {
                receiver(i, std::move(*it));
                ++it;
            }
        }
    }

    bool input_has_grad(size_t i) {
        return backward_graph->input_has_grad[i];
    }

    bool output_requires_grad(size_t i) {
        return backward_graph->save_for_backward[grad_mask_offset + i];
    }

    bool output_captured(size_t i) {
        return backward_graph->save_for_backward[output_mask_offset + i];
    }
};
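
// Backward implementation that defers to a user-defined Python grad rule.
// The incoming grads are wrapped as Python tensors, pyfunc is called, and its
// return value(s) are unpacked and forwarded to the receiver.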
struct PythonBackward {
    py::object pyfunc;
    size_t input_size;

    PythonBackward(py::object f, size_t nin)
            : pyfunc(f), input_size(nin) {}

    template <typename T, typename R>
    void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
        auto args = py::tuple(grads.size());
        for (size_t i = 0; i < grads.size(); ++i) {
            auto&& g = grads[i];
            args[i] = g ? ctx.wrap_tensor(g) : py::none();
        }
        auto input_grads = py::reinterpret_steal<py::object>(PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr));
        if (!input_grads) throw py::error_already_set();
        if (input_grads.is_none()) return;
        if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) {
            if (input_size != 1) {
                throw py::value_error("custom grad rule returned wrong number of grads");
            }
            if (!ctx.pytype) {
                ctx.pytype = Py_TYPE(input_grads.ptr());
            }
            receiver(0, tw->m_tensor);
            return;
        }
        if (py::len(input_grads) != input_size) {
            throw py::value_error("custom grad rule returned wrong number of grads");
        }
        for (auto [i, g] : views::enumerate(input_grads)) {
            if (g.is_none()) continue;
            auto* tw = TensorWrapper::try_cast(g.ptr());
            if (!tw) {
                throw py::type_error("custom grad rule returned non-tensor");
            }
            if (!ctx.pytype) {
                ctx.pytype = Py_TYPE(g.ptr());
            }
            receiver(i, tw->m_tensor);
        }
    }

    static constexpr bool input_has_grad(size_t) { return true; }
    static constexpr bool output_requires_grad(size_t) { return true; }
    static constexpr bool output_captured(size_t) { return true; }
};

} // namespace
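
// GradProducerRecord is an intrusive-list node hanging off a GradSlot's
// producer_head; it lets a slot tell when the last pending gradient
// contribution has been delivered (used to decide when to fire callbacks).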
struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> {
    using Base = intrusive_list::Node<GradProducerRecord>;

    GradProducerRecord() = default;
    GradProducerRecord(GradProducerRecord::head_t& head) : Base(intrusive_list::after_t{}, head) {}

    // GradProducerRecord(GradProducerRecord&&) = default;
    // GradProducerRecord& operator=(GradProducerRecord&) = default;
    // GradProducerRecord& operator=(GradProducerRecord&&) = default;
};

struct GradSlot {
    std::shared_ptr<Tensor> grad;
    py::object callback;
    GradProducerRecord::head_t producer_head;
};

struct GradSlotProducerPtr : GradSlotPtr {
    GradProducerRecord producer_record;

    GradSlotProducerPtr() = default;
    GradSlotProducerPtr(GradInfo& info) : GradSlotPtr(info), producer_record(info->producer_head) {}
};
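
// GradFn is one node of the backward tape: it owns the grad slots of a
// recorded forward op's outputs, pointers to the slots its gradients flow
// into, and the backward implementation chosen for that op.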
struct GradFn : std::enable_shared_from_this<GradFn> {
    static MemPool<GradFn> pool;

    std::weak_ptr<GradKey> key;
    // slots for receiving and accumulating grads
    // same length as outputs (of forward op)
    SmallVector<GradSlot> slots;
    // where to send and accumulate grads
    // same length as inputs (of forward op)
    SmallVector<GradSlotProducerPtr> dsts;
    // encapsulates actual function to compute gradient
    std::variant<std::monostate, BackwardGraphWithClosure, PythonBackward, CustomBackward> backward;
    // a flag used during backward
    bool in_ref_keeper = false;

    static void deleter(GradFn* ptr) {
        pool.free(ptr);
    }

    static std::shared_ptr<GradFn> make() {
        return std::shared_ptr<GradFn>(pool.alloc(), &deleter);
    }

    void clear() {
        key.reset();
        slots.clear();
        dsts.clear();
        backward.emplace<std::monostate>();
    }
};

GradSlotPtr::operator bool() const {
    return bool(grad_fn);
}

GradSlot* GradSlotPtr::operator->() {
    return &grad_fn->slots[idx];
}

namespace {

class GradFnHelper {
    std::shared_ptr<GradFn> grad_fn;

    GradFn* get() {
        if (!grad_fn) {
            grad_fn = std::make_shared<GradFn>();
        }
        return grad_fn.get();
    }

    friend apply_result_t imperative::python::apply_grad(ApplyContext&);

public:
    template<typename T, typename... Args>
    auto& emplace(Args&&... args) {
        return get()->backward.emplace<T>(std::forward<Args>(args)...);
    }

    void reset() { grad_fn = nullptr; }
};
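
// Default grad rule: run the forward op on copies of the inputs, then build
// (or fetch from cache) an optimized backward graph and store it, together
// with its closure, in the GradFn being assembled.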
apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
    // copy inputs first, or trace will make InputNodes for each usage
    ApplyContext ctx_dup = ctx;
    SmallVector<std::shared_ptr<Tensor>> inputs_copy;
    SmallVector<Tensor*> inputs_copy_weak;
    for (size_t i = 0; i < ctx.nargs; ++i) {
        Tensor* input = ctx.args[i];
        inputs_copy.push_back(python::apply(FastpathCopy::make(), input)[0]);
        inputs_copy_weak.push_back(inputs_copy.back().get());
        inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict;
        if (input->m_flags & Flags::GRAD) {
            inputs_copy.back()->m_flags |= Flags::GRAD;
        }
    }
    ctx_dup.args = inputs_copy_weak.data();

    auto outputs = apply(ctx_dup);

    auto backward_graph = make_backward_graph(ctx_dup, outputs);
    if (!backward_graph) {
        return outputs;
    }
    ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx_dup, outputs);

    return outputs;
}
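
// Grad rule for ops defined in Python (GenericPyOp): call the op's
// _grad_rule, take the forward outputs from its return value, and register
// the returned backward callable as a PythonBackward.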
apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
    auto* op = ctx.op->try_cast_final<GenericPyOp>();
    py::tuple pyin(ctx.nargs);
    for (size_t i = 0; i < ctx.nargs; ++i) {
        pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
    }
    auto grad_rule = py::getattr(op->obj, "_grad_rule");
    auto pyret = py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr));
    if (!pyret) throw py::error_already_set();
    auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret);
    ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs);
    if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) {
        return {tw->m_tensor};
    }
    apply_result_t ret;
    ret.reserve(py::len(outputs));
    for (auto&& i : outputs) {
        auto* tw = TensorWrapper::try_cast(i.ptr());
        mgb_assert(tw);
        ret.push_back(tw->m_tensor);
    }
    return ret;
}

} // namespace
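
// Entry point hit on every op application while grad recording may be
// active: collect the active GradKeys referenced by the inputs, run the
// forward with GRAD temporarily disabled via the selected grad rule, then
// record a GradFn on each key's tape and tag the outputs with grad info.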
apply_result_t apply_grad(ApplyContext& ctx) {
    std::unordered_set<std::shared_ptr<GradKey>> grad_keys;
    for (size_t i = 0; i < ctx.nargs; ++i) {
        auto* tensor = ctx.args[i];
        if (!tensor->m_grad_info_dict.empty()) {
            size_t grad_cnt = 0;
            for (auto&& grad_info: tensor->m_grad_info_dict) {
                auto input_grad_key = grad_info.grad_fn->key.lock();
                if (input_grad_key && input_grad_key->active && !input_grad_key->is_blocked()) {
                    grad_keys.insert(input_grad_key);
                    grad_cnt++;
                }
            }
            if (!grad_cnt) {
                tensor->m_flags &= ~Flags::GRAD;
            }
        } else {
            tensor->m_flags &= ~Flags::GRAD;
        }
    }

    ctx.flags &= ~Flags::GRAD;

    if (grad_keys.empty()) {
        return apply(ctx);
    } else if (grad_keys.size() > 1 && !GradKey::allow_higher_order_directive) {
        PyErr_SetString(
                PyExc_NotImplementedError,
                "second order directive not enabled, please call "
                "'megengine.experimental.enable_higher_order_directive'");
        throw pyext17::py_err_set();
    }

    GradFnHelper grad_fn_holder;
    auto outputs = [&]() {
        auto _ = scoped_disable(Flags::GRAD);
        if (ctx.op->same_type<GenericPyOp>()) {
            return python_grad_rule(ctx, grad_fn_holder);
        }
        auto&& registry = grad_rule_registry();
        auto&& it = registry.find(ctx.op->dyn_typeinfo());
        if (it != registry.end()) {
            auto&& maker = grad_fn_holder.emplace<CustomBackward>().maker(ctx);
            try {
                auto ret = it->second(ctx, maker);
                maker.finalize();
                return ret;
            } catch (GradRuleFallback&) {
                grad_fn_holder.reset();
            }
        }
        return backward_graph_grad_rule(ctx, grad_fn_holder);
    }();

    if (!grad_fn_holder.grad_fn) {
        return outputs;
    }

    for (auto&& grad_key: grad_keys) {
        auto grad_fn = std::make_shared<GradFn>();
        grad_fn->backward = grad_fn_holder.grad_fn->backward;
        grad_fn->key = grad_key;
        grad_fn->slots.resize(outputs.size());
        grad_fn->dsts.reserve(ctx.nargs);

        std::visit([&](auto& backward) {
            using T = std::decay_t<decltype(backward)>;
            if constexpr (std::is_same_v<T, std::monostate>) {
                mgb_assert(0);
            } else {
                for (size_t i = 0; i < ctx.nargs; ++i) {
                    if (backward.input_has_grad(i) && input_requires_grad(ctx, i) && ctx.args[i]->m_grad_info_dict.count(grad_key.get())) {
                        auto& input_grad_info = ctx.args[i]->m_grad_info_dict.at(grad_key.get());
                        grad_fn->dsts.emplace_back(input_grad_info);
                        // register as grad producer
                        grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
                    } else {
                        grad_fn->dsts.emplace_back();
                    }
                }
                for (size_t i = 0; i < outputs.size(); ++i) {
                    if (backward.output_requires_grad(i)) {
                        if (backward.output_captured(i)) {
                            // avoid reference cycle [Tensor <-> GradFn]
                            static std::shared_ptr<OpDef> op = std::make_shared<FastpathCopy>();
                            outputs[i] = python::apply(op, outputs[i])[0];
                        }
                        // populate grad info of output tensor
                        auto& grad_info = outputs[i]->m_grad_info_dict[grad_key.get()];
                        grad_info.grad_fn = grad_fn;
                        grad_info.idx = i;
                        grad_info.insert_after(grad_key->free_vars_head);
                        outputs[i]->m_flags |= Flags::GRAD;
                    }
                }
            }
        }, grad_fn->backward);

        // record forward history
        grad_key->tape.emplace_back(grad_fn);
    }

    return outputs;
}

PyObject* GradKeyWrapper::get_priority() {
    return py::cast(m_key->priority).release().ptr();
}

void GradKeyWrapper::set_priority(pybind11::handle priority) {
    m_key->priority = py::cast<int>(priority);
}

void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
    if (nargs != 2) {
        throw py::type_error("expect 2 arguments");
    }
    auto* tw = TensorWrapper::try_cast(args[0]);
    if (!tw) {
        throw py::type_error("argument 1 must be Tensor");
    }
    auto* tensor = tw->m_tensor.get();
    py::object callback;
    if (args[1] != Py_None) {
        callback = py::reinterpret_borrow<py::object>(args[1]);
    }
    m_key->attach(tensor, std::move(callback));
}

//! GradKey is weakly referred to by tensor->m_grad_info.grad_fn->key after attach
void GradKey::attach(Tensor* tensor, pybind11::object callback) {
    if (!active) {
        throw py::value_error("grad key finalized");
    }

    if (tensor->m_grad_info_dict.count(this)) {
        if (tensor->m_grad_info_dict.at(this)->callback) {
            throw py::value_error("callback already set on this tensor");
        }
    } else {
        auto& grad_info = tensor->m_grad_info_dict[this];
        grad_info.idx = 0;
        auto& grad_fn = grad_info.grad_fn;
        grad_fn = std::make_shared<GradFn>();
        grad_fn->key = shared_from_this();
        grad_fn->slots.resize(1);
        grad_info.insert_after(free_vars_head);
        tensor->m_flags |= Flags::GRAD;
    }
    tensor->m_grad_info_dict.at(this).grad_fn->slots[0].callback = std::move(callback);
}
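
// Accumulate a gradient contribution into `grad`: the first contribution is
// moved in directly, later ones are summed via an Elemwise ADD op.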
template<typename T>
void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) {
    if (!grad) {
        grad = std::forward<T>(delta);
        return;
    }
    static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD));
    grad = apply(op, grad, std::forward<T>(delta))[0];
}
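
// Run backpropagation for this key: seed the grads of the requested output
// tensors, then walk the tape in reverse, invoking each GradFn's backward
// and accumulating results into its destination slots. A leaf's callback
// fires once the last producer of its slot has contributed its gradient.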
void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
    if (!active) {
        throw py::value_error("finalized");
    }
    if (tensors.size() != grads.size()) {
        throw py::value_error("tensor and grad size mismatch");
    }

    // this GradKey is marked inactive here
    active = false;
    struct CleanupGuard {
        GradKey* owner;
        size_t priority_backup;

        CleanupGuard(GradKey* this_) : owner(this_) {
            priority_backup = sm_min_priority;
            sm_min_priority = owner->priority + 1;
        }

        ~CleanupGuard() {
            owner->cleanup();
            sm_min_priority = priority_backup;
        }
    } _cleanup_guard(this);

    if (tape.empty()) return;

    BackwardContext bctx;
    if (!grads.empty()) {
        bctx.pytype = Py_TYPE(grads[0]->self().ptr());
    }

    for (size_t i = 0; i < tensors.size(); ++i) {
        if (tensors[i]->m_tensor->m_grad_info_dict.count(this) == 0) {
            continue;
        }
        auto& grad_info = tensors[i]->m_tensor->m_grad_info_dict.at(this);
        grad_info->grad = grads[i]->m_tensor;
    }

    std::vector<std::shared_ptr<GradFn>> ref_keeper;
    ref_keeper.reserve(tape.size());

    // back-propagation in reverse order
    for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
        auto&& grad_fn = tape[k].lock();
        if (!grad_fn) continue;

        auto grad_receiver = [&](size_t i, auto&& g) {
            auto& dst = grad_fn->dsts[i];
            if (dst) {
                accum_grad(dst->grad, std::forward<decltype(g)>(g));
            }
        };
        std::visit([&](auto&& backward) {
            using T = std::decay_t<decltype(backward)>;
            if constexpr (std::is_same_v<T, std::monostate>) {
                mgb_assert(0);
            } else {
                auto&& grads = views::transform(grad_fn->slots, [](auto&& slot) {return slot.grad.get();});
                backward(bctx, std::forward<decltype(grads)>(grads), grad_receiver);
            }
        }, grad_fn->backward);

        for (auto&& dst : grad_fn->dsts) {
            if (!dst.grad_fn) continue;
            if (!dst.grad_fn->in_ref_keeper) {
                // after grad_fn is cleared, refcnt of subsequent grad_fn
                // could drop to 0
                dst.grad_fn->in_ref_keeper = true;
                ref_keeper.push_back(dst.grad_fn);
            }
            if (!dst.producer_record.next && dst->callback && dst->grad) {
                // I'm the last grad producer, invoke callback
                dst->callback(bctx.wrap_tensor(dst->grad));
            }
        }
        grad_fn->clear();
    } // finish tape loop
}
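
// Release everything owned by this key: deactivate it, drop the tape, and
// reset and unlink the grad info of every variable attached to it.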
void GradKey::cleanup() {
    active = false;
    tape.clear();
    for (intrusive_list::Iterator it(free_vars_head); it;) {
        it->grad_fn.reset();
        (it++)->unlink();
    }
}

void GradKeyWrapper::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
    m_key->backward(std::move(tensors), std::move(grads));
}

PyObject* GradKeyWrapper::get_name() {
    return py::cast(m_key->name).release().ptr();
}

void GradKeyWrapper::set_name(py::handle name) {
    m_key->name = py::cast<std::string>(name);
}

PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) {
    if (nargs != 1) {
        PyErr_SetString(PyExc_TypeError, "expect 1 argument");
        return nullptr;
    }
    auto* tw = TensorWrapper::try_cast(args[0]);
    if (!tw) {
        PyErr_SetString(PyExc_TypeError, "expect Tensor");
        return nullptr;
    }
    if (tw->m_tensor->m_grad_info_dict.count(m_key.get())) {
        Py_RETURN_TRUE;
    }
    Py_RETURN_FALSE;
}

int GradKey::sm_min_priority = std::numeric_limits<int>::min();

GradKey::~GradKey() {
    cleanup();
}

std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() {
    static std::unordered_map<Typeinfo*, GradRuleFn> registry;
    return registry;
}
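
// GradInfoCollection: per-tensor list of grad info entries keyed by GradKey.
// _shrink drops entries whose GradFn is gone or whose GradKey has expired
// before the collection is queried or extended.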
void GradInfoCollection::_shrink() {
    auto pred = [](GradInfo& info){ return !(info.grad_fn) || info.grad_fn->key.expired(); };
    auto iter = std::remove_if(m_storage.begin(), m_storage.end(), pred);
    m_storage.erase(iter, m_storage.end());
}

bool GradInfoCollection::contains(GradKey* key) {
    _shrink();
    for (auto&& grad_info: m_storage) {
        if (grad_info.grad_fn->key.lock().get() == key) {
            return true;
        }
    }
    return false;
}

GradInfo& GradInfoCollection::operator[](GradKey* key) {
    _shrink();
    for (auto&& grad_info: m_storage) {
        if (grad_info.grad_fn->key.lock().get() == key) {
            return grad_info;
        }
    }
    m_storage.emplace_back();
    return m_storage.back();
}

GradInfo& GradInfoCollection::at(GradKey* key) {
    _shrink();
    for (auto&& grad_info: m_storage) {
        if (grad_info.grad_fn->key.lock().get() == key) {
            return grad_info;
        }
    }
    mgb_assert(false);
}

} // namespace mgb::imperative::python
