You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

stack_manager.h 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. /**
  2. * \file imperative/src/impl/interpreter/stack_manager.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include "megbrain/utils/metahelper.h"
#include "megbrain/utils/small_vector.h"
  17. namespace mgb::imperative::interpreter::intl{
  18. class StackSnapshot;
  19. class StackManager: public NonCopyableObj {
  20. public:
  21. class Node;
  22. class Guard;
  23. struct Frame;
  24. class Trace;
  25. private:
  26. std::unique_ptr<Node> m_root = nullptr;
  27. Node* m_current = nullptr;
  28. SmallVector<uint64_t> m_trace_id_stack;
  29. uint64_t m_last_trace_id = 0;
  30. public:
  31. StackManager();
  32. std::pair<Node*, uint64_t> enter(std::string name);
  33. void exit(std::string name);
  34. Trace dump();
  35. Node* current();
  36. };
  37. class StackManager::Node: public NonCopyableObj {
  38. private:
  39. std::string m_name;
  40. std::unordered_map<std::string, std::unique_ptr<Node>> m_children;
  41. std::unordered_map<std::string, size_t> m_id_table;
  42. Node* m_parent = nullptr;
  43. int64_t m_depth = -1;
  44. uint64_t m_version = 0;
  45. explicit Node(std::string name, Node* parent): m_name{name}, m_parent{parent} {
  46. if (parent) {
  47. m_depth = parent->m_depth + 1;
  48. }
  49. }
  50. public:
  51. const std::string& name() const {
  52. return m_name;
  53. }
  54. Node* operator[](const std::string& name) {
  55. auto& child = m_children[name];
  56. if (child == nullptr) {
  57. child.reset(new Node(name, this));
  58. }
  59. return child.get();
  60. }
  61. Node* parent() {
  62. return m_parent;
  63. }
  64. bool is_root() {
  65. return m_parent == nullptr;
  66. }
  67. uint64_t version() const {
  68. return m_version;
  69. }
  70. void update_version() {
  71. ++m_version;
  72. for (auto&& [key, child]: m_children) {
  73. child->reset_version();
  74. }
  75. m_id_table.clear();
  76. }
  77. void reset_version() {
  78. m_version = 0;
  79. m_id_table.clear();
  80. }
  81. int64_t depth() const {
  82. return m_depth;
  83. }
  84. uint64_t next_id(std::string key) {
  85. return m_id_table[key]++;
  86. }
  87. static std::unique_ptr<Node> make() {
  88. return std::unique_ptr<Node>(new Node("", nullptr));
  89. }
  90. };
  91. class StackManager::Guard {
  92. private:
  93. std::string m_name;
  94. StackManager* m_manager;
  95. public:
  96. Guard(std::string name, StackManager* manager): m_name{name}, m_manager{manager}{
  97. if (m_manager) {
  98. m_manager->enter(name);
  99. }
  100. }
  101. ~Guard() {
  102. release();
  103. }
  104. void release() {
  105. if (m_manager) {
  106. m_manager->exit(m_name);
  107. m_manager = nullptr;
  108. }
  109. }
  110. };
  111. struct StackManager::Frame {
  112. StackManager::Node* node;
  113. uint64_t version;
  114. };
  115. class StackManager::Trace {
  116. private:
  117. SmallVector<StackManager::Frame> m_frames;
  118. uint64_t m_id = 0;
  119. public:
  120. explicit Trace(StackManager::Node* top, uint64_t id): m_id{id} {
  121. int64_t nr_frames = top->depth() + 1;
  122. m_frames = SmallVector<StackManager::Frame>(nr_frames);
  123. StackManager::Node* node = top;
  124. for (int64_t i = 0; i < nr_frames; ++i) {
  125. m_frames[m_frames.size()-1-i] = {node, node->version()};
  126. node = node->parent();
  127. }
  128. mgb_assert(node->is_root() , "");
  129. }
  130. Trace() = default;
  131. std::string to_string() const {
  132. std::string buffer;
  133. for (auto&& [node, version]: m_frames) {
  134. if (!buffer.empty()) {
  135. buffer.append(".");
  136. }
  137. buffer.append(node->name());
  138. if (version != 0) {
  139. buffer.append(ssprintf("[%zu]", version));
  140. }
  141. }
  142. return buffer;
  143. }
  144. const SmallVector<StackManager::Frame>& frames() const {
  145. return m_frames;
  146. }
  147. uint64_t id() const {
  148. return m_id;
  149. }
  150. };
  151. inline StackManager::StackManager() {
  152. m_root = Node::make();
  153. m_current = m_root.get();
  154. }
  155. inline std::pair<StackManager::Node*, uint64_t> StackManager::enter(std::string name) {
  156. m_current = (*m_current)[name];
  157. m_trace_id_stack.push_back(++m_last_trace_id);
  158. return {m_current, m_current->version()};
  159. }
  160. inline void StackManager::exit(std::string name) {
  161. mgb_assert(m_current->name() == name, "scope name mismatch");
  162. m_current = m_current->parent();
  163. m_trace_id_stack.pop_back();
  164. m_current->update_version();
  165. }
  166. inline StackManager::Trace StackManager::dump() {
  167. return Trace(m_current, m_trace_id_stack.empty() ? 0 : m_trace_id_stack.back());
  168. }
  169. inline StackManager::Node* StackManager::current() {
  170. return m_current;
  171. }
  172. }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台