You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

grad.cpp 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543
  1. /**
  2. * \file imperative/src/impl/transformations/grad.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/imperative/transformations/grad.h"
  12. #include "megbrain/imperative/graph_cache.h"
  13. #include <range/v3/all.hpp>
  14. namespace mgb {
  15. namespace imperative {
  16. static std::shared_ptr<OptimizedBackwardGraphResult> make_optimized_backward_graph(
  17. std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs,
  18. Span<bool> inputs_require_grad) {
  19. // hash
  20. using OptimizedBackwardGraphCache = OpMethResultCache<
  21. std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>;
  22. thread_local auto cache = std::make_unique<OptimizedBackwardGraphCache>();
  23. OptimizedBackwardGraphCache::key_t cache_key{op};
  24. SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs;
  25. std::get<0>(cache_key.extras) = inputs_require_grad.copy_into<SmallVector<bool>>();
  26. input_descs.resize(inputs.size());
  27. for (size_t i = 0; i < inputs.size(); ++i) {
  28. input_descs[i].layout.dtype = inputs[i].dtype().cast<DTypeValue>();
  29. input_descs[i].comp_node = inputs[i].device().cast<CompNodeValue>();
  30. }
  31. auto iter = cache->find(cache_key);
  32. if (iter != cache->end()) {
  33. return iter->second;
  34. }
  35. // slow path
  36. SmallVector<bool> output_has_grad(outputs.size(), true);
  37. std::shared_ptr<OptimizedBackwardGraphResult> ret;
  38. auto bg = OpDef::make_backward_graph(
  39. *op, input_descs, std::get<0>(cache_key.extras), output_has_grad);
  40. if (!bg.graph.empty()) {
  41. ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
  42. }
  43. cache->emplace(cache_key, ret);
  44. return ret;
  45. }
  46. BackwardGraphWithClosure::BackwardGraphWithClosure(
  47. std::shared_ptr<OptimizedBackwardGraphResult> backward_graph,
  48. std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs)
  49. : backward_graph(backward_graph),
  50. output_mask_offset(inputs.size()),
  51. grad_mask_offset(inputs.size() + outputs.size()) {
  52. auto& save_for_backward = backward_graph->save_for_backward;
  53. mgb_assert(save_for_backward.size() == inputs.size() + 2 * outputs.size());
  54. size_t count = std::count_if(
  55. save_for_backward.begin(), save_for_backward.end(), ranges::identity{});
  56. if (!backward_graph->precomp.empty()) {
  57. SmallVector<ValueRef> inputs_and_outputs;
  58. for (auto&& input : inputs) {
  59. inputs_and_outputs.push_back(input);
  60. }
  61. for (auto&& output : outputs) {
  62. inputs_and_outputs.push_back(output);
  63. }
  64. auto precomp = imperative::apply(backward_graph->precomp, inputs_and_outputs);
  65. closure.reserve(precomp.size() + count);
  66. std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure));
  67. } else {
  68. closure.reserve(count);
  69. }
  70. for (size_t i = 0; i < inputs.size(); ++i) {
  71. if (save_for_backward[i]) {
  72. closure.push_back(inputs[i]);
  73. }
  74. }
  75. for (size_t i = 0; i < outputs.size(); ++i) {
  76. if (save_for_backward[inputs.size() + i]) {
  77. closure.push_back(outputs[i]);
  78. }
  79. }
  80. }
  81. void BackwardGraphWithClosure::operator()(
  82. std::vector<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
  83. ValueRef args[closure.size() + grads.size()];
  84. size_t nargs = 0;
  85. for (auto&& value : closure) {
  86. args[nargs++] = value;
  87. }
  88. bool null_grad = false;
  89. for (size_t i = 0; i < grads.size(); ++i) {
  90. if (backward_graph->save_for_backward[grad_mask_offset + i]) {
  91. if (grads[i]) {
  92. mgb_assert(!null_grad, "null_grad");
  93. args[nargs++] = grads[i];
  94. } else {
  95. null_grad = true;
  96. }
  97. }
  98. }
  99. if (null_grad) {
  100. return;
  101. }
  102. auto igrads = imperative::apply(backward_graph->backward, Span(args, nargs));
  103. auto&& iter = igrads.begin();
  104. for (auto [i, p] : ranges::views::enumerate(backward_graph->input_has_grad)) {
  105. if (p) {
  106. receiver(i, std::move(*iter));
  107. ++iter;
  108. }
  109. }
  110. }
  111. void CustomBackward::operator()(
  112. std::vector<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
  113. size_t nargs = grads.size();
  114. ValueRef args[nargs];
  115. for (size_t i = 0; i < nargs; ++i) {
  116. args[i] = grads[i];
  117. }
  118. auto ret = m_backward({args, nargs});
  119. for (size_t i = 0; i < ret.size(); ++i) {
  120. if (auto&& t = ret[i]) {
  121. receiver(i, std::move(t));
  122. }
  123. }
  124. }
  125. std::string GradSlot::to_string() const {
  126. bool has_callback = bool(callback);
  127. return ssprintf(
  128. "GradSlot{grad=%s, has_callback=%d}", m_grad.to_string().c_str(),
  129. (int)has_callback);
  130. }
  131. std::string GradFn::to_string() const {
  132. return ssprintf("GradFn{dests=%s}", imperative::to_string(m_dests).c_str());
  133. }
  134. std::string GradSlotPtr::to_string() const {
  135. if (!m_fn) {
  136. return "<empty>";
  137. }
  138. return (*this)->to_string();
  139. }
  140. std::string GradValue::to_string() const {
  141. return ssprintf(
  142. "GradValue{key=\"%s\", slot=%s, value=%s}", m_key->name().c_str(),
  143. m_slot.to_string().c_str(), m_value.to_string().c_str());
  144. }
  145. static std::unordered_map<Typeinfo*, CustomBackward::BackwardRule>&
  146. get_backward_rule_storage() {
  147. static std::unordered_map<Typeinfo*, CustomBackward::BackwardRule> sl_storage;
  148. return sl_storage;
  149. }
  150. bool CustomBackward::register_grad_rule(Typeinfo* typeinfo, BackwardRule rule) {
  151. return get_backward_rule_storage().insert({typeinfo, rule}).second;
  152. }
  153. auto CustomBackward::lookup_grad_rule(Typeinfo* typeinfo) -> BackwardRule {
  154. auto iter = get_backward_rule_storage().find(typeinfo);
  155. if (iter == get_backward_rule_storage().end()) {
  156. return {};
  157. }
  158. return iter->second;
  159. }
/**
 * Reverse-mode autodiff: walk the frozen tape from the most recent entry to
 * the oldest, invoking each GradFn's backward and accumulating the produced
 * grads into its destination slots. Requires freeze() to have been called.
 */
