/** * \file imperative/src/impl/interpreter/stack_manager.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once #include #include #include #include "megbrain/utils/metahelper.h" #include "megbrain/utils/small_vector.h" namespace mgb::imperative::interpreter::intl{ class StackSnapshot; class StackManager: public NonCopyableObj { public: class Node; class Guard; struct Frame; class Trace; private: std::unique_ptr m_root = nullptr; Node* m_current = nullptr; SmallVector m_trace_id_stack; uint64_t m_last_trace_id = 0; public: StackManager(); std::pair enter(std::string name); void exit(std::string name); Trace dump(); Node* current(); }; class StackManager::Node: public NonCopyableObj { private: std::string m_name; std::unordered_map> m_children; std::unordered_map m_id_table; Node* m_parent = nullptr; int64_t m_depth = -1; uint64_t m_version = 0; explicit Node(std::string name, Node* parent): m_name{name}, m_parent{parent} { if (parent) { m_depth = parent->m_depth + 1; } } public: const std::string& name() const { return m_name; } Node* operator[](const std::string& name) { auto& child = m_children[name]; if (child == nullptr) { child.reset(new Node(name, this)); } return child.get(); } Node* parent() { return m_parent; } bool is_root() { return m_parent == nullptr; } uint64_t version() const { return m_version; } void update_version() { ++m_version; for (auto&& [key, child]: m_children) { child->reset_version(); } m_id_table.clear(); } void reset_version() { m_version = 0; m_id_table.clear(); } int64_t depth() const { return m_depth; } uint64_t next_id(std::string key) { return m_id_table[key]++; } static std::unique_ptr make() { return std::unique_ptr(new Node("", nullptr)); } }; class StackManager::Guard { private: std::string m_name; StackManager* m_manager; public: Guard(std::string name, StackManager* manager): m_name{name}, m_manager{manager}{ if (m_manager) { m_manager->enter(name); } } ~Guard() { release(); } void release() { if (m_manager) { m_manager->exit(m_name); m_manager = nullptr; } } }; struct StackManager::Frame { StackManager::Node* node; uint64_t version; }; class StackManager::Trace { private: SmallVector m_frames; uint64_t m_id = 0; public: explicit Trace(StackManager::Node* top, uint64_t id): m_id{id} { int64_t nr_frames = top->depth() + 1; m_frames = SmallVector(nr_frames); StackManager::Node* node = top; for (int64_t i = 0; i < nr_frames; ++i) { m_frames[m_frames.size()-1-i] = {node, node->version()}; node = node->parent(); } mgb_assert(node->is_root() , ""); } Trace() = default; std::string to_string() const { std::string buffer; for (auto&& [node, version]: m_frames) { if (!buffer.empty()) { buffer.append("."); } buffer.append(node->name()); if (version != 0) { buffer.append(ssprintf("[%zu]", version)); } } return buffer; } const SmallVector& frames() const { return m_frames; } uint64_t id() const { return m_id; } }; inline StackManager::StackManager() { m_root = Node::make(); m_current = m_root.get(); } inline std::pair StackManager::enter(std::string name) { m_current = (*m_current)[name]; m_trace_id_stack.push_back(++m_last_trace_id); return {m_current, m_current->version()}; } inline void StackManager::exit(std::string name) { mgb_assert(m_current->name() == name, "scope name mismatch"); m_current = m_current->parent(); m_trace_id_stack.pop_back(); m_current->update_version(); } inline StackManager::Trace StackManager::dump() { return Trace(m_current, m_trace_id_stack.empty() ? 0 : m_trace_id_stack.back()); } inline StackManager::Node* StackManager::current() { return m_current; } }