Browse Source

refactor(mgb): decouple node insertion from ComputingGraphImpl

GitOrigin-RevId: 59b45fcb17
tags/v1.0.0-rc1
Megvii Engine Team 5 years ago
parent
commit
d782edf80f
6 changed files with 48 additions and 14 deletions
  1. +1
    -3
      src/core/impl/graph/bases.cpp
  2. +8
    -0
      src/core/impl/graph/cg_impl.cpp
  3. +5
    -1
      src/core/impl/graph/cg_impl.h
  4. +2
    -5
      src/core/impl/graph/operator_node.cpp
  5. +22
    -0
      src/core/include/megbrain/graph/cg.h
  6. +10
    -5
      src/core/include/megbrain/utils/mempool.h

+ 1
- 3
src/core/impl/graph/bases.cpp View File

@@ -18,11 +18,9 @@ GraphNodeBase::GraphNodeBase(ComputingGraph *owner_graph):
m_owner_graph{owner_graph}
{
mgb_assert(owner_graph, "owner graph not given");
auto id = static_cast<ComputingGraphImpl*>(owner_graph)->next_node_id();
m_id = id;
m_id = owner_graph->next_node_id();
}

AsyncExecutable::~AsyncExecutable() noexcept = default;

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}


+ 8
- 0
src/core/impl/graph/cg_impl.cpp View File

@@ -267,6 +267,14 @@ void ComputingGraphImpl::cleanup() {
m_opr_refkeeper.clear();
}

void* ComputingGraphImpl::alloc_varnode_storage() {
return m_var_node_pool.alloc_raw();
};

void ComputingGraphImpl::free_varnode_storage(void *ptr) {
m_var_node_pool.free_raw(ptr);
};

OperatorNodeBase* ComputingGraphImpl::insert_opr(
std::unique_ptr<OperatorNodeBase> opr_uniqp) {
auto opr = opr_uniqp.get();


+ 5
- 1
src/core/impl/graph/cg_impl.h View File

@@ -142,6 +142,10 @@ public:
OperatorNodeBase* insert_opr(
std::unique_ptr<OperatorNodeBase> opr) override;

void* alloc_varnode_storage() override;

void free_varnode_storage(void *ptr) override;

const VarReceiverInfo& var_receiver_in_current_comp_seq(
const VarNode* var) const override;

@@ -161,7 +165,7 @@ public:

TopoSorter& topo_sorter() { return components().topo_sorter; }

size_t next_node_id() { return (*m_node_id_counter)++; }
size_t next_node_id() override { return (*m_node_id_counter)++; }

VarNodeMemManager& var_node_mem_manager() {
return components().var_node_mem_manager;


+ 2
- 5
src/core/impl/graph/operator_node.cpp View File

@@ -93,10 +93,8 @@ OperatorNodeBase::OperatorNodeBase(ComputingGraph *owner,
}

OperatorNodeBase::~OperatorNodeBase() noexcept {
auto &&pool = ComputingGraphImpl::cast(
owner_graph())->var_node_pool();
for (auto i: m_output) {
pool.free(i);
owner_graph()->free_varnode(i);
}
}

@@ -264,8 +262,7 @@ VarNode* OperatorNodeBase::add_output(const Maybe<std::string> &name) {
mgb_assert(!m_inserted_in_graph && !m_node_prop.valid(),
"add output on opr after it has been inserted into graph");

auto ptr = ComputingGraphImpl::cast(
owner_graph())->var_node_pool().alloc(
auto ptr = owner_graph()->alloc_varnode(
name.valid() ? this->name() + ":" + name.val() : name, this);
m_output.push_back(ptr);
return ptr;


+ 22
- 0
src/core/include/megbrain/graph/cg.h View File

@@ -174,6 +174,8 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
return m_id;
}

virtual size_t next_node_id() = 0;

static std::shared_ptr<ComputingGraph> make();

//! assert that refcnt for ptr is one and destories the ptr
@@ -236,6 +238,26 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
std::unique_ptr<OperatorNodeBase> opr) = 0;

/*!
* \brief used by OperatorNodeBase to allocate its outputs
*/
template<typename... Args>
VarNode* alloc_varnode(Args&&... args) {
return new(alloc_varnode_storage()) VarNode(std::forward<Args>(args)...);
}

inline void free_varnode(VarNode* var) {
var->~VarNode();
free_varnode_storage(var);
}
protected:
/*!
* \brief provided by impl to support alloc_varnode
*/
virtual void* alloc_varnode_storage() = 0;

virtual void free_varnode_storage(void *ptr) = 0;
public:
/*!
* \brief get current computing sequence
*/
virtual AsyncExecutable* current_comp_seq() = 0;


+ 10
- 5
src/core/include/megbrain/utils/mempool.h View File

@@ -86,12 +86,17 @@ namespace mgb {
};
using UniquePtr = std::unique_ptr<T, Deleter>;

void* alloc_raw() {
return m_storage.alloc(Const<>::ELEM_SIZE);
}

void free_raw(void *ptr) {
m_storage.free(ptr);
}

template<typename...Args>
T* alloc(Args&&... args) {
auto ptr = static_cast<T*>(
m_storage.alloc(Const<>::ELEM_SIZE));
new(ptr) T(std::forward<Args>(args)...);
return ptr;
return new(alloc_raw()) T(std::forward<Args>(args)...);
}

template<typename...Args>
@@ -102,7 +107,7 @@ namespace mgb {

void free(T *ptr) {
ptr->~T();
m_storage.free(ptr);
free_raw(ptr);
}

//! reorder free list for cache friendly in future alloc


Loading…
Cancel
Save