|
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
-
- import argparse
- import collections
- import hashlib
- import io
- import os
- import struct
- import textwrap
-
- from gen_param_defs import IndentWriterBase, ParamDef, member_defs
-
# FIXME: move the supportToString flag definition into the param def source
# file.
# (param name, enum name) pairs whose MgbEnumAttr record gets the extra
# ToStringTrait flag appended (checked in ConverterWriter._on_member_enum).
ENUM_TO_STRING_SPECIAL_RULES = [
    ("Elemwise", "Mode"),
    ("ElemwiseMultiType", "Mode"),
]
-
-
class ConverterWriter(IndentWriterBase):
    """Write megdnn param definitions as a TableGen file.

    ``__call__`` emits a header comment and an ``MGB_PARAM`` include guard
    around whatever ``self._process(defs)`` produces. The ``_on_*`` hooks
    below (presumably dispatched by ``IndentWriterBase._process`` -- confirm
    against the base class) turn each param and its members into TableGen
    records.
    """

    # Presumably set by IndentWriterBase for params that must not produce a
    # record (not set anywhere in this class); cleared in _on_param_end.
    _skip_current_param = False
    # Param currently being processed; set in _on_param_begin.
    _last_param = None
    # "Attr:$name" field entries collected for the current param.
    _current_tparams = None
    # Whether the current param can derive from MgbPackedParamBase; set to
    # False once a DTypeEnum field is seen.
    _packed = None
    # Names of const fields declared on the current param.
    _const = None

    # C type -> TableGen attribute class for types whose default value is
    # emitted unchanged.
    _SIMPLE_CTYPE_ATTRS = {
        "uint32_t": "MgbUI32Attr",
        "uint64_t": "MgbUI64Attr",
        "int32_t": "MgbI32Attr",
        "float": "MgbF32Attr",
        "double": "MgbF64Attr",
        "bool": "MgbBoolAttr",
    }

    def __call__(self, fout, defs):
        """Write the complete TableGen output for ``defs`` to ``fout``."""
        super().__call__(fout)
        self._write("// %s", self._get_header())
        self._write("#ifndef MGB_PARAM")
        self._write("#define MGB_PARAM")
        self._process(defs)
        self._write("#endif // MGB_PARAM")

    @staticmethod
    def _enum_member_name(member):
        """Return the bare identifier of an enum member.

        Keeps the text of ``str(member)`` before the first space, then before
        the first ``=``; e.g. ``"FOO = 1"`` and ``"FOO=1"`` both give ``"FOO"``.
        """
        return str(member).split(" ")[0].split("=")[0]

    def _ctype2attr(self, ctype, value):
        """Map a C type name and its default value to ``(attr_class, value)``.

        A ``DTypeEnum`` default is wrapped in a ``megdnn::DType::from_enum``
        expression, and the current param is marked as not packable.

        Raises:
            RuntimeError: if ``ctype`` has no known attribute mapping.
        """
        attr = self._SIMPLE_CTYPE_ATTRS.get(ctype)
        if attr is not None:
            return attr, value
        if ctype == "DTypeEnum":
            self._packed = False
            return "MgbDTypeAttr", "megdnn::DType::from_enum(megdnn::{})".format(value)
        # Name the offending type so a new param field is easy to track down.
        raise RuntimeError("unknown ctype: {!r}".format(ctype))

    def _on_param_begin(self, p):
        """Reset the per-param state before the members of ``p`` are visited."""
        self._last_param = p
        self._packed = True
        self._current_tparams = []
        self._const = set()

    def _on_param_end(self, p):
        """Emit the record for ``p`` from the collected fields, then reset state.

        A packable param becomes a ``{Name}ParamBase<string accessor>`` class
        plus a ``{Name}Param`` def instantiating it with ``"param"``; any other
        param becomes a ``{Name}Param`` def deriving from ``MgbParamBase``.
        Skipped params emit nothing and only clear the skip flag.
        """
        if self._skip_current_param:
            self._skip_current_param = False
            return
        if self._packed:
            self._write(
                'class {0}ParamBase<string accessor> : MgbPackedParamBase<"{0}", accessor> {{'.format(
                    p.name
                ),
                indent=1,
            )
        else:
            self._write('def {0}Param: MgbParamBase<"{0}"> {{'.format(p.name), indent=1)
        self._write("let fields = (ins", indent=1)
        # One field per line, aligned with the writer's current indentation.
        self._write(",\n{}".format(self._cur_indent).join(self._current_tparams))
        self._write(");", indent=-1)
        self._write("}\n", indent=-1)
        if self._packed:
            self._write('def {0}Param : {0}ParamBase<"param">;\n'.format(p.name))
        self._current_tparams = None
        self._packed = None
        self._const = None

    def _wrapped_with_default_value(self, attr, default):
        """Wrap attribute class ``attr`` in ``MgbDefaultValuedAttr`` with ``default``."""
        return 'MgbDefaultValuedAttr<{}, "{}">'.format(attr, default)

    def _on_member_enum(self, e):
        """Emit an ``MgbEnumAttr`` record for enum ``e`` and register its field.

        The record is written even for skipped params (see the comment below);
        the field itself is only added when the param is not skipped.
        """
        p = self._last_param

        # Always emit the record for the enum attribute, even when no operator
        # uses it directly; otherwise other enums could not alias to it.
        td_class = "{}{}".format(p.name, e.name)
        fullname = "::megdnn::param::{}".format(p.name)
        member_list = ",".join(
            '"{}"'.format(self._enum_member_name(m)) for m in e.members
        )
        enum_def = 'MgbEnumAttr<"{}", "{}", [{}], {}'.format(
            fullname, e.name, member_list, 1 if e.combined else 0
        )
        if (p.name, e.name) in ENUM_TO_STRING_SPECIAL_RULES:
            enum_def += ", 1"  # whether to generate ToStringTrait
        enum_def += ">"

        self._write("def {} : {};".format(td_class, enum_def))
        if self._skip_current_param:
            return

        # Combined enums default to a cast of the composed integer value;
        # plain enums default to the named member.
        if e.combined:
            default_val = "static_cast<{}::{}>({})".format(
                fullname, e.name, e.compose_combined_enum(e.default)
            )
        else:
            default_val = "{}::{}::{}".format(
                fullname, e.name, self._enum_member_name(e.members[e.default])
            )

        wrapped = self._wrapped_with_default_value(td_class, default_val)
        self._current_tparams.append("{}:${}".format(wrapped, e.name_field))

    def _on_member_enum_alias(self, e):
        """Emit an ``MgbEnumAliasAttr`` record for alias ``e`` and register its field.

        The alias refers to the source enum's record ``{src_class}{src_name}``.
        Unlike ``_on_member_enum``, nothing is written for a skipped param.
        """
        p = self._last_param
        if self._skip_current_param:
            return

        td_class = "{}{}".format(p.name, e.name)
        fullname = "::megdnn::param::{}".format(p.name)
        base_td_class = "{}{}".format(e.src_class, e.src_name)
        enum_def = 'MgbEnumAliasAttr<"{}", "{}", {}>'.format(
            fullname, e.name, base_td_class
        )
        self._write("def {} : {};".format(td_class, enum_def))

        # The default is resolved against the source enum's members.
        src = e.src_enum
        if src.combined:
            default_val = "static_cast<{}::{}>({})".format(
                fullname, e.name, src.compose_combined_enum(e.get_default())
            )
        else:
            default_val = "{}::{}::{}".format(
                fullname,
                e.name,
                self._enum_member_name(src.members[e.get_default()]),
            )

        wrapped = self._wrapped_with_default_value(td_class, default_val)
        self._current_tparams.append("{}:${}".format(wrapped, e.name_field))

    def _on_member_field(self, f):
        """Register a plain field ``f`` with its default value.

        A default that names a const field of the current param is qualified
        as ``::megdnn::param::<Param>::<const>``.
        """
        if self._skip_current_param:
            return
        attr, value = self._ctype2attr(f.dtype.cname, str(f.default))
        if value in self._const:
            value = "::megdnn::param::{}::{}".format(self._last_param.name, value)
        wrapped = self._wrapped_with_default_value(attr, value)
        self._current_tparams.append("{}:${}".format(wrapped, f.name))

    def _on_const_field(self, f):
        """Remember const field ``f`` so member defaults naming it get qualified."""
        self._const.add(str(f.name))
