#include #include #include #include "./helper.h" using llvm::raw_ostream; using llvm::RecordKeeper; enum ActionType { None, CppHeader, CppBody, Pybind }; // NOLINTNEXTLINE llvm::cl::opt action( llvm::cl::desc("Action to perform:"), llvm::cl::values(clEnumValN(CppHeader, "gen-cpp-header", "Generate operator cpp header"), clEnumValN(CppBody, "gen-cpp-body", "Generate operator cpp body"), clEnumValN(Pybind, "gen-python-binding", "Generate pybind11 python bindings"))); using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase; using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin; using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin; using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin; using MgbOp = mlir::tblgen::MgbOpBase; using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin; llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) { // Note: we have already registered the corresponding attr wrappers // for following basic ctypes so we needn't handle them here /* auto&& attr_type_name = attr.getAttrDefName(); if (attr_type_name == "UI32Attr") { return "uint32_t"; } if (attr_type_name == "UI64Attr") { return "uint64_t"; } if (attr_type_name == "I32Attr") { return "int32_t"; } if (attr_type_name == "F32Attr") { return "float"; } if (attr_type_name == "F64Attr") { return "double"; } if (attr_type_name == "StrAttr") { return "std::string"; } if (attr_type_name == "BoolAttr") { return "bool"; }*/ auto&& attr = llvm::cast(attr_); if (auto e = llvm::dyn_cast(&attr)) { return e->getEnumName(); } return attr.getUnderlyingType(); } static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) { os << formatv( "class {0} : public OpDefImplBase<{0}> {{\n" " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n" "public:\n", op.getCppClassName() ); // handle enum alias for (auto &&i : op.getMgbAttributes()) { if (auto attr = llvm::dyn_cast(&i.attr)) { os << formatv( " using {0} = {1};\n", attr->getEnumName(), attr->getUnderlyingType() ); } } for (auto &&i : op.getMgbAttributes()) { auto defaultValue = i.attr.getDefaultValue().str(); if (!defaultValue.empty()) { defaultValue = formatv(" = {0}", defaultValue); } os << formatv( " {0} {1}{2};\n", attr_to_ctype(i.attr), i.name, defaultValue ); } auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) { os << formatv( " {0}({1}){2}{3}\n", op.getCppClassName(), paramList, memInitList, body ); }; gen_ctor("", "", " = default;"); if (!op.getMgbAttributes().empty()) { std::vector paramList, initList; for (auto &&i : op.getMgbAttributes()) { paramList.push_back(formatv( "{0} {1}_", attr_to_ctype(i.attr), i.name )); initList.push_back(formatv( "{0}({0}_)", i.name )); } gen_ctor(llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "), " {}"); } auto packedParams = op.getPackedParams(); if (!packedParams.empty()) { std::vector paramList, initList; for (auto &&p : packedParams) { auto&& paramFields = p.getFields(); auto&& paramType = p.getFullName(); auto&& paramName = formatv("packed_param_{0}", paramList.size()); paramList.push_back( paramFields.empty() ? paramType.str() : formatv("{0} {1}", paramType, paramName) ); for (auto&& i : paramFields) { initList.push_back(formatv( "{0}({1}.{0})", i.name, paramName )); } } for (auto&& i : op.getExtraArguments()) { paramList.push_back(formatv( "{0} {1}_", attr_to_ctype(i.attr), i.name )); initList.push_back(formatv( "{0}({0}_)", i.name )); } gen_ctor(llvm::join(paramList, ", "), initList.empty() ? "" : ": " + llvm::join(initList, ", "), " {}"); } if (!packedParams.empty()) { for (auto&& p : packedParams) { auto accessor = p.getAccessor(); if (!accessor.empty()) { os << formatv( " {0} {1}() const {{\n", p.getFullName(), accessor ); std::vector fields; for (auto&& i : p.getFields()) { fields.push_back(i.name); } os << formatv( " return {{{0}};\n", llvm::join(fields, ", ") ); os << " }\n"; } } } if (auto decl = op.getExtraOpdefDecl()) { os << decl.getValue(); } os << formatv( "};\n\n" ); } static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { auto&& className = op.getCppClassName(); os << formatv( "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className ); auto formatMethImpl = [&](auto&& meth) { return formatv( "{0}_{1}_impl", className, meth ); }; std::vector methods; if (auto hashable = llvm::dyn_cast(&op)) { os << "namespace {\n"; // generate hash() mlir::tblgen::FmtContext ctx; os << formatv( "size_t {0}(const OpDef& def_) {{\n", formatMethImpl("hash") ); os << formatv( " auto op_ = def_.cast_final_safe<{0}>();\n" " static_cast(op_);\n", className ); ctx.withSelf("op_"); os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx); os << "}\n"; // generate is_same_st() os << formatv( "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n", formatMethImpl("is_same_st") ); os << formatv( " auto a_ = lhs_.cast_final_safe<{0}>(),\n" " b_ = rhs_.cast_final_safe<{0}>();\n" " static_cast(a_);\n" " static_cast(b_);\n", className ); os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_"); os << "}\n"; os << "} // anonymous namespace\n"; methods.push_back("hash"); methods.push_back("is_same_st"); } if (!methods.empty()) { os << formatv( "OP_TRAIT_REG({0}, {0})", op.getCppClassName() ); for (auto&& i : methods) { os << formatv( "\n .{0}({1})", i, formatMethImpl(i) ); } os << ";\n\n"; } } struct PybindContext { std::unordered_map enumAlias; }; static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext& ctx) { auto class_name = op.getCppClassName(); os << formatv( "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", class_name ); for (auto&& i : op.getMgbAttributes()) { if (auto attr = llvm::dyn_cast(&i.attr)) { unsigned int enumID; if (auto alias = llvm::dyn_cast(attr)) { auto&& aliasBase = alias->getAliasBase(); enumID = llvm::cast(aliasBase) .getBaseRecord()->getID(); } else { enumID = attr->getBaseRecord()->getID(); } auto&& enumAlias = ctx.enumAlias; auto&& iter = enumAlias.find(enumID); if (iter == enumAlias.end()) { os << formatv( "py::enum_<{0}::{1}>({0}Inst, \"{1}\")", class_name, attr->getEnumName() ); std::vector body; for (auto&& i: attr->getEnumMembers()) { os << formatv( "\n .value(\"{2}\", {0}::{1}::{2})", class_name, attr->getEnumName(), i ); body.push_back(formatv( "if (str == \"{2}\") return {0}::{1}::{2};", class_name, attr->getEnumName(), i )); } os << formatv( "\n .def(py::init([](const std::string& in) {" "\n auto&& str = normalize_enum(in);" "\n {0}" "\n throw py::cast_error(\"invalid enum value \" + in);" "\n }));\n", llvm::join(body, "\n ") ); os << formatv( "py::implicitly_convertible();\n\n", class_name, attr->getEnumName() ); enumAlias.emplace(enumID, formatv( "{0}Inst.attr(\"{1}\")", class_name, attr->getEnumName() )); } else { os << formatv( "{0}Inst.attr(\"{1}\") = {2};\n\n", class_name, attr->getEnumName(), iter->second ); } } } // generate op class binding os << formatv("{0}Inst", class_name); bool hasDefaultCtor = op.getMgbAttributes().empty(); if (!hasDefaultCtor) { os << "\n .def(py::init<"; std::vector targs; for (auto &&i : op.getMgbAttributes()) { targs.push_back(i.attr.getReturnType()); } os << llvm::join(targs, ", "); os << ">()"; for (auto &&i : op.getMgbAttributes()) { os << formatv(", py::arg(\"{0}\")", i.name); auto defaultValue = i.attr.getDefaultValue(); if (!defaultValue.empty()) { os << formatv(" = {0}", defaultValue); } else { hasDefaultCtor = true; } } os << ")"; } if (hasDefaultCtor) { os << "\n .def(py::init<>())"; } for (auto &&i : op.getMgbAttributes()) { os << formatv( "\n .def_readwrite(\"{0}\", &{1}::{0})", i.name, class_name ); } os << ";\n\n"; } static void for_each_operator(raw_ostream &os, RecordKeeper &keeper, std::function callback) { auto op_base_class = keeper.getClass("Op"); ASSERT(op_base_class, "could not find base class Op"); for (auto&& i: keeper.getDefs()) { auto&& r = i.second; if (r->isSubClassOf(op_base_class)) { auto op = mlir::tblgen::Operator(r.get()); if (op.getDialectName().str() == "mgb") { std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl; callback(os, llvm::cast(op)); } } } } static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) { for_each_operator(os, keeper, gen_op_def_c_header_single); return false; } static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) { for_each_operator(os, keeper, gen_op_def_c_body_single); return false; } static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) { PybindContext ctx; using namespace std::placeholders; for_each_operator(os, keeper, std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx))); return false; } int main(int argc, char **argv) { llvm::InitLLVM y(argc, argv); llvm::cl::ParseCommandLineOptions(argc, argv); if (action == ActionType::CppHeader) { return TableGenMain(argv[0], &gen_op_def_c_header); } if (action == ActionType::CppBody) { return TableGenMain(argv[0], &gen_op_def_c_body); } if (action == ActionType::Pybind) { return TableGenMain(argv[0], &gen_op_def_pybind11); } return -1; }