GitOrigin-RevId: f47ceae726
tags/v1.3.0
@@ -11,6 +11,11 @@ import io | |||
from gen_param_defs import member_defs, ParamDef, IndentWriterBase | |||
# FIXME: move supportToString flag definition into the param def source file | |||
ENUM_TO_STRING_SPECIAL_RULES = [ | |||
("Elemwise", "Mode"), | |||
("ElemwiseMultiType", "Mode") | |||
] | |||
class ConverterWriter(IndentWriterBase): | |||
_skip_current_param = False | |||
@@ -86,7 +91,10 @@ class ConverterWriter(IndentWriterBase): | |||
def format(v): | |||
return '\"{}\"'.format(str(v)) | |||
enum_def += ','.join(format(i) for i in e.members) | |||
enum_def += "]>" | |||
enum_def += "]" | |||
if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)): | |||
enum_def += ", 1" # whether generate ToStringTrait | |||
enum_def += ">" | |||
self._write("def {} : {};".format(td_class, enum_def)) | |||
if self._skip_current_param: | |||
@@ -12,6 +12,7 @@ | |||
#pragma once | |||
#include "megbrain/imperative/op_def.h" | |||
#include "megbrain/imperative/utils/to_string.h" | |||
#include "megdnn/opr_param_defs.h" | |||
#include "megbrain/opr/param_defs.h" | |||
@@ -179,6 +179,34 @@ static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) { | |||
); | |||
} | |||
static void gen_to_string_trait_for_enum(raw_ostream &os, MgbOp& op) { | |||
for (auto &&i : op.getMgbAttributes()) { | |||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
if (attr->supportToString()) { | |||
std::vector<std::string> case_body; | |||
std::string ename = formatv("{0}::{1}", | |||
op.getCppClassName(), attr->getEnumName()); | |||
llvm::for_each(attr->getEnumMembers(), [&](auto&& v){ | |||
case_body.push_back(formatv( | |||
"case {0}::{1}: return \"{1}\";", ename, v)); | |||
}); | |||
os << formatv(R"( | |||
template <> | |||
struct ToStringTrait<{0}> { | |||
std::string operator()({0} e) const { | |||
switch (e) { | |||
{1} | |||
default: | |||
return "{0}::Unknown"; | |||
} | |||
} | |||
}; | |||
)", ename, llvm::join(case_body, "\n")); | |||
} | |||
} | |||
} | |||
} | |||
static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||
auto&& className = op.getCppClassName(); | |||
os << formatv( | |||
@@ -241,7 +269,13 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||
os << formatv( | |||
"std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") | |||
); | |||
os << mlir::tblgen::tgfmt(hashable->getNameFunctionTemplate(), &ctx); | |||
os << formatv( | |||
" auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
" static_cast<void>(op_);\n", | |||
className | |||
); | |||
ctx.withSelf("op_"); | |||
os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx); | |||
os << "}\n"; | |||
os << "} // anonymous namespace\n"; | |||
@@ -577,6 +611,7 @@ static void for_each_operator(raw_ostream &os, RecordKeeper &keeper, | |||
static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) { | |||
for_each_operator(os, keeper, gen_op_def_c_header_single); | |||
for_each_operator(os, keeper, gen_to_string_trait_for_enum); | |||
return false; | |||
} | |||
@@ -74,6 +74,9 @@ struct MgbEnumAttrMixin : public MgbAttrWrapperBase { | |||
std::vector<StringRef> getEnumMembers() const { | |||
return getBaseRecord()->getValueAsListOfStrings("enumMembers"); | |||
} | |||
bool supportToString() const { | |||
return getBaseRecord()->getValueAsBit("supportToString"); | |||
} | |||
}; | |||
struct MgbHashableAttrMixin : public MgbAttrWrapperBase { | |||
@@ -170,6 +173,12 @@ public: | |||
} | |||
return ret; | |||
} | |||
std::string getNameFunctionTemplate() const { | |||
if (auto f = getDef().getValueAsOptionalString("nameFunction")) { | |||
return f.getValue().str(); | |||
} | |||
return formatv(" return \"{0}\";\n", getCppClassName()); | |||
} | |||
}; | |||
struct MgbHashableOpMixin : public MgbOpBase { | |||
@@ -241,30 +250,6 @@ private: | |||
body += " return props_;\n"; | |||
return body; | |||
} | |||
std::string getModeName() const { | |||
std::string body = formatv( | |||
" auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
" static_cast<void>(op_);\n", | |||
getCppClassName() | |||
); | |||
for (auto&& it : getMgbAttributes()) { | |||
if (it.name == "mode") { | |||
auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr); | |||
body += " switch (op_.mode){\n"; | |||
for (auto&& enumMember: enumAttr->getEnumMembers()) { | |||
body += formatv( | |||
" case {0}::{1}::{2}:\n", | |||
getCppClassName(), enumAttr->getEnumName(), enumMember | |||
); | |||
body += formatv(" return \"{0}\";\n", enumMember); | |||
} | |||
body += formatv( | |||
" default: return \"{0}::Unknown\";\n", getCppClassName()); | |||
body += " }\n"; | |||
} | |||
} | |||
return body; | |||
} | |||
public: | |||
static bool classof(const Operator* op) { | |||
return op->getDef().isSubClassOf("MgbHashableOpMixin"); | |||
@@ -288,12 +273,6 @@ public: | |||
} | |||
return getDefaultPropsFunction(); | |||
} | |||
std::string getNameFunctionTemplate() const { | |||
if (getDef().getValueAsBit("usingModeName")) { | |||
return getModeName(); | |||
} | |||
return formatv(" return \"{0}\";\n", getCppClassName()); | |||
} | |||
}; | |||
} // namespace tblgen | |||
@@ -33,10 +33,11 @@ class MgbHashableAttrMixin { | |||
string reprFunction = "std::to_string($0)"; | |||
} | |||
class MgbEnumAttrMixin<string namespace, string name, list<string> members> { | |||
class MgbEnumAttrMixin<string namespace, string name, list<string> members, bit toString> { | |||
string parentNamespace = namespace; | |||
string enumName = name; | |||
list<string> enumMembers = members; | |||
bit supportToString = toString; | |||
} | |||
class MgbAttrWrapper; | |||
@@ -165,8 +166,8 @@ class MgbTupleAttr<list<MgbAttrWrapper> args>: | |||
} | |||
// -- enum types | |||
class MgbEnumAttr<string namespace, string enumName, list<string> members>: | |||
HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members> { | |||
class MgbEnumAttr<string namespace, string enumName, list<string> members, bit toString=0>: | |||
HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members, toString> { | |||
let storageType = "::mlir::IntegerAttr"; | |||
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; | |||
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))"; | |||
@@ -242,7 +243,6 @@ class MgbPackedParamBase<string className, string accessor>: | |||
class MgbHashableOpMixin { | |||
string hashFunction = ?; | |||
string cmpFunction = ?; | |||
bit usingModeName = 0; | |||
} | |||
class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>: | |||
@@ -251,6 +251,7 @@ class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits= | |||
dag extraArguments = (ins); | |||
// TODO: remove it | |||
code extraOpdefDecl = ?; | |||
code nameFunction = ?; | |||
let arguments = !con( | |||
!foldl(inputs, params, args, param, !con(args, param.fields)), | |||
@@ -21,7 +21,9 @@ include "mlir/Interfaces/SideEffectInterfaces.td" | |||
def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> { | |||
let inputs = (ins Variadic<AnyType>:$input); | |||
let results = (outs AnyType); | |||
let usingModeName = 1; | |||
let nameFunction = [{ | |||
return to_string($_self.mode); | |||
}]; | |||
} | |||
def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>; | |||
@@ -248,7 +250,9 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara | |||
let extraArguments = (ins | |||
MgbDTypeAttr:$dtype | |||
); | |||
let usingModeName = 1; | |||
let nameFunction = [{ | |||
return to_string($_self.mode); | |||
}]; | |||
} | |||
def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>; | |||