remove unused op methods and enable OpTrait registration from separate files; also fix cond_take's infer_output_attrs
GitOrigin-RevId: 134c8215ce
release-1.1
@@ -36,13 +36,6 @@ SmallVector<TensorPtr> OpDef::apply_on_physical_tensor( | |||
return def.trait()->apply_on_physical_tensor(def, inputs); | |||
} | |||
void OpDef::exec( | |||
const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs, | |||
const SmallVector<TensorPtr>& outputs) { | |||
def.trait()->exec(def, inputs, outputs); | |||
} | |||
cg::OperatorNodeBase* OpDef::apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
@@ -55,12 +48,6 @@ SmallVector<LogicalTensorDesc> OpDef::infer_output_attrs_fallible( | |||
return def.trait()->infer_output_attrs_fallible(def, inputs); | |||
} | |||
SmallVector<LogicalTensorDesc> OpDef::infer_output_attrs( | |||
const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs) { | |||
return def.trait()->infer_output_attrs(def, inputs); | |||
} | |||
BackwardGraphResult OpDef::make_backward_graph( | |||
const OpDef& def, | |||
const SmallVector<LogicalTensorDesc>& inputs, | |||
@@ -34,16 +34,6 @@ StaticData& static_data() { | |||
return data; | |||
} | |||
template<typename T> | |||
struct __not_implementation__; | |||
template<typename RType, typename ...Args> | |||
struct __not_implementation__<RType(Args...)> { | |||
static RType raise(Args ...) { | |||
mgb_throw(MegBrainError, "Not Implemented"); | |||
} | |||
}; | |||
} // detail | |||
OpTrait::OpTrait(const char* name_): name(name_) {} | |||
@@ -72,89 +62,45 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){ | |||
} | |||
} | |||
// Post-registration pass (old revision, removed by this diff): fill every
// unset OpTrait hook with a stub that throws MegBrainError("Not Implemented")
// when invoked, so dispatch through the trait never hits an empty
// thin_function.  In DEBUG builds, log which hooks were defaulted.
// NOTE(review): the trailing " | |||" on each line is diff-export residue,
// not source text.
OpTraitRegistry& OpTraitRegistry::finalize() { | |||
std::ostringstream msg; | |||
// For each hook: if unset, append its name to `msg` (each entry prefixed
// with ", ", stripped below) and install the raising placeholder.
#define CHECK(field) if (!trait->field) { \ | |||
msg << ", " #field; \ | |||
trait->field = \ | |||
detail::__not_implementation__<decltype(OpDef::field)>::raise; \ | |||
} | |||
CHECK(make_from_op_node); | |||
CHECK(apply_on_physical_tensor); | |||
CHECK(exec); | |||
CHECK(apply_on_var_node); | |||
CHECK(infer_output_attrs_fallible); | |||
CHECK(infer_output_attrs); | |||
CHECK(make_backward_graph); | |||
#undef CHECK | |||
#ifdef DEBUG | |||
// tellp() > 0 means at least one hook was defaulted; warn once per trait.
if (msg.tellp() > 0) { | |||
mgb_log_warn( | |||
"%s op trait missing: %s", | |||
trait->name ? trait->name : "(anonymous)", | |||
msg.str().c_str() + 2 /* skip first ", " */); | |||
} | |||
#endif | |||
return *this; | |||
} | |||
// Generic eager-execution fallback (old revision, removed by this diff):
// derive output descriptors via OpDef::infer_output_attrs, allocate one
// tensor per descriptor, then run the trait's exec() to fill them.
SmallVector<TensorPtr> fallback_apply_on_physical_tensor( | |||
const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs) { | |||
auto desc = OpDef::infer_output_attrs(def, inputs); | |||
SmallVector<TensorPtr> outputs; | |||
for (auto&& i : desc) { | |||
// Allocate a tensor matching the inferred layout on the inferred device.
outputs.push_back(Tensor::make(i.layout, i.comp_node)); | |||
} | |||
OpDef::exec(def, inputs, outputs); | |||
return outputs; | |||
} | |||
// Fallback (old revision, removed by this diff) that mirrors each input's
// layout/comp_node as an output descriptor.
// NOTE(review): this assumes one output per input with an identical layout —
// only valid for elementwise-like ops; presumably why it was dropped in
// favor of per-op or proxy-graph inference.  Confirm against callers.
SmallVector<LogicalTensorDesc> fallback_infer_output_attrs(const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs){ | |||
SmallVector<LogicalTensorDesc> input_descs; | |||
for(auto&& input: inputs){ | |||
input_descs.push_back({input->layout(), input->comp_node()}); | |||
} | |||
return input_descs; | |||
} | |||
// Wire up derivable hooks: anything expressible via apply_on_var_node can
// borrow the proxy-graph implementation.
// NOTE(review): this span interleaves TWO revisions of the function from an
// unmarked diff — the old flat `if (!trait->X && ...)` chain (including the
// removed exec/infer_output_attrs hooks) and the new form nested under a
// single `if (trait->apply_on_var_node)`.  Braces do not balance as printed;
// recover the intended version from the original commit before editing.
OpTraitRegistry& OpTraitRegistry::fallback() { | |||
if (!trait->exec && trait->apply_on_var_node) { | |||
trait->exec = proxy_graph_detail::exec; | |||
} | |||
if (!trait->infer_output_attrs && trait->apply_on_var_node) { | |||
trait->infer_output_attrs = proxy_graph_detail::infer_output_attrs; | |||
} | |||
if (!trait->infer_output_attrs_fallible && trait->apply_on_var_node) { | |||
trait->infer_output_attrs_fallible = proxy_graph_detail::infer_output_attrs_fallible; | |||
} | |||
if (!trait->make_backward_graph && trait->apply_on_var_node) { | |||
trait->make_backward_graph = proxy_graph_detail::make_backward_graph; | |||
} | |||
if (!trait->apply_on_physical_tensor && trait->infer_output_attrs && trait->exec) { | |||
trait->apply_on_physical_tensor = fallback_apply_on_physical_tensor; | |||
} | |||
if(!trait->infer_output_attrs && trait->infer_output_attrs_fallible){ | |||
trait->infer_output_attrs = fallback_infer_output_attrs; | |||
if (trait->apply_on_var_node) { | |||
// fallback to proxy graph impl | |||
if (!trait->apply_on_physical_tensor) { | |||
trait->apply_on_physical_tensor = | |||
proxy_graph_detail::apply_on_physical_tensor; | |||
} | |||
if (!trait->infer_output_attrs_fallible) { | |||
trait->infer_output_attrs_fallible = | |||
proxy_graph_detail::infer_output_attrs_fallible; | |||
} | |||
if (!trait->make_backward_graph) { | |||
trait->make_backward_graph = | |||
proxy_graph_detail::make_backward_graph; | |||
} | |||
} | |||
return *this; | |||
} | |||
// Map an op's Typeinfo to its trait.
// NOTE(review): the two emplace lines below are the pre-/post-change
// variants of the same insertion (unmarked diff).  The new form tolerates
// re-registering the SAME trait (needed when a trait is assembled from
// separate translation units) and only asserts on a conflicting trait.
void OpTraitRegistry::do_insert(Typeinfo* type) { | |||
auto&& sd = detail::static_data(); | |||
mgb_assert(sd.type2reg.emplace(type, trait).second); | |||
auto ret = sd.type2reg.emplace(type, trait); | |||
mgb_assert(ret.second || ret.first->second == trait, | |||
"OpTrait for %s has already been registered", type->name); | |||
} | |||
// Create — or, in the new revision, look up — the registry entry for `name`.
// NOTE(review): unmarked diff interleaving.  Old code asserted on a
// duplicate name; new code returns the existing registry instead, which is
// what lets one op's trait be populated from separately compiled files, and
// it guards the name2reg insertion for a null `name` (anonymous trait).
OpTraitRegistry OpTraitRegistry::do_insert(const char* name) { | |||
auto&& sd = detail::static_data(); | |||
if (name) { | |||
mgb_assert(!sd.name2reg.count(name), | |||
"duplicated opr trait %s", name); | |||
auto iter = sd.name2reg.find(name); | |||
if (iter != sd.name2reg.end()) { | |||
return {iter->second}; | |||
} | |||
} | |||
sd.registries.emplace_back(name); | |||
auto ret = &sd.registries.back(); | |||
sd.name2reg.emplace(name, ret); | |||
if (name) { | |||
sd.name2reg.emplace(name, ret); | |||
} | |||
return {ret}; | |||
} | |||
@@ -16,29 +16,39 @@ | |||
namespace mgb { | |||
namespace imperative { | |||
using OpDefMaker = thin_function< | |||
namespace detail { | |||
// Callable wrapper for an OpTrait hook: behaves like thin_function but, when
// invoked while no implementation is registered, throws
// MegBrainError("Not Implemented") instead of invoking an empty function
// object.  This subsumes the removed __not_implementation__ raiser that
// finalize() used to stuff into unset hooks.
template<typename Signature> | |||
struct OpMeth; | |||
// Partial specialization that unpacks the hook's signature RType(Args...).
template<typename RType, typename ...Args> | |||
struct OpMeth<RType(Args...)>: public thin_function<RType(Args...)> { | |||
using Base = thin_function<RType(Args...)>; | |||
using Base::Base; | |||
// Guarded call: check the base's operator bool() before forwarding.
RType operator()(Args... args) const { | |||
if (!this->Base::operator bool()) { | |||
mgb_throw(MegBrainError, "Not Implemented"); | |||
} | |||
return this->Base::operator ()(args...); | |||
} | |||
}; | |||
} // detail | |||
using OpDefMaker = detail::OpMeth< | |||
decltype(OpDef::make_from_op_node)>; | |||
using ApplyOnPhysicalTensor = thin_function< | |||
using ApplyOnPhysicalTensor = detail::OpMeth< | |||
decltype(OpDef::apply_on_physical_tensor)>; | |||
using PhysicalTensorExecutor = thin_function< | |||
decltype(OpDef::exec)>; | |||
using ApplyOnVarNode = thin_function< | |||
using ApplyOnVarNode = detail::OpMeth< | |||
decltype(OpDef::apply_on_var_node)>; | |||
using InferOutputAttrsFallible = thin_function< | |||
using InferOutputAttrsFallible = detail::OpMeth< | |||
decltype(OpDef::infer_output_attrs_fallible)>; | |||
using InferOutputAttrs = thin_function< | |||
decltype(OpDef::infer_output_attrs)>; | |||
using GradMaker = thin_function< | |||
using GradMaker = detail::OpMeth< | |||
decltype(OpDef::make_backward_graph)>; | |||
struct OpTrait { | |||
const char* name; | |||
OpDefMaker make_from_op_node; | |||
ApplyOnPhysicalTensor apply_on_physical_tensor; | |||
PhysicalTensorExecutor exec; | |||
ApplyOnVarNode apply_on_var_node; | |||
InferOutputAttrsFallible infer_output_attrs_fallible; | |||
InferOutputAttrs infer_output_attrs; | |||
GradMaker make_backward_graph; | |||
OpTrait(const char* name); | |||
static OpTrait* find_by_name(const char* name); | |||
@@ -46,38 +56,25 @@ struct OpTrait { | |||
static void for_each_trait(thin_function<void(OpTrait&)> visitor); | |||
}; | |||
#define FOR_EACH_OP_METH(cb) \ | |||
cb(make_from_op_node) \ | |||
cb(apply_on_physical_tensor) \ | |||
cb(apply_on_var_node) \ | |||
cb(infer_output_attrs_fallible) \ | |||
cb(make_backward_graph) | |||
struct OpTraitRegistry { | |||
OpTrait* trait; | |||
OpTraitRegistry& make_from_op_node(OpDefMaker f) { | |||
trait->make_from_op_node = f; | |||
return *this; | |||
} | |||
OpTraitRegistry& apply_on_physical_tensor(ApplyOnPhysicalTensor f) { | |||
trait->apply_on_physical_tensor = f; | |||
return *this; | |||
} | |||
OpTraitRegistry& physical_tensor_executor(PhysicalTensorExecutor f) { | |||
trait->exec = f; | |||
return *this; | |||
} | |||
OpTraitRegistry& apply_on_var_node(ApplyOnVarNode f) { | |||
trait->apply_on_var_node = f; | |||
return *this; | |||
} | |||
OpTraitRegistry& infer_output_attrs_fallible(InferOutputAttrsFallible f) { | |||
trait->infer_output_attrs_fallible = f; | |||
return *this; | |||
} | |||
OpTraitRegistry& infer_output_attrs(InferOutputAttrs f) { | |||
trait->infer_output_attrs = f; | |||
return *this; | |||
} | |||
OpTraitRegistry& grad_maker(GradMaker f) { | |||
trait->make_backward_graph = f; | |||
return *this; | |||
#define DECL(meth) \ | |||
OpTraitRegistry& meth(decltype(OpTrait::meth) f) { \ | |||
mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name, #meth); \ | |||
trait->meth = f; \ | |||
return *this; \ | |||
} | |||
FOR_EACH_OP_METH(DECL) | |||
#undef DECL | |||
OpTraitRegistry& fallback(); | |||
OpTraitRegistry& finalize(); | |||
template<typename T> | |||
void insert() { | |||
@@ -102,20 +99,11 @@ struct OpTraitRegistry { | |||
static OpTraitRegistry do_insert(const char* name); | |||
}; | |||
namespace detail { | |||
struct _RegisterHelper { | |||
OpTraitRegistry registry; | |||
~_RegisterHelper() { | |||
registry.finalize(); | |||
} | |||
}; | |||
} // namespace detail | |||
} // namespace imperative | |||
} // namespace mgb | |||
#define OP_TRAIT_REG(name, ...) \ | |||
static OpTraitRegistry __##name##_global_registry__ = \ | |||
detail::_RegisterHelper{OpTraitRegistry::insert<__VA_ARGS__>(#name)}.registry | |||
OpTraitRegistry::insert<__VA_ARGS__>(#name) | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -110,20 +110,20 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
return out; | |||
} | |||
// CondTake shape inference — the "fix cond_take's infer_output_attrs" part
// of this commit.
// NOTE(review): unmarked diff interleaving.  OLD: infer_output_attrs took
// concrete TensorPtrs and returned two fully-empty descriptors.  NEW:
// infer_output_attrs_fallible takes LogicalTensorDescs and returns a value
// descriptor carrying the input's dtype plus a dtype::Int32() descriptor
// (presumably the index output — confirm against opr::CondTake), both on
// input[0]'s comp_node.  Layouts stay shapeless because CondTake's output
// size is data-dependent.
SmallVector<LogicalTensorDesc> infer_output_attrs( | |||
SmallVector<LogicalTensorDesc> infer_output_attrs_fallible( | |||
const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs) { | |||
SmallVector<LogicalTensorDesc> out; | |||
for (size_t i = 0; i < 2; ++ i) { | |||
out.push_back({TensorLayout(), inputs[0]->comp_node()}); | |||
} | |||
return out; | |||
const SmallVector<LogicalTensorDesc>& inputs) { | |||
auto cn = inputs[0].comp_node; | |||
return { | |||
{TensorLayout(inputs[0].layout.dtype), cn}, | |||
{TensorLayout(dtype::Int32()), cn} | |||
}; | |||
} | |||
// Register the CondTake trait: graph-mode application, eager application,
// and shape inference; fallback() fills any remaining hooks via the proxy
// graph.  NOTE(review): both the old .infer_output_attrs(...) and the new
// .infer_output_attrs_fallible(...) registration lines are present here
// (unmarked diff) — only one belongs in the final source.
OP_TRAIT_REG(CondTake, CondTake, opr::CondTake) | |||
.apply_on_var_node(apply_on_var_node) | |||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||
.infer_output_attrs(infer_output_attrs) | |||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
.fallback(); | |||
} // namespace | |||
@@ -28,10 +28,12 @@ namespace { | |||
// Union of the comp nodes of all inputs and all inferred outputs.
// NOTE(review): unmarked diff interleaving — the new revision builds
// LogicalTensorDescs from the inputs so it can call the descriptor-based
// infer_output_attrs_fallible instead of the removed infer_output_attrs;
// old and new loop bodies both appear below.
CompNode::UnorderedSet collect_comp_nodes( | |||
const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
CompNode::UnorderedSet comp_nodes; | |||
for (auto&& input : inputs) { | |||
comp_nodes.insert(input->comp_node()); | |||
SmallVector<LogicalTensorDesc> descs; | |||
for (auto&& i : inputs) { | |||
comp_nodes.insert(i->comp_node()); | |||
descs.push_back({i->layout(), i->comp_node(), {}}); | |||
} | |||
for (auto&& output_attr : def.infer_output_attrs(def, inputs)) { | |||
for (auto&& output_attr : def.infer_output_attrs_fallible(def, descs)) { | |||
comp_nodes.insert(output_attr.comp_node); | |||
} | |||
return comp_nodes; | |||
@@ -31,7 +31,6 @@ SmallVector<Tensor*> to_raw_ptr_array( | |||
} | |||
return ret; | |||
} | |||
} // anonymous namespace | |||
void exec(const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs_, | |||
@@ -61,11 +60,25 @@ void exec(const OpDef& def, | |||
} | |||
} | |||
// Ask the default proxy graph to infer output descriptors from concrete
// input tensors.  NOTE(review): the two signature spellings are the old/new
// variants of the same declaration (unmarked diff); the relocated
// "// anonymous namespace" closer below indicates this helper became
// file-local in the new revision while apply_on_physical_tensor moved out
// to the public proxy_graph_detail interface.
SmallVector<LogicalTensorDesc> infer_output_attrs(const OpDef& def, | |||
SmallVector<LogicalTensorDesc> | |||
infer_output_attrs(const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs) { | |||
auto&& graph = ProxyGraph::get_default_graph(); | |||
return graph->infer_output_attrs(def, to_raw_ptr_array(inputs)); | |||
} | |||
} // anonymous namespace | |||
// Proxy-graph eager path (new in this diff, replacing the removed
// fallback_apply_on_physical_tensor): infer output descriptors from the
// inputs, allocate one tensor per descriptor, then exec() through the proxy
// graph to fill them.
SmallVector<TensorPtr> | |||
apply_on_physical_tensor(const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs) { | |||
auto desc = infer_output_attrs(def, inputs); | |||
SmallVector<TensorPtr> outputs; | |||
for (auto&& i : desc) { | |||
// Allocate a tensor matching the inferred layout on the inferred device.
outputs.push_back(Tensor::make(i.layout, i.comp_node)); | |||
} | |||
exec(def, inputs, outputs); | |||
return outputs; | |||
} | |||
SmallVector<LogicalTensorDesc> | |||
infer_output_attrs_fallible(const OpDef& def, | |||
@@ -17,11 +17,8 @@ namespace mgb { | |||
namespace imperative { | |||
namespace proxy_graph_detail { | |||
void exec(const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs_, | |||
const SmallVector<TensorPtr>& outputs_); | |||
SmallVector<LogicalTensorDesc> infer_output_attrs(const OpDef& def, | |||
SmallVector<TensorPtr> | |||
apply_on_physical_tensor(const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs); | |||
SmallVector<LogicalTensorDesc> | |||
@@ -40,11 +40,6 @@ public: | |||
const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs); | |||
static void exec( | |||
const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs, | |||
const SmallVector<TensorPtr>& outputs); | |||
static cg::OperatorNodeBase* apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs); | |||
@@ -53,10 +48,6 @@ public: | |||
const OpDef& def, | |||
const SmallVector<LogicalTensorDesc>& inputs); | |||
static SmallVector<LogicalTensorDesc> infer_output_attrs( | |||
const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs); | |||
static BackwardGraphResult make_backward_graph( | |||
const OpDef& def, | |||
const SmallVector<LogicalTensorDesc>& inputs, | |||