diff --git a/dnn/scripts/gen_tablegen.py b/dnn/scripts/gen_tablegen.py index 751e7578..f1c174c3 100755 --- a/dnn/scripts/gen_tablegen.py +++ b/dnn/scripts/gen_tablegen.py @@ -11,6 +11,11 @@ import io from gen_param_defs import member_defs, ParamDef, IndentWriterBase +# FIXME: move supportToString flag definition into the param def source file +ENUM_TO_STRING_SPECIAL_RULES = [ + ("Elemwise", "Mode"), + ("ElemwiseMultiType", "Mode") +] class ConverterWriter(IndentWriterBase): _skip_current_param = False @@ -86,7 +91,10 @@ class ConverterWriter(IndentWriterBase): def format(v): return '\"{}\"'.format(str(v)) enum_def += ','.join(format(i) for i in e.members) - enum_def += "]>" + enum_def += "]" + if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)): + enum_def += ", 1" # whether generate ToStringTrait + enum_def += ">" self._write("def {} : {};".format(td_class, enum_def)) if self._skip_current_param: diff --git a/imperative/src/include/megbrain/imperative/ops/autogen.h b/imperative/src/include/megbrain/imperative/ops/autogen.h index 96e39a52..e2540278 100644 --- a/imperative/src/include/megbrain/imperative/ops/autogen.h +++ b/imperative/src/include/megbrain/imperative/ops/autogen.h @@ -12,6 +12,7 @@ #pragma once #include "megbrain/imperative/op_def.h" +#include "megbrain/imperative/utils/to_string.h" #include "megdnn/opr_param_defs.h" #include "megbrain/opr/param_defs.h" diff --git a/imperative/tablegen/autogen.cpp b/imperative/tablegen/autogen.cpp index a86788ad..eec1a5ec 100644 --- a/imperative/tablegen/autogen.cpp +++ b/imperative/tablegen/autogen.cpp @@ -179,6 +179,34 @@ static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) { ); } +static void gen_to_string_trait_for_enum(raw_ostream &os, MgbOp& op) { + for (auto &&i : op.getMgbAttributes()) { + if (auto attr = llvm::dyn_cast(&i.attr)) { + if (attr->supportToString()) { + std::vector case_body; + std::string ename = formatv("{0}::{1}", + op.getCppClassName(), attr->getEnumName()); + llvm::for_each(attr->getEnumMembers(), [&](auto&& v){ + case_body.push_back(formatv( + "case {0}::{1}: return \"{1}\";", ename, v)); + }); + os << formatv(R"( +template <> +struct ToStringTrait<{0}> { + std::string operator()({0} e) const { + switch (e) { + {1} + default: + return "{0}::Unknown"; + } + } +}; +)", ename, llvm::join(case_body, "\n")); + } + } + } +} + static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { auto&& className = op.getCppClassName(); os << formatv( @@ -241,7 +269,13 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { os << formatv( "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") ); - os << mlir::tblgen::tgfmt(hashable->getNameFunctionTemplate(), &ctx); + os << formatv( + " auto&& op_ = def_.cast_final_safe<{0}>();\n" + " static_cast(op_);\n", + className + ); + ctx.withSelf("op_"); + os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx); os << "}\n"; os << "} // anonymous namespace\n"; @@ -577,6 +611,7 @@ static void for_each_operator(raw_ostream &os, RecordKeeper &keeper, static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) { for_each_operator(os, keeper, gen_op_def_c_header_single); + for_each_operator(os, keeper, gen_to_string_trait_for_enum); return false; } diff --git a/imperative/tablegen/helper.h b/imperative/tablegen/helper.h index 9145d387..c5f084e4 100644 --- a/imperative/tablegen/helper.h +++ b/imperative/tablegen/helper.h @@ -74,6 +74,9 @@ struct MgbEnumAttrMixin : public MgbAttrWrapperBase { std::vector getEnumMembers() const { return getBaseRecord()->getValueAsListOfStrings("enumMembers"); } + bool supportToString() const { + return getBaseRecord()->getValueAsBit("supportToString"); + } }; struct MgbHashableAttrMixin : public MgbAttrWrapperBase { @@ -170,6 +173,12 @@ public: } return ret; } + std::string getNameFunctionTemplate() const { + if (auto f = getDef().getValueAsOptionalString("nameFunction")) { + return f.getValue().str(); + } + return formatv(" return \"{0}\";\n", getCppClassName()); + } }; struct MgbHashableOpMixin : public MgbOpBase { @@ -241,30 +250,6 @@ private: body += " return props_;\n"; return body; } - std::string getModeName() const { - std::string body = formatv( - " auto&& op_ = def_.cast_final_safe<{0}>();\n" - " static_cast(op_);\n", - getCppClassName() - ); - for (auto&& it : getMgbAttributes()) { - if (it.name == "mode") { - auto* enumAttr = llvm::dyn_cast(&it.attr); - body += " switch (op_.mode){\n"; - for (auto&& enumMember: enumAttr->getEnumMembers()) { - body += formatv( - " case {0}::{1}::{2}:\n", - getCppClassName(), enumAttr->getEnumName(), enumMember - ); - body += formatv(" return \"{0}\";\n", enumMember); - } - body += formatv( - " default: return \"{0}::Unknown\";\n", getCppClassName()); - body += " }\n"; - } - } - return body; - } public: static bool classof(const Operator* op) { return op->getDef().isSubClassOf("MgbHashableOpMixin"); @@ -288,12 +273,6 @@ public: } return getDefaultPropsFunction(); } - std::string getNameFunctionTemplate() const { - if (getDef().getValueAsBit("usingModeName")) { - return getModeName(); - } - return formatv(" return \"{0}\";\n", getCppClassName()); - } }; } // namespace tblgen diff --git a/src/core/include/megbrain/ir/base.td b/src/core/include/megbrain/ir/base.td index 2b11392d..d1f35ebc 100644 --- a/src/core/include/megbrain/ir/base.td +++ b/src/core/include/megbrain/ir/base.td @@ -33,10 +33,11 @@ class MgbHashableAttrMixin { string reprFunction = "std::to_string($0)"; } -class MgbEnumAttrMixin members> { +class MgbEnumAttrMixin members, bit toString> { string parentNamespace = namespace; string enumName = name; list enumMembers = members; + bit supportToString = toString; } class MgbAttrWrapper; @@ -165,8 +166,8 @@ class MgbTupleAttr args>: } // -- enum types -class MgbEnumAttr members>: - HashableAttr, MgbEnumAttrMixin { +class MgbEnumAttr members, bit toString=0>: + HashableAttr, MgbEnumAttrMixin { let storageType = "::mlir::IntegerAttr"; let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast($0))"; @@ -242,7 +243,6 @@ class MgbPackedParamBase: class MgbHashableOpMixin { string hashFunction = ?; string cmpFunction = ?; - bit usingModeName = 0; } class MgbOp params=[], list traits=[]>: @@ -251,6 +251,7 @@ class MgbOp params=[], list traits= dag extraArguments = (ins); // TODO: remove it code extraOpdefDecl = ?; + code nameFunction = ?; let arguments = !con( !foldl(inputs, params, args, param, !con(args, param.fields)), diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 674acc7e..8dae0591 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -21,7 +21,9 @@ include "mlir/Interfaces/SideEffectInterfaces.td" def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> { let inputs = (ins Variadic:$input); let results = (outs AnyType); - let usingModeName = 1; + let nameFunction = [{ + return to_string($_self.mode); + }]; } def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>; @@ -248,7 +250,9 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara let extraArguments = (ins MgbDTypeAttr:$dtype ); - let usingModeName = 1; + let nameFunction = [{ + return to_string($_self.mode); + }]; } def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>;