|
|
@@ -20,6 +20,7 @@ |
|
|
|
#include "./proxy_graph_base.h" |
|
|
|
|
|
|
|
#include <optional> |
|
|
|
#include "megbrain/opr/utility.h" |
|
|
|
#include "range/v3/all.hpp" |
|
|
|
|
|
|
|
namespace mgb::imperative::proxy_graph { |
|
|
@@ -83,7 +84,7 @@ TensorAdaptor(T&) -> TensorAdaptor<T, void>; |
|
|
|
template <typename T> |
|
|
|
TensorAdaptor(T*) -> TensorAdaptor<T, void>; |
|
|
|
|
|
|
|
SmallVector<Tensor*> to_raw_ptr_array( |
|
|
|
inline SmallVector<Tensor*> to_raw_ptr_array( |
|
|
|
const SmallVector<TensorPtr>& inputs, bool ensure_storage = true) { |
|
|
|
SmallVector<Tensor*> ret; |
|
|
|
for (auto&& i : inputs) { |
|
|
@@ -243,6 +244,13 @@ public: |
|
|
|
vinputs[i] = opr_ref_keeper.back()->output(0); |
|
|
|
} |
|
|
|
auto ovars = OpDef::apply_on_var_node(opdef, vinputs); |
|
|
|
if (!m_opr) { |
|
|
|
// identity |
|
|
|
mgb_assert(vinputs.size() == 1 && ovars.size() == 1); |
|
|
|
mgb_assert(ovars[0] == vinputs[0]); |
|
|
|
auto&& input = vinputs[0]; |
|
|
|
ovars[0] = opr::Identity::make(input).node(); |
|
|
|
} |
|
|
|
mgb_assert(m_opr); |
|
|
|
output_data.resize(m_opr->output().size()); |
|
|
|
for (auto* v : ovars) { |
|
|
@@ -343,7 +351,6 @@ public: |
|
|
|
} else { |
|
|
|
mgb_assert(j < outputs.size()); |
|
|
|
auto&& tensor = outputs[j]; |
|
|
|
auto&& layout = tensor->layout(); |
|
|
|
if (var->m_mem_plan.chunk().owner_var != var) { |
|
|
|
tensor->assign_from_dev_tensor( |
|
|
|
var->m_dev_tensor); // memory forwarding |
|
|
@@ -613,6 +620,7 @@ class ExecMiniGraph : public ProxyGraph::MiniGraph { |
|
|
|
busy_oprs.pop_front(); |
|
|
|
return m_opr; |
|
|
|
} |
|
|
|
mgb_assert(false); |
|
|
|
} |
|
|
|
|
|
|
|
template <bool in_use> |
|
|
|