void GradKey::backward() {
    mgb_assert(m_frozen);
    auto& tape = m_frozen_tape;
    // signed index: the loop must run down to and including 0
    for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
        auto& [grad_fn, op] = tape[k];
        // Accumulate an incoming grad into destination slot i, elementwise-
        // adding it to any grad already contributed by a later tape entry.
        auto grad_receiver = [&, grad_fn = grad_fn](size_t i, ValueRef grad) {
            auto& dest = grad_fn->m_dests[i];
            if (dest) {
                auto& existing_grad = dest->m_grad;
                if (!existing_grad) {
                    existing_grad = grad;
                } else {
                    existing_grad = imperative::apply(
                            ApplyOp(*Elemwise::make(Elemwise::Mode::ADD)),
                            existing_grad, grad)[0];
                }
            }
        };
        // clang-format off
        // Dispatch on the stored backward implementation (backward graph with
        // closure, or a custom rule); monostate means the entry is corrupt.
        std::visit([&, grad_fn = grad_fn, op = op](auto&& backward) {
            using T = std::decay_t<decltype(backward)>;
            if constexpr (std::is_same_v<T, std::monostate>) {
                mgb_throw(AssertionError, "invalid backward");
            } else {
                mgb_assert(grad_fn->m_slots.size() > 0);
                // the grads of this fn's output slots feed its backward
                std::vector<ValueRef> grads;
                for (auto&& slot : grad_fn->m_slots) {
                    grads.push_back(slot.m_grad);
                }
                backward(grads, grad_receiver);
            }
        }, grad_fn->m_backward);
        // clang-format on
        for (auto&& dest : grad_fn->m_dests) {
            if (!dest) {
                continue;
            }
            if (!dest.m_producer_record.next && dest->callback && dest->m_grad) {
                // I'm the last grad producer, invoke callback
                dest->callback(dest->m_grad);
            }
        }
        // drop this entry's state now that it has been consumed
        grad_fn->clear();
    }
    tape.clear();
}
/**
 * Attach `tensor` to this grad key so that, after backward(), its grad is
 * delivered to `callback`. Re-attaching a tensor already bound to this key is
 * only allowed when no callback was previously set.
 */
