GitOrigin-RevId: f47ceae726
tags/v1.3.0
@@ -11,6 +11,11 @@ import io | |||||
from gen_param_defs import member_defs, ParamDef, IndentWriterBase | from gen_param_defs import member_defs, ParamDef, IndentWriterBase | ||||
# FIXME: move supportToString flag definition into the param def source file | |||||
ENUM_TO_STRING_SPECIAL_RULES = [ | |||||
("Elemwise", "Mode"), | |||||
("ElemwiseMultiType", "Mode") | |||||
] | |||||
class ConverterWriter(IndentWriterBase): | class ConverterWriter(IndentWriterBase): | ||||
_skip_current_param = False | _skip_current_param = False | ||||
@@ -86,7 +91,10 @@ class ConverterWriter(IndentWriterBase): | |||||
def format(v): | def format(v): | ||||
return '\"{}\"'.format(str(v)) | return '\"{}\"'.format(str(v)) | ||||
enum_def += ','.join(format(i) for i in e.members) | enum_def += ','.join(format(i) for i in e.members) | ||||
enum_def += "]>" | |||||
enum_def += "]" | |||||
if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)): | |||||
enum_def += ", 1" # whether generate ToStringTrait | |||||
enum_def += ">" | |||||
self._write("def {} : {};".format(td_class, enum_def)) | self._write("def {} : {};".format(td_class, enum_def)) | ||||
if self._skip_current_param: | if self._skip_current_param: | ||||
@@ -12,6 +12,7 @@ | |||||
#pragma once | #pragma once | ||||
#include "megbrain/imperative/op_def.h" | #include "megbrain/imperative/op_def.h" | ||||
#include "megbrain/imperative/utils/to_string.h" | |||||
#include "megdnn/opr_param_defs.h" | #include "megdnn/opr_param_defs.h" | ||||
#include "megbrain/opr/param_defs.h" | #include "megbrain/opr/param_defs.h" | ||||
@@ -179,6 +179,34 @@ static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) { | |||||
); | ); | ||||
} | } | ||||
static void gen_to_string_trait_for_enum(raw_ostream &os, MgbOp& op) { | |||||
for (auto &&i : op.getMgbAttributes()) { | |||||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||||
if (attr->supportToString()) { | |||||
std::vector<std::string> case_body; | |||||
std::string ename = formatv("{0}::{1}", | |||||
op.getCppClassName(), attr->getEnumName()); | |||||
llvm::for_each(attr->getEnumMembers(), [&](auto&& v){ | |||||
case_body.push_back(formatv( | |||||
"case {0}::{1}: return \"{1}\";", ename, v)); | |||||
}); | |||||
os << formatv(R"( | |||||
template <> | |||||
struct ToStringTrait<{0}> { | |||||
std::string operator()({0} e) const { | |||||
switch (e) { | |||||
{1} | |||||
default: | |||||
return "{0}::Unknown"; | |||||
} | |||||
} | |||||
}; | |||||
)", ename, llvm::join(case_body, "\n")); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | ||||
auto&& className = op.getCppClassName(); | auto&& className = op.getCppClassName(); | ||||
os << formatv( | os << formatv( | ||||
@@ -241,7 +269,13 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||||
os << formatv( | os << formatv( | ||||
"std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") | "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") | ||||
); | ); | ||||
os << mlir::tblgen::tgfmt(hashable->getNameFunctionTemplate(), &ctx); | |||||
os << formatv( | |||||
" auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||||
" static_cast<void>(op_);\n", | |||||
className | |||||
); | |||||
ctx.withSelf("op_"); | |||||
os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx); | |||||
os << "}\n"; | os << "}\n"; | ||||
os << "} // anonymous namespace\n"; | os << "} // anonymous namespace\n"; | ||||
@@ -577,6 +611,7 @@ static void for_each_operator(raw_ostream &os, RecordKeeper &keeper, | |||||
static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) { | static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) { | ||||
for_each_operator(os, keeper, gen_op_def_c_header_single); | for_each_operator(os, keeper, gen_op_def_c_header_single); | ||||
for_each_operator(os, keeper, gen_to_string_trait_for_enum); | |||||
return false; | return false; | ||||
} | } | ||||
@@ -74,6 +74,9 @@ struct MgbEnumAttrMixin : public MgbAttrWrapperBase { | |||||
std::vector<StringRef> getEnumMembers() const { | std::vector<StringRef> getEnumMembers() const { | ||||
return getBaseRecord()->getValueAsListOfStrings("enumMembers"); | return getBaseRecord()->getValueAsListOfStrings("enumMembers"); | ||||
} | } | ||||
bool supportToString() const { | |||||
return getBaseRecord()->getValueAsBit("supportToString"); | |||||
} | |||||
}; | }; | ||||
struct MgbHashableAttrMixin : public MgbAttrWrapperBase { | struct MgbHashableAttrMixin : public MgbAttrWrapperBase { | ||||
@@ -170,6 +173,12 @@ public: | |||||
} | } | ||||
return ret; | return ret; | ||||
} | } | ||||
std::string getNameFunctionTemplate() const { | |||||
if (auto f = getDef().getValueAsOptionalString("nameFunction")) { | |||||
return f.getValue().str(); | |||||
} | |||||
return formatv(" return \"{0}\";\n", getCppClassName()); | |||||
} | |||||
}; | }; | ||||
struct MgbHashableOpMixin : public MgbOpBase { | struct MgbHashableOpMixin : public MgbOpBase { | ||||
@@ -241,30 +250,6 @@ private: | |||||
body += " return props_;\n"; | body += " return props_;\n"; | ||||
return body; | return body; | ||||
} | } | ||||
std::string getModeName() const { | |||||
std::string body = formatv( | |||||
" auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||||
" static_cast<void>(op_);\n", | |||||
getCppClassName() | |||||
); | |||||
for (auto&& it : getMgbAttributes()) { | |||||
if (it.name == "mode") { | |||||
auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr); | |||||
body += " switch (op_.mode){\n"; | |||||
for (auto&& enumMember: enumAttr->getEnumMembers()) { | |||||
body += formatv( | |||||
" case {0}::{1}::{2}:\n", | |||||
getCppClassName(), enumAttr->getEnumName(), enumMember | |||||
); | |||||
body += formatv(" return \"{0}\";\n", enumMember); | |||||
} | |||||
body += formatv( | |||||
" default: return \"{0}::Unknown\";\n", getCppClassName()); | |||||
body += " }\n"; | |||||
} | |||||
} | |||||
return body; | |||||
} | |||||
public: | public: | ||||
static bool classof(const Operator* op) { | static bool classof(const Operator* op) { | ||||
return op->getDef().isSubClassOf("MgbHashableOpMixin"); | return op->getDef().isSubClassOf("MgbHashableOpMixin"); | ||||
@@ -288,12 +273,6 @@ public: | |||||
} | } | ||||
return getDefaultPropsFunction(); | return getDefaultPropsFunction(); | ||||
} | } | ||||
std::string getNameFunctionTemplate() const { | |||||
if (getDef().getValueAsBit("usingModeName")) { | |||||
return getModeName(); | |||||
} | |||||
return formatv(" return \"{0}\";\n", getCppClassName()); | |||||
} | |||||
}; | }; | ||||
} // namespace tblgen | } // namespace tblgen | ||||
@@ -33,10 +33,11 @@ class MgbHashableAttrMixin { | |||||
string reprFunction = "std::to_string($0)"; | string reprFunction = "std::to_string($0)"; | ||||
} | } | ||||
class MgbEnumAttrMixin<string namespace, string name, list<string> members> { | |||||
class MgbEnumAttrMixin<string namespace, string name, list<string> members, bit toString> { | |||||
string parentNamespace = namespace; | string parentNamespace = namespace; | ||||
string enumName = name; | string enumName = name; | ||||
list<string> enumMembers = members; | list<string> enumMembers = members; | ||||
bit supportToString = toString; | |||||
} | } | ||||
class MgbAttrWrapper; | class MgbAttrWrapper; | ||||
@@ -165,8 +166,8 @@ class MgbTupleAttr<list<MgbAttrWrapper> args>: | |||||
} | } | ||||
// -- enum types | // -- enum types | ||||
class MgbEnumAttr<string namespace, string enumName, list<string> members>: | |||||
HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members> { | |||||
class MgbEnumAttr<string namespace, string enumName, list<string> members, bit toString=0>: | |||||
HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members, toString> { | |||||
let storageType = "::mlir::IntegerAttr"; | let storageType = "::mlir::IntegerAttr"; | ||||
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; | let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; | ||||
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))"; | let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))"; | ||||
@@ -242,7 +243,6 @@ class MgbPackedParamBase<string className, string accessor>: | |||||
class MgbHashableOpMixin { | class MgbHashableOpMixin { | ||||
string hashFunction = ?; | string hashFunction = ?; | ||||
string cmpFunction = ?; | string cmpFunction = ?; | ||||
bit usingModeName = 0; | |||||
} | } | ||||
class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>: | class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>: | ||||
@@ -251,6 +251,7 @@ class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits= | |||||
dag extraArguments = (ins); | dag extraArguments = (ins); | ||||
// TODO: remove it | // TODO: remove it | ||||
code extraOpdefDecl = ?; | code extraOpdefDecl = ?; | ||||
code nameFunction = ?; | |||||
let arguments = !con( | let arguments = !con( | ||||
!foldl(inputs, params, args, param, !con(args, param.fields)), | !foldl(inputs, params, args, param, !con(args, param.fields)), | ||||
@@ -21,7 +21,9 @@ include "mlir/Interfaces/SideEffectInterfaces.td" | |||||
def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> { | def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> { | ||||
let inputs = (ins Variadic<AnyType>:$input); | let inputs = (ins Variadic<AnyType>:$input); | ||||
let results = (outs AnyType); | let results = (outs AnyType); | ||||
let usingModeName = 1; | |||||
let nameFunction = [{ | |||||
return to_string($_self.mode); | |||||
}]; | |||||
} | } | ||||
def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>; | def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>; | ||||
@@ -248,7 +250,9 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara | |||||
let extraArguments = (ins | let extraArguments = (ins | ||||
MgbDTypeAttr:$dtype | MgbDTypeAttr:$dtype | ||||
); | ); | ||||
let usingModeName = 1; | |||||
let nameFunction = [{ | |||||
return to_string($_self.mode); | |||||
}]; | |||||
} | } | ||||
def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>; | def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>; | ||||