Browse Source

chore(imperative): refine tblgen for generating op name

GitOrigin-RevId: f47ceae726
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
ad87f78a14
6 changed files with 66 additions and 38 deletions
  1. +9
    -1
      dnn/scripts/gen_tablegen.py
  2. +1
    -0
      imperative/src/include/megbrain/imperative/ops/autogen.h
  3. +36
    -1
      imperative/tablegen/autogen.cpp
  4. +9
    -30
      imperative/tablegen/helper.h
  5. +5
    -4
      src/core/include/megbrain/ir/base.td
  6. +6
    -2
      src/core/include/megbrain/ir/ops.td

+ 9
- 1
dnn/scripts/gen_tablegen.py View File

@@ -11,6 +11,11 @@ import io


from gen_param_defs import member_defs, ParamDef, IndentWriterBase from gen_param_defs import member_defs, ParamDef, IndentWriterBase


# FIXME: move supportToString flag definition into the param def source file
ENUM_TO_STRING_SPECIAL_RULES = [
("Elemwise", "Mode"),
("ElemwiseMultiType", "Mode")
]


class ConverterWriter(IndentWriterBase): class ConverterWriter(IndentWriterBase):
_skip_current_param = False _skip_current_param = False
@@ -86,7 +91,10 @@ class ConverterWriter(IndentWriterBase):
def format(v): def format(v):
return '\"{}\"'.format(str(v)) return '\"{}\"'.format(str(v))
enum_def += ','.join(format(i) for i in e.members) enum_def += ','.join(format(i) for i in e.members)
enum_def += "]>"
enum_def += "]"
if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)):
enum_def += ", 1" # whether generate ToStringTrait
enum_def += ">"
self._write("def {} : {};".format(td_class, enum_def)) self._write("def {} : {};".format(td_class, enum_def))


if self._skip_current_param: if self._skip_current_param:


+ 1
- 0
imperative/src/include/megbrain/imperative/ops/autogen.h View File

@@ -12,6 +12,7 @@
#pragma once #pragma once


#include "megbrain/imperative/op_def.h" #include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/utils/to_string.h"
#include "megdnn/opr_param_defs.h" #include "megdnn/opr_param_defs.h"
#include "megbrain/opr/param_defs.h" #include "megbrain/opr/param_defs.h"




+ 36
- 1
imperative/tablegen/autogen.cpp View File

@@ -179,6 +179,34 @@ static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) {
); );
} }


static void gen_to_string_trait_for_enum(raw_ostream &os, MgbOp& op) {
for (auto &&i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
if (attr->supportToString()) {
std::vector<std::string> case_body;
std::string ename = formatv("{0}::{1}",
op.getCppClassName(), attr->getEnumName());
llvm::for_each(attr->getEnumMembers(), [&](auto&& v){
case_body.push_back(formatv(
"case {0}::{1}: return \"{1}\";", ename, v));
});
os << formatv(R"(
template <>
struct ToStringTrait<{0}> {
std::string operator()({0} e) const {
switch (e) {
{1}
default:
return "{0}::Unknown";
}
}
};
)", ename, llvm::join(case_body, "\n"));
}
}
}
}

static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
auto&& className = op.getCppClassName(); auto&& className = op.getCppClassName();
os << formatv( os << formatv(
@@ -241,7 +269,13 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
os << formatv( os << formatv(
"std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name")
); );
os << mlir::tblgen::tgfmt(hashable->getNameFunctionTemplate(), &ctx);
os << formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx);
os << "}\n"; os << "}\n";


os << "} // anonymous namespace\n"; os << "} // anonymous namespace\n";
@@ -577,6 +611,7 @@ static void for_each_operator(raw_ostream &os, RecordKeeper &keeper,


static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) { static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) {
for_each_operator(os, keeper, gen_op_def_c_header_single); for_each_operator(os, keeper, gen_op_def_c_header_single);
for_each_operator(os, keeper, gen_to_string_trait_for_enum);
return false; return false;
} }




+ 9
- 30
imperative/tablegen/helper.h View File

@@ -74,6 +74,9 @@ struct MgbEnumAttrMixin : public MgbAttrWrapperBase {
std::vector<StringRef> getEnumMembers() const { std::vector<StringRef> getEnumMembers() const {
return getBaseRecord()->getValueAsListOfStrings("enumMembers"); return getBaseRecord()->getValueAsListOfStrings("enumMembers");
} }
bool supportToString() const {
return getBaseRecord()->getValueAsBit("supportToString");
}
}; };