GradValue::ref_t GradKey::attach(
        ValueRef tensor, std::function<void(ValueRef)> callback) {
    auto grad_value = tensor.as_ref<GradValue>();
    if (grad_value && grad_value->has_key(shared_from_this())) {
        mgb_assert(
                !tensor.cast<GradValue>().slot_for(shared_from_this())->callback,
                "callback exists");
    } else {
        // first attachment under this key: create a leaf GradFn with a single
        // slot and wrap the tensor in a GradValue bound to it
        GradSlotPtr grad_slot;
        auto& grad_fn = grad_slot.m_fn;
        grad_fn = std::make_shared<GradFn>();
        grad_fn->m_key = shared_from_this();
        grad_fn->m_slots.resize(1);
        grad_slot.m_index = 0;
        grad_value = GradValue::make(tensor, shared_from_this(), grad_slot);
    }
    // NOTE(review): this writes m_slots[0] regardless of the slot's m_index —
    // presumably attached values always occupy slot 0; confirm for the
    // re-attachment branch.
    grad_value->slot_for(shared_from_this()).m_fn->m_slots[0].callback = callback;
    return grad_value;
}
  225. void GradKey::freeze() {
  226. mgb_assert(m_frozen_tape.empty() && !m_frozen);
  227. for (auto&& [grad_fn, op] : m_tape) {
  228. if (auto valid_grad_fn = grad_fn.lock()) {
  229. m_frozen_tape.push_back({valid_grad_fn, op});
  230. }
  231. }
  232. m_tape.clear();
  233. m_frozen = true;
  234. }
/**
 * Central dispatcher of the grad transformation: intercepts each operator,
 * records a GradFn on the key's tape for differentiable ops, and handles the
 * grad-specific pseudo-ops (AttachGrad, GradBackward, SetGrad, ...). For
 * anything else it unwraps GradValues and forwards the call.
 */
