GitOrigin-RevId: 171301fc2b
tags/v1.6.0-rc1
@@ -100,6 +100,19 @@ std::vector<std::pair<const char*, std::string>> OpDef::props( | |||
return def.trait()->props(def); | |||
} | |||
EncodedSubraph OpDef::make_forward_graph( | |||
const OpDef& def, | |||
const SmallVector<LogicalTensorDesc>& inputs){ | |||
using ForwardGraphCache = OpMethResultCache<EncodedSubraph, SmallVector<bool>, SmallVector<bool>>; | |||
thread_local ForwardGraphCache cache; | |||
decltype(cache)::key_t cache_key{const_cast<OpDef&>(def).shared_from_this(), inputs}; | |||
auto iter = cache.find(cache_key); | |||
if (iter == cache.end()) { | |||
iter = cache.insert({cache_key, def.trait()->make_forward_graph(def, inputs)}).first; | |||
} | |||
return iter->second; | |||
} | |||
std::string OpDef::to_string() const { | |||
std::string builder = trait()->make_name(*this) + "{"; | |||
for (auto&& [name, value]: props(*this)) { | |||
@@ -16,6 +16,7 @@ | |||
#include "megbrain/imperative/op_def.h" | |||
#include "megbrain/imperative/ops/opr_attr.h" | |||
#include "megbrain/imperative/proxy_graph_detail.h" | |||
#include "megbrain/imperative/subgraph_detail.h" | |||
#include "megbrain/tensor.h" | |||
#include "./op_trait.h" | |||
@@ -38,24 +39,45 @@ StaticData& static_data() { | |||
return data; | |||
} | |||
void OpMethFallback::impl(ApplyOnPhysicalTensor& func, | |||
void OpMethFallbackByProxyGraph::impl(ApplyOnPhysicalTensor& func, | |||
op_meth_tag::ApplyOnPhysicalTensor) { | |||
func.Base::operator=(proxy_graph_detail::apply_on_physical_tensor); | |||
} | |||
void OpMethFallback::impl(Execute& func, op_meth_tag::Execute) { | |||
void OpMethFallbackByProxyGraph::impl(Execute& func, op_meth_tag::Execute) { | |||
func.Base::operator=(proxy_graph_detail::execute); | |||
} | |||
void OpMethFallback::impl(InferOutputMemDesc& func, | |||
void OpMethFallbackByProxyGraph::impl(InferOutputMemDesc& func, | |||
op_meth_tag::InferOutputMemDesc) { | |||
func.Base::operator=(proxy_graph_detail::infer_output_mem_desc); | |||
} | |||
void OpMethFallback::impl(InferOutputAttrsFallible& func, | |||
void OpMethFallbackByProxyGraph::impl(InferOutputAttrsFallible& func, | |||
op_meth_tag::InferOutputAttrsFallible) { | |||
func.Base::operator=(proxy_graph_detail::infer_output_attrs_fallible); | |||
} | |||
void OpMethFallback::impl(GradMaker& func, op_meth_tag::GradMaker) { | |||
void OpMethFallbackByProxyGraph::impl(GradMaker& func, op_meth_tag::GradMaker) { | |||
func.Base::operator=(proxy_graph_detail::make_backward_graph); | |||
} | |||
void OpMethFallbackFromSubgraph::impl(ApplyOnPhysicalTensor& func, | |||
op_meth_tag::ApplyOnPhysicalTensor) { | |||
func.Base::operator=(subgraph_detail::apply_on_physical_tensor); | |||
} | |||
void OpMethFallbackFromSubgraph::impl(InferOutputMemDesc& func, | |||
op_meth_tag::InferOutputMemDesc) { | |||
func.Base::operator=(subgraph_detail::infer_output_mem_desc); | |||
} | |||
void OpMethFallbackFromSubgraph::impl(ApplyOnVarNode& func, | |||
op_meth_tag::ApplyOnVarNode) { | |||
func.Base::operator=(subgraph_detail::apply_on_var_node); | |||
} | |||
void OpMethFallbackFromSubgraph::impl(InferOutputAttrsFallible& func, | |||
op_meth_tag::InferOutputAttrsFallible) { | |||
func.Base::operator=(subgraph_detail::infer_output_attrs_fallible); | |||
} | |||
void OpMethFallbackFromSubgraph::impl(GradMaker& func, op_meth_tag::GradMaker) { | |||
func.Base::operator=(subgraph_detail::make_backward_graph); | |||
} | |||
void OpMethFallback::impl(DecideDispatchMode& func, | |||
op_meth_tag::DecideDispatchMode) { | |||
static auto decide_dispatch_mode = | |||
@@ -99,16 +121,20 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){ | |||
} | |||
OpTraitRegistry& OpTraitRegistry::fallback() { | |||
using Mode = detail::OpMethFallbackMode; | |||
uint64_t mode = Mode::None; | |||
if (trait->make_forward_graph) { | |||
mode |= Mode::FromSubgraph; | |||
} | |||
if (trait->apply_on_var_node) { | |||
// fallback to proxy graph impl | |||
trait->apply_on_physical_tensor.allow_fallback = true; | |||
trait->execute.allow_fallback = true; | |||
trait->infer_output_mem_desc.allow_fallback = true; | |||
trait->infer_output_attrs_fallible.allow_fallback = true; | |||
trait->make_backward_graph.allow_fallback = true; | |||
mode |= Mode::ByProxyGraph; | |||
} | |||
trait->decide_dispatch_mode.allow_fallback = true; | |||
trait->make_name.allow_fallback = true; | |||
mode |= Mode::Default; | |||
#define SET_FALLBACK_MODE(meth) \ | |||
trait->meth.fallback_mode = mode; | |||
FOR_EACH_OP_METH(SET_FALLBACK_MODE) | |||
#undef SET_FALLBACK_MODE | |||
return *this; | |||
} | |||
@@ -95,9 +95,18 @@ OpMethType(IsSame, | |||
OpMethType(MakeNameFunc, | |||
std::string(const OpDef&)); | |||
OpMethType(GraphMaker, | |||
decltype(OpDef::make_forward_graph)); | |||
// clang-format on | |||
namespace detail { | |||
struct OpMethImplBase { | |||
template <typename Tag, typename RType, typename... Args> | |||
static void impl(thin_function<RType(Args...)>& func, Tag) {} | |||
}; | |||
struct OpMethNotImpl { | |||
template <typename Tag, typename RType, typename... Args> | |||
static void impl(thin_function<RType(Args...)>& func, Tag) { | |||
@@ -106,8 +115,15 @@ struct OpMethNotImpl { | |||
}; | |||
} | |||
}; | |||
struct OpMethFallback : public OpMethNotImpl { | |||
using OpMethNotImpl::impl; | |||
struct OpMethFallback: OpMethImplBase { | |||
using OpMethImplBase::impl; | |||
static void impl(DecideDispatchMode& func, op_meth_tag::DecideDispatchMode); | |||
static void impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc); | |||
}; | |||
struct OpMethFallbackByProxyGraph: OpMethImplBase { | |||
using OpMethImplBase::impl; | |||
static void impl(ApplyOnPhysicalTensor& func, | |||
op_meth_tag::ApplyOnPhysicalTensor); | |||
static void impl(Execute& func, op_meth_tag::Execute); | |||
@@ -115,18 +131,48 @@ struct OpMethFallback : public OpMethNotImpl { | |||
static void impl(InferOutputAttrsFallible& func, | |||
op_meth_tag::InferOutputAttrsFallible); | |||
static void impl(GradMaker& func, op_meth_tag::GradMaker); | |||
static void impl(DecideDispatchMode& func, op_meth_tag::DecideDispatchMode); | |||
static void impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc); | |||
}; | |||
struct OpMethFallbackFromSubgraph: OpMethImplBase { | |||
using OpMethImplBase::impl; | |||
static void impl(ApplyOnPhysicalTensor& func, | |||
op_meth_tag::ApplyOnPhysicalTensor); | |||
static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc); | |||
static void impl(ApplyOnVarNode& func, op_meth_tag::ApplyOnVarNode); | |||
static void impl(InferOutputAttrsFallible& func, | |||
op_meth_tag::InferOutputAttrsFallible); | |||
static void impl(GradMaker& func, op_meth_tag::GradMaker); | |||
}; | |||
struct OpMethFallbackMode { | |||
static constexpr uint64_t None = 0; | |||
static constexpr uint64_t Default = 1; | |||
static constexpr uint64_t ByProxyGraph = 2; | |||
static constexpr uint64_t FromSubgraph = 4; | |||
}; | |||
template <typename Tag, typename RType, typename... Args> | |||
struct OpMeth<Tag, RType(Args...)> : public thin_function<RType(Args...)> { | |||
using Base = thin_function<RType(Args...)>; | |||
OpMeth() : Base{}, allow_fallback(false){}; | |||
OpMeth() : Base{}{}; | |||
explicit OpMeth(const Base& base) { this->Base::operator=(base); } | |||
using Base::operator bool; | |||
RType operator()(Args... args) const { | |||
if (!this->Base::operator bool()) { | |||
if (allow_fallback) { | |||
uint64_t mode_mask = ~uint64_t(0); | |||
auto match_mode = [&](uint64_t mode){ | |||
if ((fallback_mode & mode_mask) & mode) { | |||
mode_mask &= ~mode; | |||
return true; | |||
} | |||
return false; | |||
}; | |||
while (!this->Base::operator bool()) { | |||
using Mode = OpMethFallbackMode; | |||
if (match_mode(Mode::FromSubgraph)) { | |||
OpMethFallbackFromSubgraph::impl(*const_cast<OpMeth*>(this), Tag{}); | |||
} else if (match_mode(Mode::ByProxyGraph)) { | |||
OpMethFallbackByProxyGraph::impl(*const_cast<OpMeth*>(this), Tag{}); | |||
} else if (match_mode(Mode::Default)) { | |||
OpMethFallback::impl(*const_cast<OpMeth*>(this), Tag{}); | |||
} else { | |||
OpMethNotImpl::impl(*const_cast<OpMeth*>(this), Tag{}); | |||
@@ -134,7 +180,7 @@ struct OpMeth<Tag, RType(Args...)> : public thin_function<RType(Args...)> { | |||
} | |||
return this->Base::operator()(std::forward<Args>(args)...); | |||
} | |||
bool allow_fallback = false; | |||
uint64_t fallback_mode = OpMethFallbackMode::None; | |||
}; | |||
} // namespace detail | |||
@@ -153,6 +199,7 @@ struct OpTrait { | |||
HashFunc hash; | |||
IsSame is_same_st; | |||
MakeNameFunc make_name; | |||
GraphMaker make_forward_graph; | |||
OpTrait(const char* name); | |||
static OpTrait* find_by_name(const char* name); | |||
static OpTrait* find_by_typeinfo(Typeinfo* type); | |||
@@ -173,7 +220,9 @@ struct OpTrait { | |||
cb(props) \ | |||
cb(hash) \ | |||
cb(is_same_st) \ | |||
cb(make_name) | |||
cb(make_name) \ | |||
cb(make_forward_graph) \ | |||
// clang-format on | |||
struct OpTraitRegistry { | |||
@@ -0,0 +1,169 @@ | |||
/** | |||
* \file imperative/src/impl/subgraph_detail.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/imperative/subgraph_detail.h" | |||
#include "megbrain/imperative/graph_builder.h" | |||
#include "megbrain/opr/io.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "./op_trait.h" | |||
namespace mgb { | |||
namespace imperative { | |||
namespace subgraph_detail { | |||
VarNodeArray apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
SmallVector<LogicalTensorDesc> input_descs; | |||
for (auto&& input: inputs) { | |||
input_descs.push_back({TensorLayout{input->dtype()}, input->comp_node()}); | |||
} | |||
auto apply_functor = [](const std::shared_ptr<OpDef>& op, const VarNodeArray& inputs, size_t nr_outputs){ | |||
return OpDef::apply_on_var_node(*op, inputs); | |||
}; | |||
auto const_functor = [&](const TensorPtr& value) { | |||
return opr::ImmutableTensor::make(*inputs[0]->owner_graph(), value->get_value()).node(); | |||
}; | |||
auto subgraph = def.trait()->make_forward_graph(def, input_descs); | |||
auto outputs = subgraph.apply(inputs, apply_functor, const_functor); | |||
return outputs; | |||
} | |||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
const OpDef& def, | |||
const SmallVector<LogicalTensorDesc>& inputs) { | |||
auto subgraph = def.trait()->make_forward_graph(def, inputs); | |||
bool all_validated = true; | |||
auto apply_functor = [&](const std::shared_ptr<OpDef>& op, const SmallVector<LogicalTensorDesc>& inputs, size_t nr_outputs){ | |||
auto [outputs, validated] = OpDef::infer_output_attrs_fallible(*op, inputs); | |||
all_validated = all_validated && validated; | |||
return outputs; | |||
}; | |||
auto const_functor = [&](const TensorPtr& value) { | |||
return LogicalTensorDesc{value->layout(), value->comp_node(), value->get_value().proxy_to_default_cpu()}; | |||
}; | |||
auto outputs = subgraph.apply(inputs, apply_functor, const_functor); | |||
return { outputs, all_validated }; | |||
} | |||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||
const OpDef& def, | |||
SmallVector<TensorPtr> inputs) { | |||
SmallVector<LogicalTensorDesc> input_descs; | |||
for (auto&& input: inputs) { | |||
input_descs.push_back({input->layout(), input->comp_node()}); | |||
} | |||
auto subgraph = def.trait()->make_forward_graph(def, input_descs); | |||
auto apply_functor = [](const std::shared_ptr<OpDef>& op, const SmallVector<TensorPtr>& inputs, size_t nr_outputs){ | |||
return OpDef::apply_on_physical_tensor(*op, inputs); | |||
}; | |||
auto const_functor = [&](const TensorPtr& value) { | |||
return value; | |||
}; | |||
auto outputs = subgraph.apply(inputs, apply_functor, const_functor); | |||
return outputs; | |||
} | |||
static EncodedSubraph make_backward_graph_from_forward( | |||
const SmallVector<LogicalTensorDesc>& inputs, | |||
const SmallVector<bool>& input_requires_grad, | |||
const SmallVector<bool>& output_has_grad, | |||
EncodedSubraph forward_graph) { | |||
using namespace std::placeholders; | |||
using var_t = Subgraph::var_t; | |||
using vars_t = Subgraph::vars_t; | |||
Subgraph::Builder<LogicalTensorDesc> builder([](auto&& op, auto&& input_descs, size_t nr_outputs){ | |||
auto [descs, _] = OpDef::infer_output_attrs_fallible(*op, input_descs); | |||
return descs; | |||
}); | |||
auto accum_grad = [&](var_t lhs, var_t rhs) { | |||
return builder.write_expr(Elemwise::make(Elemwise::Mode::ADD), {lhs, rhs}, 1)[0]; | |||
}; | |||
GradContext<var_t> grad_context{accum_grad}; | |||
auto input_vars = builder.write_inputs(inputs); | |||
auto outputs = forward_graph.apply(input_vars, std::bind(&decltype(builder)::write_expr, &builder, _1, _2, _3), [&](TensorPtr constant){ | |||
return builder.write_constant(constant, {constant->layout(), constant->comp_node()}); | |||
}); | |||
size_t nr_outputs = outputs.size(); | |||
auto apply_mask = [](auto&& values, SmallVector<bool> mask) { | |||
mgb_assert(mask.size() == values.size(), ""); | |||
std::decay_t<decltype(values)> results; | |||
for (size_t i = 0; i < mask.size(); ++i) { | |||
if (mask[i]) { | |||
results.push_back(values[i]); | |||
} | |||
} | |||
return results; | |||
}; | |||
grad_context.mark_require_grads(apply_mask(input_vars, input_requires_grad)); | |||
builder.iterate([&](std::list<Subgraph::expr_t>::iterator iter){ | |||
grad_context.record_expr(iter->op, iter->inputs, iter->outputs); | |||
}); | |||
auto output_descs = builder.get_descs(outputs); | |||
auto computed_outputs = builder.write_inputs(output_descs); | |||
auto output_grads = builder.write_inputs(output_descs); | |||
grad_context.backward( | |||
apply_mask(outputs, output_has_grad), | |||
apply_mask(output_grads, output_has_grad), | |||
[&](Subgraph::expr_t expr, vars_t output_grads) { | |||
auto bg = OpDef::make_backward_graph( | |||
*expr.op, builder.get_descs(expr.inputs), | |||
grad_context.get_require_grads(expr.inputs), | |||
grad_context.get_has_grads(expr.outputs)); | |||
if (bg.graph.empty()) { | |||
return vars_t(expr.inputs.size(), 0); | |||
} | |||
vars_t grad_inputs; | |||
grad_inputs.insert(grad_inputs.end(), expr.inputs.begin(), | |||
expr.inputs.end()); | |||
grad_inputs.insert(grad_inputs.end(), expr.outputs.begin(), | |||
expr.outputs.end()); | |||
grad_inputs.insert(grad_inputs.end(), output_grads.begin(), | |||
output_grads.end()); | |||
auto apply_functor = std::bind(&decltype(builder)::write_expr, | |||
&builder, _1, _2, _3); | |||
auto const_functor = [&](TensorPtr constant) { | |||
return builder.write_constant(constant, {constant->layout(), | |||
constant->comp_node()}); | |||
}; | |||
return bg.apply(grad_inputs, apply_functor, const_functor); | |||
}); | |||
builder.add_outputs(grad_context.get_grads(input_vars)); | |||
for (size_t i = 0; i < nr_outputs; ++i) { | |||
builder.replace_var(outputs[i], computed_outputs[i]); | |||
} | |||
auto backward_graph = builder.encode(); | |||
return backward_graph; | |||
} | |||
EncodedSubraph make_backward_graph( | |||
const OpDef& def, | |||
const SmallVector<LogicalTensorDesc>& inputs, | |||
const SmallVector<bool>& input_requires_grad, | |||
const SmallVector<bool>& output_has_grad) { | |||
auto forward_graph = OpDef::make_forward_graph(def, inputs); | |||
return make_backward_graph_from_forward(inputs, input_requires_grad, output_has_grad, forward_graph); | |||
} | |||
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||
const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs_tensors, | |||
const SmallVector<MemoryDesc>& inputs_mems) { | |||
return {{}, {}}; | |||
} | |||
} | |||
} | |||
} |
@@ -13,6 +13,7 @@ | |||
#include "megbrain/graph.h" | |||
#include "megbrain/imperative/physical_tensor.h" | |||
#include "megbrain/imperative/subgraph.h" | |||
#include "megbrain/imperative/utils/to_string.h" | |||
#include "megbrain/imperative/subgraph.h" | |||
@@ -94,6 +95,10 @@ public: | |||
static std::vector<std::pair<const char*, std::string>> props( | |||
const OpDef& def); | |||
static EncodedSubraph make_forward_graph( | |||
const OpDef& def, | |||
const SmallVector<LogicalTensorDesc>& inputs); | |||
const OpTrait* trait() const; | |||
std::string to_string() const; | |||
@@ -0,0 +1,51 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/subgraph_detail.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megbrain/imperative/op_def.h" | |||
namespace mgb { | |||
namespace imperative { | |||
namespace subgraph_detail { | |||
SmallVector<TensorPtr> | |||
apply_on_physical_tensor(const OpDef& def, | |||
SmallVector<TensorPtr> inputs); | |||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def, | |||
const SmallVector<LogicalTensorDesc>& inputs); | |||
EncodedSubraph | |||
make_backward_graph(const OpDef& def, | |||
const SmallVector<LogicalTensorDesc>& inputs, | |||
const SmallVector<bool>& input_requires_grad, | |||
const SmallVector<bool>& output_has_grad); | |||
cg::VarNodeArray | |||
apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs); | |||
EncodedSubraph make_backward_graph( | |||
const OpDef& def, | |||
const SmallVector<LogicalTensorDesc>& inputs, | |||
const SmallVector<bool>& input_requires_grad, | |||
const SmallVector<bool>& output_has_grad); | |||
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||
const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs_tensors, | |||
const SmallVector<MemoryDesc>& inputs_mems); | |||
} | |||
} | |||
} |