#include #include #include #include #include #include #include #include #include "../emitter.h" #include "python_c_extension.h" namespace mlir::tblgen { namespace { class TypeInfo; std::pair parse_type(const std::string&, const int); std::pair, int> parse_namespace(const std::string&, const int); struct Unit {}; Unit unit; struct ParseError {}; class TypeInfo { public: TypeInfo(std::string name) : name(name) {} std::string to_python_type_string() { std::stringstream ss; ss << translate_type_name(name); if (params.size() > 0) { ss << "[" << params[0].to_python_type_string(); for (auto i = 1; i < params.size(); i++) { ss << ", " << params[i].to_python_type_string(); } ss << "]"; } return ss.str(); } std::string translate_type_name(const std::string& cppTypeName) { auto res = translation.find(cppTypeName); if (res != translation.end()) return res->second; try { auto segments = parse_namespace(cppTypeName, 0).first; // special rules if (segments.size() > 3 && segments[0] == "megdnn" && segments[1] == "param") { segments.erase(segments.begin(), segments.begin() + 3); } else if ( segments.size() == 2 && segments[0] == "megdnn" && segments[1] == "DType") { segments.erase(segments.begin(), segments.begin() + 1); segments[0] = "str"; } else if ( segments.size() == 2 && segments[0] == "mgb" && segments[1] == "CompNode") { segments.erase(segments.begin(), segments.begin() + 1); segments[0] = "str"; } std::stringstream joined; joined << segments[0]; for (auto i = 1; i < segments.size(); i++) { joined << "." << segments[i]; } return joined.str(); } catch (ParseError) { return cppTypeName; } } std::string name; std::vector params; private: static const std::unordered_map translation; }; const std::unordered_map TypeInfo::translation = { {"bool", "bool"}, {"double", "float"}, {"float", "float"}, {"int32_t", "int"}, {"int8_t", "int"}, {"size_t", "int"}, {"std::string", "str"}, {"std::tuple", "tuple"}, {"std::vector", "list"}, {"uint32_t", "int"}, {"uint64_t", "int"}, }; // a parser takes: // 1. a string to parse // 2. location to parse from (index of character) // returns: // 1. parsing result (type T) // 2. end location of substring which is consumed by parsing // throws exception when failed to parse template using Parser = std::function(const std::string&, const int)>; std::pair parse_blank(const std::string& text, const int begin) { auto now = begin; while (now < text.length() && isblank(text[now])) now += 1; return {unit, now}; } Parser parse_non_blank_char(char ch) { return [=](const std::string& text, const int begin) -> std::pair { auto blankEnd = parse_blank(text, begin).second; if (blankEnd >= text.length() || text[blankEnd] != ch) throw ParseError{}; return {unit, blankEnd + 1}; }; } Parser parse_allowed_chars(std::function allow) { return [=](const std::string& text, const int begin) -> std::pair { auto now = begin; while (now < text.length() && allow(text[now])) now += 1; return {text.substr(begin, now - begin), now}; }; } template Parser> parse_seq(Parser only) { return [=](const std::string& text, const int begin) -> std::pair, int> { auto res = only(text, begin); return {{res.first}, res.second}; }; } template Parser> parse_seq(Parser head, Parser... tail) { return [=](const std::string& text, const int begin) -> std::pair, int> { std::pair headRes = head(text, begin); std::pair, int> tailRes = parse_seq(tail...)(text, headRes.second); return {std::tuple_cat(std::tuple(headRes.first), tailRes.first), tailRes.second}; }; } template Parser> parse_many_at_least0(Parser one) { return [=](const std::string& text, const int begin) -> std::pair, int> { std::vector ret; auto now = begin; try { while (true) { auto oneRes = one(text, now); ret.emplace_back(oneRes.first); now = oneRes.second; } } catch (ParseError) { } return {ret, now}; }; } template Parser> parse_sep_by_at_least1( Parser separator, Parser component) { return [=](const std::string& text, const int begin) -> std::pair, int> { std::vector ret; auto headRes = component(text, begin); ret.emplace_back(headRes.first); auto tailRes = parse_many_at_least0(parse_seq(separator, component))( text, headRes.second); for (const auto& elem : tailRes.first) { ret.emplace_back(std::get<1>(elem)); } return {ret, tailRes.second}; }; } std::pair parse_identifier(const std::string& text, const int begin) { auto blankEnd = parse_blank(text, begin).second; auto indentRes = parse_allowed_chars( [](char ch) { return std::isalnum(ch) || ch == '_'; })(text, blankEnd); if (indentRes.first.empty()) throw ParseError{}; return indentRes; }; std::pair parse_qualified(const std::string& text, const int begin) { auto blankEnd = parse_blank(text, begin).second; auto indentRes = parse_allowed_chars([](char ch) { return std::isalnum(ch) || ch == '_' || ch == ':'; })(text, blankEnd); if (indentRes.first.empty()) throw ParseError{}; return indentRes; }; std::pair, int> parse_namespace( const std::string& text, const int begin) { auto res = parse_many_at_least0(parse_seq( parse_non_blank_char(':'), parse_non_blank_char(':'), Parser(parse_identifier)))(text, begin); std::vector ret; for (const auto& elem : res.first) { ret.emplace_back(std::get<2>(elem)); } return {ret, res.second}; } std::pair parse_leaf_type(const std::string& text, const int begin) { auto ret = parse_qualified(text, begin); return {TypeInfo(ret.first), ret.second}; }; std::pair parse_node_type(const std::string& text, const int begin) { auto nameRes = parse_qualified(text, begin); auto ret = TypeInfo(nameRes.first); auto now = parse_non_blank_char('<')(text, nameRes.second).second; auto argsRes = parse_sep_by_at_least1( parse_non_blank_char(','), Parser(parse_type))(text, now); ret.params = argsRes.first; now = parse_non_blank_char('>')(text, argsRes.second).second; return {ret, now}; }; std::pair parse_type(const std::string& text, const int begin) { try { return parse_node_type(text, begin); } catch (ParseError) { } return parse_leaf_type(text, begin); }; std::string cpp_type_to_python_type(const std::string& input) { auto res = parse_type(input, 0); return res.first.to_python_type_string(); } struct Initproc { std::string func; Initproc(std::string&& s) : func(std::move(s)) {} std::string operator()(std::string argument) { return formatv("{0}({1})", func, argument); } }; class OpDefEmitter : public EmitterBase { public: OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_) : EmitterBase(os_, env_), op(op_) { ctx.withSelf(op.getCppClassName()); } Initproc emit(); private: void emit_class(); void emit_py_init(); void emit_py_getsetters(); void emit_py_methods(); void emit_py_init_proxy(); void emit_py_init_methoddef( const std::unordered_map>& enum_attr_members); Initproc emit_initproc(); MgbOp& op; std::vector subclasses; mlir::tblgen::FmtContext ctx; }; class EnumAttrEmitter : public EmitterBase { public: EnumAttrEmitter( llvm::StringRef parent, MgbEnumAttr* attr_, raw_ostream& os_, Environment& env_) : EmitterBase(os_, env_), attr(attr_) { unsigned int enumID; if (auto alias = llvm::dyn_cast(attr)) { auto&& aliasBase = alias->getAliasBase(); enumID = llvm::cast(aliasBase).getBaseRecord()->getID(); } else { enumID = attr->getBaseRecord()->getID(); } ctx.addSubst( "enumTpl", attr->getEnumCombinedFlag() ? "BitCombinedEnumWrapper" : "EnumWrapper"); ctx.addSubst("opClass", parent); ctx.addSubst("enumClass", attr->getEnumName()); firstOccur = env().enumAlias .emplace(enumID, std::make_pair(parent, attr->getEnumName())) .second; } Initproc emit(); protected: void emit_trait(); void emit_tpl_spl(); Initproc emit_initproc(); MgbEnumAttr* attr; bool firstOccur; mlir::tblgen::FmtContext ctx; }; Initproc EnumAttrEmitter::emit() { emit_trait(); emit_tpl_spl(); return emit_initproc(); } void EnumAttrEmitter::emit_trait() { if (!firstOccur) return; auto enumMax = [&] { if (attr->getEnumCombinedFlag()) { return formatv("(1llu << {0}) - 1", attr->getEnumMembers().size()); } else { return formatv("{0} - 1", attr->getEnumMembers().size()); } }; os << tgfmt( R"( template<> struct EnumTrait<$opClass::$enumClass> { static constexpr const char *name = "$opClass.$enumClass"; static constexpr std::underlying_type_t<$opClass::$enumClass> max = $0; }; )", &ctx, enumMax()); } void EnumAttrEmitter::emit_tpl_spl() { if (!firstOccur) return; os << tgfmt( "template<> PyTypeObject* $enumTpl<$opClass::$enumClass>::type = " "nullptr;\n", &ctx); auto quote = [&](auto&& i) -> std::string { size_t d1 = i.find(' '); size_t d2 = i.find('='); size_t d = d1 <= d2 ? d1 : d2; return formatv("\"{0}\"", i.substr(0, d)); }; os << tgfmt( R"( template<> const char* $enumTpl<$opClass::$enumClass>::members[] = {$0}; )", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), quote), ", ")); auto mem2value = [&](auto&& i) -> std::string { size_t d1 = i.find(' '); size_t d2 = i.find('='); size_t d = d1 <= d2 ? d1 : d2; return tgfmt( "{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, i.substr(0, d)); }; os << tgfmt( R"( template<> std::unordered_map $enumTpl<$opClass::$enumClass>::mem2value = {$0}; )", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), mem2value), ", ")); os << tgfmt( "template<> PyObject* " "$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};\n", &ctx, attr->getEnumMembers().size()); } Initproc EnumAttrEmitter::emit_initproc() { std::string initproc = formatv("_init_py_{0}_{1}", ctx.getSubstFor("opClass"), ctx.getSubstFor("enumClass")); os << tgfmt( R"( void $0(PyTypeObject& py_type) { auto& e_type = $enumTpl<$opClass::$enumClass>::type; )", &ctx, initproc); if (firstOccur) { os << tgfmt( R"( static PyMethodDef tp_methods[] = { {const_cast("dump"), (PyCFunction)$enumTpl<$opClass::$enumClass>::py_dump, METH_NOARGS, NULL}, {NULL} /* Sentinel */ }; )", &ctx); os << tgfmt( R"( static PyType_Slot slots[] = { {Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr}, {Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare}, {Py_tp_methods, tp_methods}, )", &ctx); if (attr->getEnumCombinedFlag()) { // only bit combined enum could new instance because bitwise operation, // others should always use singleton os << tgfmt( R"( {Py_tp_new, (void*)$enumTpl<$opClass::$enumClass>::py_new_combined_enum}, {Py_nb_or, (void*)$enumTpl<$opClass::$enumClass>::py_or}, {Py_nb_and, (void*)$enumTpl<$opClass::$enumClass>::py_and}, )", &ctx); } os << R"( {0, NULL} };)"; os << tgfmt( R"( static PyType_Spec spec = { // name "megengine.core._imperative_rt.ops.$opClass.$enumClass", // basicsize sizeof($enumTpl<$opClass::$enumClass>), // itemsize 0, // flags Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, // slots slots };)", &ctx); os << tgfmt( R"( e_type = reinterpret_cast(PyType_FromSpec(&spec)); )", &ctx); for (auto&& i : {std::pair{ "__name__", tgfmt("$enumClass", &ctx)}, {"__module__", "megengine.core._imperative_rt.ops"}, {"__qualname__", tgfmt("$opClass.$enumClass", &ctx)}}) { os << formatv( R"( mgb_assert( e_type->tp_setattro( reinterpret_cast(e_type), py::cast("{0}").release().ptr(), py::cast("{1}").release().ptr()) >= 0); )", i.first, i.second); } auto&& members = attr->getEnumMembers(); for (size_t idx = 0; idx < members.size(); ++idx) { size_t d1 = members[idx].find(' '); size_t d2 = members[idx].find('='); size_t d = d1 <= d2 ? d1 : d2; os << tgfmt( R"({ PyObject* inst = e_type->tp_alloc(e_type, 0); reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0; mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0); $enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst; })", &ctx, members[idx].substr(0, d), idx); } } os << tgfmt( R"( Py_INCREF(e_type); mgb_assert(PyDict_SetItemString( py_type.tp_dict, "$enumClass", reinterpret_cast(e_type)) >= 0); )", &ctx); os << "}\n"; return initproc; } Initproc OpDefEmitter::emit() { std::unordered_map> enum_attr_members; for (auto&& i : op.getMgbAttributes()) { if (auto attr = llvm::dyn_cast(&i.attr)) { subclasses.push_back( EnumAttrEmitter(op.getCppClassName(), attr, os, env()).emit()); auto retType = cpp_type_to_python_type(std::string(attr->getReturnType())); enum_attr_members[retType] = std::vector(); for (const auto& member : attr->getEnumMembers()) { enum_attr_members[retType].emplace_back(member); } } } emit_class(); emit_py_init(); emit_py_getsetters(); emit_py_methods(); emit_py_init_proxy(); emit_py_init_methoddef(enum_attr_members); return emit_initproc(); } void OpDefEmitter::emit_class() { auto&& className = op.getCppClassName(); std::string method_defs; std::vector body; llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { body.push_back( formatv(R"( {{"{0}", serialization::dump(opdef.{0})})", attr.name)); }); method_defs += formatv(R"( static PyObject* getstate(PyObject* self, PyObject*) {{ auto& opdef = reinterpret_cast(self)->inst(); static_cast(opdef); std::unordered_map state {{ {1} }; return py::cast(state).release().ptr(); })", className, llvm::join(body, ",")); body.clear(); llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { body.push_back( formatv(R"( {{ auto&& iter = state.find("{0}"); if (iter != state.end()) { opdef.{0} = serialization::load(iter->second); } })", attr.name)); }); method_defs += formatv(R"( static PyObject* setstate(PyObject* self, PyObject* args) {{ PyObject* dict = PyTuple_GetItem(args, 0); if (!dict) return NULL; auto state = py::cast>(dict); auto& opdef = reinterpret_cast(self)->inst(); static_cast(opdef); {1} Py_RETURN_NONE; })", className, llvm::join(body, "\n")); os << tgfmt( R"( PyOpDefBegin($_self) // { static PyGetSetDef py_getsetters[]; static PyMethodDef tp_methods[]; $0 static int py_init(PyObject *self, PyObject *args, PyObject *kwds); static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds); static PyMethodDef py_init_methoddef; // }; PyOpDefEnd($_self) )", &ctx, method_defs); } void OpDefEmitter::emit_py_init() { std::string initBody; if (!op.getMgbAttributes().empty()) { initBody += "static const char* kwlist[] = {"; std::vector attr_name_list; llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { attr_name_list.push_back(attr.name); }); attr_name_list.push_back("scope"); llvm::for_each(attr_name_list, [&](auto&& attr) { initBody += formatv("\"{0}\", ", attr); }); initBody += "NULL};\n"; initBody += " PyObject "; auto initializer = [&](auto&& attr) -> std::string { return formatv("*{0} = NULL", attr); }; initBody += llvm::join(llvm::map_range(attr_name_list, initializer), ", ") + ";\n"; initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|"; // an extra slot created for name initBody += std::string(attr_name_list.size(), 'O'); initBody += "\", const_cast(kwlist)"; llvm::for_each(attr_name_list, [&](auto&& attr) { initBody += formatv(", &{0}", attr); }); initBody += "))\n"; initBody += " return -1;\n"; llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { initBody += tgfmt(R"( if ($0) { try { // TODO: remove this guard which is used for pybind11 implicit conversion py::detail::loader_life_support guard{}; reinterpret_cast(self)->inst().$0 = py::cast(py::handle($0)); } CATCH_ALL(-1) } )", &ctx, attr.name); }); initBody += tgfmt(R"( if (scope) { try { reinterpret_cast(self)->op ->set_scope(py::cast(py::handle(scope))); } CATCH_ALL(-1) } )", &ctx); } initBody += "\n return 0;"; os << tgfmt( R"( int PyOp($_self)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { $0 } )", &ctx, initBody); } void OpDefEmitter::emit_py_getsetters() { auto f = [&](auto&& attr) -> std::string { return tgfmt( "{const_cast(\"$0\"), py_get_generic($_self, $0), " "py_set_generic($_self, $0), const_cast(\"$0\"), NULL},", &ctx, attr.name); }; os << tgfmt( R"( PyGetSetDef PyOp($_self)::py_getsetters[] = { $0 {NULL} /* Sentinel */ }; )", &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n ")); } void OpDefEmitter::emit_py_methods() { // generate methods std::string method_defs; std::vector method_items; { auto&& className = op.getCppClassName(); // generate getstate method_items.push_back( formatv("{{const_cast(\"__getstate__\"), PyOp({0})::getstate, " "METH_NOARGS, \"{0} getstate\"},", className)); // generate setstate method_items.push_back( formatv("{{const_cast(\"__setstate__\"), PyOp({0})::setstate, " "METH_VARARGS, \"{0} setstate\"},", className)); } os << tgfmt( R"( PyMethodDef PyOp($_self)::tp_methods[] = { $0 {NULL} /* Sentinel */ }; )", &ctx, llvm::join(method_items, "\n ")); } void OpDefEmitter::emit_py_init_proxy() { os << tgfmt( R"( PyObject *PyOp($_self)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) { if (PyOp($_self)::py_init(self, args, kwds) < 0) { return NULL; } Py_RETURN_NONE; } )", &ctx); } void OpDefEmitter::emit_py_init_methoddef( const std::unordered_map>& enum_attr_members) { std::string docstring = "__init__(self"; for (const auto& attr : op.getMgbAttributes()) { if (attr.name == "workspace_limit") continue; auto pyType = cpp_type_to_python_type(std::string(attr.attr.getReturnType())); auto findRes = enum_attr_members.find(pyType); if (findRes != enum_attr_members.end()) { pyType = formatv("Union[str, {0}]", pyType); // TODO stubgen cannot handle Literal strings for now // auto members = findRes->second; // std::string enumTypeString = "Literal["; // enumTypeString += formatv("'{0}'", lowercase(members[0])); // for (auto i = 1; i < members.size(); i++) { // enumTypeString += formatv(", '{0}'", lowercase(members[i])); // } // enumTypeString += "]"; // pyType = enumTypeString; } docstring += formatv(", {0}: {1} = ...", attr.name, pyType); } docstring += ") -> None\\n"; os << tgfmt( R"( PyMethodDef PyOp($_self)::py_init_methoddef = { "__init__", (PyCFunction)PyOp($_self)::py_init_proxy, METH_VARARGS | METH_KEYWORDS, "$0" }; )", &ctx, docstring); } Initproc OpDefEmitter::emit_initproc() { std::string initproc = formatv("_init_py_{0}", op.getCppClassName()); std::string subclass_init_call; for (auto&& i : subclasses) { subclass_init_call += formatv(" {0};\n", i("py_type")); } os << tgfmt( R"( void $0(py::module m) { using py_op = PyOp($_self); auto& py_type = PyOpType($_self); py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; py_type.tp_name = "megengine.core._imperative_rt.ops.$_self"; py_type.tp_basicsize = sizeof(PyOp($_self)); py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; py_type.tp_doc = "$_self"; py_type.tp_base = &PyOpType(OpDef); py_type.tp_dealloc = py_dealloc_generic; py_type.tp_new = py_new_generic; py_type.tp_init = py_op::py_init; py_type.tp_methods = py_op::tp_methods; py_type.tp_getset = py_op::py_getsetters; py_type.tp_dict = PyDict_New(); PyObject* descr = PyDescr_NewMethod(&PyOpType($_self), &PyOp($_self)::py_init_methoddef); PyDict_SetItemString(py_type.tp_dict, "__init__", descr); mgb_assert(PyType_Ready(&py_type) >= 0); $1 PyType_Modified(&py_type); m.add_object("$_self", reinterpret_cast(&py_type)); mgb_assert(PyOp(OpDef)::ctype2pytype.emplace($_self::typeinfo(), &py_type).second); } )", &ctx, initproc, subclass_init_call); return initproc; } } // namespace bool gen_op_def_python_c_extension(raw_ostream& os, llvm::RecordKeeper& keeper) { Environment env; using namespace std::placeholders; std::vector initprocs; foreach_operator(keeper, [&](MgbOp& op) { initprocs.emplace_back(OpDefEmitter(op, os, env).emit()); }); os << "#define INIT_ALL_OP(m)"; for (auto&& init : initprocs) { os << formatv(" \\\n {0};", init("m")); } os << "\n"; return false; } } // namespace mlir::tblgen