struct MgbHashableAttrMixin : public MgbAttrWrapperBase { struct MgbHashableAttrMixin : public MgbAttrWrapperBase {
@@ -170,6 +173,12 @@ public:
} }
return ret; return ret;
} }
std::string getNameFunctionTemplate() const {
if (auto f = getDef().getValueAsOptionalString("nameFunction")) {
return f.getValue().str();
}
return formatv(" return \"{0}\";\n", getCppClassName());
}
}; };


struct MgbHashableOpMixin : public MgbOpBase { struct MgbHashableOpMixin : public MgbOpBase {
@@ -241,30 +250,6 @@ private:
body += " return props_;\n"; body += " return props_;\n";
return body; return body;
} }
std::string getModeName() const {
std::string body = formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
getCppClassName()
);
for (auto&& it : getMgbAttributes()) {
if (it.name == "mode") {
auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr);
body += " switch (op_.mode){\n";
for (auto&& enumMember: enumAttr->getEnumMembers()) {
body += formatv(
" case {0}::{1}::{2}:\n",
getCppClassName(), enumAttr->getEnumName(), enumMember
);
body += formatv(" return \"{0}\";\n", enumMember);
}
body += formatv(
" default: return \"{0}::Unknown\";\n", getCppClassName());
body += " }\n";
}
}
return body;
}
public: public:
static bool classof(const Operator* op) { static bool classof(const Operator* op) {
return op->getDef().isSubClassOf("MgbHashableOpMixin"); return op->getDef().isSubClassOf("MgbHashableOpMixin");
@@ -288,12 +273,6 @@ public:
} }
return getDefaultPropsFunction(); return getDefaultPropsFunction();
} }
std::string getNameFunctionTemplate() const {
if (getDef().getValueAsBit("usingModeName")) {
return getModeName();
}
return formatv(" return \"{0}\";\n", getCppClassName());
}
}; };


} // namespace tblgen } // namespace tblgen


+ 5
- 4
src/core/include/megbrain/ir/base.td View File

@@ -33,10 +33,11 @@ class MgbHashableAttrMixin {
string reprFunction = "std::to_string($0)"; string reprFunction = "std::to_string($0)";
} }


class MgbEnumAttrMixin<string namespace, string name, list<string> members> {
class MgbEnumAttrMixin<string namespace, string name, list<string> members, bit toString> {
string parentNamespace = namespace; string parentNamespace = namespace;
string enumName = name; string enumName = name;
list<string> enumMembers = members; list<string> enumMembers = members;
bit supportToString = toString;
} }


class MgbAttrWrapper; class MgbAttrWrapper;
@@ -165,8 +166,8 @@ class MgbTupleAttr<list<MgbAttrWrapper> args>:
} }


// -- enum types // -- enum types
class MgbEnumAttr<string namespace, string enumName, list<string> members>:
HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members> {
class MgbEnumAttr<string namespace, string enumName, list<string> members, bit toString=0>:
HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members, toString> {
let storageType = "::mlir::IntegerAttr"; let storageType = "::mlir::IntegerAttr";
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))"; let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";
@@ -242,7 +243,6 @@ class MgbPackedParamBase<string className, string accessor>:
class MgbHashableOpMixin { class MgbHashableOpMixin {
string hashFunction = ?; string hashFunction = ?;
string cmpFunction = ?; string cmpFunction = ?;
bit usingModeName = 0;
} }


class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>: class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>:
@@ -251,6 +251,7 @@ class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=
dag extraArguments = (ins); dag extraArguments = (ins);
// TODO: remove it // TODO: remove it
code extraOpdefDecl = ?; code extraOpdefDecl = ?;
code nameFunction = ?;


let arguments = !con( let arguments = !con(
!foldl(inputs, params, args, param, !con(args, param.fields)), !foldl(inputs, params, args, param, !con(args, param.fields)),


+ 6
- 2
src/core/include/megbrain/ir/ops.td View File

@@ -21,7 +21,9 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> { def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> {
let inputs = (ins Variadic<AnyType>:$input); let inputs = (ins Variadic<AnyType>:$input);
let results = (outs AnyType); let results = (outs AnyType);
let usingModeName = 1;
let nameFunction = [{
return to_string($_self.mode);
}];
} }


def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>; def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>;
@@ -248,7 +250,9 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara
let extraArguments = (ins let extraArguments = (ins
MgbDTypeAttr:$dtype MgbDTypeAttr:$dtype
); );
let usingModeName = 1;
let nameFunction = [{
return to_string($_self.mode);
}];
} }


def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>; def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>;


Loading…
Cancel
Save