From 20e8541bbbe7de333a626f7a00f032808db117eb Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 9 Aug 2021 17:01:03 +0800
Subject: [PATCH] refactor(imperative): bind fallback impl on first op method
 call

GitOrigin-RevId: 82ae1e32052f274dea67ced95dc6ab694883425b
---
 imperative/src/impl/op_trait.cpp            |  73 ++++++------
 imperative/src/impl/op_trait.h              | 165 +++++++++++++++++++---
 imperative/src/impl/tensor_sanity_check.cpp |  44 ++++----
 3 files changed, 175 insertions(+), 107 deletions(-)

diff --git a/imperative/src/impl/op_trait.cpp b/imperative/src/impl/op_trait.cpp
index 59b3241d..59f1befa 100644
--- a/imperative/src/impl/op_trait.cpp
+++ b/imperative/src/impl/op_trait.cpp
@@ -38,6 +38,38 @@ StaticData& static_data() {
     return data;
 }
 
+void OpMethFallback::impl(ApplyOnPhysicalTensor& func,
+                          op_meth_tag::ApplyOnPhysicalTensor) {
+    func.Base::operator=(proxy_graph_detail::apply_on_physical_tensor);
+}
+void OpMethFallback::impl(Execute& func, op_meth_tag::Execute) {
+    func.Base::operator=(proxy_graph_detail::execute);
+}
+void OpMethFallback::impl(InferOutputMemDesc& func,
+                          op_meth_tag::InferOutputMemDesc) {
+    func.Base::operator=(proxy_graph_detail::infer_output_mem_desc);
+}
+void OpMethFallback::impl(InferOutputAttrsFallible& func,
+                          op_meth_tag::InferOutputAttrsFallible) {
+    func.Base::operator=(proxy_graph_detail::infer_output_attrs_fallible);
+}
+void OpMethFallback::impl(GradMaker& func, op_meth_tag::GradMaker) {
+    func.Base::operator=(proxy_graph_detail::make_backward_graph);
+}
+void OpMethFallback::impl(DecideDispatchMode& func,
+                          op_meth_tag::DecideDispatchMode) {
+    static auto decide_dispatch_mode =
+            [](const OpDef&, const SmallVector<LogicalTensorDesc>&) {
+                return DispatchMode::KERNEL;
+            };
+    func.Base::operator=(decide_dispatch_mode);
+}
+void OpMethFallback::impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc) {
+    static auto make_name = [](const OpDef& def) -> std::string {
+        return def.trait()->name;
+    };
+    func.Base::operator=(make_name);
+}
 } // detail
 
 OpTrait::OpTrait(const char* name_): name(name_) {}
@@ -66,44 +98,17 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
     }
 }
 
-DispatchMode fallback_decide_dispatch_mode(
-        const OpDef& def,
-        const SmallVector<LogicalTensorDesc>& inputs) {
-    return KERNEL;
-}
-
 OpTraitRegistry& OpTraitRegistry::fallback() {
     if (trait->apply_on_var_node) {
         // fallback to proxy graph impl
-        if (!trait->apply_on_physical_tensor) {
-            trait->apply_on_physical_tensor =
-                    proxy_graph_detail::apply_on_physical_tensor;
-        }
-        if (!trait->execute) {
-            trait->execute = proxy_graph_detail::execute;
-        }
-        if (!trait->infer_output_mem_desc) {
-            trait->infer_output_mem_desc =
-                    proxy_graph_detail::infer_output_mem_desc;
-        }
-        if (!trait->infer_output_attrs_fallible) {
-            trait->infer_output_attrs_fallible =
-                    proxy_graph_detail::infer_output_attrs_fallible;
-        }
-        if (!trait->make_backward_graph) {
-            trait->make_backward_graph =
-                    proxy_graph_detail::make_backward_graph;
-        }
-    }
-    if (!trait->decide_dispatch_mode) {
-        trait->decide_dispatch_mode = fallback_decide_dispatch_mode;
-    }
-    if (!trait->make_name) {
-        static auto make_name = [](const OpDef& def) -> std::string {
-            return def.trait()->name;
-        };
-        trait->make_name = make_name;
+        trait->apply_on_physical_tensor.allow_fallback = true;
+        trait->execute.allow_fallback = true;
+        trait->infer_output_mem_desc.allow_fallback = true;
+        trait->infer_output_attrs_fallible.allow_fallback = true;
+        trait->make_backward_graph.allow_fallback = true;
     }
+    trait->decide_dispatch_mode.allow_fallback = true;
+    trait->make_name.allow_fallback = true;
     return *this;
 }
 
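A note on the new fallback plumbing above: OpMethFallback declares one impl() overload per op method that has a proxy-graph fallback, while the inherited OpMethNotImpl::impl() template covers every other method tag, so ordinary overload resolution decides which binding an OpMeth receives when it resolves itself on first call. The following is a minimal, self-contained sketch of that overload-selection idea, not MegEngine code: std::function stands in for thin_function, and NotImpl, Fallback, AddTag and MulTag are made-up names.

#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>

struct AddTag { static constexpr const char* name = "Add"; };
struct MulTag { static constexpr const char* name = "Mul"; };

struct NotImpl {
    // Generic stub: chosen for any method tag without a dedicated fallback.
    template <typename Tag, typename R, typename... Args>
    static void impl(std::function<R(Args...)>& func, Tag) {
        func = [](Args...) -> R {
            throw std::runtime_error(std::string(Tag::name) + " not implemented");
        };
    }
};

struct Fallback : NotImpl {
    using NotImpl::impl;
    // Dedicated fallback: the exact, non-template overload wins for AddTag.
    static void impl(std::function<int(int, int)>& func, AddTag) {
        func = [](int a, int b) { return a + b; };
    }
};

int main() {
    std::function<int(int, int)> add, mul;
    Fallback::impl(add, AddTag{});   // bound to the dedicated fallback
    Fallback::impl(mul, MulTag{});   // bound to the generic throwing stub
    std::cout << add(2, 3) << "\n";  // prints 5
    try {
        mul(2, 3);
    } catch (const std::exception& e) {
        std::cout << e.what() << "\n";  // prints "Mul not implemented"
    }
    return 0;
}
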
diff --git a/imperative/src/impl/op_trait.h b/imperative/src/impl/op_trait.h
index dff44147..93e5c397 100644
--- a/imperative/src/impl/op_trait.h
+++ b/imperative/src/impl/op_trait.h
@@ -15,21 +15,10 @@
 
 namespace mgb {
 namespace imperative {
-
 namespace detail {
-template <typename Signature>
+template <typename Tag, typename Signature>
 struct OpMeth;
-template <typename RType, typename... Args>
-struct OpMeth<RType(Args...)> : public thin_function<RType(Args...)> {
-    using Base = thin_function<RType(Args...)>;
-    using Base::Base;
-    RType operator()(Args... args) const {
-        if (!this->Base::operator bool()) {
-            mgb_throw(MegBrainError, "Not Implemented");
-        }
-        return this->Base::operator()(std::forward<Args>(args)...);
-    }
-};
+
 template<typename T>
 struct ToVarNodeArray: std::false_type {};
 template<>
@@ -58,28 +47,95 @@ struct ToVarNodeArray: std::true_type {
 };
 } // namespace detail
 
-using OpDefMaker = detail::OpMeth<
-        decltype(OpDef::make_from_op_node)>;
-using DecideDispatchMode = detail::OpMeth<
-        decltype(OpDef::decide_dispatch_mode)>;
-using ApplyOnPhysicalTensor = detail::OpMeth<
-        decltype(OpDef::apply_on_physical_tensor)>;
-using InferOutputMemDesc = detail::OpMeth<
-        decltype(OpDef::infer_output_mem_desc)>;
-using Execute = detail::OpMeth<
-        decltype(OpDef::execute)>;
-using ApplyOnDeviceTensorND = detail::OpMeth<
-        decltype(OpDef::apply_on_device_tensornd)>;
-using ApplyOnVarNode = detail::OpMeth<
-        decltype(OpDef::apply_on_var_node)>;
-using InferOutputAttrsFallible = detail::OpMeth<
-        decltype(OpDef::infer_output_attrs_fallible)>;
-using GradMaker = detail::OpMeth<
-        decltype(OpDef::make_backward_graph)>;
-using Props = detail::OpMeth<decltype(OpDef::props)>;
-using HashFunc = detail::OpMeth<size_t(const OpDef&)>;
-using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>;
-using MakeNameFunc = detail::OpMeth<std::string(const OpDef&)>;
+// clang-format off
+#define OpMethType(TYPE, SIG) \
+    namespace detail::op_meth_tag { \
+        struct TYPE { \
+            constexpr static char name[] = #TYPE; \
+        }; \
+    } \
+    using TYPE = detail::OpMeth<detail::op_meth_tag::TYPE, SIG>
+
+OpMethType(OpDefMaker,
+           decltype(OpDef::make_from_op_node));
+
+OpMethType(DecideDispatchMode,
+           decltype(OpDef::decide_dispatch_mode));
+
+OpMethType(ApplyOnPhysicalTensor,
+           decltype(OpDef::apply_on_physical_tensor));
+
+OpMethType(InferOutputMemDesc,
+           decltype(OpDef::infer_output_mem_desc));
+
+OpMethType(Execute,
+           decltype(OpDef::execute));
+
+OpMethType(ApplyOnDeviceTensorND,
+           decltype(OpDef::apply_on_device_tensornd));
+
+OpMethType(ApplyOnVarNode,
+           decltype(OpDef::apply_on_var_node));
+
+OpMethType(InferOutputAttrsFallible,
+           decltype(OpDef::infer_output_attrs_fallible));
+
+OpMethType(GradMaker,
+           decltype(OpDef::make_backward_graph));
+
+OpMethType(Props,
+           decltype(OpDef::props));
+
+OpMethType(HashFunc,
+           size_t(const OpDef&));
+
+OpMethType(IsSame,
+           bool(const OpDef&, const OpDef&));
+
+OpMethType(MakeNameFunc,
+           std::string(const OpDef&));
+// clang-format on
+
+namespace detail {
+struct OpMethNotImpl {
+    template <typename Tag, typename RType, typename... Args>
+    static void impl(thin_function<RType(Args...)>& func, Tag) {
+        func = [](Args... args) -> RType {
+            mgb_throw(MegBrainError, "%s was not implemented yet", Tag::name);
+        };
+    }
+};
+struct OpMethFallback : public OpMethNotImpl {
+    using OpMethNotImpl::impl;
+    static void impl(ApplyOnPhysicalTensor& func,
+                     op_meth_tag::ApplyOnPhysicalTensor);
+    static void impl(Execute& func, op_meth_tag::Execute);
+    static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc);
+    static void impl(InferOutputAttrsFallible& func,
+                     op_meth_tag::InferOutputAttrsFallible);
+    static void impl(GradMaker& func, op_meth_tag::GradMaker);
+    static void impl(DecideDispatchMode& func, op_meth_tag::DecideDispatchMode);
+    static void impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc);
+};
+template <typename Tag, typename RType, typename... Args>
+struct OpMeth<Tag, RType(Args...)> : public thin_function<RType(Args...)> {
+    using Base = thin_function<RType(Args...)>;
+    using Base::operator bool;
+    OpMeth() : Base{}, allow_fallback(false){};
+    explicit OpMeth(const Base& base) { this->Base::operator=(base); }
+    RType operator()(Args... args) const {
+        if (!this->Base::operator bool()) {
+            if (allow_fallback) {
+                OpMethFallback::impl(*const_cast<OpMeth*>(this), Tag{});
+            } else {
+                OpMethNotImpl::impl(*const_cast<OpMeth*>(this), Tag{});
+            }
+        }
+        return this->Base::operator()(std::forward<Args>(args)...);
+    }
+    bool allow_fallback = false;
+};
+} // namespace detail
 
 struct OpTrait {
     const char* name;
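The OpMeth wrapper defined above carries the behaviour named in the subject line: a trait method starts out unbound, and the first invocation binds either the proxy-graph fallback (when allow_fallback was set by OpTraitRegistry::fallback()) or a throwing "not implemented" stub, after which the call proceeds through the freshly bound target. Below is a minimal, self-contained sketch of that bind-on-first-call behaviour; it is not MegEngine code, std::function replaces thin_function, and LazyMeth and its members are illustrative names only.

#include <functional>
#include <iostream>
#include <stdexcept>
#include <utility>

template <typename Signature>
struct LazyMeth;

template <typename R, typename... Args>
struct LazyMeth<R(Args...)> : std::function<R(Args...)> {
    using Base = std::function<R(Args...)>;
    bool allow_fallback = false;
    Base fallback;  // target to bind on first call if allow_fallback is set

    R operator()(Args... args) const {
        if (!static_cast<const Base&>(*this)) {
            // Resolve the target lazily, on the first call; the const_cast
            // mirrors the patch, where the stored target is logically mutable.
            auto* self = const_cast<LazyMeth*>(this);
            if (allow_fallback && fallback) {
                self->Base::operator=(fallback);
            } else {
                self->Base::operator=([](Args...) -> R {
                    throw std::runtime_error("not implemented");
                });
            }
        }
        return this->Base::operator()(std::forward<Args>(args)...);
    }
};

int main() {
    LazyMeth<int(int, int)> meth;
    meth.allow_fallback = true;
    meth.fallback = [](int a, int b) { return a * b; };
    std::cout << meth(6, 7) << "\n";  // the fallback is bound here; prints 42
    return 0;
}
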
@@ -102,28 +158,31 @@ struct OpTrait {
     static void for_each_trait(thin_function<void(OpTrait&)> visitor);
 };
 
-#define FOR_EACH_OP_METH(cb) \
-    cb(make_from_op_node) \
-    cb(decide_dispatch_mode) \
-    cb(apply_on_physical_tensor) \
-    cb(infer_output_mem_desc) \
-    cb(execute) \
-    cb(apply_on_device_tensornd) \
-    cb(apply_on_var_node) \
+// clang-format off
+#define FOR_EACH_OP_METH(cb)            \
+    cb(make_from_op_node)               \
+    cb(decide_dispatch_mode)            \
+    cb(apply_on_physical_tensor)        \
+    cb(infer_output_mem_desc)           \
+    cb(execute)                         \
+    cb(apply_on_device_tensornd)        \
+    cb(apply_on_var_node)               \
     cb(infer_output_attrs_fallible) \
-    cb(make_backward_graph) \
-    cb(props) \
-    cb(hash) \
-    cb(is_same_st) \
+    cb(make_backward_graph)             \
+    cb(props)                           \
+    cb(hash)                            \
+    cb(is_same_st)                      \
     cb(make_name)
+// clang-format on
 
 struct OpTraitRegistry {
     OpTrait* trait;
-#define DECL(meth) \
-    OpTraitRegistry& meth(decltype(OpTrait::meth) f) { \
-        mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name, #meth); \
-        trait->meth = f; \
-        return *this; \
+#define DECL(meth)                                                              \
+    OpTraitRegistry& meth(decltype(OpTrait::meth)::Base f) {                    \
+        mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name,  \
+                   #meth);                                                      \
+        trait->meth.Base::operator=(f);                                         \
+        return *this;                                                           \
     }
     FOR_EACH_OP_METH(DECL)
 #undef DECL
@@ -162,7 +221,7 @@
     }
 };
 
-} // namespace imperative
+}  // namespace imperative
 } // namespace mgb
 
 #define OP_TRAIT_REG(name, ...) \
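The header changes lean on two preprocessor patterns: OpMethType stamps out a tag struct whose name is captured with the # stringization operator (later printed by the "%s was not implemented yet" stub), and FOR_EACH_OP_METH is an X-macro that generates one chainable setter per op method inside OpTraitRegistry. The sketch below shows both patterns in isolation; DEF_TAG, FOR_EACH_METH, Registry and the int-valued setters are hypothetical stand-ins, not MegEngine API.

#include <cstdio>

// Tag generation with stringization, in the spirit of OpMethType.
#define DEF_TAG(TYPE)                         \
    namespace tag {                           \
    struct TYPE {                             \
        constexpr static char name[] = #TYPE; \
    };                                        \
    }

DEF_TAG(Execute)
DEF_TAG(MakeName)

// X-macro listing the method names once, in the spirit of FOR_EACH_OP_METH.
#define FOR_EACH_METH(cb) \
    cb(execute)           \
    cb(make_name)

struct Registry {
    // One chainable setter (plus a slot) per listed method, like DECL above.
#define DECL(meth)                             \
    Registry& meth(int f) {                    \
        std::printf("registered %s\n", #meth); \
        value_##meth = f;                      \
        return *this;                          \
    }                                          \
    int value_##meth = 0;
    FOR_EACH_METH(DECL)
#undef DECL
};

int main() {
    Registry reg;
    reg.execute(1).make_name(2);              // chained, like OpTraitRegistry
    std::printf("%s\n", tag::Execute::name);  // prints "Execute"
    return 0;
}
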
diff --git a/imperative/src/impl/tensor_sanity_check.cpp b/imperative/src/impl/tensor_sanity_check.cpp
index af6c9056..090617bc 100644
--- a/imperative/src/impl/tensor_sanity_check.cpp
+++ b/imperative/src/impl/tensor_sanity_check.cpp
@@ -80,26 +80,30 @@ void TensorSanityCheck::enable() {
     OpTrait::for_each_trait([this](OpTrait& trait) {
         auto backup = std::make_unique<ApplyOnPhysicalTensor>(
                 std::move(trait.apply_on_physical_tensor));
-        trait.apply_on_physical_tensor = [this, backup = backup.get()] (
-                const OpDef& def, const SmallVector<TensorPtr>& inputs) {
-            for (auto&& i: inputs) {
-                if (!m_checker->check(i)) {
-                    mgb_throw(TensorChecksumCalc::Error,
-                            "tensor modified before exec %s", print_op(def).c_str());
-                }
-            }
-            auto output = (*backup)(def, inputs);
-            for (auto&& i: output) {
-                mgb_assert(m_checker->check(i));
-            }
-            for (auto&& i: inputs) {
-                if (!m_checker->check(i)) {
-                    mgb_throw(TensorChecksumCalc::Error,
-                            "tensor modified after exec %s", print_op(def).c_str());
-                }
-            }
-            return output;
-        };
+        trait.apply_on_physical_tensor = ApplyOnPhysicalTensor(
+                [this, backup = backup.get()](
+                        const OpDef& def,
+                        const SmallVector<TensorPtr>& inputs) {
+                    for (auto&& i : inputs) {
+                        if (!m_checker->check(i)) {
+                            mgb_throw(TensorChecksumCalc::Error,
+                                      "tensor modified before exec %s",
+                                      print_op(def).c_str());
+                        }
+                    }
+                    auto output = (*backup)(def, inputs);
+                    for (auto&& i : output) {
+                        mgb_assert(m_checker->check(i));
+                    }
+                    for (auto&& i : inputs) {
+                        if (!m_checker->check(i)) {
+                            mgb_throw(TensorChecksumCalc::Error,
+                                      "tensor modified after exec %s",
+                                      print_op(def).c_str());
+                        }
+                    }
+                    return output;
+                });
         m_checker->hook_list.push_back({&trait, std::move(backup)});
     });
 }
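The tensor_sanity_check.cpp hunk keeps the original apply_on_physical_tensor alive in a backup and installs a wrapper that checks every input before the call and the outputs (and inputs again) after it; since OpMeth no longer inherits thin_function's constructors and its converting constructor is explicit, the checking lambda now has to be wrapped in ApplyOnPhysicalTensor(...) instead of being assigned directly. A minimal, self-contained sketch of that hook-and-check pattern follows; Method, the int(int) signature and the checks are hypothetical, not MegEngine code.

#include <functional>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <utility>

struct Method {
    // Stand-in for a wrapper with an explicit constructor, such as OpMeth;
    // a raw lambda cannot be assigned to it without the explicit wrapping.
    std::function<int(int)> fn;
    explicit Method(std::function<int(int)> f = {}) : fn(std::move(f)) {}
    int operator()(int x) const { return fn(x); }
};

int main() {
    Method apply{[](int x) { return x * 2; }};

    // Hook: keep the original implementation alive, then replace it.
    auto backup = std::make_unique<Method>(std::move(apply));
    apply = Method([raw = backup.get()](int x) {
        if (x < 0) throw std::runtime_error("bad input");    // pre-check
        int out = (*raw)(x);                                  // delegate
        if (out < 0) throw std::runtime_error("bad output");  // post-check
        return out;
    });

    std::cout << apply(21) << "\n";  // prints 42, via the checking wrapper
    return 0;
}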