You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

subgraph.cpp 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. /**
  2. * \file imperative/src/impl/subgraph.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
#include "megbrain/imperative/subgraph.h"

#include <algorithm>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
  13. namespace mgb {
  14. namespace imperative {
  15. void Subgraph::remove_unused_exprs() {
  16. std::unordered_set<size_t> required_vars = {outputs.begin(), outputs.end()};
  17. required_vars.erase(0);
  18. for (auto iter = exprs.rbegin(); iter != exprs.rend(); ++iter) {
  19. auto& expr = *iter;
  20. bool required = false;
  21. for (auto output : expr.outputs) {
  22. if (required_vars.count(output)) {
  23. required = true;
  24. break;
  25. }
  26. }
  27. if (required) {
  28. required_vars.insert(expr.inputs.begin(), expr.inputs.end());
  29. } else {
  30. expr.op = nullptr;
  31. }
  32. }
  33. exprs.erase(std::remove_if(exprs.begin(), exprs.end(),
  34. [](auto expr) { return expr.op == nullptr; }),
  35. exprs.end());
  36. }
  37. SmallVector<bool> Subgraph::gen_input_mask() {
  38. std::unordered_set<size_t> unused_inputs = {inputs.begin(), inputs.end()};
  39. for (auto&& expr : exprs) {
  40. for (auto&& input : expr.inputs) {
  41. unused_inputs.erase(input);
  42. }
  43. }
  44. for (auto&& output : outputs) {
  45. unused_inputs.erase(output);
  46. }
  47. unused_inputs.insert(0);
  48. SmallVector<bool> mask(inputs.size(), true);
  49. for (size_t i = 0; i < inputs.size(); ++i) {
  50. if (unused_inputs.count(inputs[i])) {
  51. mask[i] = false;
  52. }
  53. }
  54. return mask;
  55. }
  56. SmallVector<bool> Subgraph::gen_output_mask() {
  57. std::unordered_set<size_t> invalid_outputs = {outputs.begin(),
  58. outputs.end()};
  59. for (auto&& input : inputs) {
  60. invalid_outputs.erase(input);
  61. }
  62. for (auto&& expr : exprs) {
  63. for (auto&& output : expr.outputs) {
  64. invalid_outputs.erase(output);
  65. }
  66. }
  67. for (auto&& constant: constants) {
  68. invalid_outputs.erase(constant.first);
  69. }
  70. invalid_outputs.insert(0);
  71. SmallVector<bool> mask(outputs.size(), true);
  72. for (size_t i = 0; i < outputs.size(); ++i) {
  73. if (invalid_outputs.count(outputs[i])) {
  74. mask[i] = false;
  75. }
  76. }
  77. return mask;
  78. }
  79. void Subgraph::replace_vars(
  80. const std::unordered_map<size_t, size_t>& replace_map) {
  81. // FIXME: preprocess replace_map
  82. auto replace_var = [&](var_t& var) {
  83. // TODO: detect infinite loop
  84. while (replace_map.count(var)) {
  85. var = replace_map.at(var);
  86. }
  87. };
  88. for (auto& expr : exprs) {
  89. for (auto& input : expr.inputs) {
  90. replace_var(input);
  91. }
  92. }
  93. for (auto& output : outputs) {
  94. replace_var(output);
  95. }
  96. }
  97. std::string EncodedSubraph::repr() const {
  98. std::string buffer;
  99. buffer.push_back('|');
  100. for (size_t i = 0; i < input_mask.size(); ++i) {
  101. buffer.push_back(input_mask[i] ? '#' : ' ');
  102. }
  103. buffer.push_back('|');
  104. buffer.push_back('\n');
  105. buffer.append(graph.repr());
  106. buffer.push_back('|');
  107. for (size_t i = 0; i < output_mask.size(); ++i) {
  108. buffer.push_back(output_mask[i] ? '#' : ' ');
  109. }
  110. buffer.push_back('|');
  111. return buffer;
  112. }
  113. size_t EncodedSubraph::hash() const {
  114. return std::hash<std::string>{}(repr());
  115. }
  116. } // namespace imperative
  117. } // namespace mgb

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。