remove unused op methods and enable OpTrait registered from seperated files, also fix cond_take's infer_output_attrs
GitOrigin-RevId: 134c8215ce
release-1.1
@@ -36,13 +36,6 @@ SmallVector<TensorPtr> OpDef::apply_on_physical_tensor( | |||||
return def.trait()->apply_on_physical_tensor(def, inputs); | return def.trait()->apply_on_physical_tensor(def, inputs); | ||||
} | } | ||||
void OpDef::exec( | |||||
const OpDef& def, | |||||
const SmallVector<TensorPtr>& inputs, | |||||
const SmallVector<TensorPtr>& outputs) { | |||||
def.trait()->exec(def, inputs, outputs); | |||||
} | |||||
cg::OperatorNodeBase* OpDef::apply_on_var_node( | cg::OperatorNodeBase* OpDef::apply_on_var_node( | ||||
const OpDef& def, | const OpDef& def, | ||||
const VarNodeArray& inputs) { | const VarNodeArray& inputs) { | ||||
@@ -55,12 +48,6 @@ SmallVector<LogicalTensorDesc> OpDef::infer_output_attrs_fallible( | |||||
return def.trait()->infer_output_attrs_fallible(def, inputs); | return def.trait()->infer_output_attrs_fallible(def, inputs); | ||||
} | } | ||||
SmallVector<LogicalTensorDesc> OpDef::infer_output_attrs( | |||||
const OpDef& def, | |||||
const SmallVector<TensorPtr>& inputs) { | |||||
return def.trait()->infer_output_attrs(def, inputs); | |||||
} | |||||
BackwardGraphResult OpDef::make_backward_graph( | BackwardGraphResult OpDef::make_backward_graph( | ||||
const OpDef& def, | const OpDef& def, | ||||
const SmallVector<LogicalTensorDesc>& inputs, | const SmallVector<LogicalTensorDesc>& inputs, | ||||
@@ -34,16 +34,6 @@ StaticData& static_data() { | |||||
return data; | return data; | ||||
} | } | ||||
template<typename T> | |||||
struct __not_implementation__; | |||||
template<typename RType, typename ...Args> | |||||
struct __not_implementation__<RType(Args...)> { | |||||
static RType raise(Args ...) { | |||||
mgb_throw(MegBrainError, "Not Implemented"); | |||||
} | |||||
}; | |||||
} // detail | } // detail | ||||
OpTrait::OpTrait(const char* name_): name(name_) {} | OpTrait::OpTrait(const char* name_): name(name_) {} | ||||
@@ -72,89 +62,45 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){ | |||||
} | } | ||||
} | } | ||||
OpTraitRegistry& OpTraitRegistry::finalize() { | |||||
std::ostringstream msg; | |||||
#define CHECK(field) if (!trait->field) { \ | |||||
msg << ", " #field; \ | |||||
trait->field = \ | |||||
detail::__not_implementation__<decltype(OpDef::field)>::raise; \ | |||||
} | |||||
CHECK(make_from_op_node); | |||||
CHECK(apply_on_physical_tensor); | |||||
CHECK(exec); | |||||
CHECK(apply_on_var_node); | |||||
CHECK(infer_output_attrs_fallible); | |||||
CHECK(infer_output_attrs); | |||||
CHECK(make_backward_graph); | |||||
#undef CHECK | |||||
#ifdef DEBUG | |||||
if (msg.tellp() > 0) { | |||||
mgb_log_warn( | |||||
"%s op trait missing: %s", | |||||
trait->name ? trait->name : "(anonymous)", | |||||
msg.str().c_str() + 2 /* skip first ", " */); | |||||
} | |||||
#endif | |||||
return *this; | |||||
} | |||||
SmallVector<TensorPtr> fallback_apply_on_physical_tensor( | |||||
const OpDef& def, | |||||
const SmallVector<TensorPtr>& inputs) { | |||||
auto desc = OpDef::infer_output_attrs(def, inputs); | |||||
SmallVector<TensorPtr> outputs; | |||||
for (auto&& i : desc) { | |||||
outputs.push_back(Tensor::make(i.layout, i.comp_node)); | |||||
} | |||||
OpDef::exec(def, inputs, outputs); | |||||
return outputs; | |||||
} | |||||
SmallVector<LogicalTensorDesc> fallback_infer_output_attrs(const OpDef& def, | |||||
const SmallVector<TensorPtr>& inputs){ | |||||
SmallVector<LogicalTensorDesc> input_descs; | |||||
for(auto&& input: inputs){ | |||||
input_descs.push_back({input->layout(), input->comp_node()}); | |||||
} | |||||
return input_descs; | |||||
} | |||||
OpTraitRegistry& OpTraitRegistry::fallback() { | OpTraitRegistry& OpTraitRegistry::fallback() { | ||||
if (!trait->exec && trait->apply_on_var_node) { | |||||
trait->exec = proxy_graph_detail::exec; | |||||
} | |||||
if (!trait->infer_output_attrs && trait->apply_on_var_node) { | |||||
trait->infer_output_attrs = proxy_graph_detail::infer_output_attrs; | |||||
} | |||||
if (!trait->infer_output_attrs_fallible && trait->apply_on_var_node) { | |||||
trait->infer_output_attrs_fallible = proxy_graph_detail::infer_output_attrs_fallible; | |||||
} | |||||
if (!trait->make_backward_graph && trait->apply_on_var_node) { | |||||
trait->make_backward_graph = proxy_graph_detail::make_backward_graph; | |||||
} | |||||
if (!trait->apply_on_physical_tensor && trait->infer_output_attrs && trait->exec) { | |||||
trait->apply_on_physical_tensor = fallback_apply_on_physical_tensor; | |||||
} | |||||
if(!trait->infer_output_attrs && trait->infer_output_attrs_fallible){ | |||||
trait->infer_output_attrs = fallback_infer_output_attrs; | |||||
if (trait->apply_on_var_node) { | |||||
// fallback to proxy graph impl | |||||
if (!trait->apply_on_physical_tensor) { | |||||
trait->apply_on_physical_tensor = | |||||
proxy_graph_detail::apply_on_physical_tensor; | |||||
} | |||||
if (!trait->infer_output_attrs_fallible) { | |||||
trait->infer_output_attrs_fallible = | |||||
proxy_graph_detail::infer_output_attrs_fallible; | |||||
} | |||||
if (!trait->make_backward_graph) { | |||||
trait->make_backward_graph = | |||||
proxy_graph_detail::make_backward_graph; | |||||
} | |||||
} | } | ||||
return *this; | return *this; | ||||
} | } | ||||
void OpTraitRegistry::do_insert(Typeinfo* type) { | void OpTraitRegistry::do_insert(Typeinfo* type) { | ||||
auto&& sd = detail::static_data(); | auto&& sd = detail::static_data(); | ||||
mgb_assert(sd.type2reg.emplace(type, trait).second); | |||||
auto ret = sd.type2reg.emplace(type, trait); | |||||
mgb_assert(ret.second || ret.first->second == trait, | |||||
"OpTrait for %s has already been registered", type->name); | |||||
} | } | ||||
OpTraitRegistry OpTraitRegistry::do_insert(const char* name) { | OpTraitRegistry OpTraitRegistry::do_insert(const char* name) { | ||||
auto&& sd = detail::static_data(); | auto&& sd = detail::static_data(); | ||||
if (name) { | if (name) { | ||||
mgb_assert(!sd.name2reg.count(name), | |||||
"duplicated opr trait %s", name); | |||||
auto iter = sd.name2reg.find(name); | |||||
if (iter != sd.name2reg.end()) { | |||||
return {iter->second}; | |||||
} | |||||
} | } | ||||
sd.registries.emplace_back(name); | sd.registries.emplace_back(name); | ||||
auto ret = &sd.registries.back(); | auto ret = &sd.registries.back(); | ||||
sd.name2reg.emplace(name, ret); | |||||
if (name) { | |||||
sd.name2reg.emplace(name, ret); | |||||
} | |||||
return {ret}; | return {ret}; | ||||
} | } | ||||
@@ -16,29 +16,39 @@ | |||||
namespace mgb { | namespace mgb { | ||||
namespace imperative { | namespace imperative { | ||||
using OpDefMaker = thin_function< | |||||
namespace detail { | |||||
template<typename Signature> | |||||
struct OpMeth; | |||||
template<typename RType, typename ...Args> | |||||
struct OpMeth<RType(Args...)>: public thin_function<RType(Args...)> { | |||||
using Base = thin_function<RType(Args...)>; | |||||
using Base::Base; | |||||
RType operator()(Args... args) const { | |||||
if (!this->Base::operator bool()) { | |||||
mgb_throw(MegBrainError, "Not Implemented"); | |||||
} | |||||
return this->Base::operator ()(args...); | |||||
} | |||||
}; | |||||
} // detail | |||||
using OpDefMaker = detail::OpMeth< | |||||
decltype(OpDef::make_from_op_node)>; | decltype(OpDef::make_from_op_node)>; | ||||
using ApplyOnPhysicalTensor = thin_function< | |||||
using ApplyOnPhysicalTensor = detail::OpMeth< | |||||
decltype(OpDef::apply_on_physical_tensor)>; | decltype(OpDef::apply_on_physical_tensor)>; | ||||
using PhysicalTensorExecutor = thin_function< | |||||
decltype(OpDef::exec)>; | |||||
using ApplyOnVarNode = thin_function< | |||||
using ApplyOnVarNode = detail::OpMeth< | |||||
decltype(OpDef::apply_on_var_node)>; | decltype(OpDef::apply_on_var_node)>; | ||||
using InferOutputAttrsFallible = thin_function< | |||||
using InferOutputAttrsFallible = detail::OpMeth< | |||||
decltype(OpDef::infer_output_attrs_fallible)>; | decltype(OpDef::infer_output_attrs_fallible)>; | ||||
using InferOutputAttrs = thin_function< | |||||
decltype(OpDef::infer_output_attrs)>; | |||||
using GradMaker = thin_function< | |||||
using GradMaker = detail::OpMeth< | |||||
decltype(OpDef::make_backward_graph)>; | decltype(OpDef::make_backward_graph)>; | ||||
struct OpTrait { | struct OpTrait { | ||||
const char* name; | const char* name; | ||||
OpDefMaker make_from_op_node; | OpDefMaker make_from_op_node; | ||||
ApplyOnPhysicalTensor apply_on_physical_tensor; | ApplyOnPhysicalTensor apply_on_physical_tensor; | ||||
PhysicalTensorExecutor exec; | |||||
ApplyOnVarNode apply_on_var_node; | ApplyOnVarNode apply_on_var_node; | ||||
InferOutputAttrsFallible infer_output_attrs_fallible; | InferOutputAttrsFallible infer_output_attrs_fallible; | ||||
InferOutputAttrs infer_output_attrs; | |||||
GradMaker make_backward_graph; | GradMaker make_backward_graph; | ||||
OpTrait(const char* name); | OpTrait(const char* name); | ||||
static OpTrait* find_by_name(const char* name); | static OpTrait* find_by_name(const char* name); | ||||
@@ -46,38 +56,25 @@ struct OpTrait { | |||||
static void for_each_trait(thin_function<void(OpTrait&)> visitor); | static void for_each_trait(thin_function<void(OpTrait&)> visitor); | ||||
}; | }; | ||||
#define FOR_EACH_OP_METH(cb) \ | |||||
cb(make_from_op_node) \ | |||||
cb(apply_on_physical_tensor) \ | |||||
cb(apply_on_var_node) \ | |||||
cb(infer_output_attrs_fallible) \ | |||||
cb(make_backward_graph) | |||||
struct OpTraitRegistry { | struct OpTraitRegistry { | ||||
OpTrait* trait; | OpTrait* trait; | ||||
OpTraitRegistry& make_from_op_node(OpDefMaker f) { | |||||
trait->make_from_op_node = f; | |||||
return *this; | |||||
} | |||||
OpTraitRegistry& apply_on_physical_tensor(ApplyOnPhysicalTensor f) { | |||||
trait->apply_on_physical_tensor = f; | |||||
return *this; | |||||
} | |||||
OpTraitRegistry& physical_tensor_executor(PhysicalTensorExecutor f) { | |||||
trait->exec = f; | |||||
return *this; | |||||
} | |||||
OpTraitRegistry& apply_on_var_node(ApplyOnVarNode f) { | |||||
trait->apply_on_var_node = f; | |||||
return *this; | |||||
} | |||||
OpTraitRegistry& infer_output_attrs_fallible(InferOutputAttrsFallible f) { | |||||
trait->infer_output_attrs_fallible = f; | |||||
return *this; | |||||
} | |||||
OpTraitRegistry& infer_output_attrs(InferOutputAttrs f) { | |||||
trait->infer_output_attrs = f; | |||||
return *this; | |||||
} | |||||
OpTraitRegistry& grad_maker(GradMaker f) { | |||||
trait->make_backward_graph = f; | |||||
return *this; | |||||
#define DECL(meth) \ | |||||
OpTraitRegistry& meth(decltype(OpTrait::meth) f) { \ | |||||
mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name, #meth); \ | |||||
trait->meth = f; \ | |||||
return *this; \ | |||||
} | } | ||||
FOR_EACH_OP_METH(DECL) | |||||
#undef DECL | |||||
OpTraitRegistry& fallback(); | OpTraitRegistry& fallback(); | ||||
OpTraitRegistry& finalize(); | |||||
template<typename T> | template<typename T> | ||||
void insert() { | void insert() { | ||||
@@ -102,20 +99,11 @@ struct OpTraitRegistry { | |||||
static OpTraitRegistry do_insert(const char* name); | static OpTraitRegistry do_insert(const char* name); | ||||
}; | }; | ||||
namespace detail { | |||||
struct _RegisterHelper { | |||||
OpTraitRegistry registry; | |||||
~_RegisterHelper() { | |||||
registry.finalize(); | |||||
} | |||||
}; | |||||
} // namespace detail | |||||
} // namespace imperative | } // namespace imperative | ||||
} // namespace mgb | } // namespace mgb | ||||
#define OP_TRAIT_REG(name, ...) \ | #define OP_TRAIT_REG(name, ...) \ | ||||
static OpTraitRegistry __##name##_global_registry__ = \ | static OpTraitRegistry __##name##_global_registry__ = \ | ||||
detail::_RegisterHelper{OpTraitRegistry::insert<__VA_ARGS__>(#name)}.registry | |||||
OpTraitRegistry::insert<__VA_ARGS__>(#name) | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -110,20 +110,20 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
return out; | return out; | ||||
} | } | ||||
SmallVector<LogicalTensorDesc> infer_output_attrs( | |||||
SmallVector<LogicalTensorDesc> infer_output_attrs_fallible( | |||||
const OpDef& def, | const OpDef& def, | ||||
const SmallVector<TensorPtr>& inputs) { | |||||
SmallVector<LogicalTensorDesc> out; | |||||
for (size_t i = 0; i < 2; ++ i) { | |||||
out.push_back({TensorLayout(), inputs[0]->comp_node()}); | |||||
} | |||||
return out; | |||||
const SmallVector<LogicalTensorDesc>& inputs) { | |||||
auto cn = inputs[0].comp_node; | |||||
return { | |||||
{TensorLayout(inputs[0].layout.dtype), cn}, | |||||
{TensorLayout(dtype::Int32()), cn} | |||||
}; | |||||
} | } | ||||
OP_TRAIT_REG(CondTake, CondTake, opr::CondTake) | OP_TRAIT_REG(CondTake, CondTake, opr::CondTake) | ||||
.apply_on_var_node(apply_on_var_node) | .apply_on_var_node(apply_on_var_node) | ||||
.apply_on_physical_tensor(apply_on_physical_tensor) | .apply_on_physical_tensor(apply_on_physical_tensor) | ||||
.infer_output_attrs(infer_output_attrs) | |||||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||||
.fallback(); | .fallback(); | ||||
} // namespace | } // namespace | ||||
@@ -28,10 +28,12 @@ namespace { | |||||
CompNode::UnorderedSet collect_comp_nodes( | CompNode::UnorderedSet collect_comp_nodes( | ||||
const OpDef& def, const SmallVector<TensorPtr>& inputs) { | const OpDef& def, const SmallVector<TensorPtr>& inputs) { | ||||
CompNode::UnorderedSet comp_nodes; | CompNode::UnorderedSet comp_nodes; | ||||
for (auto&& input : inputs) { | |||||
comp_nodes.insert(input->comp_node()); | |||||
SmallVector<LogicalTensorDesc> descs; | |||||
for (auto&& i : inputs) { | |||||
comp_nodes.insert(i->comp_node()); | |||||
descs.push_back({i->layout(), i->comp_node(), {}}); | |||||
} | } | ||||
for (auto&& output_attr : def.infer_output_attrs(def, inputs)) { | |||||
for (auto&& output_attr : def.infer_output_attrs_fallible(def, descs)) { | |||||
comp_nodes.insert(output_attr.comp_node); | comp_nodes.insert(output_attr.comp_node); | ||||
} | } | ||||
return comp_nodes; | return comp_nodes; | ||||
@@ -31,7 +31,6 @@ SmallVector<Tensor*> to_raw_ptr_array( | |||||
} | } | ||||
return ret; | return ret; | ||||
} | } | ||||
} // anonymous namespace | |||||
void exec(const OpDef& def, | void exec(const OpDef& def, | ||||
const SmallVector<TensorPtr>& inputs_, | const SmallVector<TensorPtr>& inputs_, | ||||
@@ -61,11 +60,25 @@ void exec(const OpDef& def, | |||||
} | } | ||||
} | } | ||||
SmallVector<LogicalTensorDesc> infer_output_attrs(const OpDef& def, | |||||
SmallVector<LogicalTensorDesc> | |||||
infer_output_attrs(const OpDef& def, | |||||
const SmallVector<TensorPtr>& inputs) { | const SmallVector<TensorPtr>& inputs) { | ||||
auto&& graph = ProxyGraph::get_default_graph(); | auto&& graph = ProxyGraph::get_default_graph(); | ||||
return graph->infer_output_attrs(def, to_raw_ptr_array(inputs)); | return graph->infer_output_attrs(def, to_raw_ptr_array(inputs)); | ||||
} | } | ||||
} // anonymous namespace | |||||
SmallVector<TensorPtr> | |||||
apply_on_physical_tensor(const OpDef& def, | |||||
const SmallVector<TensorPtr>& inputs) { | |||||
auto desc = infer_output_attrs(def, inputs); | |||||
SmallVector<TensorPtr> outputs; | |||||
for (auto&& i : desc) { | |||||
outputs.push_back(Tensor::make(i.layout, i.comp_node)); | |||||
} | |||||
exec(def, inputs, outputs); | |||||
return outputs; | |||||
} | |||||
SmallVector<LogicalTensorDesc> | SmallVector<LogicalTensorDesc> | ||||
infer_output_attrs_fallible(const OpDef& def, | infer_output_attrs_fallible(const OpDef& def, | ||||
@@ -17,11 +17,8 @@ namespace mgb { | |||||
namespace imperative { | namespace imperative { | ||||
namespace proxy_graph_detail { | namespace proxy_graph_detail { | ||||
void exec(const OpDef& def, | |||||
const SmallVector<TensorPtr>& inputs_, | |||||
const SmallVector<TensorPtr>& outputs_); | |||||
SmallVector<LogicalTensorDesc> infer_output_attrs(const OpDef& def, | |||||
SmallVector<TensorPtr> | |||||
apply_on_physical_tensor(const OpDef& def, | |||||
const SmallVector<TensorPtr>& inputs); | const SmallVector<TensorPtr>& inputs); | ||||
SmallVector<LogicalTensorDesc> | SmallVector<LogicalTensorDesc> | ||||
@@ -40,11 +40,6 @@ public: | |||||
const OpDef& def, | const OpDef& def, | ||||
const SmallVector<TensorPtr>& inputs); | const SmallVector<TensorPtr>& inputs); | ||||
static void exec( | |||||
const OpDef& def, | |||||
const SmallVector<TensorPtr>& inputs, | |||||
const SmallVector<TensorPtr>& outputs); | |||||
static cg::OperatorNodeBase* apply_on_var_node( | static cg::OperatorNodeBase* apply_on_var_node( | ||||
const OpDef& def, | const OpDef& def, | ||||
const VarNodeArray& inputs); | const VarNodeArray& inputs); | ||||
@@ -53,10 +48,6 @@ public: | |||||
const OpDef& def, | const OpDef& def, | ||||
const SmallVector<LogicalTensorDesc>& inputs); | const SmallVector<LogicalTensorDesc>& inputs); | ||||
static SmallVector<LogicalTensorDesc> infer_output_attrs( | |||||
const OpDef& def, | |||||
const SmallVector<TensorPtr>& inputs); | |||||
static BackwardGraphResult make_backward_graph( | static BackwardGraphResult make_backward_graph( | ||||
const OpDef& def, | const OpDef& def, | ||||
const SmallVector<LogicalTensorDesc>& inputs, | const SmallVector<LogicalTensorDesc>& inputs, | ||||