- #include "./pybind11.h"
- #include "../emitter.h"
-
- namespace mlir::tblgen {
- namespace {
- class OpDefEmitter final : public EmitterBase {
- public:
- OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_)
- : EmitterBase(os_, env_), op(op_) {}
-
- void emit();
-
- private:
- MgbOp& op;
- };
-
- void OpDefEmitter::emit() {
- auto className = op.getCppClassName();
- os << formatv(
- "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
- className);
- for (auto&& i : op.getMgbAttributes()) {
- if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
- unsigned int enumID;
- if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
- auto&& aliasBase = alias->getAliasBase();
- enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
- } else {
- enumID = attr->getBaseRecord()->getID();
- }
- auto&& enumAlias = env().enumAlias;
- auto&& iter = enumAlias.find(enumID);
- if (iter == enumAlias.end()) {
- os << formatv(
- "py::enum_<{0}::{1}>({0}Inst, \"{1}\")", className,
- attr->getEnumName());
- std::vector<std::string> body;
- for (auto&& i : attr->getEnumMembers()) {
- size_t d1 = i.find(' ');
- size_t d2 = i.find('=');
- size_t d = d1 <= d2 ? d1 : d2;
- os << formatv(
- "\n .value(\"{2}\", {0}::{1}::{2})", className,
- attr->getEnumName(), i.substr(0, d));
- body.push_back(
- formatv("if (str == \"{2}\") return {0}::{1}::{2};",
- className, attr->getEnumName(), i.substr(0, d)));
- }
- if (attr->getEnumCombinedFlag()) {
- //! define operator |
- os << formatv(
- "\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ "
- "\n return static_cast<{0}::{1}>(uint32_t(s0) | "
- "uint32_t(s1));"
- "\n })",
- className, attr->getEnumName());
- //! define operator &
- os << formatv(
- "\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{"
- "\n return static_cast<{0}::{1}>(uint32_t(s0) & "
- "uint32_t(s1));"
- "\n })",
- className, attr->getEnumName());
- }
- os << formatv(
- "\n .def(py::init([](const std::string& in) {"
- "\n auto&& str = normalize_enum(in);"
- "\n {0}"
- "\n throw py::cast_error(\"invalid enum value \" + in);"
- "\n }));\n",
- llvm::join(body, "\n "));
- os << formatv(
- "py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
- className, attr->getEnumName());
- enumAlias.emplace(
- enumID, std::make_pair(className, attr->getEnumName()));
- } else {
- os << formatv(
- "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n", className,
- attr->getEnumName(), iter->second.first, iter->second.second);
- }
- }
- }
- // generate op class binding
- os << formatv("{0}Inst", className);
- bool hasDefaultCtor = op.getMgbAttributes().empty();
- if (!hasDefaultCtor) {
- os << "\n .def(py::init<";
- std::vector<llvm::StringRef> targs;
- for (auto&& i : op.getMgbAttributes()) {
- targs.push_back(i.attr.getReturnType());
- }
- os << llvm::join(targs, ", ");
- os << ", std::string>()";
- for (auto&& i : op.getMgbAttributes()) {
- os << formatv(", py::arg(\"{0}\")", i.name);
- auto defaultValue = i.attr.getDefaultValue();
- if (!defaultValue.empty()) {
- os << formatv(" = {0}", defaultValue);
- } else {
- hasDefaultCtor = true;
- }
- }
- os << ", py::arg(\"scope\") = {})";
- }
- if (hasDefaultCtor) {
- os << "\n .def(py::init<>())";
- }
- for (auto&& i : op.getMgbAttributes()) {
- os << formatv("\n .def_readwrite(\"{0}\", &{1}::{0})", i.name, className);
- }
- os << ";\n\n";
- }
- } // namespace
-
- bool gen_op_def_pybind11(raw_ostream& os, llvm::RecordKeeper& keeper) {
- Environment env;
- using namespace std::placeholders;
- foreach_operator(keeper, [&](MgbOp& op) { OpDefEmitter(op, os, env).emit(); });
- return false;
- }
- } // namespace mlir::tblgen
|