std::vector<ValueRef> GradTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    // strip GradValue wrappers so the op is applied to the plain values
    auto unwrap_inputs = [this](Span<ValueRef> inputs) -> SmallVector<ValueRef> {
        SmallVector<ValueRef> unwrapped_inputs;
        for (auto&& input : inputs) {
            if (auto grad_value = as_grad_value(input)) {
                unwrapped_inputs.push_back(grad_value->m_value);
            } else {
                unwrapped_inputs.push_back(input);
            }
        }
        return unwrapped_inputs;
    };
    // inside suppress()/resume(): pass through without recording any grads
    if (m_suppressed) {
        return imperative::apply(op, unwrap_inputs(inputs));
    }
    if (auto* op_val = op.as<ApplyOp>()) {
        // mark which inputs carry a grad under some key
        size_t nr_require_grad = 0;
        SmallVector<bool> require_grads;
        for (auto&& input : inputs) {
            if (is_grad_value(input)) {
                nr_require_grad++;
                require_grads.push_back(true);
            } else {
                require_grads.push_back(false);
            }
        }
        // no differentiable input: nothing to record
        if (nr_require_grad == 0) {
            return imperative::apply(op, inputs);
        }
        SmallVector<ValueRef> captured_inputs;
        SmallVector<bool> inputs_require_grad;
        // capture value so that trace could assume input as same
        auto capture_value = [](ValueRef value) {
            // TODO: fastpath copy shouldn't be an OpDef
            return imperative::apply(ApplyOp(*FastpathCopy::make()), {&value, 1})[0];
        };
        for (auto& input : inputs) {
            if (auto grad_value = as_grad_value(input)) {
                captured_inputs.push_back(capture_value(grad_value->m_value));
                inputs_require_grad.push_back(true);
            } else {
                captured_inputs.push_back(capture_value(input));
                inputs_require_grad.push_back(false);
            }
        }
        // variant<monostate, BackwardGraphWithClosure, CustomBackward, ...>;
        // stays monostate when the op turns out to have no backward
        decltype(std::declval<GradFn>().m_backward) backward_storage;
        auto outputs = [&] {
            // a registered custom rule takes precedence over the generated
            // backward graph
            auto backward_rule =
                    CustomBackward::lookup_grad_rule(op_val->op().dyn_typeinfo());
            if (backward_rule) {
                CustomBackward backward;
                auto optional_outputs = backward_rule(
                        op_val->op(), {captured_inputs.data(), captured_inputs.size()},
                        {inputs_require_grad.data(), inputs_require_grad.size()},
                        backward);
                if (optional_outputs) {
                    backward_storage = backward;
                    // backward by rule
                    return *optional_outputs;
                }
            }
            auto outputs = imperative::apply(
                    op, {captured_inputs.begin(), captured_inputs.end()});
            auto backward_graph = make_optimized_backward_graph(
                    op.cast<ApplyOp>().op().shared_from_this(),
                    {captured_inputs.begin(), captured_inputs.end()},
                    {outputs.data(), outputs.size()},
                    {inputs_require_grad.data(), inputs_require_grad.size()});
            if (backward_graph) {
                backward_storage = BackwardGraphWithClosure(
                        backward_graph, op.cast<ApplyOp>().op().shared_from_this(),
                        {captured_inputs.begin(), captured_inputs.end()},
                        {outputs.data(), outputs.size()});
                // backward by make_backward_graph
                return outputs;
            } else {
                // no backward
                return outputs;
            }
        }();
        // no backward available: return plain (unwrapped) outputs
        if (std::holds_alternative<std::monostate>(backward_storage)) {
            return outputs;
        }
        auto grad_fn = std::make_shared<GradFn>();
        grad_fn->m_key = m_key;
        grad_fn->m_slots.resize(outputs.size());
        grad_fn->m_backward = backward_storage;
        mgb_assert(!outputs.empty());
        grad_fn->m_dests.reserve(inputs.size());
        // clang-format off
        // wire up dests (input slots) and wrap outputs into GradValues
        std::visit([&](auto& backward) {
            using T = std::decay_t<decltype(backward)>;
            if constexpr (std::is_same_v<T, std::monostate>) {
                mgb_throw(AssertionError, "invalid backward");
            } else {
                for (size_t i = 0; i < inputs.size(); ++i) {
                    if (backward.input_has_grad(i) && require_grads[i]) {
                        auto& input_grad_slot =
                                inputs[i].cast<GradValue>().slot_for(m_key);
                        grad_fn->m_dests.emplace_back(input_grad_slot);
                        // link into the slot's producer list so the last
                        // producer can be detected during backward()
                        grad_fn->m_dests.back().m_producer_record.insert_after(
                                input_grad_slot->m_producer_head);
                    } else {
                        grad_fn->m_dests.emplace_back();
                    }
                }
                for (size_t i = 0; i < outputs.size(); ++i) {
                    if (backward.output_requires_grad(i)) {
                        auto grad_value = GradValue::make(outputs[i], m_key, GradSlotPtr{grad_fn, i});
                        outputs[i] = record_grad(grad_value);
                    }
                }
            }
        }, grad_fn->m_backward);
        // clang-format on
        mgb_assert(!grad_fn->m_slots.empty());
        m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()});
        return outputs;
    } else if (auto* attach_grad = op.as<AttachGrad>()) {
        // attach inputs[0] under the given key; inputs[1] is the grad callback
        if (!has_key(attach_grad->key())) {
            return imperative::apply(op, unwrap_inputs(inputs));
        }
        auto tensor = inputs[0];
        GenericFunction callback = (GenericFunction&)inputs[1].cast<FunctionValue>();
        auto output = attach_grad->key()->attach(tensor, [callback](ValueRef grad) {
            auto ret = callback({&grad, 1});
            assert(ret.empty());
        });
        return {record_grad(output)};
    } else if (auto* grad_backward = op.as<GradBackward>()) {
        // inputs layout: first half forward values, second half their grads
        if (!has_key(grad_backward->key())) {
            return imperative::apply(op, unwrap_inputs(inputs));
        }
        size_t nr_grads = inputs.size() / 2;
        mgb_assert(nr_grads * 2 == inputs.size());
        auto values = inputs.sub(0, nr_grads);
        auto grads = inputs.sub(nr_grads, nr_grads);
        make_backward_closure(values)(grads);
        return {};
    } else if (auto* is_attached_to = op.as<IsAttachedTo>()) {
        if (has_key(is_attached_to->key())) {
            if (auto grad_value = as_grad_value(inputs[0])) {
                // TODO: assert grad_fn
                return {BoolValue::make(true)};
            }
        }
        return {BoolValue::make(false)};
    } else if (auto* set_grad = op.as<SetGrad>()) {
        // TODO: merge SetGrad and ApplyOp
        // install a user-supplied grad function: inputs are
        // [nr_inputs inputs | (size - nr_inputs) outputs]
        auto grad_fn = std::make_shared<GradFn>();
        auto& backward =
                std::get<CustomBackward>(grad_fn->m_backward = CustomBackward());
        size_t nr_inputs = set_grad->nr_inputs();
        mgb_assert(inputs.size() > nr_inputs);
        size_t nr_outputs = inputs.size() - nr_inputs;
        Span<ValueRef> inputs_ = {inputs.data(), nr_inputs};
        Span<ValueRef> outputs_ = {inputs.data() + nr_inputs, nr_outputs};
        backward.m_input_has_grad = SmallVector(nr_inputs, true);
        backward.m_output_attrs =
                SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true});
        backward.m_backward = set_grad->grad_fn();
        std::vector<ValueRef> outputs;
        grad_fn->m_key = m_key;
        grad_fn->m_slots.resize(nr_outputs);
        grad_fn->m_dests.reserve(nr_inputs);
        for (size_t i = 0; i < nr_inputs; ++i) {
            if (auto grad_value = as_grad_value(inputs_[i])) {
                auto& input_grad_slot = grad_value->m_slot;
                grad_fn->m_dests.emplace_back(grad_value->m_slot);
                grad_fn->m_dests.back().m_producer_record.insert_after(
                        input_grad_slot->m_producer_head);
            } else {
                grad_fn->m_dests.emplace_back();
            }
        }
        // rebind every output (wrapped or not) to this grad_fn's slots
        for (size_t i = 0; i < nr_outputs; ++i) {
            auto& output = outputs_[i];
            auto grad_value = as_grad_value(output);
            if (grad_value) {
                grad_value = GradValue::make(
                        grad_value->m_value, m_key, GradSlotPtr(grad_fn, i));
            } else {
                grad_value = GradValue::make(output, m_key, GradSlotPtr(grad_fn, i));
            }
            outputs.push_back(record_grad(grad_value));
        }
        m_key->m_tape.push_back({grad_fn, nullptr});
        return outputs;
    } else if (auto* gbc = op.as<GetBackwardColsure>()) {
        if (gbc->key() != m_key) {
            return imperative::apply(op, unwrap_inputs(inputs));
        }
        return {FunctionValue::make(make_backward_closure(inputs))};
    } else if (op.is<DetachGrad>()) {
        // detach: drop the grad wrapper of inputs[0] if present
        if (auto grad_value = as_grad_value(inputs[0])) {
            return {grad_value->m_value};
        } else {
            return {inputs[0]};
        }
    } else if (op.is<GetGradKey>()) {
        // report the key of the first grad-carrying input
        for (auto&& input : inputs) {
            if (auto grad_value = as_grad_value(input)) {
                return {GradKeyValue::make(grad_value->m_key)};
            }
        }
        return imperative::apply(op, inputs);
    } else if (op.kind() == Operator::IdentityLike) {
        // identity-like ops keep the same grad slot on the output
        mgb_assert(inputs.size() == 1);
        if (auto grad_value = as_grad_value(inputs[0])) {
            auto output = imperative::apply(op, grad_value->m_value)[0];
            auto grad_output = GradValue::make(
                    output, grad_value->key(), grad_value->slot_for(m_key));
            return {record_grad(grad_output)};
        } else {
            return imperative::apply(op, inputs);
        }
    } else if (op.is<CreateTensor>()) {
        return imperative::apply(op, inputs);
    } else {
        // fallback: unwrap and forward; only attribute-getters may produce
        // outputs here, anything else must be side-effect-only
        SmallVector<ValueRef> unwrapped_inputs;
        for (auto&& input : inputs) {
            if (auto grad_value = as_grad_value(input)) {
                unwrapped_inputs.push_back(grad_value->m_value);
            } else {
                unwrapped_inputs.push_back(input);
            }
        }
        auto outputs = imperative::apply(
                op, {unwrapped_inputs.data(), unwrapped_inputs.size()});
        mgb_assert(op.kind() == Operator::GetAttrLike || outputs.empty());
        return outputs;
    }
}
  469. GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) {
  470. // reset GradKey
  471. auto grad_key = m_key;
  472. std::vector<GradSlotPtr> y_slots;
  473. for (auto&& y : ys) {
  474. if (auto grad_value = as_grad_value(y)) {
  475. y_slots.push_back(grad_value->slot_for(grad_key));
  476. } else {
  477. y_slots.emplace_back();
  478. }
  479. }
  480. GenericFunction closure = [grad_key,
  481. y_slots](Span<ValueRef> dys) -> std::vector<ValueRef> {
  482. size_t nr_grads = y_slots.size();
  483. mgb_assert(dys.size() == nr_grads);
  484. for (size_t i = 0; i < nr_grads; ++i) {
  485. if (y_slots[i]) {
  486. y_slots[i]->m_grad = dys[i];
  487. }
  488. }
  489. grad_key->backward();
  490. return {};
  491. };
  492. grad_key->freeze();
  493. cleanup();
  494. return closure;
  495. }
// Called when this transformation is removed from the dispatch stack: release
// all recorded grad state (same as cleanup()).
void GradTransformation::on_unregister() noexcept {
    cleanup();
}
  499. void GradTransformation::cleanup() {
  500. for (auto&& weak_value : m_weak_values) {
  501. auto grad_value = weak_value.lock();
  502. if (grad_value) {
  503. mgb_assert(grad_value->m_key == m_key);
  504. grad_value.reset(grad_value->m_value);
  505. }
  506. }
  507. m_weak_values.clear();
  508. m_key = {};
  509. }
  510. void GradTransformation::suppress() {
  511. m_suppressed++;
  512. }
  513. void GradTransformation::resume() {
  514. m_suppressed--;
  515. }
  516. } // namespace imperative
  517. } // namespace mgb