|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350 |
- /**
- * \file imperative/tablegen/targets/python_c_extension.cpp
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- */
-
- #include "python_c_extension.h"
- #include "../emitter.h"
-
- namespace mlir::tblgen {
- namespace {
- struct Initproc {
- std::string func;
- Initproc(std::string&& s): func(std::move(s)) {}
- std::string operator()(std::string argument) {
- return formatv("{0}({1})", func, argument);
- }
- };
-
- class OpDefEmitter: public EmitterBase {
- public:
- OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_):
- EmitterBase(os_, env_), op(op_) {
- ctx.withSelf(op.getCppClassName());
- }
-
- Initproc emit();
- private:
- void emit_class();
- void emit_py_init();
- void emit_py_getsetters();
- Initproc emit_initproc();
-
- MgbOp& op;
- std::vector<Initproc> subclasses;
- mlir::tblgen::FmtContext ctx;
- };
-
- class EnumAttrEmitter: public EmitterBase {
- public:
- EnumAttrEmitter(llvm::StringRef parent, MgbEnumAttr* attr_, raw_ostream& os_, Environment& env_):
- EmitterBase(os_, env_), attr(attr_) {
- unsigned int enumID;
- if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
- auto&& aliasBase = alias->getAliasBase();
- enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
- } else {
- enumID = attr->getBaseRecord()->getID();
- }
- ctx.addSubst("enumTpl", attr->getEnumCombinedFlag() ? "BitCombinedEnumWrapper" : "EnumWrapper");
- ctx.addSubst("opClass", parent);
- ctx.addSubst("enumClass", attr->getEnumName());
- firstOccur = env().enumAlias.emplace(enumID, std::make_pair(parent, attr->getEnumName())).second;
- }
-
- Initproc emit();
- protected:
- void emit_trait();
- void emit_tpl_spl();
- Initproc emit_initproc();
-
- MgbEnumAttr* attr;
- bool firstOccur;
- mlir::tblgen::FmtContext ctx;
- };
-
- Initproc EnumAttrEmitter::emit() {
- emit_trait();
- emit_tpl_spl();
- return emit_initproc();
- }
-
- void EnumAttrEmitter::emit_trait() {
- if (!firstOccur) return;
-
- auto enumMax = [&] {
- if (attr->getEnumCombinedFlag()) {
- return formatv("(1llu << {0}) - 1", attr->getEnumMembers().size());
- } else {
- return formatv("{0} - 1", attr->getEnumMembers().size());
- }
- };
- os << tgfmt(R"(
- template<> struct EnumTrait<$opClass::$enumClass> {
- static constexpr const char *name = "$opClass.$enumClass";
- static constexpr std::underlying_type_t<$opClass::$enumClass> max = $0;
- };
- )", &ctx, enumMax());
- }
-
- void EnumAttrEmitter::emit_tpl_spl() {
- if (!firstOccur) return;
-
- os << tgfmt(
- "template<> PyTypeObject* $enumTpl<$opClass::$enumClass>::type = nullptr;\n",
- &ctx);
-
- auto quote = [&](auto&& i) -> std::string {
- return formatv("\"{0}\"", i);
- };
- os << tgfmt(R"(
- template<> const char*
- $enumTpl<$opClass::$enumClass>::members[] = {$0};
- )", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), quote), ", "));
-
- auto mem2value = [&](auto&& i) -> std::string {
- return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, i);
- };
- os << tgfmt(R"(
- template<> std::unordered_map<std::string, $opClass::$enumClass>
- $enumTpl<$opClass::$enumClass>::mem2value = {$0};
- )", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), mem2value), ", "));
-
- os << tgfmt(
- "template<> PyObject* "
- "$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};\n",
- &ctx, attr->getEnumMembers().size());
- }
-
- Initproc EnumAttrEmitter::emit_initproc() {
- std::string initproc = formatv("_init_py_{0}_{1}",
- ctx.getSubstFor("opClass"), ctx.getSubstFor("enumClass"));
-
- os << tgfmt(R"(
- void $0(PyTypeObject& py_type) {
- auto& e_type = $enumTpl<$opClass::$enumClass>::type;
- )", &ctx, initproc);
-
- if (firstOccur) {
- os << tgfmt(R"(
- static PyType_Slot slots[] = {
- {Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr},
- {Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare},
- )", &ctx);
- if (attr->getEnumCombinedFlag()) {
- // only bit combined enum could new instance because bitwise operation,
- // others should always use singleton
- os << tgfmt(R"(
- {Py_tp_new, (void*)$enumTpl<$opClass::$enumClass>::py_new_combined_enum},
- {Py_nb_or, (void*)$enumTpl<$opClass::$enumClass>::py_or},
- {Py_nb_and, (void*)$enumTpl<$opClass::$enumClass>::py_and},
- )", &ctx);
- }
- os << R"(
- {0, NULL}
- };)";
-
- os << tgfmt(R"(
- static PyType_Spec spec = {
- // name
- "megengine.core._imperative_rt.ops.$opClass.$enumClass",
- // basicsize
- sizeof($enumTpl<$opClass::$enumClass>),
- // itemsize
- 0,
- // flags
- Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
- // slots
- slots
- };)", &ctx);
-
- os << tgfmt(R"(
- e_type = reinterpret_cast<PyTypeObject*>(PyType_FromSpec(&spec));
- )", &ctx);
-
- for (auto&& i : {
- std::pair<std::string, std::string>{"__name__", tgfmt("$enumClass", &ctx)},
- {"__module__", "megengine.core._imperative_rt.ops"},
- {"__qualname__", tgfmt("$opClass.$enumClass", &ctx)}}) {
- os << formatv(R"(
- mgb_assert(
- e_type->tp_setattro(
- reinterpret_cast<PyObject*>(e_type),
- py::cast("{0}").release().ptr(),
- py::cast("{1}").release().ptr()) >= 0);
- )", i.first, i.second);
- }
-
-
- auto&& members = attr->getEnumMembers();
- for (size_t idx = 0; idx < members.size(); ++ idx) {
- os << tgfmt(R"({
- PyObject* inst = e_type->tp_alloc(e_type, 0);
- reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
- mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0);
- $enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
- })", &ctx, members[idx], idx);
- }
- }
-
- os << tgfmt(R"(
- Py_INCREF(e_type);
- mgb_assert(PyDict_SetItemString(
- py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(e_type)) >= 0);
- )", &ctx);
- os << "}\n";
- return initproc;
- }
-
- Initproc OpDefEmitter::emit() {
- for (auto&& i : op.getMgbAttributes()) {
- if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
- subclasses.push_back(EnumAttrEmitter(op.getCppClassName(), attr, os, env()).emit());
- }
- }
-
- emit_class();
- emit_py_init();
- emit_py_getsetters();
- return emit_initproc();
- }
-
- void OpDefEmitter::emit_class() {
- os << tgfmt(R"(
- PyOpDefBegin($_self) // {
- static PyGetSetDef py_getsetters[];
- static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
- // };
- PyOpDefEnd($_self)
- )", &ctx);
- }
-
- void OpDefEmitter::emit_py_init() {
- std::string initBody;
- if (!op.getMgbAttributes().empty()) {
- initBody += "static const char* kwlist[] = {";
-
- std::vector<llvm::StringRef> attr_name_list;
- llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
- attr_name_list.push_back(attr.name);
- });
- attr_name_list.push_back("scope");
-
- llvm::for_each(attr_name_list, [&](auto&& attr) {
- initBody += formatv("\"{0}\", ", attr);
- });
- initBody += "NULL};\n";
- initBody += " PyObject ";
- auto initializer = [&](auto&& attr) -> std::string {
- return formatv("*{0} = NULL", attr);
- };
- initBody += llvm::join(llvm::map_range(attr_name_list, initializer), ", ") + ";\n";
- initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
- // an extra slot created for name
- initBody += std::string(attr_name_list.size(), 'O');
- initBody += "\", const_cast<char**>(kwlist)";
- llvm::for_each(attr_name_list, [&](auto&& attr) {
- initBody += formatv(", &{0}", attr);
- });
- initBody += "))\n";
- initBody += " return -1;\n";
-
- llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
- initBody += tgfmt(R"(
- if ($0) {
- try {
- // TODO: remove this guard which is used for pybind11 implicit conversion
- py::detail::loader_life_support guard{};
- reinterpret_cast<PyOp($_self)*>(self)->inst().$0 =
- py::cast<decltype($_self::$0)>(py::handle($0));
- } CATCH_ALL(-1)
- }
- )", &ctx, attr.name);
- });
-
- initBody += tgfmt(R"(
- if (scope) {
- try {
- reinterpret_cast<PyOp(OpDef)*>(self)->op
- ->set_scope(py::cast<std::string>(py::handle(scope)));
- } CATCH_ALL(-1)
- }
- )", &ctx);
-
- }
- initBody += "\n return 0;";
-
-
- os << tgfmt(R"(
- int PyOp($_self)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
- $0
- }
- )", &ctx, initBody);
- }
-
- void OpDefEmitter::emit_py_getsetters() {
- auto f = [&](auto&& attr) -> std::string {
- return tgfmt(
- "{const_cast<char*>(\"$0\"), py_get_generic($_self, $0), py_set_generic($_self, $0), const_cast<char*>(\"$0\"), NULL},",
- &ctx, attr.name);
- };
- os << tgfmt(R"(
- PyGetSetDef PyOp($_self)::py_getsetters[] = {
- $0
- {NULL} /* Sentinel */
- };
- )", &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n "));
- }
-
- Initproc OpDefEmitter::emit_initproc() {
- std::string initproc = formatv("_init_py_{0}", op.getCppClassName());
- std::string subclass_init_call;
- for (auto&& i : subclasses) {
- subclass_init_call += formatv(" {0};\n", i("py_type"));
- }
- os << tgfmt(R"(
- void $0(py::module m) {
- using py_op = PyOp($_self);
- auto& py_type = PyOpType($_self);
- py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
- py_type.tp_name = "megengine.core._imperative_rt.ops.$_self";
- py_type.tp_basicsize = sizeof(PyOp($_self));
- py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
- py_type.tp_doc = "$_self";
- py_type.tp_base = &PyOpType(OpDef);
- py_type.tp_dealloc = py_dealloc_generic<py_op>;
- py_type.tp_new = py_new_generic<py_op>;
- py_type.tp_init = py_op::py_init;
- py_type.tp_getset = py_op::py_getsetters;
- mgb_assert(PyType_Ready(&py_type) >= 0);
- $1
- PyType_Modified(&py_type);
- m.add_object("$_self", reinterpret_cast<PyObject*>(&py_type));
- mgb_assert(PyOp(OpDef)::ctype2pytype.emplace($_self::typeinfo(), &py_type).second);
- }
- )", &ctx, initproc, subclass_init_call);
- return initproc;
- }
- } // namespace
-
- bool gen_op_def_python_c_extension(raw_ostream &os, llvm::RecordKeeper &keeper) {
- Environment env;
- using namespace std::placeholders;
- std::vector<Initproc> initprocs;
- foreach_operator(keeper, [&](MgbOp& op) {
- initprocs.emplace_back(OpDefEmitter(op, os, env).emit());
- });
- os << "#define INIT_ALL_OP(m)";
- for(auto&& init : initprocs) {
- os << formatv(" \\\n {0};", init("m"));
- }
- os << "\n";
- return false;
- }
- } // namespace mlir::tblgen
|