
proxy_graph_detail.cpp 6.0 kB

/**
 * \file imperative/src/impl/proxy_graph_detail.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./proxy_graph.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/ops/autogen.h"

namespace mgb {
namespace imperative {
namespace proxy_graph_detail {

namespace {

SmallVector<Tensor*> to_raw_ptr_array(
        const SmallVector<TensorPtr>& inputs,
        bool ensure_storage = true) {
    SmallVector<Tensor*> ret;
    for (auto&& i : inputs) {
        mgb_assert(i);
        ret.push_back(i.get());
        if (ensure_storage) {
            // apply lazy allocation
            i->blob()->storage();
        }
    }
    return ret;
}

SmallVector<LogicalTensorDesc> infer_output_attrs(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    auto&& graph = ProxyGraph::get_default_graph();
    return graph->infer_output_attrs(def, to_raw_ptr_array(inputs));
}

} // anonymous namespace

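// exec() runs `def` through the default proxy graph. Before the op is
// invoked, each output comp node waits on events recorded for inputs that
// live on a different comp node; after the op is invoked, release callbacks
// keep those cross-device inputs alive until the work queued on each output
// comp node has finished.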
void exec(const OpDef& def,
        const SmallVector<TensorPtr>& inputs,
        const SmallVector<TensorPtr>& outputs) {
    auto&& graph = ProxyGraph::get_default_graph();
    auto raw_inputs = to_raw_ptr_array(inputs),
         raw_outputs = to_raw_ptr_array(outputs);
    CompNode::UnorderedSet used_cns;
    for (auto&& out : raw_outputs) {
        auto cn = out->comp_node();
        if (used_cns.insert(cn).second) {
            for (auto&& in : inputs) {
                if (in->comp_node() != cn) {
                    auto&& e = in->get_or_create_event();
                    e->device_wait_by(cn);
                }
            }
        }
    }
    graph->invoke_op(def, raw_inputs, raw_outputs);
    for (auto&& cn : used_cns) {
        for (auto&& in : inputs) {
            if (in->comp_node() != cn) {
                in->add_release_callback(cn);
            }
        }
    }
}
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, SmallVector<TensorPtr> inputs) {
    auto output_descs = infer_output_attrs(def, inputs);
    SmallVector<TensorPtr> outputs(output_descs.size(), {});
    for (size_t i = 0; i < outputs.size(); i++) {
        auto& output = outputs[i];
        auto& output_desc = output_descs[i];
        if (def.same_type<Elemwise>()) {
            for (size_t j = 0; j < inputs.size(); j++) {
                // TODO: reindex inputs to support inplace exprs like 'y = x op x'.
                auto& input = inputs[j];
                // Because we pass inputs by value, if input and input->blob() are all unique,
                // their ownerships are on the stack, thus we can reuse them safely.
                // @see: interpreter::intl::ChannelImpl::process_one_task
                if (input.unique() && input->blob().unique() && input->blob()->storage().unique() &&
                        input->layout().dtype == output_desc.layout.dtype &&
                        input->layout().eq_layout(output_desc.layout) &&
                        input->comp_node() == output_desc.comp_node) {
                    static std::atomic_llong inplace_count = 0;
                    mgb_log_debug("do inplace for elemwise, layout: %s, count: %lld",
                            output_desc.layout.to_string().c_str(), ++inplace_count);
                    output = Tensor::make(input->blob(), input->layout(), input->offset());
                    break;
                }
            }
        }
        if (!output) {
            output = Tensor::make(output_desc.layout, output_desc.comp_node);
        }
    }
    exec(def, inputs, outputs);
    auto async_error = ProxyGraph::get_async_error();
    if (async_error) {
        throw *async_error;
    }
    return outputs;
}
// std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
//         const SmallVector<LogicalTensorDesc>& inputs) {
//     auto&& graph = ProxyGraph::get_default_graph();
//     return graph->infer_output_attrs_fallible(def, inputs);
// }
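
// Backward graphs are memoized below: the cache key hashes the op definition
// together with each input's dtype and comp node, plus the requires-grad and
// has-grad masks, so structurally identical requests hit the cache.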
namespace {

size_t get_backward_graph_hash_key(const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    XXHash state;
    size_t length = 0, data[3 + 2 * inputs.size()];
    data[length++] = def.hash();
    for (auto&& i : inputs) {
        data[length++] = mgb::hash(i.layout.dtype.handle());
        data[length++] = mgb::hash(i.comp_node);
    }
    data[length++] = mgb::hash(input_requires_grad);
    data[length++] = mgb::hash(output_has_grad);
    mgb_assert(length == 3 + 2 * inputs.size());
    state.update(data, length * sizeof(size_t));
    return state.digest();
}

struct BackwardGraphCache : std::unordered_map<size_t, BackwardGraphResult>, CompNodeDepedentObject {
    std::shared_ptr<void> on_comp_node_finalize() override {
        clear();
        return {};
    }
} backward_graph_cache;

} // anonymous namespace
BackwardGraphResult make_backward_graph(const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    auto hash_key = get_backward_graph_hash_key(def, inputs, input_requires_grad, output_has_grad);
    auto&& iter = backward_graph_cache.find(hash_key);
    if (iter != backward_graph_cache.end()) {
        return iter->second;
    }
    auto&& graph = ProxyGraph::get_default_graph();
    auto res = graph->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
    backward_graph_cache.emplace(hash_key, res);
    return res;
}

} // namespace proxy_graph_detail
} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
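
The inplace branch in apply_on_physical_tensor above hinges on reference counts: an input's storage may be handed to the output only when every shared handle along the chain (tensor, blob, storage) has a single owner. Below is a minimal standalone sketch of that gate; Buffer and alloc_or_reuse are hypothetical stand-ins for illustration, not MegEngine API.

#include <cstdio>
#include <memory>
#include <vector>

// Stand-in for a reference-counted storage blob.
using Buffer = std::shared_ptr<std::vector<float>>;

// Reuse the input buffer for the output only when this shared_ptr is the
// sole owner, i.e. nobody else can observe the in-place write.
Buffer alloc_or_reuse(Buffer input) {
    if (input.use_count() == 1) {
        std::puts("inplace: reusing input storage");
        return input;  // output aliases the input's storage
    }
    std::puts("fallback: input is shared, allocating fresh storage");
    return std::make_shared<std::vector<float>>(input->size());
}

int main() {
    // Caller moves its handle in, so the parameter is the sole owner.
    Buffer a = std::make_shared<std::vector<float>>(16, 1.f);
    Buffer out1 = alloc_or_reuse(std::move(a));  // -> inplace

    // A second handle keeps the buffer alive elsewhere; no reuse allowed.
    Buffer b = std::make_shared<std::vector<float>>(16, 2.f);
    Buffer alias = b;
    Buffer out2 = alloc_or_reuse(b);             // -> fresh allocation
    return 0;
}

The real check is stricter: it also requires matching dtype, layout, and comp node, and the whole input vector is passed by value so that uniqueness of the copies implies the caller has relinquished ownership (see the @see comment in the file).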

The MegEngine installation package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has a GPU and that its driver is installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.