GitOrigin-RevId: 171301fc2b
tags/v1.6.0-rc1
--- a/imperative/src/impl/op_def.cpp
+++ b/imperative/src/impl/op_def.cpp
@@ -100,6 +100,19 @@ std::vector<std::pair<const char*, std::string>> OpDef::props(
     return def.trait()->props(def);
 }
 
+EncodedSubraph OpDef::make_forward_graph(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    using ForwardGraphCache = OpMethResultCache<EncodedSubraph, SmallVector<bool>, SmallVector<bool>>;
+    thread_local ForwardGraphCache cache;
+    decltype(cache)::key_t cache_key{const_cast<OpDef&>(def).shared_from_this(), inputs};
+    auto iter = cache.find(cache_key);
+    if (iter == cache.end()) {
+        iter = cache.insert({cache_key, def.trait()->make_forward_graph(def, inputs)}).first;
+    }
+    return iter->second;
+}
+
 std::string OpDef::to_string() const {
     std::string builder = trait()->make_name(*this) + "{";
     for (auto&& [name, value] : props(*this)) {
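Note on the hunk above (not part of the patch): make_forward_graph memoizes the decomposition per thread, keyed on the op handle plus the input descriptors, so repeated applications of the same composite op build the subgraph only once. A minimal standalone sketch of the same pattern, with a simplified string key and a stand-in Graph type (the real key_t hashes the shared OpDef pointer and the descriptor vectors):

    #include <string>
    #include <unordered_map>

    struct Graph {};  // stand-in for EncodedSubraph

    Graph build_graph(const std::string& key) {
        return Graph{};  // expensive construction elided
    }

    Graph cached_graph(const std::string& key) {
        // one cache per thread: no locking needed, lives until the thread exits
        thread_local std::unordered_map<std::string, Graph> cache;
        auto iter = cache.find(key);
        if (iter == cache.end()) {
            // build once; every later call with the same key reuses the entry
            iter = cache.insert({key, build_graph(key)}).first;
        }
        return iter->second;
    }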
--- a/imperative/src/impl/op_trait.cpp
+++ b/imperative/src/impl/op_trait.cpp
@@ -16,6 +16,7 @@
 #include "megbrain/imperative/op_def.h"
 #include "megbrain/imperative/ops/opr_attr.h"
 #include "megbrain/imperative/proxy_graph_detail.h"
+#include "megbrain/imperative/subgraph_detail.h"
 #include "megbrain/tensor.h"
 
 #include "./op_trait.h"
@@ -38,24 +39,45 @@ StaticData& static_data() {
     return data;
 }
 
-void OpMethFallback::impl(ApplyOnPhysicalTensor& func,
+void OpMethFallbackByProxyGraph::impl(ApplyOnPhysicalTensor& func,
                           op_meth_tag::ApplyOnPhysicalTensor) {
     func.Base::operator=(proxy_graph_detail::apply_on_physical_tensor);
 }
-void OpMethFallback::impl(Execute& func, op_meth_tag::Execute) {
+void OpMethFallbackByProxyGraph::impl(Execute& func, op_meth_tag::Execute) {
     func.Base::operator=(proxy_graph_detail::execute);
 }
-void OpMethFallback::impl(InferOutputMemDesc& func,
+void OpMethFallbackByProxyGraph::impl(InferOutputMemDesc& func,
                           op_meth_tag::InferOutputMemDesc) {
     func.Base::operator=(proxy_graph_detail::infer_output_mem_desc);
 }
-void OpMethFallback::impl(InferOutputAttrsFallible& func,
+void OpMethFallbackByProxyGraph::impl(InferOutputAttrsFallible& func,
                           op_meth_tag::InferOutputAttrsFallible) {
     func.Base::operator=(proxy_graph_detail::infer_output_attrs_fallible);
 }
-void OpMethFallback::impl(GradMaker& func, op_meth_tag::GradMaker) {
+void OpMethFallbackByProxyGraph::impl(GradMaker& func, op_meth_tag::GradMaker) {
     func.Base::operator=(proxy_graph_detail::make_backward_graph);
 }
+void OpMethFallbackFromSubgraph::impl(ApplyOnPhysicalTensor& func,
+                                      op_meth_tag::ApplyOnPhysicalTensor) {
+    func.Base::operator=(subgraph_detail::apply_on_physical_tensor);
+}
+void OpMethFallbackFromSubgraph::impl(InferOutputMemDesc& func,
+                                      op_meth_tag::InferOutputMemDesc) {
+    func.Base::operator=(subgraph_detail::infer_output_mem_desc);
+}
+void OpMethFallbackFromSubgraph::impl(ApplyOnVarNode& func,
+                                      op_meth_tag::ApplyOnVarNode) {
+    func.Base::operator=(subgraph_detail::apply_on_var_node);
+}
+void OpMethFallbackFromSubgraph::impl(InferOutputAttrsFallible& func,
+                                      op_meth_tag::InferOutputAttrsFallible) {
+    func.Base::operator=(subgraph_detail::infer_output_attrs_fallible);
+}
+void OpMethFallbackFromSubgraph::impl(GradMaker& func, op_meth_tag::GradMaker) {
+    func.Base::operator=(subgraph_detail::make_backward_graph);
+}
 void OpMethFallback::impl(DecideDispatchMode& func,
                           op_meth_tag::DecideDispatchMode) {
     static auto decide_dispatch_mode =
@@ -99,16 +121,20 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
 }
 
 OpTraitRegistry& OpTraitRegistry::fallback() {
+    using Mode = detail::OpMethFallbackMode;
+    uint64_t mode = Mode::None;
+    if (trait->make_forward_graph) {
+        mode |= Mode::FromSubgraph;
+    }
     if (trait->apply_on_var_node) {
-        // fallback to proxy graph impl
-        trait->apply_on_physical_tensor.allow_fallback = true;
-        trait->execute.allow_fallback = true;
-        trait->infer_output_mem_desc.allow_fallback = true;
-        trait->infer_output_attrs_fallible.allow_fallback = true;
-        trait->make_backward_graph.allow_fallback = true;
+        mode |= Mode::ByProxyGraph;
     }
-    trait->decide_dispatch_mode.allow_fallback = true;
-    trait->make_name.allow_fallback = true;
+    mode |= Mode::Default;
+#define SET_FALLBACK_MODE(meth) \
+    trait->meth.fallback_mode = mode;
+    FOR_EACH_OP_METH(SET_FALLBACK_MODE)
+#undef SET_FALLBACK_MODE
     return *this;
 }
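Note on the hunk above (not part of the patch): fallback eligibility is no longer a per-method boolean. Every method slot of an op that defines both make_forward_graph and apply_on_var_node carries fallback_mode == FromSubgraph | ByProxyGraph | Default (== 7 with the constants declared in op_trait.h below), while an op with neither still carries Default, which is enough for decide_dispatch_mode and make_name. Which tier actually fills a given slot is decided lazily at first call, in the OpMeth::operator() hunk further down.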
--- a/imperative/src/impl/op_trait.h
+++ b/imperative/src/impl/op_trait.h
@@ -95,9 +95,18 @@ OpMethType(IsSame,
 OpMethType(MakeNameFunc,
            std::string(const OpDef&));
 
+OpMethType(GraphMaker,
+           decltype(OpDef::make_forward_graph));
+
 // clang-format on
 
 namespace detail {
+struct OpMethImplBase {
+    template <typename Tag, typename RType, typename... Args>
+    static void impl(thin_function<RType(Args...)>& func, Tag) {}
+};
+
 struct OpMethNotImpl {
     template <typename Tag, typename RType, typename... Args>
     static void impl(thin_function<RType(Args...)>& func, Tag) {
@@ -106,8 +115,15 @@ struct OpMethNotImpl {
         };
     }
 };
-struct OpMethFallback : public OpMethNotImpl {
-    using OpMethNotImpl::impl;
+struct OpMethFallback : OpMethImplBase {
+    using OpMethImplBase::impl;
+    static void impl(DecideDispatchMode& func, op_meth_tag::DecideDispatchMode);
+    static void impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc);
+};
+struct OpMethFallbackByProxyGraph : OpMethImplBase {
+    using OpMethImplBase::impl;
     static void impl(ApplyOnPhysicalTensor& func,
                      op_meth_tag::ApplyOnPhysicalTensor);
     static void impl(Execute& func, op_meth_tag::Execute);
@@ -115,18 +131,48 @@ struct OpMethFallback : public OpMethNotImpl {
     static void impl(InferOutputAttrsFallible& func,
                      op_meth_tag::InferOutputAttrsFallible);
     static void impl(GradMaker& func, op_meth_tag::GradMaker);
-    static void impl(DecideDispatchMode& func, op_meth_tag::DecideDispatchMode);
-    static void impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc);
 };
+struct OpMethFallbackFromSubgraph : OpMethImplBase {
+    using OpMethImplBase::impl;
+    static void impl(ApplyOnPhysicalTensor& func,
+                     op_meth_tag::ApplyOnPhysicalTensor);
+    static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc);
+    static void impl(ApplyOnVarNode& func, op_meth_tag::ApplyOnVarNode);
+    static void impl(InferOutputAttrsFallible& func,
+                     op_meth_tag::InferOutputAttrsFallible);
+    static void impl(GradMaker& func, op_meth_tag::GradMaker);
+};
+
+struct OpMethFallbackMode {
+    static constexpr uint64_t None = 0;
+    static constexpr uint64_t Default = 1;
+    static constexpr uint64_t ByProxyGraph = 2;
+    static constexpr uint64_t FromSubgraph = 4;
+};
+
 template <typename Tag, typename RType, typename... Args>
 struct OpMeth<Tag, RType(Args...)> : public thin_function<RType(Args...)> {
     using Base = thin_function<RType(Args...)>;
-    OpMeth() : Base{}, allow_fallback(false){};
+    OpMeth() : Base{} {};
    explicit OpMeth(const Base& base) { this->Base::operator=(base); }
     using Base::operator bool;
     RType operator()(Args... args) const {
-        if (!this->Base::operator bool()) {
-            if (allow_fallback) {
+        uint64_t mode_mask = ~uint64_t(0);
+        auto match_mode = [&](uint64_t mode) {
+            if ((fallback_mode & mode_mask) & mode) {
+                mode_mask &= ~mode;
+                return true;
+            }
+            return false;
+        };
+        while (!this->Base::operator bool()) {
+            using Mode = OpMethFallbackMode;
+            if (match_mode(Mode::FromSubgraph)) {
+                OpMethFallbackFromSubgraph::impl(*const_cast<OpMeth*>(this), Tag{});
+            } else if (match_mode(Mode::ByProxyGraph)) {
+                OpMethFallbackByProxyGraph::impl(*const_cast<OpMeth*>(this), Tag{});
+            } else if (match_mode(Mode::Default)) {
                 OpMethFallback::impl(*const_cast<OpMeth*>(this), Tag{});
             } else {
                 OpMethNotImpl::impl(*const_cast<OpMeth*>(this), Tag{});
@@ -134,7 +180,7 @@ struct OpMeth<Tag, RType(Args...)> : public thin_function<RType(Args...)> {
             }
         }
         return this->Base::operator()(std::forward<Args>(args)...);
     }
-    bool allow_fallback = false;
+    uint64_t fallback_mode = OpMethFallbackMode::None;
 };
 } // namespace detail
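Note on the two hunks above (not part of the patch): the design combines tag dispatch with a consume-once bitmask. Each fallback tier inherits a no-op catch-all impl from OpMethImplBase, so a tier that cannot provide a given method leaves the slot empty and the loop moves on to the next enabled tier; mode_mask guarantees each tier is attempted at most once, so the while loop always terminates (at OpMethNotImpl in the worst case). A standalone sketch under those assumptions — names illustrative, a single method kind with an int(int) signature:

    #include <cstdint>
    #include <functional>
    #include <stdexcept>

    using Meth = std::function<int(int)>;  // stand-in for one OpMeth slot
    struct ExecuteTag {};                  // one tag per method kind in the real code

    struct ImplBase {
        // catch-all: a tier with no impl for this method leaves the slot untouched
        template <typename Tag>
        static void impl(Meth&, Tag) {}
    };
    struct FromSubgraphTier : ImplBase {
        using ImplBase::impl;  // no ExecuteTag overload -> falls through to the no-op
    };
    struct ByProxyTier : ImplBase {
        using ImplBase::impl;
        static void impl(Meth& func, ExecuteTag) {
            func = [](int x) { return x + 1; };  // install the fallback impl
        }
    };

    struct Mode {
        static constexpr uint64_t Default = 1, ByProxy = 2, Subgraph = 4;
    };

    struct Slot {
        Meth fn;  // empty until filled explicitly or by a fallback tier
        uint64_t fallback_mode = Mode::Subgraph | Mode::ByProxy;

        int operator()(int arg) {
            uint64_t mode_mask = ~uint64_t(0);  // each tier may be tried only once
            auto match_mode = [&](uint64_t mode) {
                if ((fallback_mode & mode_mask) & mode) {
                    mode_mask &= ~mode;  // consume the bit so the loop terminates
                    return true;
                }
                return false;
            };
            while (!fn) {
                if (match_mode(Mode::Subgraph)) {
                    FromSubgraphTier::impl(fn, ExecuteTag{});  // declines (no-op)
                } else if (match_mode(Mode::ByProxy)) {
                    ByProxyTier::impl(fn, ExecuteTag{});       // fills the slot
                } else {
                    throw std::runtime_error("not implemented");
                }
            }
            return fn(arg);
        }
    };

Here Slot{}(41) returns 42: the Subgraph tier is tried first but declines (it has no ExecuteTag overload), then the ByProxy tier installs the implementation.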
@@ -153,6 +199,7 @@ struct OpTrait {
     HashFunc hash;
     IsSame is_same_st;
     MakeNameFunc make_name;
+    GraphMaker make_forward_graph;
     OpTrait(const char* name);
     static OpTrait* find_by_name(const char* name);
     static OpTrait* find_by_typeinfo(Typeinfo* type);
@@ -173,7 +220,9 @@ struct OpTrait {
     cb(props) \
     cb(hash) \
     cb(is_same_st) \
-    cb(make_name)
+    cb(make_name) \
+    cb(make_forward_graph)
 // clang-format on
 
 struct OpTraitRegistry {
--- /dev/null
+++ b/imperative/src/impl/subgraph_detail.cpp
@@ -0,0 +1,169 @@
+/**
+ * \file imperative/src/impl/subgraph_detail.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
#include "megbrain/imperative/subgraph_detail.h" | |||||
#include "megbrain/imperative/graph_builder.h" | |||||
#include "megbrain/opr/io.h" | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
#include "./op_trait.h" | |||||
namespace mgb { | |||||
namespace imperative { | |||||
namespace subgraph_detail { | |||||
VarNodeArray apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
SmallVector<LogicalTensorDesc> input_descs; | |||||
for (auto&& input: inputs) { | |||||
input_descs.push_back({TensorLayout{input->dtype()}, input->comp_node()}); | |||||
} | |||||
auto apply_functor = [](const std::shared_ptr<OpDef>& op, const VarNodeArray& inputs, size_t nr_outputs){ | |||||
return OpDef::apply_on_var_node(*op, inputs); | |||||
}; | |||||
auto const_functor = [&](const TensorPtr& value) { | |||||
return opr::ImmutableTensor::make(*inputs[0]->owner_graph(), value->get_value()).node(); | |||||
}; | |||||
auto subgraph = def.trait()->make_forward_graph(def, input_descs); | |||||
auto outputs = subgraph.apply(inputs, apply_functor, const_functor); | |||||
return outputs; | |||||
} | |||||
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    auto subgraph = def.trait()->make_forward_graph(def, inputs);
+    bool all_validated = true;
+    auto apply_functor = [&](const std::shared_ptr<OpDef>& op,
+                             const SmallVector<LogicalTensorDesc>& inputs,
+                             size_t nr_outputs) {
+        auto [outputs, validated] = OpDef::infer_output_attrs_fallible(*op, inputs);
+        all_validated = all_validated && validated;
+        return outputs;
+    };
+    auto const_functor = [&](const TensorPtr& value) {
+        return LogicalTensorDesc{value->layout(), value->comp_node(),
+                                 value->get_value().proxy_to_default_cpu()};
+    };
+    auto outputs = subgraph.apply(inputs, apply_functor, const_functor);
+    return {outputs, all_validated};
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def,
+        SmallVector<TensorPtr> inputs) {
+    SmallVector<LogicalTensorDesc> input_descs;
+    for (auto&& input : inputs) {
+        input_descs.push_back({input->layout(), input->comp_node()});
+    }
+    auto subgraph = def.trait()->make_forward_graph(def, input_descs);
+    auto apply_functor = [](const std::shared_ptr<OpDef>& op,
+                            const SmallVector<TensorPtr>& inputs, size_t nr_outputs) {
+        return OpDef::apply_on_physical_tensor(*op, inputs);
+    };
+    auto const_functor = [&](const TensorPtr& value) { return value; };
+    auto outputs = subgraph.apply(inputs, apply_functor, const_functor);
+    return outputs;
+}
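Note (not part of the patch): the three entry points above are one interpreter instantiated over three value domains — VarNode* for graph building, LogicalTensorDesc for shape/dtype inference, TensorPtr for eager execution — differing only in the two functors handed to subgraph.apply. A self-contained sketch of that replay loop, with illustrative types standing in for Subgraph (constant routing via const_functor omitted for brevity):

    #include <cassert>
    #include <functional>
    #include <unordered_map>
    #include <vector>

    // Vars are named by integer ids; exprs are replayed in recording order.
    struct Expr {
        int op;                    // opaque op id
        std::vector<int> inputs;   // var ids read
        std::vector<int> outputs;  // var ids written
    };

    template <typename T>
    std::vector<T> apply_subgraph(
            const std::vector<Expr>& exprs,
            const std::vector<int>& input_ids, const std::vector<T>& inputs,
            const std::vector<int>& output_ids,
            // the caller decides what "run one op" means in its value domain
            const std::function<std::vector<T>(int, const std::vector<T>&)>& apply_functor) {
        assert(input_ids.size() == inputs.size());
        std::unordered_map<int, T> env;  // var id -> value in the target domain
        for (size_t i = 0; i < inputs.size(); ++i) {
            env.emplace(input_ids[i], inputs[i]);
        }
        for (auto&& expr : exprs) {
            std::vector<T> args;
            for (int v : expr.inputs) args.push_back(env.at(v));
            auto results = apply_functor(expr.op, args);
            for (size_t i = 0; i < expr.outputs.size(); ++i) {
                env.emplace(expr.outputs[i], results[i]);
            }
        }
        std::vector<T> out;
        for (int v : output_ids) out.push_back(env.at(v));
        return out;
    }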
+
+static EncodedSubraph make_backward_graph_from_forward(
+        const SmallVector<LogicalTensorDesc>& inputs,
+        const SmallVector<bool>& input_requires_grad,
+        const SmallVector<bool>& output_has_grad,
+        EncodedSubraph forward_graph) {
+    using namespace std::placeholders;
+    using var_t = Subgraph::var_t;
+    using vars_t = Subgraph::vars_t;
+    Subgraph::Builder<LogicalTensorDesc> builder(
+            [](auto&& op, auto&& input_descs, size_t nr_outputs) {
+                auto [descs, _] = OpDef::infer_output_attrs_fallible(*op, input_descs);
+                return descs;
+            });
+    auto accum_grad = [&](var_t lhs, var_t rhs) {
+        return builder.write_expr(Elemwise::make(Elemwise::Mode::ADD), {lhs, rhs}, 1)[0];
+    };
+    GradContext<var_t> grad_context{accum_grad};
+    auto input_vars = builder.write_inputs(inputs);
+    auto outputs = forward_graph.apply(
+            input_vars,
+            std::bind(&decltype(builder)::write_expr, &builder, _1, _2, _3),
+            [&](TensorPtr constant) {
+                return builder.write_constant(
+                        constant, {constant->layout(), constant->comp_node()});
+            });
+    size_t nr_outputs = outputs.size();
+    auto apply_mask = [](auto&& values, SmallVector<bool> mask) {
+        mgb_assert(mask.size() == values.size(), "");
+        std::decay_t<decltype(values)> results;
+        for (size_t i = 0; i < mask.size(); ++i) {
+            if (mask[i]) {
+                results.push_back(values[i]);
+            }
+        }
+        return results;
+    };
+    grad_context.mark_require_grads(apply_mask(input_vars, input_requires_grad));
+    builder.iterate([&](std::list<Subgraph::expr_t>::iterator iter) {
+        grad_context.record_expr(iter->op, iter->inputs, iter->outputs);
+    });
+    auto output_descs = builder.get_descs(outputs);
+    auto computed_outputs = builder.write_inputs(output_descs);
+    auto output_grads = builder.write_inputs(output_descs);
+    grad_context.backward(
+            apply_mask(outputs, output_has_grad),
+            apply_mask(output_grads, output_has_grad),
+            [&](Subgraph::expr_t expr, vars_t output_grads) {
+                auto bg = OpDef::make_backward_graph(
+                        *expr.op, builder.get_descs(expr.inputs),
+                        grad_context.get_require_grads(expr.inputs),
+                        grad_context.get_has_grads(expr.outputs));
+                if (bg.graph.empty()) {
+                    return vars_t(expr.inputs.size(), 0);
+                }
+                vars_t grad_inputs;
+                grad_inputs.insert(grad_inputs.end(), expr.inputs.begin(),
+                                   expr.inputs.end());
+                grad_inputs.insert(grad_inputs.end(), expr.outputs.begin(),
+                                   expr.outputs.end());
+                grad_inputs.insert(grad_inputs.end(), output_grads.begin(),
+                                   output_grads.end());
+                auto apply_functor = std::bind(&decltype(builder)::write_expr,
+                                               &builder, _1, _2, _3);
+                auto const_functor = [&](TensorPtr constant) {
+                    return builder.write_constant(
+                            constant, {constant->layout(), constant->comp_node()});
+                };
+                return bg.apply(grad_inputs, apply_functor, const_functor);
+            });
+    builder.add_outputs(grad_context.get_grads(input_vars));
+    for (size_t i = 0; i < nr_outputs; ++i) {
+        builder.replace_var(outputs[i], computed_outputs[i]);
+    }
+    auto backward_graph = builder.encode();
+    return backward_graph;
+}
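Note (not part of the patch): a hand-traced run of the builder above for a hypothetical one-op forward y = exp(x); exp is illustrative, chosen because its derivative reuses the forward output, dx = dy * y:

    forward input      : x              (write_inputs)
    replayed forward   : y = exp(x)     (forward_graph.apply; captured by record_expr)
    extra inputs       : y_c, dy        (two more write_inputs over the output descs)
    per-op backward    : dx = dy * y    (that op's make_backward_graph)
    grad accumulation  : Elemwise ADD, only where several paths feed one var
    replace_var        : y -> y_c       (the encoded graph never recomputes exp)
    encoded signature  : (x, y_c, dy) -> dx

The replace_var step is what lets callers feed saved forward results into the backward graph instead of recomputing the forward pass.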
+
+EncodedSubraph make_backward_graph(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs,
+        const SmallVector<bool>& input_requires_grad,
+        const SmallVector<bool>& output_has_grad) {
+    auto forward_graph = OpDef::make_forward_graph(def, inputs);
+    return make_backward_graph_from_forward(
+            inputs, input_requires_grad, output_has_grad, forward_graph);
+}
+
+std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs_tensors,
+        const SmallVector<MemoryDesc>& inputs_mems) {
+    return {{}, {}};
+}
+
+}  // namespace subgraph_detail
+}  // namespace imperative
+}  // namespace mgb
--- a/imperative/src/include/megbrain/imperative/op_def.h
+++ b/imperative/src/include/megbrain/imperative/op_def.h
@@ -13,6 +13,7 @@
 #include "megbrain/graph.h"
 #include "megbrain/imperative/physical_tensor.h"
+#include "megbrain/imperative/subgraph.h"
 #include "megbrain/imperative/utils/to_string.h"
#include "megbrain/imperative/subgraph.h" | #include "megbrain/imperative/subgraph.h" | ||||
@@ -94,6 +95,10 @@ public:
     static std::vector<std::pair<const char*, std::string>> props(
             const OpDef& def);
 
+    static EncodedSubraph make_forward_graph(
+            const OpDef& def,
+            const SmallVector<LogicalTensorDesc>& inputs);
+
     const OpTrait* trait() const;
 
     std::string to_string() const;
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/subgraph_detail.h
@@ -0,0 +1,51 @@
+/**
+ * \file imperative/src/include/megbrain/imperative/subgraph_detail.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "megbrain/imperative/op_def.h"
+
+namespace mgb {
+namespace imperative {
+namespace subgraph_detail {
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def,
+        SmallVector<TensorPtr> inputs);
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs);
+
+EncodedSubraph make_backward_graph(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs,
+        const SmallVector<bool>& input_requires_grad,
+        const SmallVector<bool>& output_has_grad);
+
+cg::VarNodeArray apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs);
+
+std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs_tensors,
+        const SmallVector<MemoryDesc>& inputs_mems);
+
+}  // namespace subgraph_detail
+}  // namespace imperative
+}  // namespace mgb
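Note (not part of the patch): putting the pieces together, the intended usage appears to be that a composite op registers only make_forward_graph and lets fallback() wire up every tensor-level method through the FromSubgraph tier. A hedged sketch — MyFusedOp and make_my_forward_graph are hypothetical names, and the OP_TRAIT_REG/.fallback() chaining is inferred from the registry code touched above; the builder calls mirror subgraph_detail.cpp:

    // Hypothetical composite op: only its decomposition is supplied; the
    // tensor-level methods resolve through subgraph_detail::* on first use.
    EncodedSubraph make_my_forward_graph(
            const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
        Subgraph::Builder<LogicalTensorDesc> builder(
                [](auto&& op, auto&& descs, size_t) {
                    auto [out, _] = OpDef::infer_output_attrs_fallible(*op, descs);
                    return out;
                });
        auto vars = builder.write_inputs(inputs);
        // stand-in decomposition: a single ADD of the two inputs
        auto out = builder.write_expr(Elemwise::make(Elemwise::Mode::ADD), vars, 1);
        builder.add_outputs(out);
        return builder.encode();
    }

    OP_TRAIT_REG(MyFusedOp, MyFusedOp)
            .make_forward_graph(make_my_forward_graph)
            .fallback();  // sets Mode::FromSubgraph because make_forward_graph is present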