-
-
def main():
    """Command-line entry point: generate the op param TableGen file.

    Reads the param definition script ``input`` and executes it with ``pdef``
    bound to ``ParamDef`` and ``Doc`` bound to ``member_defs.Doc``. It then
    writes ``ParamDef.all_param_defs`` to ``output`` through
    ``ConverterWriter``. The SHA-256 of the input text is handed to the
    writer via ``set_input_hash``.
    """
    # The first positional argument of ArgumentParser is ``prog``; pass the
    # text as ``description`` so the program name in usage stays intact.
    parser = argparse.ArgumentParser(description="generate op param tablegen file")
    parser.add_argument("input", help="param definition script to execute")
    parser.add_argument("output", help="path of the TableGen file to write")
    args = parser.parse_args()

    # Decode explicitly as UTF-8: the hash below re-encodes the text as UTF-8,
    # and a locale-dependent codec could fail on non-ASCII input or make the
    # hash depend on the build machine's locale.
    with open(args.input, encoding="utf-8") as fin:
        inputs = fin.read()

    # NOTE: exec runs the input as Python code; only pass trusted build inputs.
    exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc})
    input_hash = hashlib.sha256(inputs.encode("utf-8")).hexdigest()

    writer = ConverterWriter()
    with open(args.output, "w", encoding="utf-8") as fout:
        writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)


if __name__ == "__main__":
    main()
|