
refactor(imperative): bind fallback impl on first op method call

GitOrigin-RevId: 82ae1e3205
Tag: tags/v1.6.0-rc1
Author: Megvii Engine Team (3 years ago)
Parent commit: 20e8541bbb
3 changed files with 175 additions and 107 deletions:

  1. imperative/src/impl/op_trait.cpp             +39  -34
  2. imperative/src/impl/op_trait.h               +112 -53
  3. imperative/src/impl/tensor_sanity_check.cpp  +24  -20

imperative/src/impl/op_trait.cpp (+39 -34)

@@ -38,6 +38,38 @@ StaticData& static_data() {
     return data;
 }
 
+void OpMethFallback::impl(ApplyOnPhysicalTensor& func,
+                          op_meth_tag::ApplyOnPhysicalTensor) {
+    func.Base::operator=(proxy_graph_detail::apply_on_physical_tensor);
+}
+void OpMethFallback::impl(Execute& func, op_meth_tag::Execute) {
+    func.Base::operator=(proxy_graph_detail::execute);
+}
+void OpMethFallback::impl(InferOutputMemDesc& func,
+                          op_meth_tag::InferOutputMemDesc) {
+    func.Base::operator=(proxy_graph_detail::infer_output_mem_desc);
+}
+void OpMethFallback::impl(InferOutputAttrsFallible& func,
+                          op_meth_tag::InferOutputAttrsFallible) {
+    func.Base::operator=(proxy_graph_detail::infer_output_attrs_fallible);
+}
+void OpMethFallback::impl(GradMaker& func, op_meth_tag::GradMaker) {
+    func.Base::operator=(proxy_graph_detail::make_backward_graph);
+}
+void OpMethFallback::impl(DecideDispatchMode& func,
+                          op_meth_tag::DecideDispatchMode) {
+    static auto decide_dispatch_mode =
+            [](const OpDef&, const SmallVector<LogicalTensorDesc>&) {
+                return DispatchMode::KERNEL;
+            };
+    func.Base::operator=(decide_dispatch_mode);
+}
+void OpMethFallback::impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc) {
+    static auto make_name = [](const OpDef& def) -> std::string {
+        return def.trait()->name;
+    };
+    func.Base::operator=(make_name);
+}
 } // detail
 
 OpTrait::OpTrait(const char* name_): name(name_) {}
@@ -66,44 +98,17 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
     }
 }
 
-DispatchMode fallback_decide_dispatch_mode(
-        const OpDef& def,
-        const SmallVector<LogicalTensorDesc>& inputs) {
-    return KERNEL;
-}
-
 OpTraitRegistry& OpTraitRegistry::fallback() {
     if (trait->apply_on_var_node) {
         // fallback to proxy graph impl
-        if (!trait->apply_on_physical_tensor) {
-            trait->apply_on_physical_tensor =
-                    proxy_graph_detail::apply_on_physical_tensor;
-        }
-        if (!trait->execute) {
-            trait->execute = proxy_graph_detail::execute;
-        }
-        if (!trait->infer_output_mem_desc) {
-            trait->infer_output_mem_desc =
-                    proxy_graph_detail::infer_output_mem_desc;
-        }
-        if (!trait->infer_output_attrs_fallible) {
-            trait->infer_output_attrs_fallible =
-                    proxy_graph_detail::infer_output_attrs_fallible;
-        }
-        if (!trait->make_backward_graph) {
-            trait->make_backward_graph =
-                    proxy_graph_detail::make_backward_graph;
-        }
-    }
-    if (!trait->decide_dispatch_mode) {
-        trait->decide_dispatch_mode = fallback_decide_dispatch_mode;
-    }
-    if (!trait->make_name) {
-        static auto make_name = [](const OpDef& def) -> std::string {
-            return def.trait()->name;
-        };
-        trait->make_name = make_name;
+        trait->apply_on_physical_tensor.allow_fallback = true;
+        trait->execute.allow_fallback = true;
+        trait->infer_output_mem_desc.allow_fallback = true;
+        trait->infer_output_attrs_fallible.allow_fallback = true;
+        trait->make_backward_graph.allow_fallback = true;
     }
+    trait->decide_dispatch_mode.allow_fallback = true;
+    trait->make_name.allow_fallback = true;
     return *this;
 }
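Editor's note: the heart of the refactor is visible above. OpTraitRegistry::fallback() no longer copies the proxy-graph implementations into every unset slot at registration time; it merely sets allow_fallback, and the binding happens inside OpMeth::operator() on the first call (see op_trait.h below). A minimal sketch of that lazy-binding pattern, using std::function in place of thin_function and a print statement in place of the proxy-graph kernels (Meth and its members are illustrative names, not MegEngine's):

#include <cstdio>
#include <functional>
#include <stdexcept>
#include <utility>

template <typename Sig>
struct Meth;

template <typename R, typename... Args>
struct Meth<R(Args...)> : std::function<R(Args...)> {
    using Base = std::function<R(Args...)>;
    bool allow_fallback = false;

    R operator()(Args... args) const {
        if (!this->Base::operator bool()) {
            // empty slot: bind something on first use, then dispatch through it
            auto* self = const_cast<Meth*>(this);
            if (allow_fallback) {
                self->Base::operator=([](Args...) -> R {
                    std::puts("fallback implementation");
                    return R{};
                });
            } else {
                self->Base::operator=([](Args...) -> R {
                    throw std::runtime_error("not implemented");
                });
            }
        }
        return this->Base::operator()(std::forward<Args>(args)...);
    }
};

int main() {
    Meth<int(int)> m;
    m.allow_fallback = true;    // what OpTraitRegistry::fallback() now does
    std::printf("%d\n", m(7));  // first call binds the fallback, then runs it
    return 0;
}

The payoff is that registration no longer has to enumerate the empty slots, and an op that never exercises a method never pays for its fallback.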



imperative/src/impl/op_trait.h (+112 -53)

@@ -15,21 +15,10 @@

 namespace mgb {
 namespace imperative {
 
 namespace detail {
-template <typename Signature>
+template <typename Tag, typename Signature>
 struct OpMeth;
-template <typename RType, typename... Args>
-struct OpMeth<RType(Args...)> : public thin_function<RType(Args...)> {
-    using Base = thin_function<RType(Args...)>;
-    using Base::Base;
-    RType operator()(Args... args) const {
-        if (!this->Base::operator bool()) {
-            mgb_throw(MegBrainError, "Not Implemented");
-        }
-        return this->Base::operator()(std::forward<Args>(args)...);
-    }
-};
 
 template<typename T>
 struct ToVarNodeArray: std::false_type {};
 template<>
@@ -58,28 +47,95 @@ struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type {
 };
 } // namespace detail
 
-using OpDefMaker = detail::OpMeth<
-        decltype(OpDef::make_from_op_node)>;
-using DecideDispatchMode = detail::OpMeth<
-        decltype(OpDef::decide_dispatch_mode)>;
-using ApplyOnPhysicalTensor = detail::OpMeth<
-        decltype(OpDef::apply_on_physical_tensor)>;
-using InferOutputMemDesc = detail::OpMeth<
-        decltype(OpDef::infer_output_mem_desc)>;
-using Execute = detail::OpMeth<
-        decltype(OpDef::execute)>;
-using ApplyOnDeviceTensorND = detail::OpMeth<
-        decltype(OpDef::apply_on_device_tensornd)>;
-using ApplyOnVarNode = detail::OpMeth<
-        decltype(OpDef::apply_on_var_node)>;
-using InferOutputAttrsFallible = detail::OpMeth<
-        decltype(OpDef::infer_output_attrs_fallible)>;
-using GradMaker = detail::OpMeth<
-        decltype(OpDef::make_backward_graph)>;
-using Props = detail::OpMeth<decltype(OpDef::props)>;
-using HashFunc = detail::OpMeth<size_t(const OpDef&)>;
-using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>;
-using MakeNameFunc = detail::OpMeth<std::string(const OpDef&)>;
+// clang-format off
+#define OpMethType(TYPE, SIG)                 \
+    namespace detail::op_meth_tag {           \
+    struct TYPE {                             \
+        constexpr static char name[] = #TYPE; \
+    };                                        \
+    }                                         \
+    using TYPE = detail::OpMeth<detail::op_meth_tag::TYPE, SIG>
+
+OpMethType(OpDefMaker,
+           decltype(OpDef::make_from_op_node));
+
+OpMethType(DecideDispatchMode,
+           decltype(OpDef::decide_dispatch_mode));
+
+OpMethType(ApplyOnPhysicalTensor,
+           decltype(OpDef::apply_on_physical_tensor));
+
+OpMethType(InferOutputMemDesc,
+           decltype(OpDef::infer_output_mem_desc));
+
+OpMethType(Execute,
+           decltype(OpDef::execute));
+
+OpMethType(ApplyOnDeviceTensorND,
+           decltype(OpDef::apply_on_device_tensornd));
+
+OpMethType(ApplyOnVarNode,
+           decltype(OpDef::apply_on_var_node));
+
+OpMethType(InferOutputAttrsFallible,
+           decltype(OpDef::infer_output_attrs_fallible));
+
+OpMethType(GradMaker,
+           decltype(OpDef::make_backward_graph));
+
+OpMethType(Props,
+           decltype(OpDef::props));
+
+OpMethType(HashFunc,
+           size_t(const OpDef&));
+
+OpMethType(IsSame,
+           bool(const OpDef&, const OpDef&));
+
+OpMethType(MakeNameFunc,
+           std::string(const OpDef&));
+// clang-format on
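For reference, OpMethType(Execute, decltype(OpDef::execute)); expands to roughly the following; the tag type exists purely so each OpMeth instantiation carries a compile-time name for diagnostics:

namespace detail::op_meth_tag {
struct Execute {
    // interpolated into the "%s was not implemented yet" error below
    constexpr static char name[] = "Execute";
};
}  // namespace detail::op_meth_tag
using Execute = detail::OpMeth<detail::op_meth_tag::Execute,
                               decltype(OpDef::execute)>;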

+namespace detail {
+struct OpMethNotImpl {
+    template <typename Tag, typename RType, typename... Args>
+    static void impl(thin_function<RType(Args...)>& func, Tag) {
+        func = [](Args... args) -> RType {
+            mgb_throw(MegBrainError, "%s was not implemented yet", Tag::name);
+        };
+    }
+};
+struct OpMethFallback : public OpMethNotImpl {
+    using OpMethNotImpl::impl;
+    static void impl(ApplyOnPhysicalTensor& func,
+                     op_meth_tag::ApplyOnPhysicalTensor);
+    static void impl(Execute& func, op_meth_tag::Execute);
+    static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc);
+    static void impl(InferOutputAttrsFallible& func,
+                     op_meth_tag::InferOutputAttrsFallible);
+    static void impl(GradMaker& func, op_meth_tag::GradMaker);
+    static void impl(DecideDispatchMode& func, op_meth_tag::DecideDispatchMode);
+    static void impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc);
+};
+template <typename Tag, typename RType, typename... Args>
+struct OpMeth<Tag, RType(Args...)> : public thin_function<RType(Args...)> {
+    using Base = thin_function<RType(Args...)>;
+    using Base::operator bool;
+    OpMeth() : Base{}, allow_fallback(false){};
+    explicit OpMeth(const Base& base) { this->Base::operator=(base); }
+    RType operator()(Args... args) const {
+        if (!this->Base::operator bool()) {
+            if (allow_fallback) {
+                OpMethFallback::impl(*const_cast<OpMeth*>(this), Tag{});
+            } else {
+                OpMethNotImpl::impl(*const_cast<OpMeth*>(this), Tag{});
+            }
+        }
+        return this->Base::operator()(std::forward<Args>(args)...);
+    }
+    bool allow_fallback = false;
+};
+} // namespace detail
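The split between OpMethNotImpl and OpMethFallback above works by ordinary overload resolution: operator() calls impl(*this, Tag{}), the dedicated overload wins for tags that have a proxy-graph fallback, and the inherited function template catches every other tag and binds a thrower. A stripped-down, compilable sketch of just that mechanism (TagA/TagB and the printed strings are invented for illustration):

#include <cstdio>

struct TagA { static constexpr char name[] = "A"; };
struct TagB { static constexpr char name[] = "B"; };

struct NotImpl {
    template <typename Tag>
    static void impl(Tag) { std::printf("%s: bind thrower\n", Tag::name); }
};

struct Fallback : NotImpl {
    using NotImpl::impl;  // keep the template visible as the catch-all
    static void impl(TagA) { std::printf("A: bind proxy-graph fallback\n"); }
};

int main() {
    Fallback::impl(TagA{});  // exact non-template overload wins
    Fallback::impl(TagB{});  // falls through to the inherited template
}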

 struct OpTrait {
     const char* name;
@@ -102,28 +158,31 @@ struct OpTrait {
     static void for_each_trait(thin_function<void(OpTrait&)> visitor);
 };
 
-#define FOR_EACH_OP_METH(cb) \
-    cb(make_from_op_node) \
-    cb(decide_dispatch_mode) \
-    cb(apply_on_physical_tensor) \
-    cb(infer_output_mem_desc) \
-    cb(execute) \
-    cb(apply_on_device_tensornd) \
-    cb(apply_on_var_node) \
+// clang-format off
+#define FOR_EACH_OP_METH(cb)        \
+    cb(make_from_op_node)           \
+    cb(decide_dispatch_mode)        \
+    cb(apply_on_physical_tensor)    \
+    cb(infer_output_mem_desc)       \
+    cb(execute)                     \
+    cb(apply_on_device_tensornd)    \
+    cb(apply_on_var_node)           \
     cb(infer_output_attrs_fallible) \
-    cb(make_backward_graph) \
-    cb(props) \
-    cb(hash) \
-    cb(is_same_st) \
+    cb(make_backward_graph)         \
+    cb(props)                       \
+    cb(hash)                        \
+    cb(is_same_st)                  \
     cb(make_name)
+// clang-format on
 
 struct OpTraitRegistry {
     OpTrait* trait;
-#define DECL(meth) \
-    OpTraitRegistry& meth(decltype(OpTrait::meth) f) { \
-        mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name, #meth); \
-        trait->meth = f; \
-        return *this; \
+#define DECL(meth)                                                \
+    OpTraitRegistry& meth(decltype(OpTrait::meth)::Base f) {      \
+        mgb_assert(!trait->meth, "op %s has duplicate method %s", \
+                   trait->name, #meth);                           \
+        trait->meth.Base::operator=(f);                           \
+        return *this;                                             \
     }
     FOR_EACH_OP_METH(DECL)
 #undef DECL
@@ -162,7 +221,7 @@ struct OpTraitRegistry {
     }
 };
 
-} // namespace imperative
+}  // namespace imperative
 } // namespace mgb
 
 #define OP_TRAIT_REG(name, ...) \


imperative/src/impl/tensor_sanity_check.cpp (+24 -20)

@@ -80,26 +80,30 @@ void TensorSanityCheck::enable() {
     OpTrait::for_each_trait([this](OpTrait& trait) {
         auto backup = std::make_unique<ApplyOnPhysicalTensor>(
                 std::move(trait.apply_on_physical_tensor));
-        trait.apply_on_physical_tensor = [this, backup = backup.get()] (
-                const OpDef& def, const SmallVector<TensorPtr>& inputs) {
-            for (auto&& i: inputs) {
-                if (!m_checker->check(i)) {
-                    mgb_throw(TensorChecksumCalc::Error,
-                            "tensor modified before exec %s", print_op(def).c_str());
-                }
-            }
-            auto output = (*backup)(def, inputs);
-            for (auto&& i: output) {
-                mgb_assert(m_checker->check(i));
-            }
-            for (auto&& i: inputs) {
-                if (!m_checker->check(i)) {
-                    mgb_throw(TensorChecksumCalc::Error,
-                            "tensor modified after exec %s", print_op(def).c_str());
-                }
-            }
-            return output;
-        };
+        trait.apply_on_physical_tensor = ApplyOnPhysicalTensor(
+                [this, backup = backup.get()](
+                        const OpDef& def,
+                        const SmallVector<TensorPtr>& inputs) {
+                    for (auto&& i : inputs) {
+                        if (!m_checker->check(i)) {
+                            mgb_throw(TensorChecksumCalc::Error,
+                                      "tensor modified before exec %s",
+                                      print_op(def).c_str());
+                        }
+                    }
+                    auto output = (*backup)(def, inputs);
+                    for (auto&& i : output) {
+                        mgb_assert(m_checker->check(i));
+                    }
+                    for (auto&& i : inputs) {
+                        if (!m_checker->check(i)) {
+                            mgb_throw(TensorChecksumCalc::Error,
+                                      "tensor modified after exec %s",
+                                      print_op(def).c_str());
+                        }
+                    }
+                    return output;
+                });
         m_checker->hook_list.push_back({&trait, std::move(backup)});
     });
 }
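The hook above backs up the original apply_on_physical_tensor on the heap, installs a checking shim that calls through the backup, and parks the backup in hook_list, presumably so disabling the checker can restore it. The explicit ApplyOnPhysicalTensor(...) around the lambda is what this commit forces: OpMeth's converting constructor is now explicit. A self-contained sketch of the same hook-and-restore pattern, with std::function standing in for the OpMeth types (Hooker and Fn are invented names):

#include <cstdio>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

using Fn = std::function<int(int)>;

struct Hooker {
    // heap-allocated backups keep the shim's captured pointer stable,
    // mirroring the std::make_unique backup in TensorSanityCheck::enable()
    std::vector<std::pair<Fn*, std::unique_ptr<Fn>>> hook_list;

    void hook(Fn& slot) {
        auto backup = std::make_unique<Fn>(std::move(slot));
        slot = [raw = backup.get()](int x) {
            std::puts("check inputs");
            int y = (*raw)(x);  // call the original implementation
            std::puts("check outputs");
            return y;
        };
        hook_list.emplace_back(&slot, std::move(backup));
    }

    void unhook() {
        for (auto& [slot, orig] : hook_list)
            *slot = std::move(*orig);  // put the originals back
        hook_list.clear();
    }
};

int main() {
    Fn f = [](int x) { return x + 1; };
    Hooker h;
    h.hook(f);
    std::printf("%d\n", f(1));  // checks run around the original call
    h.unhook();
    std::printf("%d\n", f(1));  // original behavior restored
}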

