
refactor(imperative): refactor OpTrait methods and registration

remove unused op methods, enable OpTrait registration from separate files, and fix cond_take's output attribute inference (now via infer_output_attrs_fallible)

GitOrigin-RevId: 134c8215ce
release-1.1
Megvii Engine Team, 4 years ago
commit 120e719e17
8 changed files with 91 additions and 167 deletions

  1. imperative/src/impl/op_def.cpp (+0, -13)
  2. imperative/src/impl/op_trait.cpp (+24, -78)
  3. imperative/src/impl/op_trait.h (+37, -49)
  4. imperative/src/impl/ops/cond_take.cpp (+8, -8)
  5. imperative/src/impl/profiler.cpp (+5, -3)
  6. imperative/src/impl/proxy_graph_detail.cpp (+15, -2)
  7. imperative/src/impl/proxy_graph_detail.h (+2, -5)
  8. imperative/src/include/megbrain/imperative/op_def.h (+0, -9)

imperative/src/impl/op_def.cpp (+0, -13)

@@ -36,13 +36,6 @@ SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
     return def.trait()->apply_on_physical_tensor(def, inputs);
 }
 
-void OpDef::exec(
-    const OpDef& def,
-    const SmallVector<TensorPtr>& inputs,
-    const SmallVector<TensorPtr>& outputs) {
-    def.trait()->exec(def, inputs, outputs);
-}
-
 cg::OperatorNodeBase* OpDef::apply_on_var_node(
     const OpDef& def,
     const VarNodeArray& inputs) {
@@ -55,12 +48,6 @@ SmallVector<LogicalTensorDesc> OpDef::infer_output_attrs_fallible(
     return def.trait()->infer_output_attrs_fallible(def, inputs);
 }
 
-SmallVector<LogicalTensorDesc> OpDef::infer_output_attrs(
-    const OpDef& def,
-    const SmallVector<TensorPtr>& inputs) {
-    return def.trait()->infer_output_attrs(def, inputs);
-}
-
 BackwardGraphResult OpDef::make_backward_graph(
     const OpDef& def,
     const SmallVector<LogicalTensorDesc>& inputs,


imperative/src/impl/op_trait.cpp (+24, -78)

@@ -34,16 +34,6 @@ StaticData& static_data() {
     return data;
 }
 
-template<typename T>
-struct __not_implementation__;
-
-template<typename RType, typename ...Args>
-struct __not_implementation__<RType(Args...)> {
-    static RType raise(Args ...) {
-        mgb_throw(MegBrainError, "Not Implemented");
-    }
-};
-
 } // detail
 
 OpTrait::OpTrait(const char* name_): name(name_) {}
@@ -72,89 +62,45 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
     }
 }
 
-OpTraitRegistry& OpTraitRegistry::finalize() {
-    std::ostringstream msg;
-#define CHECK(field) if (!trait->field) { \
-        msg << ", " #field; \
-        trait->field = \
-            detail::__not_implementation__<decltype(OpDef::field)>::raise; \
-    }
-    CHECK(make_from_op_node);
-    CHECK(apply_on_physical_tensor);
-    CHECK(exec);
-    CHECK(apply_on_var_node);
-    CHECK(infer_output_attrs_fallible);
-    CHECK(infer_output_attrs);
-    CHECK(make_backward_graph);
-#undef CHECK
-#ifdef DEBUG
-    if (msg.tellp() > 0) {
-        mgb_log_warn(
-            "%s op trait missing: %s",
-            trait->name ? trait->name : "(anonymous)",
-            msg.str().c_str() + 2 /* skip first ", " */);
-    }
-#endif
-    return *this;
-}
-
-SmallVector<TensorPtr> fallback_apply_on_physical_tensor(
-        const OpDef& def,
-        const SmallVector<TensorPtr>& inputs) {
-    auto desc = OpDef::infer_output_attrs(def, inputs);
-    SmallVector<TensorPtr> outputs;
-    for (auto&& i : desc) {
-        outputs.push_back(Tensor::make(i.layout, i.comp_node));
-    }
-    OpDef::exec(def, inputs, outputs);
-    return outputs;
-}
-
-SmallVector<LogicalTensorDesc> fallback_infer_output_attrs(const OpDef& def,
-        const SmallVector<TensorPtr>& inputs){
-    SmallVector<LogicalTensorDesc> input_descs;
-    for(auto&& input: inputs){
-        input_descs.push_back({input->layout(), input->comp_node()});
-    }
-    return input_descs;
-}
-
 OpTraitRegistry& OpTraitRegistry::fallback() {
-    if (!trait->exec && trait->apply_on_var_node) {
-        trait->exec = proxy_graph_detail::exec;
-    }
-    if (!trait->infer_output_attrs && trait->apply_on_var_node) {
-        trait->infer_output_attrs = proxy_graph_detail::infer_output_attrs;
-    }
-    if (!trait->infer_output_attrs_fallible && trait->apply_on_var_node) {
-        trait->infer_output_attrs_fallible = proxy_graph_detail::infer_output_attrs_fallible;
-    }
-    if (!trait->make_backward_graph && trait->apply_on_var_node) {
-        trait->make_backward_graph = proxy_graph_detail::make_backward_graph;
-    }
-    if (!trait->apply_on_physical_tensor && trait->infer_output_attrs && trait->exec) {
-        trait->apply_on_physical_tensor = fallback_apply_on_physical_tensor;
-    }
-    if(!trait->infer_output_attrs && trait->infer_output_attrs_fallible){
-        trait->infer_output_attrs = fallback_infer_output_attrs;
+    if (trait->apply_on_var_node) {
+        // fallback to proxy graph impl
+        if (!trait->apply_on_physical_tensor) {
+            trait->apply_on_physical_tensor =
+                    proxy_graph_detail::apply_on_physical_tensor;
+        }
+        if (!trait->infer_output_attrs_fallible) {
+            trait->infer_output_attrs_fallible =
+                    proxy_graph_detail::infer_output_attrs_fallible;
+        }
+        if (!trait->make_backward_graph) {
+            trait->make_backward_graph =
+                    proxy_graph_detail::make_backward_graph;
+        }
     }
     return *this;
 }
 
 void OpTraitRegistry::do_insert(Typeinfo* type) {
     auto&& sd = detail::static_data();
-    mgb_assert(sd.type2reg.emplace(type, trait).second);
+    auto ret = sd.type2reg.emplace(type, trait);
+    mgb_assert(ret.second || ret.first->second == trait,
+               "OpTrait for %s has already been registered", type->name);
 }
 
 OpTraitRegistry OpTraitRegistry::do_insert(const char* name) {
     auto&& sd = detail::static_data();
     if (name) {
-        mgb_assert(!sd.name2reg.count(name),
-            "duplicated opr trait %s", name);
+        auto iter = sd.name2reg.find(name);
+        if (iter != sd.name2reg.end()) {
+            return {iter->second};
+        }
     }
     sd.registries.emplace_back(name);
     auto ret = &sd.registries.back();
-    sd.name2reg.emplace(name, ret);
+    if (name) {
+        sd.name2reg.emplace(name, ret);
+    }
     return {ret};
 }
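
The reworked do_insert is what the commit message means by registration from separate files: the first call for a given name creates the registry entry, later calls return the existing one, and accidental double registration of a method is caught by the mgb_assert in the DECL'd setters of op_trait.h below. A minimal sketch of the enabled pattern, with a hypothetical Example op and placeholder function names (not code from this commit):

// example_trait_a.cpp -- registers the graph lowering and lets fallback()
// fill the remaining methods with the proxy-graph implementations
OP_TRAIT_REG(Example, Example)
    .apply_on_var_node(example_apply_on_var_node)
    .fallback();

// example_trait_b.cpp -- do_insert("Example") now returns the same entry,
// so this translation unit can contribute another method independently
OP_TRAIT_REG(Example, Example)
    .make_from_op_node(example_make_from_op_node);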



imperative/src/impl/op_trait.h (+37, -49)

@@ -16,29 +16,39 @@
 namespace mgb {
 namespace imperative {
 
-using OpDefMaker = thin_function<
+namespace detail {
+template<typename Signature>
+struct OpMeth;
+template<typename RType, typename ...Args>
+struct OpMeth<RType(Args...)>: public thin_function<RType(Args...)> {
+    using Base = thin_function<RType(Args...)>;
+    using Base::Base;
+    RType operator()(Args... args) const {
+        if (!this->Base::operator bool()) {
+            mgb_throw(MegBrainError, "Not Implemented");
+        }
+        return this->Base::operator ()(args...);
+    }
+};
+} // detail
+
+using OpDefMaker = detail::OpMeth<
         decltype(OpDef::make_from_op_node)>;
-using ApplyOnPhysicalTensor = thin_function<
+using ApplyOnPhysicalTensor = detail::OpMeth<
         decltype(OpDef::apply_on_physical_tensor)>;
-using PhysicalTensorExecutor = thin_function<
-        decltype(OpDef::exec)>;
-using ApplyOnVarNode = thin_function<
+using ApplyOnVarNode = detail::OpMeth<
         decltype(OpDef::apply_on_var_node)>;
-using InferOutputAttrsFallible = thin_function<
+using InferOutputAttrsFallible = detail::OpMeth<
        decltype(OpDef::infer_output_attrs_fallible)>;
-using InferOutputAttrs = thin_function<
-        decltype(OpDef::infer_output_attrs)>;
-using GradMaker = thin_function<
+using GradMaker = detail::OpMeth<
         decltype(OpDef::make_backward_graph)>;
 
 struct OpTrait {
     const char* name;
     OpDefMaker make_from_op_node;
     ApplyOnPhysicalTensor apply_on_physical_tensor;
-    PhysicalTensorExecutor exec;
     ApplyOnVarNode apply_on_var_node;
     InferOutputAttrsFallible infer_output_attrs_fallible;
-    InferOutputAttrs infer_output_attrs;
     GradMaker make_backward_graph;
     OpTrait(const char* name);
     static OpTrait* find_by_name(const char* name);
@@ -46,38 +56,25 @@ struct OpTrait {
     static void for_each_trait(thin_function<void(OpTrait&)> visitor);
 };
 
+#define FOR_EACH_OP_METH(cb) \
+    cb(make_from_op_node) \
+    cb(apply_on_physical_tensor) \
+    cb(apply_on_var_node) \
+    cb(infer_output_attrs_fallible) \
+    cb(make_backward_graph)
+
 struct OpTraitRegistry {
     OpTrait* trait;
-    OpTraitRegistry& make_from_op_node(OpDefMaker f) {
-        trait->make_from_op_node = f;
-        return *this;
-    }
-    OpTraitRegistry& apply_on_physical_tensor(ApplyOnPhysicalTensor f) {
-        trait->apply_on_physical_tensor = f;
-        return *this;
-    }
-    OpTraitRegistry& physical_tensor_executor(PhysicalTensorExecutor f) {
-        trait->exec = f;
-        return *this;
-    }
-    OpTraitRegistry& apply_on_var_node(ApplyOnVarNode f) {
-        trait->apply_on_var_node = f;
-        return *this;
-    }
-    OpTraitRegistry& infer_output_attrs_fallible(InferOutputAttrsFallible f) {
-        trait->infer_output_attrs_fallible = f;
-        return *this;
-    }
-    OpTraitRegistry& infer_output_attrs(InferOutputAttrs f) {
-        trait->infer_output_attrs = f;
-        return *this;
-    }
-    OpTraitRegistry& grad_maker(GradMaker f) {
-        trait->make_backward_graph = f;
-        return *this;
+#define DECL(meth) \
+    OpTraitRegistry& meth(decltype(OpTrait::meth) f) { \
+        mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name, #meth); \
+        trait->meth = f; \
+        return *this; \
     }
+    FOR_EACH_OP_METH(DECL)
+#undef DECL
+
     OpTraitRegistry& fallback();
-    OpTraitRegistry& finalize();
 
     template<typename T>
     void insert() {
@@ -102,20 +99,11 @@ struct OpTraitRegistry {
     static OpTraitRegistry do_insert(const char* name);
 };
 
-namespace detail {
-struct _RegisterHelper {
-    OpTraitRegistry registry;
-    ~_RegisterHelper() {
-        registry.finalize();
-    }
-};
-} // namespace detail
-
 } // namespace imperative
 } // namespace mgb
 
 #define OP_TRAIT_REG(name, ...) \
     static OpTraitRegistry __##name##_global_registry__ = \
-        detail::_RegisterHelper{OpTraitRegistry::insert<__VA_ARGS__>(#name)}.registry
+        OpTraitRegistry::insert<__VA_ARGS__>(#name)
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
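
The detail::OpMeth wrapper above replaces the old finalize()/__not_implementation__ machinery: a method that was never registered now throws at the call site instead of being patched with a stub at registration time. A self-contained sketch of the same idea, using std::function in place of MegBrain's thin_function (the names here are illustrative, not part of the commit):

#include <functional>
#include <stdexcept>

// Models detail::OpMeth<RType(Args...)> with the standard library: calling an
// empty method raises "Not Implemented"; once assigned, it dispatches normally.
template <typename Signature>
struct OpMeth;

template <typename RType, typename... Args>
struct OpMeth<RType(Args...)> : std::function<RType(Args...)> {
    using Base = std::function<RType(Args...)>;
    using Base::Base;
    RType operator()(Args... args) const {
        if (!static_cast<const Base&>(*this)) {
            throw std::runtime_error("Not Implemented");
        }
        return Base::operator()(args...);
    }
};

int main() {
    OpMeth<int(int)> square;                 // not registered yet
    try {
        square(3);                           // throws "Not Implemented"
    } catch (const std::runtime_error&) {
    }
    square = [](int x) { return x * x; };    // register an implementation
    return square(4) == 16 ? 0 : 1;          // now dispatches normally
}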

imperative/src/impl/ops/cond_take.cpp (+8, -8)

@@ -110,20 +110,20 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     return out;
 }
 
-SmallVector<LogicalTensorDesc> infer_output_attrs(
+SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
         const OpDef& def,
-        const SmallVector<TensorPtr>& inputs) {
-    SmallVector<LogicalTensorDesc> out;
-    for (size_t i = 0; i < 2; ++ i) {
-        out.push_back({TensorLayout(), inputs[0]->comp_node()});
-    }
-    return out;
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    auto cn = inputs[0].comp_node;
+    return {
+        {TensorLayout(inputs[0].layout.dtype), cn},
+        {TensorLayout(dtype::Int32()), cn}
+    };
 }
 
 OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
     .apply_on_var_node(apply_on_var_node)
     .apply_on_physical_tensor(apply_on_physical_tensor)
-    .infer_output_attrs(infer_output_attrs)
+    .infer_output_attrs_fallible(infer_output_attrs_fallible)
     .fallback();
 
 } // namespace
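
The switch above is needed because CondTake's output sizes depend on the mask values, so only dtype and comp_node can be promised before execution; the empty TensorLayout (ndim == 0) marks the shape as unknown, and the real shapes come from apply_on_physical_tensor. A caller-side sketch of how such a descriptor is usually read (the helper name is hypothetical):

// Sketch only: a default-constructed TensorLayout has ndim == 0, which
// downstream code treats as "shape not known until the op actually runs".
bool shape_known_statically(const LogicalTensorDesc& desc) {
    return desc.layout.ndim != 0;
}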


imperative/src/impl/profiler.cpp (+5, -3)

@@ -28,10 +28,12 @@ namespace {
 CompNode::UnorderedSet collect_comp_nodes(
         const OpDef& def, const SmallVector<TensorPtr>& inputs) {
     CompNode::UnorderedSet comp_nodes;
-    for (auto&& input : inputs) {
-        comp_nodes.insert(input->comp_node());
+    SmallVector<LogicalTensorDesc> descs;
+    for (auto&& i : inputs) {
+        comp_nodes.insert(i->comp_node());
+        descs.push_back({i->layout(), i->comp_node(), {}});
     }
-    for (auto&& output_attr : def.infer_output_attrs(def, inputs)) {
+    for (auto&& output_attr : def.infer_output_attrs_fallible(def, descs)) {
         comp_nodes.insert(output_attr.comp_node);
     }
     return comp_nodes;


imperative/src/impl/proxy_graph_detail.cpp (+15, -2)

@@ -31,7 +31,6 @@ SmallVector<Tensor*> to_raw_ptr_array(
     }
     return ret;
 }
-} // anonymous namespace
 
 void exec(const OpDef& def,
         const SmallVector<TensorPtr>& inputs_,
@@ -61,11 +60,25 @@ void exec(const OpDef& def,
     }
 }
 
-SmallVector<LogicalTensorDesc> infer_output_attrs(const OpDef& def,
+SmallVector<LogicalTensorDesc>
+infer_output_attrs(const OpDef& def,
         const SmallVector<TensorPtr>& inputs) {
     auto&& graph = ProxyGraph::get_default_graph();
     return graph->infer_output_attrs(def, to_raw_ptr_array(inputs));
 }
+} // anonymous namespace
 
+SmallVector<TensorPtr>
+apply_on_physical_tensor(const OpDef& def,
+        const SmallVector<TensorPtr>& inputs) {
+    auto desc = infer_output_attrs(def, inputs);
+    SmallVector<TensorPtr> outputs;
+    for (auto&& i : desc) {
+        outputs.push_back(Tensor::make(i.layout, i.comp_node));
+    }
+    exec(def, inputs, outputs);
+    return outputs;
+}
+
 SmallVector<LogicalTensorDesc>
 infer_output_attrs_fallible(const OpDef& def,
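
Together with the fallback() change in op_trait.cpp, this new apply_on_physical_tensor is the generic execution path for ops that only provide apply_on_var_node. Rough call flow, reconstructed from the code above rather than quoted from it:

    OpDef::apply_on_physical_tensor(def, inputs)
      -> def.trait()->apply_on_physical_tensor            // OpMeth installed by fallback()
      -> proxy_graph_detail::apply_on_physical_tensor
           infer_output_attrs(def, inputs)                // proxy-graph inference, now file-local
           Tensor::make(desc.layout, desc.comp_node)      // allocate each output tensor
           exec(def, inputs, outputs)                     // run through the proxy graph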


imperative/src/impl/proxy_graph_detail.h (+2, -5)

@@ -17,11 +17,8 @@ namespace mgb {
 namespace imperative {
 namespace proxy_graph_detail {
 
-void exec(const OpDef& def,
-        const SmallVector<TensorPtr>& inputs_,
-        const SmallVector<TensorPtr>& outputs_);
-
-SmallVector<LogicalTensorDesc> infer_output_attrs(const OpDef& def,
+SmallVector<TensorPtr>
+apply_on_physical_tensor(const OpDef& def,
         const SmallVector<TensorPtr>& inputs);
 
 SmallVector<LogicalTensorDesc>


imperative/src/include/megbrain/imperative/op_def.h (+0, -9)

@@ -40,11 +40,6 @@ public:
         const OpDef& def,
         const SmallVector<TensorPtr>& inputs);
 
-    static void exec(
-        const OpDef& def,
-        const SmallVector<TensorPtr>& inputs,
-        const SmallVector<TensorPtr>& outputs);
-
     static cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs);
@@ -53,10 +48,6 @@ public:
         const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs);
 
-    static SmallVector<LogicalTensorDesc> infer_output_attrs(
-        const OpDef& def,
-        const SmallVector<TensorPtr>& inputs);
-
     static BackwardGraphResult make_backward_graph(
         const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs,

