GitOrigin-RevId: 1fb68a1da2
tags/v1.9.0
@@ -121,22 +121,6 @@ private:
 };
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(ProxyGraph::InputPlaceholder);
-class ProxyGraph::ExecEnv final : public cg::GraphExecutable::ExecEnv {
-public:
-    void dispatch_on_comp_node(CompNode, Task&& task) override { task(); }
-    void dispatch_on_comp_node_with_mask(
-            CompNode, Task&& task, cg::ExecutionMask* mask) override {
-        mgb_throw_if(
-                mask, GraphError, "ExecutionMask not supported in imperative mode");
-        task();
-    }
-    void pause_exec() override {}
-    void resume_exec() override {}
-};
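Note: the ExecEnv deleted above was a pure inline executor — every task dispatched to it ran immediately on the calling thread, which is all the single-opr imperative path needed. A minimal standalone sketch of that pattern (illustrative C++ only, not MegBrain API):

    #include <functional>

    struct InlineExecEnv {
        using Task = std::function<void()>;
        // No queue and no worker threads: dispatching a task just runs it.
        void dispatch(Task&& task) { task(); }
    };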
 class ProxyGraph::StaticInferManager : public cg::static_infer::StaticInferManager {
 public:
     using Tag = cg::static_infer::Tag;
@@ -183,26 +167,8 @@ public:
     }
     InferType get_infer_type(Tag var) override {
-        // may be called during get_proxy_opr or make_backward_graph
         // don't let opr apply any immediate optimization
         return {InferType::MISSING_INP, InferType::MISSING_INP};
-        if (auto opr = var->owner_opr()->try_cast_final<InputPlaceholder>()) {
-            return {var->shape().ndim ? InferType::CONST : InferType::MISSING_INP,
-                    opr->m_tensor ? InferType::CONST : InferType::MISSING_INP};
-        }
-        if (cur_opr) {
-            auto&& outputs = cur_opr->output();
-            auto&& it = std::find(outputs.begin(), outputs.end(), var);
-            if (it != outputs.end()) {
-                return {infer_shape_fallible(var) ? InferType::CONST
-                                                  : InferType::MISSING_INP,
-                        // value inference could be expensive
-                        InferType::MISSING_INP};
-            }
-        }
-        return {InferType::MISSING_INP, InferType::MISSING_INP};
     }
     void update() {
@@ -471,7 +437,6 @@ std::atomic<size_t> ProxyGraph::ProxyGraphImpl::m_node_id = 0;
 ProxyGraph::ProxyGraph()
         : m_graph(ProxyGraphImpl::make(this)),
-          m_env{new ExecEnv},
           m_static_infer_manager(new StaticInferManager(this)),
           m_seq_comp_node_optimizer(new SeqCompNodeOptimizer()) {}
@@ -506,32 +471,6 @@ private:
 /*********************** Physical Tensor Impl ***********************/
-SmallVector<LogicalTensorDesc> ProxyGraph::infer_output_attrs(
-        const OpDef& opdef, const SmallVector<Tensor*>& inputs) {
-    SmallVector<LogicalTensorDesc> ret;
-    CUR_OPR_GUARD(get_proxy_opr(opdef, inputs));
-    ::mgb::opr::intl::WorkspaceLimitHook::set_impl(
-            m_graph.get(), ProxyGraph::get_workspace_limit);
-    do_shape_infer(true);
-    for (auto&& i : m_cur_opr->usable_output()) {
-        mgb_assert(i->dtype().valid() && i->comp_node().valid());
-        mgb_assert(i->shape().ndim || i->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC));
-        ret.push_back({{i->shape(), i->dtype()}, i->comp_node()});
-    }
-    return ret;
-}
-void ProxyGraph::invoke_op(
-        const OpDef& opdef, const SmallVector<Tensor*>& inputs,
-        const SmallVector<Tensor*>& outputs, const SmallVector<Tensor*>& workspaces) {
-    CUR_OPR_GUARD(get_proxy_opr(opdef, inputs));
-    init_output_tensor(outputs, workspaces);
-    for (auto oup : m_cur_opr->output()) {
-        m_graph->add_used_comp_node(oup->comp_node());
-    }
-    m_cur_opr->execute(*m_env);
-}
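Note: infer_output_attrs and invoke_op, deleted above, formed the old two-step physical-tensor path: infer output descriptors first, let the caller allocate tensors from them, then execute with the preallocated outputs bound. A self-contained sketch of that flow under assumed stand-in types (Desc, Buffer, and Op are illustrative, not MegBrain types):

    #include <cstddef>
    #include <vector>

    struct Desc { std::vector<std::size_t> shape; };  // stand-in for LogicalTensorDesc
    struct Buffer { std::vector<float> data; };       // stand-in for tensor storage

    template <typename Op>
    std::vector<Buffer> run_two_step(const Op& op, const std::vector<Buffer>& inputs) {
        // step 1: static shape inference yields one descriptor per output
        std::vector<Desc> descs = op.infer_output_attrs(inputs);
        // step 2: the caller allocates outputs from the inferred descriptors
        std::vector<Buffer> outputs;
        for (auto&& d : descs) {
            std::size_t n = 1;
            for (std::size_t s : d.shape) n *= s;
            outputs.push_back(Buffer{std::vector<float>(n)});
        }
        // step 3: execute the opr with the preallocated outputs bound
        op.invoke(inputs, outputs);
        return outputs;
    }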
 void ProxyGraph::cleanup() {
     if (m_cur_opr) {
         for (auto&& i : m_cur_opr->input()) {
@@ -545,102 +484,8 @@ void ProxyGraph::cleanup() {
     m_cur_opr = nullptr;
 }
-void ProxyGraph::init_output_tensor(
-        const SmallVector<Tensor*>& outputs, const SmallVector<Tensor*>& workspaces) {
-    // get proxy opr
-    auto proxy = m_cur_opr;
-    auto get_workspace_size = [=](CompNode cn, size_t old_limit) {
-        size_t limit = 0;
-        for (auto&& var : workspaces) {
-            limit += var->dtype().size(var->shape().total_nr_elems());
-        }
-        return limit;
-    };
-    ::mgb::opr::intl::WorkspaceLimitHook::set_impl(m_graph.get(), get_workspace_size);
-    do_shape_infer(true);
-    size_t j = 0;
-    size_t k = 0;
-    for (auto&& var : proxy->output()) {
-        auto&& chk = var->m_mem_plan.reset_from_owner_var().chunk();
-        if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
-            // workspace
-            if (workspaces.size()) {
-                mgb_assert(k < workspaces.size());
-                auto&& layout = workspaces[k]->layout();
-                mgb_assert(
-                        var->comp_node() == workspaces[k]->comp_node() &&
-                        var->shape().eq_shape(layout) && var->dtype() == layout.dtype);
-                var->m_dev_tensor = workspaces[k]->dev_tensor();
-                ++k;
-            } else {
-                TensorLayout layout{var->shape(), var->dtype(), var->format()};
-                var->m_dev_tensor = BlobManager::inst()->alloc_workspace_with_defrag(
-                        var->comp_node(), layout);
-            }
-        } else {
-            mgb_assert(j < outputs.size());
-            auto&& tensor = outputs[j];
-            auto&& layout = tensor->layout();
-            mgb_assert(
-                    var->comp_node() == tensor->comp_node() &&
-                    var->shape().eq_shape(layout) && var->dtype() == layout.dtype);
-            var->assign_dev_tensor_from_tensor(tensor->dev_tensor());
-            ++j;
-        }
-        chk.mem_alloc_status.set_from_owner_var();
-    }
-    mgb_assert(j == outputs.size());
-    mgb_assert(k == workspaces.size());
-    // Memory forwarding is bypassed in megbrain when the graph option
-    // imperative_proxy_graph is on; here we call mem_plan_fwd_in2out_readonly
-    // to initialize some oprs' (e.g. Subtensor) internal state
-    // TODO: implement memory forwarding
-    proxy->mem_plan_fwd_in2out_readonly();
-    {
-        // some oprs (e.g. Reduce) rely on on_mem_status_changed to set
-        // input/output tensors correctly; since we bypass var_node_mem_mgr,
-        // on_mem_status_changed must be called here
-        auto&& cb = proxy->get_opr_event_callback().on_mem_status_changed;
-        if (cb.valid()) {
-            cb.val()();
-        }
-    }
-}
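Note: the get_workspace_size lambda deleted above sized the workspace limit as the sum of bytes over all caller-provided workspace tensors. A standalone restatement with a worked example (stand-in types, not MegBrain API):

    #include <cstddef>
    #include <utility>
    #include <vector>

    // Each entry is (total_nr_elems, bytes per element) for one workspace tensor,
    // standing in for var->shape().total_nr_elems() and the dtype size.
    std::size_t workspace_limit_bytes(
            const std::vector<std::pair<std::size_t, std::size_t>>& ws) {
        std::size_t limit = 0;
        for (auto&& [elems, dsize] : ws) limit += elems * dsize;
        return limit;
    }
    // e.g. a single float32 workspace of shape (256, 1024):
    // 256 * 1024 elems * 4 B/elem = 1048576 B (1 MiB)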
-cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(
-        const OpDef& opdef, const SmallVector<Tensor*>& inputs) {
-    VarNodeArray vinputs(inputs.size());
-    for (size_t i = 0; i < inputs.size(); ++i) {
-        vinputs[i] = InputPlaceholder::make(*m_graph, *inputs[i]).node();
-    }
-    auto opr = OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr();
-    mgb_assert(!opr->same_type<InputPlaceholder>());
-    for (auto&& i : opr->input()) {
-        mgb_assert(i->owner_opr()->same_type<InputPlaceholder>());
-    }
-    return opr;
-}
 /*********************** Logical Tensor Impl ***********************/
-std::tuple<SmallVector<LogicalTensorDesc>, bool> ProxyGraph::
-        infer_output_attrs_fallible(
-                const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs) {
-    // this function is just a placeholder;
-    // it is overridden by ProxyGraphTypeI::infer_output_attrs_fallible in minigraph
-    mgb_assert(0);
-}
-struct ProxyGraph::GradGraph {
-    cg::VarNodeArray inputs;
-    cg::VarNodeArray outputs;
-    cg::VarNodeArray output_grads;
-    cg::VarNode* grad;
-};
 EncodedSubgraph ProxyGraph::make_backward_graph(
         const OpDef& opdef, const SmallVector<LogicalTensorDesc>& input_descs,
         const SmallVector<bool>& input_requires_grad,
@@ -793,22 +638,6 @@ VarNodeArray ProxyGraph::make_input_place_holders(
 /*********************** Common Impl ***********************/
-bool ProxyGraph::do_shape_infer(bool sync_value) {
-    m_static_infer_manager->update();
-    bool validated = true;
-    for (auto* var : m_cur_opr->output()) {
-        if (sync_value) {
-            var->shape(m_static_infer_manager->infer_shape(var));
-        } else if (auto* shape = m_static_infer_manager->infer_shape_fallible(var)) {
-            var->shape(*shape);
-        } else {
-            validated = false;
-        }
-    }
-    return validated;
-}
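Note: do_shape_infer, deleted above, walked m_cur_opr's outputs and either forced shape inference (sync_value == true) or tolerated missing shapes via infer_shape_fallible. In the minigraph replacement the same per-output query runs through an inference session, roughly (condensed from the mini_graph.h hunk further down, not a verbatim quote):

    // auto sess = minigraph.infer_session(raw_inputs);
    // for (size_t i = 0; i < minigraph.opr()->output().size(); ++i) {
    //     auto* shape = sess.infer(sess.output_data[i].shape_infer, true);
    //     mgb_assert(shape);  // `true` = forced, mirroring sync_value == true
    // }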
 TensorPtr ProxyGraph::as_tensor(cg::OperatorNodeBase* opr, bool share) {
     // TODO : maybe some tensor should copy value from origin opr rather than
     // share the RawStorage
@@ -27,44 +27,22 @@ public:
     static std::unique_ptr<MegBrainError> get_async_error() {
         return std::move(tm_async_error);
     }
-    static size_t get_workspace_limit(CompNode cn, size_t old_limit) {
-        size_t free = cn.get_free_mem();
-        size_t lmt = cn.get_max_block_size_available();
-        return std::max(lmt, free);
-    }
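Note: the removed helper capped an opr's workspace request at whichever is larger, the comp node's currently free memory or its largest allocatable block; the minigraph path installs a get_workspace_limit of its own through WorkspaceLimitHook (see the mini_graph.h hunk below). A standalone restatement of the heuristic (illustrative, not MegBrain API):

    #include <algorithm>
    #include <cstddef>

    std::size_t workspace_limit(std::size_t free_mem, std::size_t max_block) {
        // e.g. free_mem = 2 GiB, max_block = 512 MiB  ->  limit = 2 GiB
        return std::max(max_block, free_mem);
    }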
     /********************** Physical Tensor API **********************/
-    SmallVector<LogicalTensorDesc> infer_output_attrs(
-            const OpDef& opdef, const SmallVector<Tensor*>& inputs);
-    void invoke_op(
-            const OpDef& opdef, const SmallVector<Tensor*>& inputs,
-            const SmallVector<Tensor*>& outputs, const SmallVector<Tensor*>& workspace);
     EncodedSubgraph make_backward_graph(
             const OpDef& opdef, const SmallVector<LogicalTensorDesc>& input_descs,
             const SmallVector<bool>& input_requires_grad,
             const SmallVector<bool>& output_has_grad);
-    /********************** Logical Tensor API **********************/
-    size_t get_opr_output_size(
-            const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs);
-    std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
-            const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs);
 private:
     ProxyGraph();
     class ProxyGraphImpl;
-    class ExecEnv;
     class StaticInferManager;
     class SeqCompNodeOptimizer;
     class InputPlaceholder;
     struct ProxyGraphInst;
-    struct GradGraph;
     class CurOprGuard;
     void reset();
@@ -73,12 +51,6 @@ private:
     void cleanup();
-    void init_output_tensor(
-            const SmallVector<Tensor*>& outputs, const SmallVector<Tensor*>& workspace);
-    cg::OperatorNodeBase* get_proxy_opr(
-            const OpDef& opdef, const SmallVector<Tensor*>& inputs);
     /********************** Logical Tensor Helper **********************/
     cg::VarNodeArray make_input_place_holders(
@@ -86,14 +58,11 @@ private:
     /********************** Common Helper **********************/
-    bool do_shape_infer(bool sync_value);
     TensorPtr as_tensor(cg::OperatorNodeBase* opr, bool share = true);
     cg::OperatorNodeBase* m_cur_opr = nullptr;
     std::unique_ptr<ProxyGraphImpl> m_graph;
     size_t m_max_op_cnt = 100;
-    std::unique_ptr<ExecEnv> m_env;
     std::unique_ptr<StaticInferManager> m_static_infer_manager;
     std::unique_ptr<SeqCompNodeOptimizer> m_seq_comp_node_optimizer;
@@ -801,18 +801,19 @@ public:
         return ret;
     }
-    SmallVector<LogicalTensorDesc> infer_output_attrs(
-            const OpDef& def, const SmallVector<Tensor*>& inputs) {
-        SmallVector<LogicalTensorDesc> descs;
-        auto& minigraph = get_cached_minigraph(def, inputs);
+    SmallVector<TensorPtr> apply_on_physical_tensor(
+            const OpDef& def, SmallVector<TensorPtr> inputs) {
+        auto raw_inputs = to_raw_ptr_array(inputs);
+        auto& minigraph = get_cached_minigraph(def, raw_inputs);
         auto _ = scoped_attach(&minigraph);
-        auto sess = minigraph.infer_session(inputs);
+        auto sess = minigraph.infer_session(raw_inputs);
+        ::mgb::opr::intl::WorkspaceLimitHook::set_impl(
+                minigraph.opr()->owner_graph(), get_workspace_limit);
         // some output vars in minigraph.opr()->output() may not appear in
         // minigraph.opr()->usable_output(), but execution may use the attrs of
         // those output vars, so we infer attrs for all outputs and only return
         // LogicalTensorDesc for minigraph.opr()->usable_output()
-        ::mgb::opr::intl::WorkspaceLimitHook::set_impl(
-                minigraph.opr()->owner_graph(), get_workspace_limit);
+        SmallVector<LogicalTensorDesc> output_descs;
         for (size_t i = 0; i < minigraph.opr()->output().size(); ++i) {
             auto* shape = sess.infer(sess.output_data[i].shape_infer, true);
             mgb_assert(shape);
@@ -825,15 +826,9 @@ public:
             mgb_assert(
                     ovar->shape().ndim ||
                     ovar->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC));
-            descs.push_back({{ovar->shape(), ovar->dtype()}, ovar->comp_node()});
+            output_descs.push_back({{ovar->shape(), ovar->dtype()}, ovar->comp_node()});
         }
-        return descs;
-    }
-    SmallVector<TensorPtr> apply_on_physical_tensor(
-            const OpDef& def, SmallVector<TensorPtr> inputs) {
-        auto raw_inputs = to_raw_ptr_array(inputs);
-        auto output_descs = infer_output_attrs(def, raw_inputs);
         SmallVector<TensorPtr> outputs(output_descs.size(), {});
         for (size_t i = 0; i < outputs.size(); i++) {
             outputs[i] =
@@ -853,11 +848,8 @@ public:
                 }
             }
         }
-        auto& minigraph = get_cached_minigraph(def, raw_inputs);
-        auto _ = scoped_attach(&minigraph);
         // some oprs (e.g. Subtensor) may invoke infer_value during execution,
         // so we need to create the inference session here
-        auto sess = minigraph.infer_session(raw_inputs);
         minigraph.execute(raw_inputs, raw_outputs, m_env);
         for (auto&& cn : used_cns) {
             for (auto&& in : inputs) {
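Note: across the three hunks above, infer_output_attrs is folded into apply_on_physical_tensor, so one cached minigraph and one inference session serve both shape inference and execution instead of being built twice. The resulting control flow, condensed from the surviving lines:

    // auto raw_inputs = to_raw_ptr_array(inputs);
    // auto& minigraph = get_cached_minigraph(def, raw_inputs);
    // auto _ = scoped_attach(&minigraph);
    // auto sess = minigraph.infer_session(raw_inputs);    // created once...
    // ... infer output_descs, allocate outputs, sync comp nodes ...
    // minigraph.execute(raw_inputs, raw_outputs, m_env);  // ...and still alive
    //                                                     // for oprs that call
    //                                                     // infer_value mid-run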
@@ -10,11 +10,6 @@
  */
 #include "./mini_graph.h"
-#if 0
-// ../proxy_graph.h is deprecated; kept here for debugging only
-// (change the `#if 0` above to `#if 1` to enable)
-#include "../proxy_graph.h"
-#endif
 namespace mgb::imperative::proxy_graph {
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(ProxyGraph::InputPlaceholder);
@@ -28,18 +23,6 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
     auto ret = proxy_graph::ProxyGraphTypeI::inst().infer_output_attrs_fallible(
             def, inputs);
-#if 0
-    // delete me after the new implementation is stable
-    auto ref = ProxyGraph::get_default_graph()->infer_output_attrs_fallible(def, inputs);
-    auto& [a, _1] = ret;
-    auto& [b, _2] = ref;
-    if (a.size() != b.size()) mgb_trap();
-    for (size_t i = 0; i < a.size(); ++i) {
-        if (a[i].layout.dtype != b[i].layout.dtype) mgb_trap();
-        if (a[i].comp_node != b[i].comp_node) mgb_trap();
-        if (!a[i].layout.eq_shape(b[i].layout)) mgb_trap();
-    }
-#endif
     return ret;
 }
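Note: the #if 0 block deleted above was migration scaffolding: it re-ran the deprecated ProxyGraph implementation alongside the new one and trapped on any divergence in dtype, comp node, or shape. The generic pattern, as a hedged standalone sketch (check_same is an illustrative name, not MegBrain API):

    #include <cstddef>
    #include <cstdlib>

    template <typename Seq, typename Eq>
    void check_same(const Seq& a, const Seq& b, Eq eq) {
        if (a.size() != b.size()) std::abort();  // analogous to mgb_trap()
        for (std::size_t i = 0; i < a.size(); ++i)
            if (!eq(a[i], b[i])) std::abort();
    }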
@@ -17,83 +17,6 @@ namespace mgb {
 namespace imperative {
 namespace proxy_graph_detail {
-// those functions are reimplemented with opr cache
-// in ./proxy_graph/mini_graph.h
-#if 0
-namespace {
-SmallVector<Tensor*> to_raw_ptr_array(
-        const SmallVector<TensorPtr>& inputs, bool ensure_storage = true) {
-    SmallVector<Tensor*> ret;
-    for (auto&& i : inputs) {
-        mgb_assert(i);
-        ret.push_back(i.get());
-        if (ensure_storage) {
-            // apply lazy allocation
-            i->blob()->storage();
-        }
-    }
-    return ret;
-}
-SmallVector<LogicalTensorDesc> infer_output_attrs(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
-    auto&& graph = ProxyGraph::get_default_graph();
-    return graph->infer_output_attrs(def, to_raw_ptr_array(inputs));
-}
-} // anonymous namespace
-void exec(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs,
-        const SmallVector<TensorPtr>& outputs,
-        const SmallVector<TensorPtr>& workspaces) {
-    auto&& graph = ProxyGraph::get_default_graph();
-    auto raw_inputs = to_raw_ptr_array(inputs), raw_outputs = to_raw_ptr_array(outputs),
-         raw_workspaces = to_raw_ptr_array(workspaces);
-    CompNode::UnorderedSet used_cns;
-    for (auto&& out : raw_outputs) {
-        auto cn = out->comp_node();
-        if (used_cns.insert(cn).second) {
-            for (auto&& in : inputs) {
-                if (in->comp_node() != cn) {
-                    auto&& e = in->get_or_create_event();
-                    e->device_wait_by(cn);
-                }
-            }
-        }
-    }
-    graph->invoke_op(def, raw_inputs, raw_outputs, raw_workspaces);
-    for (auto&& cn : used_cns) {
-        for (auto&& in : inputs) {
-            if (in->comp_node() != cn) {
-                in->add_release_callback(cn);
-            }
-        }
-    }
-}
-SmallVector<TensorPtr> apply_on_physical_tensor(
-        const OpDef& def, SmallVector<TensorPtr> inputs) {
-    auto output_descs = infer_output_attrs(def, inputs);
-    SmallVector<TensorPtr> outputs(output_descs.size(), {});
-    for (size_t i = 0; i < outputs.size(); i++) {
-        outputs[i] = Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
-    }
-    exec(def, inputs, outputs, {});
-    auto async_error = ProxyGraph::get_async_error();
-    if (async_error) {
-        throw *async_error;
-    }
-    return outputs;
-}
-std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
-        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
-    auto&& graph = ProxyGraph::get_default_graph();
-    return graph->infer_output_attrs_fallible(def, inputs);
-}
-#endif
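Note: the retired exec() above carried the cross-device ordering protocol that the minigraph path now performs itself (the same used_cns loops appear in the mini_graph.h hunk earlier): before launch, for every output comp node cn, each input living on a different node records an event that cn waits on; after launch, each such input gets a release callback so its storage outlives cn's pending reads. Condensed from the deleted lines:

    // for each output comp node cn in used_cns:     // before invoke_op
    //     for each input not on cn:
    //         in->get_or_create_event()->device_wait_by(cn);
    // graph->invoke_op(def, raw_inputs, raw_outputs, raw_workspaces);
    // for each output comp node cn in used_cns:     // after invoke_op
    //     for each input not on cn:
    //         in->add_release_callback(cn);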
 EncodedSubgraph make_backward_graph(
         const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
         const SmallVector<bool>& input_requires_grad,