Browse Source

fix(minigraph): supports varnode forwarding

GitOrigin-RevId: 4494106f0a
release-1.10
Megvii Engine Team 3 years ago
parent
commit
951ed476d6
1 changed files with 10 additions and 2 deletions
  1. +10
    -2
      imperative/src/impl/proxy_graph/mini_graph.h

+ 10
- 2
imperative/src/impl/proxy_graph/mini_graph.h View File

@@ -20,6 +20,7 @@
#include "./proxy_graph_base.h"

#include <optional>
#include "megbrain/opr/utility.h"
#include "range/v3/all.hpp"

namespace mgb::imperative::proxy_graph {
@@ -83,7 +84,7 @@ TensorAdaptor(T&) -> TensorAdaptor<T, void>;
template <typename T>
TensorAdaptor(T*) -> TensorAdaptor<T, void>;

SmallVector<Tensor*> to_raw_ptr_array(
inline SmallVector<Tensor*> to_raw_ptr_array(
const SmallVector<TensorPtr>& inputs, bool ensure_storage = true) {
SmallVector<Tensor*> ret;
for (auto&& i : inputs) {
@@ -243,6 +244,13 @@ public:
vinputs[i] = opr_ref_keeper.back()->output(0);
}
auto ovars = OpDef::apply_on_var_node(opdef, vinputs);
if (!m_opr) {
// identity
mgb_assert(vinputs.size() == 1 && ovars.size() == 1);
mgb_assert(ovars[0] == vinputs[0]);
auto&& input = vinputs[0];
ovars[0] = opr::Identity::make(input).node();
}
mgb_assert(m_opr);
output_data.resize(m_opr->output().size());
for (auto* v : ovars) {
@@ -343,7 +351,6 @@ public:
} else {
mgb_assert(j < outputs.size());
auto&& tensor = outputs[j];
auto&& layout = tensor->layout();
if (var->m_mem_plan.chunk().owner_var != var) {
tensor->assign_from_dev_tensor(
var->m_dev_tensor); // memory forwarding
@@ -613,6 +620,7 @@ class ExecMiniGraph : public ProxyGraph::MiniGraph {
busy_oprs.pop_front();
return m_opr;
}
mgb_assert(false);
}

template <bool in_use>


Loading…
Cancel
Save