GitOrigin-RevId: b81b085762
release-1.3
@@ -1,6 +1,7 @@ | |||
# mgb tablegen executable | |||
set(TABLE_TARGET mgb-mlir-autogen) | |||
add_executable(${TABLE_TARGET} autogen.cpp) | |||
file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR}/*.h ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) | |||
add_executable(${TABLE_TARGET} ${SRCS}) | |||
target_include_directories(${TABLE_TARGET} PRIVATE ${MLIR_LLVM_INCLUDE_DIR}) | |||
target_link_libraries(${TABLE_TARGET} PRIVATE LLVMTableGen MLIRTableGen LLVMSupport) | |||
set(MGB_TABLEGEN_EXE ${TABLE_TARGET}) | |||
@@ -1,8 +1,17 @@ | |||
#include <iostream> | |||
#include <unordered_map> | |||
#include <functional> | |||
#include "./helper.h" | |||
/** | |||
* \file imperative/tablegen/autogen.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "./targets/cpp_class.h" | |||
#include "./targets/pybind11.h" | |||
#include "./targets/python_c_extension.h" | |||
using llvm::raw_ostream; | |||
using llvm::RecordKeeper; | |||
@@ -27,731 +36,7 @@ llvm::cl::opt<ActionType> action( | |||
clEnumValN(CPython, "gen-python-c-extension", | |||
"Generate python c extensions"))); | |||
using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase; | |||
using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin; | |||
using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin; | |||
using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin; | |||
using MgbOp = mlir::tblgen::MgbOpBase; | |||
using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin; | |||
llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) { | |||
// Note: we have already registered the corresponding attr wrappers | |||
// for following basic ctypes so we needn't handle them here | |||
/* auto&& attr_type_name = attr.getAttrDefName(); | |||
if (attr_type_name == "UI32Attr") { | |||
return "uint32_t"; | |||
} | |||
if (attr_type_name == "UI64Attr") { | |||
return "uint64_t"; | |||
} | |||
if (attr_type_name == "I32Attr") { | |||
return "int32_t"; | |||
} | |||
if (attr_type_name == "F32Attr") { | |||
return "float"; | |||
} | |||
if (attr_type_name == "F64Attr") { | |||
return "double"; | |||
} | |||
if (attr_type_name == "StrAttr") { | |||
return "std::string"; | |||
} | |||
if (attr_type_name == "BoolAttr") { | |||
return "bool"; | |||
}*/ | |||
auto&& attr = llvm::cast<MgbAttrWrapper>(attr_); | |||
if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) { | |||
return e->getEnumName(); | |||
} | |||
return attr.getUnderlyingType(); | |||
} | |||
static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) { | |||
os << formatv( | |||
"class {0} : public OpDefImplBase<{0}> {{\n" | |||
" MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n" | |||
"public:\n", | |||
op.getCppClassName() | |||
); | |||
// handle enum alias | |||
for (auto &&i : op.getMgbAttributes()) { | |||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
os << formatv( | |||
" using {0} = {1};\n", | |||
attr->getEnumName(), attr->getUnderlyingType() | |||
); | |||
} | |||
} | |||
for (auto &&i : op.getMgbAttributes()) { | |||
auto defaultValue = i.attr.getDefaultValue().str(); | |||
if (!defaultValue.empty()) { | |||
defaultValue = formatv(" = {0}", defaultValue); | |||
} | |||
os << formatv( | |||
" {0} {1}{2};\n", | |||
attr_to_ctype(i.attr), i.name, defaultValue | |||
); | |||
} | |||
auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) { | |||
os << formatv( | |||
" {0}({1}){2}{3}\n", | |||
op.getCppClassName(), paramList, memInitList, body | |||
); | |||
}; | |||
gen_ctor("", "", " = default;"); | |||
if (!op.getMgbAttributes().empty()) { | |||
std::vector<std::string> paramList, initList; | |||
for (auto &&i : op.getMgbAttributes()) { | |||
paramList.push_back(formatv( | |||
"{0} {1}_", attr_to_ctype(i.attr), i.name | |||
)); | |||
initList.push_back(formatv( | |||
"{0}({0}_)", i.name | |||
)); | |||
} | |||
paramList.push_back("std::string scope_ = {}"); | |||
gen_ctor(llvm::join(paramList, ", "), | |||
": " + llvm::join(initList, ", "), | |||
" { set_scope(scope_); }"); | |||
} | |||
auto packedParams = op.getPackedParams(); | |||
if (!packedParams.empty()) { | |||
std::vector<std::string> paramList, initList; | |||
for (auto &&p : packedParams) { | |||
auto&& paramFields = p.getFields(); | |||
auto&& paramType = p.getFullName(); | |||
auto&& paramName = formatv("packed_param_{0}", paramList.size()); | |||
paramList.push_back( | |||
paramFields.empty() ? paramType.str() | |||
: formatv("{0} {1}", paramType, paramName) | |||
); | |||
for (auto&& i : paramFields) { | |||
initList.push_back(formatv( | |||
"{0}({1}.{0})", i.name, paramName | |||
)); | |||
} | |||
} | |||
for (auto&& i : op.getExtraArguments()) { | |||
paramList.push_back(formatv( | |||
"{0} {1}_", attr_to_ctype(i.attr), i.name | |||
)); | |||
initList.push_back(formatv( | |||
"{0}({0}_)", i.name | |||
)); | |||
} | |||
gen_ctor(llvm::join(paramList, ", "), | |||
initList.empty() ? "" : ": " + llvm::join(initList, ", "), | |||
" {}"); | |||
} | |||
if (!packedParams.empty()) { | |||
for (auto&& p : packedParams) { | |||
auto accessor = p.getAccessor(); | |||
if (!accessor.empty()) { | |||
os << formatv( | |||
" {0} {1}() const {{\n", | |||
p.getFullName(), accessor | |||
); | |||
std::vector<llvm::StringRef> fields; | |||
for (auto&& i : p.getFields()) { | |||
fields.push_back(i.name); | |||
} | |||
os << formatv( | |||
" return {{{0}};\n", | |||
llvm::join(fields, ", ") | |||
); | |||
os << " }\n"; | |||
} | |||
} | |||
} | |||
if (auto decl = op.getExtraOpdefDecl()) { | |||
os << decl.getValue(); | |||
} | |||
os << formatv( | |||
"};\n\n" | |||
); | |||
} | |||
static void gen_to_string_trait_for_enum(raw_ostream &os, MgbOp& op) { | |||
for (auto &&i : op.getMgbAttributes()) { | |||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
if (attr->supportToString()) { | |||
std::vector<std::string> case_body; | |||
std::string ename = formatv("{0}::{1}", | |||
op.getCppClassName(), attr->getEnumName()); | |||
llvm::for_each(attr->getEnumMembers(), [&](auto&& v){ | |||
case_body.push_back(formatv( | |||
"case {0}::{1}: return \"{1}\";", ename, v)); | |||
}); | |||
os << formatv(R"( | |||
template <> | |||
struct ToStringTrait<{0}> { | |||
std::string operator()({0} e) const { | |||
switch (e) { | |||
{1} | |||
default: | |||
return "{0}::Unknown"; | |||
} | |||
} | |||
}; | |||
)", ename, llvm::join(case_body, "\n")); | |||
} | |||
} | |||
} | |||
} | |||
static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||
auto&& className = op.getCppClassName(); | |||
os << formatv( | |||
"MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className | |||
); | |||
auto formatMethImpl = [&](auto&& meth) { | |||
return formatv( | |||
"{0}_{1}_impl", className, meth | |||
); | |||
}; | |||
std::vector<std::string> methods; | |||
if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) { | |||
os << "namespace {\n"; | |||
// generate hash() | |||
mlir::tblgen::FmtContext ctx; | |||
os << formatv( | |||
"size_t {0}(const OpDef& def_) {{\n", | |||
formatMethImpl("hash") | |||
); | |||
os << formatv( | |||
" auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
" static_cast<void>(op_);\n", | |||
className | |||
); | |||
ctx.withSelf("op_"); | |||
os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx); | |||
os << "}\n"; | |||
// generate is_same_st() | |||
os << formatv( | |||
"bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n", | |||
formatMethImpl("is_same_st") | |||
); | |||
os << formatv( | |||
" auto &&a_ = lhs_.cast_final_safe<{0}>(),\n" | |||
" &&b_ = rhs_.cast_final_safe<{0}>();\n" | |||
" static_cast<void>(a_);\n" | |||
" static_cast<void>(b_);\n", | |||
className | |||
); | |||
os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_"); | |||
os << "}\n"; | |||
// generate props() | |||
os << formatv( | |||
"std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n", | |||
formatMethImpl("props") | |||
); | |||
os << formatv( | |||
" auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
" static_cast<void>(op_);\n", | |||
className | |||
); | |||
ctx.withSelf("op_"); | |||
os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx); | |||
os << "}\n"; | |||
// generate make_name() | |||
os << formatv( | |||
"std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") | |||
); | |||
os << formatv( | |||
" auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
" static_cast<void>(op_);\n", | |||
className | |||
); | |||
ctx.withSelf("op_"); | |||
os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx); | |||
os << "}\n"; | |||
os << "} // anonymous namespace\n"; | |||
methods.push_back("hash"); | |||
methods.push_back("is_same_st"); | |||
methods.push_back("props"); | |||
methods.push_back("make_name"); | |||
} | |||
if (!methods.empty()) { | |||
os << formatv( | |||
"OP_TRAIT_REG({0}, {0})", op.getCppClassName() | |||
); | |||
for (auto&& i : methods) { | |||
os << formatv( | |||
"\n .{0}({1})", i, formatMethImpl(i) | |||
); | |||
} | |||
os << ";\n\n"; | |||
} | |||
} | |||
struct EnumContext { | |||
std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias; | |||
}; | |||
static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||
auto className = op.getCppClassName(); | |||
os << formatv( | |||
"py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", | |||
className | |||
); | |||
for (auto&& i : op.getMgbAttributes()) { | |||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
unsigned int enumID; | |||
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
auto&& aliasBase = alias->getAliasBase(); | |||
enumID = | |||
llvm::cast<MgbEnumAttr>(aliasBase) | |||
.getBaseRecord()->getID(); | |||
} else { | |||
enumID = attr->getBaseRecord()->getID(); | |||
} | |||
auto&& enumAlias = ctx.enumAlias; | |||
auto&& iter = enumAlias.find(enumID); | |||
if (iter == enumAlias.end()) { | |||
os << formatv( | |||
"py::enum_<{0}::{1}>({0}Inst, \"{1}\")", | |||
className, attr->getEnumName() | |||
); | |||
std::vector<std::string> body; | |||
for (auto&& i: attr->getEnumMembers()) { | |||
os << formatv( | |||
"\n .value(\"{2}\", {0}::{1}::{2})", | |||
className, attr->getEnumName(), i | |||
); | |||
body.push_back(formatv( | |||
"if (str == \"{2}\") return {0}::{1}::{2};", | |||
className, attr->getEnumName(), i | |||
)); | |||
} | |||
if (attr->getEnumCombinedFlag()) { | |||
//! define operator | | |||
os << formatv( | |||
"\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ " | |||
"\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));" | |||
"\n })", | |||
className, attr->getEnumName()); | |||
//! define operator & | |||
os << formatv( | |||
"\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{" | |||
"\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));" | |||
"\n })", | |||
className, attr->getEnumName()); | |||
} | |||
os << formatv( | |||
"\n .def(py::init([](const std::string& in) {" | |||
"\n auto&& str = normalize_enum(in);" | |||
"\n {0}" | |||
"\n throw py::cast_error(\"invalid enum value \" + in);" | |||
"\n }));\n", | |||
llvm::join(body, "\n ") | |||
); | |||
os << formatv( | |||
"py::implicitly_convertible<std::string, {0}::{1}>();\n\n", | |||
className, attr->getEnumName() | |||
); | |||
enumAlias.emplace(enumID, | |||
std::make_pair(className, attr->getEnumName())); | |||
} else { | |||
os << formatv( | |||
"{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n", | |||
className, attr->getEnumName(), | |||
iter->second.first, iter->second.second | |||
); | |||
} | |||
} | |||
} | |||
// generate op class binding | |||
os << formatv("{0}Inst", className); | |||
bool hasDefaultCtor = op.getMgbAttributes().empty(); | |||
if (!hasDefaultCtor) { | |||
os << "\n .def(py::init<"; | |||
std::vector<llvm::StringRef> targs; | |||
for (auto &&i : op.getMgbAttributes()) { | |||
targs.push_back(i.attr.getReturnType()); | |||
} | |||
os << llvm::join(targs, ", "); | |||
os << ", std::string>()"; | |||
for (auto &&i : op.getMgbAttributes()) { | |||
os << formatv(", py::arg(\"{0}\")", i.name); | |||
auto defaultValue = i.attr.getDefaultValue(); | |||
if (!defaultValue.empty()) { | |||
os << formatv(" = {0}", defaultValue); | |||
} else { | |||
hasDefaultCtor = true; | |||
} | |||
} | |||
os << ", py::arg(\"scope\") = {})"; | |||
} | |||
if (hasDefaultCtor) { | |||
os << "\n .def(py::init<>())"; | |||
} | |||
for (auto &&i : op.getMgbAttributes()) { | |||
os << formatv( | |||
"\n .def_readwrite(\"{0}\", &{1}::{0})", | |||
i.name, className | |||
); | |||
} | |||
os << ";\n\n"; | |||
} | |||
static std::string gen_op_def_python_c_extension_enum( | |||
raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, | |||
llvm::StringRef className) { | |||
std::string body; | |||
unsigned int enumID; | |||
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
auto&& aliasBase = alias->getAliasBase(); | |||
enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||
} else { | |||
enumID = attr->getBaseRecord()->getID(); | |||
} | |||
auto&& enumAlias = ctx.enumAlias; | |||
auto&& iter = enumAlias.find(enumID); | |||
auto enumName = attr->getEnumName(); | |||
body += "{\n"; | |||
body += formatv("auto& e_type = EnumWrapper<{0}::{1}>::type;", className, | |||
enumName); | |||
if (iter == enumAlias.end()) { | |||
os << formatv( | |||
"template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", | |||
className, enumName); | |||
os << formatv( | |||
"template<> const char* EnumWrapper<{0}::{1}>::name = " | |||
"\"{0}.{1}\";\n", | |||
className, enumName); | |||
std::vector<std::string> pairStr; | |||
for (auto&& i : attr->getEnumMembers()) { | |||
pairStr.push_back( | |||
formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||
className, enumName, i)); | |||
} | |||
os << formatv(R"( | |||
template<> std::unordered_map<std::string, {0}::{1}> | |||
EnumWrapper<{0}::{1}>::str2type = {{ | |||
{2} | |||
}; | |||
)", | |||
className, enumName, llvm::join(pairStr, ", ")); | |||
pairStr.clear(); | |||
for (auto&& i : attr->getEnumMembers()) { | |||
pairStr.push_back( | |||
formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||
className, enumName, i)); | |||
} | |||
os << formatv(R"( | |||
template<> std::unordered_map<{0}::{1}, std::string> | |||
EnumWrapper<{0}::{1}>::type2str = {{ | |||
{2} | |||
}; | |||
)", | |||
className, enumName, llvm::join(pairStr, ", ")); | |||
body += formatv(R"( | |||
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | |||
e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>); | |||
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
e_type.tp_doc = "{0}.{1}"; | |||
e_type.tp_base = &PyBaseObject_Type; | |||
e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr; | |||
e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare; | |||
mgb_assert(PyType_Ready(&e_type) >= 0); | |||
)", | |||
className, enumName); | |||
for (auto&& i : attr->getEnumMembers()) { | |||
body += formatv(R"({{ | |||
PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||
reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | |||
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | |||
})", | |||
className, enumName, i); | |||
} | |||
enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||
} | |||
body += formatv(R"( | |||
PyType_Modified(&e_type); | |||
mgb_assert(PyDict_SetItemString( | |||
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||
)", | |||
enumName); | |||
body += "}\n"; | |||
return body; | |||
} | |||
static std::string gen_op_def_python_c_extension_bit_combined_enum( | |||
raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, | |||
llvm::StringRef className) { | |||
std::string body; | |||
unsigned int enumID; | |||
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
auto&& aliasBase = alias->getAliasBase(); | |||
enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||
} else { | |||
enumID = attr->getBaseRecord()->getID(); | |||
} | |||
auto&& enumAlias = ctx.enumAlias; | |||
auto&& iter = enumAlias.find(enumID); | |||
auto enumName = attr->getEnumName(); | |||
body += "{\n"; | |||
body += formatv("auto& e_type = BitCombinedEnumWrapper<{0}::{1}>::type;", | |||
className, enumName); | |||
if (iter == enumAlias.end()) { | |||
os << formatv( | |||
"template<> PyTypeObject " | |||
"BitCombinedEnumWrapper<{0}::{1}>::type={{};\n", | |||
className, enumName); | |||
os << formatv( | |||
"template<> PyNumberMethods " | |||
"BitCombinedEnumWrapper<{0}::{1}>::number_methods={{};\n", | |||
className, enumName); | |||
os << formatv( | |||
"template<> const char* BitCombinedEnumWrapper<{0}::{1}>::name " | |||
"= \"{0}.{1}\";\n", | |||
className, enumName); | |||
os << formatv( | |||
"template<> struct EnumTrait<{0}::{1}> {{ static constexpr " | |||
"bool is_bit_combined = true;};\n", | |||
className, enumName); | |||
std::vector<std::string> pairStr; | |||
for (auto&& i : attr->getEnumMembers()) { | |||
pairStr.push_back( | |||
formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||
className, enumName, i)); | |||
} | |||
os << formatv(R"( | |||
template<> std::unordered_map<std::string, {0}::{1}> | |||
BitCombinedEnumWrapper<{0}::{1}>::str2type = {{ | |||
{2} | |||
}; | |||
)", | |||
className, enumName, llvm::join(pairStr, ", ")); | |||
pairStr.clear(); | |||
for (auto&& i : attr->getEnumMembers()) { | |||
pairStr.push_back( | |||
formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||
className, enumName, i)); | |||
} | |||
os << formatv(R"( | |||
template<> std::unordered_map<{0}::{1}, std::string> | |||
BitCombinedEnumWrapper<{0}::{1}>::type2str = {{ | |||
{2} | |||
}; | |||
)", | |||
className, enumName, llvm::join(pairStr, ", ")); | |||
body += formatv(R"( | |||
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | |||
e_type.tp_basicsize = sizeof(BitCombinedEnumWrapper<{0}::{1}>); | |||
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
e_type.tp_doc = "{0}.{1}"; | |||
e_type.tp_base = &PyBaseObject_Type; | |||
e_type.tp_new = BitCombinedEnumWrapper<{0}::{1}>::py_new_combined_enum; | |||
e_type.tp_init = BitCombinedEnumWrapper<{0}::{1}>::py_init; | |||
e_type.tp_repr = BitCombinedEnumWrapper<{0}::{1}>::py_repr; | |||
e_type.tp_richcompare = BitCombinedEnumWrapper<{0}::{1}>::tp_richcompare; | |||
auto& number_method = BitCombinedEnumWrapper<{0}::{1}>::number_methods; | |||
number_method.nb_or = BitCombinedEnumWrapper<{0}::{1}>::py_or; | |||
number_method.nb_and = BitCombinedEnumWrapper<{0}::{1}>::py_and; | |||
e_type.tp_as_number = &number_method; | |||
mgb_assert(PyType_Ready(&e_type) >= 0); | |||
)", | |||
className, enumName); | |||
for (auto&& i : attr->getEnumMembers()) { | |||
body += formatv(R"({{ | |||
PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||
reinterpret_cast<BitCombinedEnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | |||
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | |||
})", | |||
className, enumName, i); | |||
} | |||
enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||
} | |||
body += formatv(R"( | |||
PyType_Modified(&e_type); | |||
mgb_assert(PyDict_SetItemString( | |||
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||
)", | |||
enumName); | |||
body += "}\n"; | |||
return body; | |||
} | |||
static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||
auto className = op.getCppClassName(); | |||
std::string body; | |||
// generate PyType for enum class member | |||
for (auto&& i : op.getMgbAttributes()) { | |||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
if (attr->getEnumCombinedFlag()) { | |||
body += gen_op_def_python_c_extension_bit_combined_enum( | |||
os, ctx, attr, className); | |||
} else { | |||
body += gen_op_def_python_c_extension_enum(os, ctx, attr, | |||
className); | |||
} | |||
} | |||
} | |||
// generate getsetters | |||
std::vector<std::string> getsetters; | |||
for (auto &&i : op.getMgbAttributes()) { | |||
getsetters.push_back(formatv( | |||
"{{const_cast<char*>(\"{1}\"), py_get_generic({0}, {1}), py_set_generic({0}, {1}), const_cast<char*>(\"{1}\"), NULL},", | |||
className, i.name)); | |||
} | |||
// generate tp_init | |||
std::string initBody; | |||
if (!op.getMgbAttributes().empty()) { | |||
initBody += "static const char* kwlist[] = {"; | |||
std::vector<llvm::StringRef> attr_name_list; | |||
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
attr_name_list.push_back(attr.name); | |||
}); | |||
attr_name_list.push_back("scope"); | |||
llvm::for_each(attr_name_list, [&](auto&& attr) { | |||
initBody += formatv("\"{0}\", ", attr); | |||
}); | |||
initBody += "NULL};\n"; | |||
initBody += " PyObject "; | |||
std::vector<std::string> attr_init; | |||
llvm::for_each(attr_name_list, [&](auto&& attr) { | |||
attr_init.push_back(formatv("*{0} = NULL", attr)); | |||
}); | |||
initBody += llvm::join(attr_init, ", ") + ";\n"; | |||
initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|"; | |||
// an extra slot created for name | |||
initBody += std::string(attr_name_list.size(), 'O'); | |||
initBody += "\", const_cast<char**>(kwlist)"; | |||
llvm::for_each(attr_name_list, [&](auto&& attr) { | |||
initBody += formatv(", &{0}", attr); | |||
}); | |||
initBody += "))\n"; | |||
initBody += " return -1;\n"; | |||
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
initBody += formatv(R"( | |||
if ({1}) {{ | |||
try {{ | |||
reinterpret_cast<PyOp({0})*>(self)->inst().{1} = | |||
pyobj_convert_generic<decltype({0}::{1})>::from({1}); | |||
} CATCH_ALL(-1) | |||
} | |||
)", className, attr.name); | |||
}); | |||
initBody += formatv(R"( | |||
if (scope) {{ | |||
try {{ | |||
reinterpret_cast<PyOp(OpDef)*>(self)->op | |||
->set_scope(pyobj_convert_generic<std::string>::from(scope)); | |||
} CATCH_ALL(-1) | |||
} | |||
)", className); | |||
} | |||
initBody += "\n return 0;"; | |||
os << formatv(R"( | |||
PyOpDefBegin({0}) // {{ | |||
static PyGetSetDef py_getsetters[]; | |||
static int py_init(PyObject *self, PyObject *args, PyObject *kwds); | |||
// }; | |||
PyOpDefEnd({0}) | |||
PyGetSetDef PyOp({0})::py_getsetters[] = {{ | |||
{1} | |||
{{NULL} /* Sentinel */ | |||
}; | |||
int PyOp({0})::py_init(PyObject *self, PyObject *args, PyObject *kwds) {{ | |||
{2} | |||
} | |||
void _init_py_{0}(py::module m) {{ | |||
using py_op = PyOp({0}); | |||
auto& py_type = PyOpType({0}); | |||
py_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||
py_type.tp_name = "megengine.core._imperative_rt.ops.{0}"; | |||
py_type.tp_basicsize = sizeof(PyOp({0})); | |||
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
py_type.tp_doc = "{0}"; | |||
py_type.tp_base = &PyOpType(OpDef); | |||
py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||
py_type.tp_new = py_new_generic<py_op>; | |||
py_type.tp_init = py_op::py_init; | |||
py_type.tp_getset = py_op::py_getsetters; | |||
mgb_assert(PyType_Ready(&py_type) >= 0); | |||
{3} | |||
PyType_Modified(&py_type); | |||
m.add_object("{0}", reinterpret_cast<PyObject*>(&py_type)); | |||
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace({0}::typeinfo(), &py_type).second); | |||
} | |||
)", | |||
op.getCppClassName(), llvm::join(getsetters, "\n "), initBody, body); | |||
} | |||
static void for_each_operator(raw_ostream &os, RecordKeeper &keeper, | |||
std::function<void(raw_ostream&, MgbOp&)> callback) { | |||
auto op_base_class = keeper.getClass("Op"); | |||
ASSERT(op_base_class, "could not find base class Op"); | |||
for (auto&& i: keeper.getDefs()) { | |||
auto&& r = i.second; | |||
if (r->isSubClassOf(op_base_class)) { | |||
auto op = mlir::tblgen::Operator(r.get()); | |||
if (op.getDialectName().str() == "mgb") { | |||
std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl; | |||
callback(os, llvm::cast<MgbOp>(op)); | |||
} | |||
} | |||
} | |||
} | |||
static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) { | |||
for_each_operator(os, keeper, gen_op_def_c_header_single); | |||
for_each_operator(os, keeper, gen_to_string_trait_for_enum); | |||
return false; | |||
} | |||
static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) { | |||
for_each_operator(os, keeper, gen_op_def_c_body_single); | |||
return false; | |||
} | |||
static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) { | |||
EnumContext ctx; | |||
using namespace std::placeholders; | |||
for_each_operator(os, keeper, | |||
std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx))); | |||
return false; | |||
} | |||
static bool gen_op_def_python_c_extension(raw_ostream &os, RecordKeeper &keeper) { | |||
EnumContext ctx; | |||
using namespace std::placeholders; | |||
for_each_operator(os, keeper, | |||
std::bind(gen_op_def_python_c_extension_single, _1, _2, std::ref(ctx))); | |||
os << "#define INIT_ALL_OP(m)"; | |||
for_each_operator(os, keeper, [&](raw_ostream& os, MgbOp& op) { | |||
os << formatv(" \\\n _init_py_{0}(m);", op.getCppClassName()); | |||
}); | |||
os << "\n"; | |||
return false; | |||
} | |||
using namespace mlir::tblgen; | |||
int main(int argc, char **argv) { | |||
llvm::InitLLVM y(argc, argv); | |||
@@ -0,0 +1,40 @@ | |||
/** | |||
* \file imperative/tablegen/emitter.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <unordered_map> | |||
#include <stdexcept> | |||
#include "llvm/ADT/StringRef.h" | |||
#include "llvm/Support/raw_ostream.h" | |||
namespace mlir::tblgen { | |||
struct Environment { | |||
std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias; | |||
}; | |||
struct EmitterBase { | |||
EmitterBase(raw_ostream& os_): os(os_) {} | |||
EmitterBase(raw_ostream& os_, Environment& env): os(os_), env_p(&env) {} | |||
protected: | |||
void newline() { os << "\n"; } | |||
Environment& env() { | |||
if (env_p) { | |||
return *env_p; | |||
} | |||
throw std::runtime_error("access global environment via non-environment emitter"); | |||
} | |||
raw_ostream& os; | |||
Environment* env_p = nullptr; | |||
}; | |||
} // namespace mlir::tblgen |
@@ -1,3 +1,16 @@ | |||
/** | |||
* \file imperative/tablegen/helper.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <iostream> | |||
#include <string> | |||
#include <vector> | |||
@@ -278,5 +291,28 @@ public: | |||
} | |||
}; | |||
using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase; | |||
using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin; | |||
using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin; | |||
using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin; | |||
using MgbOp = mlir::tblgen::MgbOpBase; | |||
using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin; | |||
static inline void foreach_operator(llvm::RecordKeeper &keeper, | |||
std::function<void(MgbOp&)> callback) { | |||
auto op_base_class = keeper.getClass("Op"); | |||
ASSERT(op_base_class, "could not find base class Op"); | |||
for (auto&& i: keeper.getDefs()) { | |||
auto&& r = i.second; | |||
if (r->isSubClassOf(op_base_class)) { | |||
auto op = mlir::tblgen::Operator(r.get()); | |||
if (op.getDialectName().str() == "mgb") { | |||
std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl; | |||
callback(llvm::cast<MgbOp>(op)); | |||
} | |||
} | |||
} | |||
} | |||
} // namespace tblgen | |||
} // namespace mlir |
@@ -0,0 +1,309 @@ | |||
/** | |||
* \file imperative/tablegen/targets/cpp_class.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "./cpp_class.h" | |||
#include "../emitter.h" | |||
namespace mlir::tblgen { | |||
namespace { | |||
llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) { | |||
// Note: we have already registered the corresponding attr wrappers | |||
// for following basic ctypes so we needn't handle them here | |||
/* auto&& attr_type_name = attr.getAttrDefName(); | |||
if (attr_type_name == "UI32Attr") { | |||
return "uint32_t"; | |||
} | |||
if (attr_type_name == "UI64Attr") { | |||
return "uint64_t"; | |||
} | |||
if (attr_type_name == "I32Attr") { | |||
return "int32_t"; | |||
} | |||
if (attr_type_name == "F32Attr") { | |||
return "float"; | |||
} | |||
if (attr_type_name == "F64Attr") { | |||
return "double"; | |||
} | |||
if (attr_type_name == "StrAttr") { | |||
return "std::string"; | |||
} | |||
if (attr_type_name == "BoolAttr") { | |||
return "bool"; | |||
}*/ | |||
auto&& attr = llvm::cast<MgbAttrWrapper>(attr_); | |||
if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) { | |||
return e->getEnumName(); | |||
} | |||
return attr.getUnderlyingType(); | |||
} | |||
class OpDefEmitter final: public EmitterBase { | |||
public: | |||
OpDefEmitter(MgbOp& op_, raw_ostream& os_): | |||
EmitterBase(os_), op(op_) {} | |||
void emit_header(); | |||
void emit_tpl_spl(); | |||
void emit_body(); | |||
private: | |||
MgbOp& op; | |||
}; | |||
void OpDefEmitter::emit_header() { | |||
os << formatv( | |||
"class {0} : public OpDefImplBase<{0}> {{\n" | |||
" MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n" | |||
"public:\n", | |||
op.getCppClassName() | |||
); | |||
// handle enum alias | |||
for (auto &&i : op.getMgbAttributes()) { | |||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
os << formatv( | |||
" using {0} = {1};\n", | |||
attr->getEnumName(), attr->getUnderlyingType() | |||
); | |||
} | |||
} | |||
for (auto &&i : op.getMgbAttributes()) { | |||
auto defaultValue = i.attr.getDefaultValue().str(); | |||
if (!defaultValue.empty()) { | |||
defaultValue = formatv(" = {0}", defaultValue); | |||
} | |||
os << formatv( | |||
" {0} {1}{2};\n", | |||
attr_to_ctype(i.attr), i.name, defaultValue | |||
); | |||
} | |||
auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) { | |||
os << formatv( | |||
" {0}({1}){2}{3}\n", | |||
op.getCppClassName(), paramList, memInitList, body | |||
); | |||
}; | |||
gen_ctor("", "", " = default;"); | |||
if (!op.getMgbAttributes().empty()) { | |||
std::vector<std::string> paramList, initList; | |||
for (auto &&i : op.getMgbAttributes()) { | |||
paramList.push_back(formatv( | |||
"{0} {1}_", attr_to_ctype(i.attr), i.name | |||
)); | |||
initList.push_back(formatv( | |||
"{0}({0}_)", i.name | |||
)); | |||
} | |||
paramList.push_back("std::string scope_ = {}"); | |||
gen_ctor(llvm::join(paramList, ", "), | |||
": " + llvm::join(initList, ", "), | |||
" { set_scope(scope_); }"); | |||
} | |||
auto packedParams = op.getPackedParams(); | |||
if (!packedParams.empty()) { | |||
std::vector<std::string> paramList, initList; | |||
for (auto &&p : packedParams) { | |||
auto&& paramFields = p.getFields(); | |||
auto&& paramType = p.getFullName(); | |||
auto&& paramName = formatv("packed_param_{0}", paramList.size()); | |||
paramList.push_back( | |||
paramFields.empty() ? paramType.str() | |||
: formatv("{0} {1}", paramType, paramName) | |||
); | |||
for (auto&& i : paramFields) { | |||
initList.push_back(formatv( | |||
"{0}({1}.{0})", i.name, paramName | |||
)); | |||
} | |||
} | |||
for (auto&& i : op.getExtraArguments()) { | |||
paramList.push_back(formatv( | |||
"{0} {1}_", attr_to_ctype(i.attr), i.name | |||
)); | |||
initList.push_back(formatv( | |||
"{0}({0}_)", i.name | |||
)); | |||
} | |||
gen_ctor(llvm::join(paramList, ", "), | |||
initList.empty() ? "" : ": " + llvm::join(initList, ", "), | |||
" {}"); | |||
} | |||
if (!packedParams.empty()) { | |||
for (auto&& p : packedParams) { | |||
auto accessor = p.getAccessor(); | |||
if (!accessor.empty()) { | |||
os << formatv( | |||
" {0} {1}() const {{\n", | |||
p.getFullName(), accessor | |||
); | |||
std::vector<llvm::StringRef> fields; | |||
for (auto&& i : p.getFields()) { | |||
fields.push_back(i.name); | |||
} | |||
os << formatv( | |||
" return {{{0}};\n", | |||
llvm::join(fields, ", ") | |||
); | |||
os << " }\n"; | |||
} | |||
} | |||
} | |||
if (auto decl = op.getExtraOpdefDecl()) { | |||
os << decl.getValue(); | |||
} | |||
os << formatv( | |||
"};\n\n" | |||
); | |||
} | |||
void OpDefEmitter::emit_tpl_spl() { | |||
for (auto &&i : op.getMgbAttributes()) { | |||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
if (attr->supportToString()) { | |||
std::vector<std::string> case_body; | |||
std::string ename = formatv("{0}::{1}", | |||
op.getCppClassName(), attr->getEnumName()); | |||
llvm::for_each(attr->getEnumMembers(), [&](auto&& v){ | |||
case_body.push_back(formatv( | |||
"case {0}::{1}: return \"{1}\";", ename, v)); | |||
}); | |||
os << formatv(R"( | |||
template <> | |||
struct ToStringTrait<{0}> { | |||
std::string operator()({0} e) const { | |||
switch (e) { | |||
{1} | |||
default: | |||
return "{0}::Unknown"; | |||
} | |||
} | |||
}; | |||
)", ename, llvm::join(case_body, "\n")); | |||
} | |||
} | |||
} | |||
} | |||
void OpDefEmitter::emit_body() { | |||
auto&& className = op.getCppClassName(); | |||
os << formatv( | |||
"MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className | |||
); | |||
auto formatMethImpl = [&](auto&& meth) { | |||
return formatv( | |||
"{0}_{1}_impl", className, meth | |||
); | |||
}; | |||
std::vector<std::string> methods; | |||
if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) { | |||
os << "namespace {\n"; | |||
// generate hash() | |||
mlir::tblgen::FmtContext ctx; | |||
os << formatv( | |||
"size_t {0}(const OpDef& def_) {{\n", | |||
formatMethImpl("hash") | |||
); | |||
os << formatv( | |||
" auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
" static_cast<void>(op_);\n", | |||
className | |||
); | |||
ctx.withSelf("op_"); | |||
os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx); | |||
os << "}\n"; | |||
// generate is_same_st() | |||
os << formatv( | |||
"bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n", | |||
formatMethImpl("is_same_st") | |||
); | |||
os << formatv( | |||
" auto &&a_ = lhs_.cast_final_safe<{0}>(),\n" | |||
" &&b_ = rhs_.cast_final_safe<{0}>();\n" | |||
" static_cast<void>(a_);\n" | |||
" static_cast<void>(b_);\n", | |||
className | |||
); | |||
os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_"); | |||
os << "}\n"; | |||
// generate props() | |||
os << formatv( | |||
"std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n", | |||
formatMethImpl("props") | |||
); | |||
os << formatv( | |||
" auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
" static_cast<void>(op_);\n", | |||
className | |||
); | |||
ctx.withSelf("op_"); | |||
os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx); | |||
os << "}\n"; | |||
// generate make_name() | |||
os << formatv( | |||
"std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") | |||
); | |||
os << formatv( | |||
" auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
" static_cast<void>(op_);\n", | |||
className | |||
); | |||
ctx.withSelf("op_"); | |||
os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx); | |||
os << "}\n"; | |||
os << "} // anonymous namespace\n"; | |||
methods.push_back("hash"); | |||
methods.push_back("is_same_st"); | |||
methods.push_back("props"); | |||
methods.push_back("make_name"); | |||
} | |||
if (!methods.empty()) { | |||
os << formatv( | |||
"OP_TRAIT_REG({0}, {0})", op.getCppClassName() | |||
); | |||
for (auto&& i : methods) { | |||
os << formatv( | |||
"\n .{0}({1})", i, formatMethImpl(i) | |||
); | |||
} | |||
os << ";\n\n"; | |||
} | |||
} | |||
} // namespace | |||
bool gen_op_def_c_header(raw_ostream &os, llvm::RecordKeeper &keeper) { | |||
foreach_operator(keeper, [&](MgbOp& op) { | |||
OpDefEmitter emitter(op, os); | |||
emitter.emit_header(); | |||
emitter.emit_tpl_spl(); | |||
}); | |||
return false; | |||
} | |||
bool gen_op_def_c_body(raw_ostream &os, llvm::RecordKeeper &keeper) { | |||
foreach_operator(keeper, [&](MgbOp& op) { | |||
OpDefEmitter emitter(op, os); | |||
emitter.emit_body(); | |||
}); | |||
return false; | |||
} | |||
} // namespace mlir::tblgen |
@@ -0,0 +1,21 @@ | |||
/** | |||
* \file imperative/tablegen/targets/cpp_class.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "../helper.h" | |||
namespace mlir::tblgen { | |||
bool gen_op_def_c_header(raw_ostream &os, llvm::RecordKeeper &keeper); | |||
bool gen_op_def_c_body(raw_ostream &os, llvm::RecordKeeper &keeper); | |||
} // namespace mlir::tblgen |
@@ -0,0 +1,142 @@ | |||
/** | |||
* \file imperative/tablegen/targets/pybind11.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "./pybind11.h" | |||
#include "../emitter.h" | |||
namespace mlir::tblgen { | |||
namespace { | |||
class OpDefEmitter final: public EmitterBase { | |||
public: | |||
OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_): | |||
EmitterBase(os_, env_), op(op_) {} | |||
void emit(); | |||
private: | |||
MgbOp& op; | |||
}; | |||
void OpDefEmitter::emit() { | |||
auto className = op.getCppClassName(); | |||
os << formatv( | |||
"py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", | |||
className | |||
); | |||
for (auto&& i : op.getMgbAttributes()) { | |||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
unsigned int enumID; | |||
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
auto&& aliasBase = alias->getAliasBase(); | |||
enumID = | |||
llvm::cast<MgbEnumAttr>(aliasBase) | |||
.getBaseRecord()->getID(); | |||
} else { | |||
enumID = attr->getBaseRecord()->getID(); | |||
} | |||
auto&& enumAlias = env().enumAlias; | |||
auto&& iter = enumAlias.find(enumID); | |||
if (iter == enumAlias.end()) { | |||
os << formatv( | |||
"py::enum_<{0}::{1}>({0}Inst, \"{1}\")", | |||
className, attr->getEnumName() | |||
); | |||
std::vector<std::string> body; | |||
for (auto&& i: attr->getEnumMembers()) { | |||
os << formatv( | |||
"\n .value(\"{2}\", {0}::{1}::{2})", | |||
className, attr->getEnumName(), i | |||
); | |||
body.push_back(formatv( | |||
"if (str == \"{2}\") return {0}::{1}::{2};", | |||
className, attr->getEnumName(), i | |||
)); | |||
} | |||
if (attr->getEnumCombinedFlag()) { | |||
//! define operator | | |||
os << formatv( | |||
"\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ " | |||
"\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));" | |||
"\n })", | |||
className, attr->getEnumName()); | |||
//! define operator & | |||
os << formatv( | |||
"\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{" | |||
"\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));" | |||
"\n })", | |||
className, attr->getEnumName()); | |||
} | |||
os << formatv( | |||
"\n .def(py::init([](const std::string& in) {" | |||
"\n auto&& str = normalize_enum(in);" | |||
"\n {0}" | |||
"\n throw py::cast_error(\"invalid enum value \" + in);" | |||
"\n }));\n", | |||
llvm::join(body, "\n ") | |||
); | |||
os << formatv( | |||
"py::implicitly_convertible<std::string, {0}::{1}>();\n\n", | |||
className, attr->getEnumName() | |||
); | |||
enumAlias.emplace(enumID, | |||
std::make_pair(className, attr->getEnumName())); | |||
} else { | |||
os << formatv( | |||
"{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n", | |||
className, attr->getEnumName(), | |||
iter->second.first, iter->second.second | |||
); | |||
} | |||
} | |||
} | |||
// generate op class binding | |||
os << formatv("{0}Inst", className); | |||
bool hasDefaultCtor = op.getMgbAttributes().empty(); | |||
if (!hasDefaultCtor) { | |||
os << "\n .def(py::init<"; | |||
std::vector<llvm::StringRef> targs; | |||
for (auto &&i : op.getMgbAttributes()) { | |||
targs.push_back(i.attr.getReturnType()); | |||
} | |||
os << llvm::join(targs, ", "); | |||
os << ", std::string>()"; | |||
for (auto &&i : op.getMgbAttributes()) { | |||
os << formatv(", py::arg(\"{0}\")", i.name); | |||
auto defaultValue = i.attr.getDefaultValue(); | |||
if (!defaultValue.empty()) { | |||
os << formatv(" = {0}", defaultValue); | |||
} else { | |||
hasDefaultCtor = true; | |||
} | |||
} | |||
os << ", py::arg(\"scope\") = {})"; | |||
} | |||
if (hasDefaultCtor) { | |||
os << "\n .def(py::init<>())"; | |||
} | |||
for (auto &&i : op.getMgbAttributes()) { | |||
os << formatv( | |||
"\n .def_readwrite(\"{0}\", &{1}::{0})", | |||
i.name, className | |||
); | |||
} | |||
os << ";\n\n"; | |||
} | |||
} // namespace | |||
bool gen_op_def_pybind11(raw_ostream &os, llvm::RecordKeeper &keeper) { | |||
Environment env; | |||
using namespace std::placeholders; | |||
foreach_operator(keeper, [&](MgbOp& op) { | |||
OpDefEmitter(op, os, env).emit(); | |||
}); | |||
return false; | |||
} | |||
} // namespace mlir::tblgen |
@@ -0,0 +1,19 @@ | |||
/** | |||
* \file imperative/tablegen/targets/pybind11.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "../helper.h" | |||
namespace mlir::tblgen { | |||
bool gen_op_def_pybind11(raw_ostream &os, llvm::RecordKeeper &keeper); | |||
} // namespace mlir::tblgen |
@@ -0,0 +1,313 @@ | |||
/** | |||
* \file imperative/tablegen/targets/python_c_extension.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "python_c_extension.h" | |||
#include "../emitter.h" | |||
namespace mlir::tblgen { | |||
namespace { | |||
struct Initproc { | |||
std::string func; | |||
Initproc(std::string&& s): func(std::move(s)) {} | |||
std::string operator()(std::string argument) { | |||
return formatv("{0}({1})", func, argument); | |||
} | |||
}; | |||
class OpDefEmitter: public EmitterBase { | |||
public: | |||
OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_): | |||
EmitterBase(os_, env_), op(op_) { | |||
ctx.withSelf(op.getCppClassName()); | |||
} | |||
Initproc emit(); | |||
private: | |||
void emit_class(); | |||
void emit_py_init(); | |||
void emit_py_getsetters(); | |||
Initproc emit_initproc(); | |||
MgbOp& op; | |||
std::vector<Initproc> subclasses; | |||
mlir::tblgen::FmtContext ctx; | |||
}; | |||
class EnumAttrEmitter: public EmitterBase { | |||
public: | |||
EnumAttrEmitter(llvm::StringRef parent, MgbEnumAttr* attr_, raw_ostream& os_, Environment& env_): | |||
EmitterBase(os_, env_), attr(attr_) { | |||
unsigned int enumID; | |||
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
auto&& aliasBase = alias->getAliasBase(); | |||
enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||
} else { | |||
enumID = attr->getBaseRecord()->getID(); | |||
} | |||
ctx.addSubst("enumTpl", attr->getEnumCombinedFlag() ? "BitCombinedEnumWrapper" : "EnumWrapper"); | |||
ctx.addSubst("opClass", parent); | |||
ctx.addSubst("enumClass", attr->getEnumName()); | |||
firstOccur = env().enumAlias.emplace(enumID, std::make_pair(parent, attr->getEnumName())).second; | |||
} | |||
Initproc emit(); | |||
protected: | |||
void emit_tpl_spl(); | |||
Initproc emit_initproc(); | |||
MgbEnumAttr* attr; | |||
bool firstOccur; | |||
mlir::tblgen::FmtContext ctx; | |||
}; | |||
Initproc EnumAttrEmitter::emit() { | |||
emit_tpl_spl(); | |||
return emit_initproc(); | |||
} | |||
void EnumAttrEmitter::emit_tpl_spl() { | |||
if (!firstOccur) return; | |||
os << tgfmt( | |||
"template<> PyTypeObject $enumTpl<$opClass::$enumClass>::type={};\n", | |||
&ctx); | |||
os << tgfmt( | |||
"template<> const char* $enumTpl<$opClass::$enumClass>::name = " | |||
"\"$opClass.$enumClass\";\n", | |||
&ctx); | |||
if (attr->getEnumCombinedFlag()) { | |||
os << tgfmt( | |||
"template<> PyNumberMethods " | |||
"$enumTpl<$opClass::$enumClass>::number_methods={};\n", | |||
&ctx); | |||
os << tgfmt( | |||
"template<> struct EnumTrait<$opClass::$enumClass> { static constexpr " | |||
"bool is_bit_combined = true;};\n", | |||
&ctx); | |||
} | |||
auto str2type = [&](auto&& i) -> std::string { | |||
return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, i); | |||
}; | |||
os << tgfmt(R"( | |||
template<> std::unordered_map<std::string, $opClass::$enumClass> | |||
$enumTpl<$opClass::$enumClass>::str2type = {$0}; | |||
)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), str2type), ", ")); | |||
auto type2str = [&](auto&& i) -> std::string { | |||
return tgfmt("{$opClass::$enumClass::$0, normalize_enum(\"$0\")}", &ctx, i); | |||
}; | |||
os << tgfmt(R"( | |||
template<> std::unordered_map<$opClass::$enumClass, std::string> | |||
$enumTpl<$opClass::$enumClass>::type2str = {$0}; | |||
)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), type2str), ", ")); | |||
} | |||
Initproc EnumAttrEmitter::emit_initproc() { | |||
std::string initproc = formatv("_init_py_{0}_{1}", | |||
ctx.getSubstFor("opClass"), ctx.getSubstFor("enumClass")); | |||
os << tgfmt(R"( | |||
void $0(PyTypeObject& py_type) { | |||
auto& e_type = $enumTpl<$opClass::$enumClass>::type; | |||
)", &ctx, initproc); | |||
if (firstOccur) { | |||
os << tgfmt(R"( | |||
e_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||
e_type.tp_name = "megengine.core._imperative_rt.ops.$opClass.$enumClass"; | |||
e_type.tp_basicsize = sizeof($enumTpl<$opClass::$enumClass>); | |||
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
e_type.tp_doc = "$opClass.$enumClass"; | |||
e_type.tp_base = &PyBaseObject_Type; | |||
e_type.tp_repr = $enumTpl<$opClass::$enumClass>::py_repr; | |||
e_type.tp_richcompare = $enumTpl<$opClass::$enumClass>::tp_richcompare; | |||
)", &ctx); | |||
if (attr->getEnumCombinedFlag()) { | |||
// only bit combined enum could new instance because bitwise operation, | |||
// others should always use singleton | |||
os << tgfmt(R"( | |||
e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum; | |||
e_type.tp_init = $enumTpl<$opClass::$enumClass>::py_init; | |||
auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods; | |||
number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or; | |||
number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and; | |||
e_type.tp_as_number = &number_method; | |||
)", &ctx); | |||
} | |||
os << " mgb_assert(PyType_Ready(&e_type) >= 0);\n"; | |||
for (auto&& i : attr->getEnumMembers()) { | |||
os << tgfmt(R"({ | |||
PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0; | |||
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "$0", inst) >= 0); | |||
PyType_Modified(&e_type); | |||
})", &ctx, i); | |||
} | |||
} | |||
os << tgfmt(R"( | |||
mgb_assert(PyDict_SetItemString( | |||
py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||
)", &ctx); | |||
os << "}\n"; | |||
return initproc; | |||
} | |||
Initproc OpDefEmitter::emit() { | |||
for (auto&& i : op.getMgbAttributes()) { | |||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
subclasses.push_back(EnumAttrEmitter(op.getCppClassName(), attr, os, env()).emit()); | |||
} | |||
} | |||
emit_class(); | |||
emit_py_init(); | |||
emit_py_getsetters(); | |||
return emit_initproc(); | |||
} | |||
void OpDefEmitter::emit_class() { | |||
os << tgfmt(R"( | |||
PyOpDefBegin($_self) // { | |||
static PyGetSetDef py_getsetters[]; | |||
static int py_init(PyObject *self, PyObject *args, PyObject *kwds); | |||
// }; | |||
PyOpDefEnd($_self) | |||
)", &ctx); | |||
} | |||
void OpDefEmitter::emit_py_init() { | |||
std::string initBody; | |||
if (!op.getMgbAttributes().empty()) { | |||
initBody += "static const char* kwlist[] = {"; | |||
std::vector<llvm::StringRef> attr_name_list; | |||
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
attr_name_list.push_back(attr.name); | |||
}); | |||
attr_name_list.push_back("scope"); | |||
llvm::for_each(attr_name_list, [&](auto&& attr) { | |||
initBody += formatv("\"{0}\", ", attr); | |||
}); | |||
initBody += "NULL};\n"; | |||
initBody += " PyObject "; | |||
auto initializer = [&](auto&& attr) -> std::string { | |||
return formatv("*{0} = NULL", attr); | |||
}; | |||
initBody += llvm::join(llvm::map_range(attr_name_list, initializer), ", ") + ";\n"; | |||
initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|"; | |||
// an extra slot created for name | |||
initBody += std::string(attr_name_list.size(), 'O'); | |||
initBody += "\", const_cast<char**>(kwlist)"; | |||
llvm::for_each(attr_name_list, [&](auto&& attr) { | |||
initBody += formatv(", &{0}", attr); | |||
}); | |||
initBody += "))\n"; | |||
initBody += " return -1;\n"; | |||
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
initBody += tgfmt(R"( | |||
if ($0) { | |||
try { | |||
reinterpret_cast<PyOp($_self)*>(self)->inst().$0 = | |||
pyobj_convert_generic<decltype($_self::$0)>::from($0); | |||
} CATCH_ALL(-1) | |||
} | |||
)", &ctx, attr.name); | |||
}); | |||
initBody += tgfmt(R"( | |||
if (scope) { | |||
try { | |||
reinterpret_cast<PyOp(OpDef)*>(self)->op | |||
->set_scope(pyobj_convert_generic<std::string>::from(scope)); | |||
} CATCH_ALL(-1) | |||
} | |||
)", &ctx); | |||
} | |||
initBody += "\n return 0;"; | |||
os << tgfmt(R"( | |||
int PyOp($_self)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { | |||
$0 | |||
} | |||
)", &ctx, initBody); | |||
} | |||
void OpDefEmitter::emit_py_getsetters() { | |||
auto f = [&](auto&& attr) -> std::string { | |||
return tgfmt( | |||
"{const_cast<char*>(\"$0\"), py_get_generic($_self, $0), py_set_generic($_self, $0), const_cast<char*>(\"$0\"), NULL},", | |||
&ctx, attr.name); | |||
}; | |||
os << tgfmt(R"( | |||
PyGetSetDef PyOp($_self)::py_getsetters[] = { | |||
$0 | |||
{NULL} /* Sentinel */ | |||
}; | |||
)", &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n ")); | |||
} | |||
Initproc OpDefEmitter::emit_initproc() { | |||
std::string initproc = formatv("_init_py_{0}", op.getCppClassName()); | |||
std::string subclass_init_call; | |||
for (auto&& i : subclasses) { | |||
subclass_init_call += formatv(" {0};\n", i("py_type")); | |||
} | |||
os << tgfmt(R"( | |||
void $0(py::module m) { | |||
using py_op = PyOp($_self); | |||
auto& py_type = PyOpType($_self); | |||
py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||
py_type.tp_name = "megengine.core._imperative_rt.ops.$_self"; | |||
py_type.tp_basicsize = sizeof(PyOp($_self)); | |||
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
py_type.tp_doc = "$_self"; | |||
py_type.tp_base = &PyOpType(OpDef); | |||
py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||
py_type.tp_new = py_new_generic<py_op>; | |||
py_type.tp_init = py_op::py_init; | |||
py_type.tp_getset = py_op::py_getsetters; | |||
mgb_assert(PyType_Ready(&py_type) >= 0); | |||
$1 | |||
PyType_Modified(&py_type); | |||
m.add_object("$_self", reinterpret_cast<PyObject*>(&py_type)); | |||
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace($_self::typeinfo(), &py_type).second); | |||
} | |||
)", &ctx, initproc, subclass_init_call); | |||
return initproc; | |||
} | |||
} // namespace | |||
bool gen_op_def_python_c_extension(raw_ostream &os, llvm::RecordKeeper &keeper) { | |||
Environment env; | |||
using namespace std::placeholders; | |||
std::vector<Initproc> initprocs; | |||
foreach_operator(keeper, [&](MgbOp& op) { | |||
initprocs.emplace_back(OpDefEmitter(op, os, env).emit()); | |||
}); | |||
os << "#define INIT_ALL_OP(m)"; | |||
for(auto&& init : initprocs) { | |||
os << formatv(" \\\n {0};", init("m")); | |||
} | |||
os << "\n"; | |||
return false; | |||
} | |||
} // namespace mlir::tblgen |
@@ -0,0 +1,19 @@ | |||
/** | |||
* \file imperative/tablegen/targets/python_c_extension.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "../helper.h" | |||
namespace mlir::tblgen { | |||
bool gen_op_def_python_c_extension(raw_ostream &os, llvm::RecordKeeper &keeper); | |||
} // namespace mlir::tblgen |