
backward_graph_opt.cpp 4.1 kB

/**
 * \file imperative/src/impl/backward_graph_opt.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/autogen.h"

#include <algorithm>  // for std::all_of

using namespace mgb;
using namespace imperative;

OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src)
        : input_has_grad(src.input_has_grad) {
    if (!src.backward->same_type<BackwardGraph>()) {
        // backward graph only contains a single op
        backward = src.backward;
        save_for_backward = src.save_for_backward;
        return;
    }
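    // otherwise: split the single backward graph into a precompute part
    // (evaluated during the forward pass) and a smaller backward part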
    save_for_backward.resize(src.save_for_backward.size(), false);
    precomp.reset(new BackwardGraph);
    backward.reset(new BackwardGraph);
    auto&& graph = src.backward->cast_final_safe<BackwardGraph>().graph();
    auto&& mask = src.save_for_backward;
    size_t input_size = src.input_has_grad.size();
    size_t output_size = (mask.size() - input_size) / 2;
    mgb_assert(input_size + output_size * 2 == mask.size());
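    // mask layout: [forward inputs | forward outputs | output grads],
    // so entries at index >= input_size + output_size correspond to grads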
    auto& fgraph = precomp->cast_final<BackwardGraph>().graph();
    auto& bgraph = backward->cast_final<BackwardGraph>().graph();

    // optimization: move ops (e.g. GetVarShape) to forward to
    // reduce memory footprint
    struct VInfo {
        bool appears_in_backward = false;
    };
    std::unordered_map<size_t, VInfo> vinfo;
    // step 1.1: ops not in the whitelist must run in backward;
    // mark their inputs as always appearing in backward
    for (auto&& [op, iv, ov] : graph.exprs) {
        if (!op->same_type<GetVarShape>()) {
            for (auto&& v : iv) {
                vinfo[v].appears_in_backward = true;
            }
        }
    }
    // step 1.2: inputs only available in backward (i.e. grads)
    // should be marked as always appearing in backward
    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
        if (!mask[i]) continue;
        // grads occupy the trailing output_size slots of the mask,
        // starting at index input_size + output_size (hence >=, not >)
        if (i >= input_size + output_size) {
            vinfo[graph.inputs[j]].appears_in_backward = true;
        }
        ++j;
    }
    // step 2: try to move ops to forward if not all of their inputs
    // are marked as always appearing in backward (otherwise there is
    // no memory saving)
    for (auto&& expr : graph.exprs) {
        auto&& [op, iv, ov] = expr;
        if (std::all_of(iv.begin(), iv.end(),
                        [&](auto&& v) { return vinfo[v].appears_in_backward; })) {
            bgraph.exprs.push_back(expr);
            for (auto&& v : ov) {
                vinfo[v].appears_in_backward = true;
            }
            // logically we should also mark all inputs as appearing in
            // backward, but that is clearly a no-op here
        } else {
            fgraph.exprs.push_back(expr);
            for (auto&& v : ov) {
                if (vinfo[v].appears_in_backward) {
                    // appears_in_backward won't change after this point,
                    // so it is safe to set fgraph.outputs based on the
                    // current value
                    fgraph.outputs.push_back(v);
                }
            }
        }
    }
    // initialize remaining parts
    fgraph.constants = graph.constants;
    fgraph.inputs.reserve(input_size + output_size);
    for (size_t i = 0, j = 0; i < input_size + output_size; ++i) {
        if (!mask[i]) {
            // placeholder id: this input was not saved for backward and is
            // never referenced by any expr in fgraph
            fgraph.inputs.push_back(1000000000 + i);
            continue;
        }
        fgraph.inputs.push_back(graph.inputs[j++]);
    }
    bgraph.constants = graph.constants;
    bgraph.outputs = graph.outputs;
    bgraph.inputs = fgraph.outputs;
    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
        if (mask[i]) {
            auto&& v = graph.inputs[j++];
            if (vinfo[v].appears_in_backward) {
                save_for_backward[i] = true;
                bgraph.inputs.push_back(v);
            }
        }
    }
}
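
For readers outside the MegEngine codebase, here is a minimal, self-contained sketch of the same partitioning idea. All names in it (Expr, Split, split_graph, the string-based whitelist) are hypothetical simplifications for illustration, not the MegEngine API, and it omits the input/output bookkeeping that the real constructor performs after the split.

// Minimal sketch (assumed names, not the MegEngine API): partition a list
// of exprs into a forward (precompute) part and a backward part.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct Expr {
    std::string op;              // op name; only "GetVarShape" is whitelisted
    std::vector<size_t> inputs;  // value ids consumed
    std::vector<size_t> outputs; // value ids produced
};

struct Split {
    std::vector<Expr> forward;   // ops hoisted into the forward pass
    std::vector<Expr> backward;  // ops that stay in the backward pass
};

Split split_graph(const std::vector<Expr>& exprs,
                  const std::vector<size_t>& grad_inputs) {
    std::unordered_map<size_t, bool> in_backward;
    // step 1.1: inputs of non-whitelisted ops must be alive in backward
    for (auto&& e : exprs) {
        if (e.op != "GetVarShape") {
            for (auto v : e.inputs) in_backward[v] = true;
        }
    }
    // step 1.2: gradients only exist during the backward pass
    for (auto v : grad_inputs) in_backward[v] = true;
    Split split;
    // step 2: hoist an op to forward unless all of its inputs are already
    // pinned to backward (in which case hoisting saves no memory)
    for (auto&& e : exprs) {
        bool all_in_bw = std::all_of(
                e.inputs.begin(), e.inputs.end(),
                [&](size_t v) { return in_backward[v]; });
        if (all_in_bw) {
            split.backward.push_back(e);
            for (auto v : e.outputs) in_backward[v] = true;
        } else {
            split.forward.push_back(e);
        }
    }
    return split;
}

int main() {
    // value ids: 0 = saved input, 1 = incoming grad, 2 = shape, 3 = result
    std::vector<Expr> exprs = {
            {"GetVarShape", {0}, {2}}, // depends only on a saved input
            {"Mul", {1, 2}, {3}},      // consumes the grad: backward only
    };
    auto split = split_graph(exprs, /*grad_inputs=*/{1});
    std::cout << "forward ops: " << split.forward.size()           // 1
              << ", backward ops: " << split.backward.size() << "\n"; // 1
}

In this toy run, GetVarShape is hoisted to the forward pass because its only input is available there, while Mul must stay in backward since it consumes the incoming gradient; this mirrors how the constructor above fills fgraph and bgraph.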

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep learning development on a cloud GPU compute platform, you are welcome to visit the MegStudio platform.