|
|
@@ -15,21 +15,10 @@ |
|
|
|
|
|
|
|
namespace mgb { |
|
|
|
namespace imperative { |
|
|
|
|
|
|
|
namespace detail { |
|
|
|
template <typename Signature> |
|
|
|
template <typename Tag, typename Signature> |
|
|
|
struct OpMeth; |
|
|
|
template <typename RType, typename... Args> |
|
|
|
struct OpMeth<RType(Args...)> : public thin_function<RType(Args...)> { |
|
|
|
using Base = thin_function<RType(Args...)>; |
|
|
|
using Base::Base; |
|
|
|
RType operator()(Args... args) const { |
|
|
|
if (!this->Base::operator bool()) { |
|
|
|
mgb_throw(MegBrainError, "Not Implemented"); |
|
|
|
} |
|
|
|
return this->Base::operator()(std::forward<Args>(args)...); |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
template<typename T> |
|
|
|
struct ToVarNodeArray: std::false_type {}; |
|
|
|
template<> |
|
|
@@ -58,28 +47,95 @@ struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type { |
|
|
|
}; |
|
|
|
} // namespace detail |
|
|
|
|
|
|
|
using OpDefMaker = detail::OpMeth< |
|
|
|
decltype(OpDef::make_from_op_node)>; |
|
|
|
using DecideDispatchMode = detail::OpMeth< |
|
|
|
decltype(OpDef::decide_dispatch_mode)>; |
|
|
|
using ApplyOnPhysicalTensor = detail::OpMeth< |
|
|
|
decltype(OpDef::apply_on_physical_tensor)>; |
|
|
|
using InferOutputMemDesc = detail::OpMeth< |
|
|
|
decltype(OpDef::infer_output_mem_desc)>; |
|
|
|
using Execute = detail::OpMeth< |
|
|
|
decltype(OpDef::execute)>; |
|
|
|
using ApplyOnDeviceTensorND = detail::OpMeth< |
|
|
|
decltype(OpDef::apply_on_device_tensornd)>; |
|
|
|
using ApplyOnVarNode = detail::OpMeth< |
|
|
|
decltype(OpDef::apply_on_var_node)>; |
|
|
|
using InferOutputAttrsFallible = detail::OpMeth< |
|
|
|
decltype(OpDef::infer_output_attrs_fallible)>; |
|
|
|
using GradMaker = detail::OpMeth< |
|
|
|
decltype(OpDef::make_backward_graph)>; |
|
|
|
using Props = detail::OpMeth<decltype(OpDef::props)>; |
|
|
|
using HashFunc = detail::OpMeth<size_t(const OpDef&)>; |
|
|
|
using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>; |
|
|
|
using MakeNameFunc = detail::OpMeth<std::string(const OpDef&)>; |
|
|
|
// clang-format off |
|
|
|
#define OpMethType(TYPE, SIG) \ |
|
|
|
namespace detail::op_meth_tag { \ |
|
|
|
struct TYPE { \ |
|
|
|
constexpr static char name[] = #TYPE; \ |
|
|
|
}; \ |
|
|
|
} \ |
|
|
|
using TYPE = detail::OpMeth<detail::op_meth_tag::TYPE, SIG> |
|
|
|
|
|
|
|
OpMethType(OpDefMaker, |
|
|
|
decltype(OpDef::make_from_op_node)); |
|
|
|
|
|
|
|
OpMethType(DecideDispatchMode, |
|
|
|
decltype(OpDef::decide_dispatch_mode)); |
|
|
|
|
|
|
|
OpMethType(ApplyOnPhysicalTensor, |
|
|
|
decltype(OpDef::apply_on_physical_tensor)); |
|
|
|
|
|
|
|
OpMethType(InferOutputMemDesc, |
|
|
|
decltype(OpDef::infer_output_mem_desc)); |
|
|
|
|
|
|
|
OpMethType(Execute, |
|
|
|
decltype(OpDef::execute)); |
|
|
|
|
|
|
|
OpMethType(ApplyOnDeviceTensorND, |
|
|
|
decltype(OpDef::apply_on_device_tensornd)); |
|
|
|
|
|
|
|
OpMethType(ApplyOnVarNode, |
|
|
|
decltype(OpDef::apply_on_var_node)); |
|
|
|
|
|
|
|
OpMethType(InferOutputAttrsFallible, |
|
|
|
decltype(OpDef::infer_output_attrs_fallible)); |
|
|
|
|
|
|
|
OpMethType(GradMaker, |
|
|
|
decltype(OpDef::make_backward_graph)); |
|
|
|
|
|
|
|
OpMethType(Props, |
|
|
|
decltype(OpDef::props)); |
|
|
|
|
|
|
|
OpMethType(HashFunc, |
|
|
|
size_t(const OpDef&)); |
|
|
|
|
|
|
|
OpMethType(IsSame, |
|
|
|
bool(const OpDef&, const OpDef&)); |
|
|
|
|
|
|
|
OpMethType(MakeNameFunc, |
|
|
|
std::string(const OpDef&)); |
|
|
|
// clang-format on |
|
|
|
|
|
|
|
namespace detail { |
|
|
|
struct OpMethNotImpl { |
|
|
|
template <typename Tag, typename RType, typename... Args> |
|
|
|
static void impl(thin_function<RType(Args...)>& func, Tag) { |
|
|
|
func = [](Args... args) -> RType { |
|
|
|
mgb_throw(MegBrainError, "%s was not implemented yet", Tag::name); |
|
|
|
}; |
|
|
|
} |
|
|
|
}; |
|
|
|
struct OpMethFallback : public OpMethNotImpl { |
|
|
|
using OpMethNotImpl::impl; |
|
|
|
static void impl(ApplyOnPhysicalTensor& func, |
|
|
|
op_meth_tag::ApplyOnPhysicalTensor); |
|
|
|
static void impl(Execute& func, op_meth_tag::Execute); |
|
|
|
static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc); |
|
|
|
static void impl(InferOutputAttrsFallible& func, |
|
|
|
op_meth_tag::InferOutputAttrsFallible); |
|
|
|
static void impl(GradMaker& func, op_meth_tag::GradMaker); |
|
|
|
static void impl(DecideDispatchMode& func, op_meth_tag::DecideDispatchMode); |
|
|
|
static void impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc); |
|
|
|
}; |
|
|
|
template <typename Tag, typename RType, typename... Args> |
|
|
|
struct OpMeth<Tag, RType(Args...)> : public thin_function<RType(Args...)> { |
|
|
|
using Base = thin_function<RType(Args...)>; |
|
|
|
using Base::operator bool; |
|
|
|
OpMeth() : Base{}, allow_fallback(false){}; |
|
|
|
explicit OpMeth(const Base& base) { this->Base::operator=(base); } |
|
|
|
RType operator()(Args... args) const { |
|
|
|
if (!this->Base::operator bool()) { |
|
|
|
if (allow_fallback) { |
|
|
|
OpMethFallback::impl(*const_cast<OpMeth*>(this), Tag{}); |
|
|
|
} else { |
|
|
|
OpMethNotImpl::impl(*const_cast<OpMeth*>(this), Tag{}); |
|
|
|
} |
|
|
|
} |
|
|
|
return this->Base::operator()(std::forward<Args>(args)...); |
|
|
|
} |
|
|
|
bool allow_fallback = false; |
|
|
|
}; |
|
|
|
} // namespace detail |
|
|
|
|
|
|
|
struct OpTrait { |
|
|
|
const char* name; |
|
|
@@ -102,28 +158,31 @@ struct OpTrait { |
|
|
|
static void for_each_trait(thin_function<void(OpTrait&)> visitor); |
|
|
|
}; |
|
|
|
|
|
|
|
#define FOR_EACH_OP_METH(cb) \ |
|
|
|
cb(make_from_op_node) \ |
|
|
|
cb(decide_dispatch_mode) \ |
|
|
|
cb(apply_on_physical_tensor) \ |
|
|
|
cb(infer_output_mem_desc) \ |
|
|
|
cb(execute) \ |
|
|
|
cb(apply_on_device_tensornd) \ |
|
|
|
cb(apply_on_var_node) \ |
|
|
|
// clang-format off |
|
|
|
#define FOR_EACH_OP_METH(cb) \ |
|
|
|
cb(make_from_op_node) \ |
|
|
|
cb(decide_dispatch_mode) \ |
|
|
|
cb(apply_on_physical_tensor) \ |
|
|
|
cb(infer_output_mem_desc) \ |
|
|
|
cb(execute) \ |
|
|
|
cb(apply_on_device_tensornd) \ |
|
|
|
cb(apply_on_var_node) \ |
|
|
|
cb(infer_output_attrs_fallible) \ |
|
|
|
cb(make_backward_graph) \ |
|
|
|
cb(props) \ |
|
|
|
cb(hash) \ |
|
|
|
cb(is_same_st) \ |
|
|
|
cb(make_backward_graph) \ |
|
|
|
cb(props) \ |
|
|
|
cb(hash) \ |
|
|
|
cb(is_same_st) \ |
|
|
|
cb(make_name) |
|
|
|
// clang-format on |
|
|
|
|
|
|
|
struct OpTraitRegistry { |
|
|
|
OpTrait* trait; |
|
|
|
#define DECL(meth) \ |
|
|
|
OpTraitRegistry& meth(decltype(OpTrait::meth) f) { \ |
|
|
|
mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name, #meth); \ |
|
|
|
trait->meth = f; \ |
|
|
|
return *this; \ |
|
|
|
#define DECL(meth) \ |
|
|
|
OpTraitRegistry& meth(decltype(OpTrait::meth)::Base f) { \ |
|
|
|
mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name, \ |
|
|
|
#meth); \ |
|
|
|
trait->meth.Base::operator=(f); \ |
|
|
|
return *this; \ |
|
|
|
} |
|
|
|
FOR_EACH_OP_METH(DECL) |
|
|
|
#undef DECL |
|
|
@@ -162,7 +221,7 @@ struct OpTraitRegistry { |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
} // namespace imperative |
|
|
|
} // namespace imperative |
|
|
|
} // namespace mgb |
|
|
|
|
|
|
|
#define OP_TRAIT_REG(name, ...) \ |
|
|
|