You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_tablegen.py 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import argparse
  4. import collections
  5. import hashlib
  6. import io
  7. import os
  8. import struct
  9. import textwrap
  10. from gen_param_defs import IndentWriterBase, ParamDef, member_defs
  11. # FIXME: move supportToString flag definition into the param def source file
  12. ENUM_TO_STRING_SPECIAL_RULES = [("Elemwise", "Mode"), ("ElemwiseMultiType", "Mode")]
  13. class ConverterWriter(IndentWriterBase):
  14. _skip_current_param = False
  15. _last_param = None
  16. _current_tparams = None
  17. _packed = None
  18. _const = None
  19. def __call__(self, fout, defs):
  20. super().__call__(fout)
  21. self._write("// %s", self._get_header())
  22. self._write("#ifndef MGB_PARAM")
  23. self._write("#define MGB_PARAM")
  24. self._process(defs)
  25. self._write("#endif // MGB_PARAM")
  26. def _ctype2attr(self, ctype, value):
  27. if ctype == "uint32_t":
  28. return "MgbUI32Attr", value
  29. if ctype == "uint64_t":
  30. return "MgbUI64Attr", value
  31. if ctype == "int32_t":
  32. return "MgbI32Attr", value
  33. if ctype == "float":
  34. return "MgbF32Attr", value
  35. if ctype == "double":
  36. return "MgbF64Attr", value
  37. if ctype == "bool":
  38. return "MgbBoolAttr", value
  39. if ctype == "DTypeEnum":
  40. self._packed = False
  41. return "MgbDTypeAttr", "megdnn::DType::from_enum(megdnn::{})".format(value)
  42. raise RuntimeError("unknown ctype")
  43. def _on_param_begin(self, p):
  44. self._last_param = p
  45. self._packed = True
  46. self._current_tparams = []
  47. self._const = set()
  48. def _on_param_end(self, p):
  49. if self._skip_current_param:
  50. self._skip_current_param = False
  51. return
  52. if self._packed:
  53. self._write(
  54. 'class {0}ParamBase<string accessor> : MgbPackedParamBase<"{0}", accessor> {{'.format(
  55. p.name
  56. ),
  57. indent=1,
  58. )
  59. else:
  60. self._write('def {0}Param: MgbParamBase<"{0}"> {{'.format(p.name), indent=1)
  61. self._write("let fields = (ins", indent=1)
  62. self._write(",\n{}".format(self._cur_indent).join(self._current_tparams))
  63. self._write(");", indent=-1)
  64. self._write("}\n", indent=-1)
  65. if self._packed:
  66. self._write('def {0}Param : {0}ParamBase<"param">;\n'.format(p.name))
  67. self._current_tparams = None
  68. self._packed = None
  69. self._const = None
  70. def _wrapped_with_default_value(self, attr, default):
  71. return 'MgbDefaultValuedAttr<{}, "{}">'.format(attr, default)
  72. def _on_member_enum(self, e):
  73. p = self._last_param
  74. # Note: always generate llvm Record def for enum attribute even it was not
  75. # directly used by any operator, or other enum couldn't alias to this enum
  76. td_class = "{}{}".format(p.name, e.name)
  77. fullname = "::megdnn::param::{}".format(p.name)
  78. enum_def = 'MgbEnumAttr<"{}", "{}", ['.format(fullname, e.name)
  79. def format(v):
  80. return '"{}"'.format(str(v).split(" ")[0].split("=")[0])
  81. enum_def += ",".join(format(i) for i in e.members)
  82. if e.combined:
  83. enum_def += "], 1"
  84. else:
  85. enum_def += "], 0"
  86. if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)):
  87. enum_def += ", 1" # whether generate ToStringTrait
  88. enum_def += ">"
  89. self._write("def {} : {};".format(td_class, enum_def))
  90. if self._skip_current_param:
  91. return
  92. # wrapped with default value
  93. if e.combined:
  94. default_val = "static_cast<{}::{}>({})".format(
  95. fullname, e.name, e.compose_combined_enum(e.default)
  96. )
  97. else:
  98. default_val = "{}::{}::{}".format(
  99. fullname, e.name, str(e.members[e.default]).split(" ")[0].split("=")[0]
  100. )
  101. wrapped = self._wrapped_with_default_value(td_class, default_val)
  102. self._current_tparams.append("{}:${}".format(wrapped, e.name_field))
  103. def _on_member_enum_alias(self, e):
  104. p = self._last_param
  105. if self._skip_current_param:
  106. return
  107. # write enum attr def
  108. td_class = "{}{}".format(p.name, e.name)
  109. fullname = "::megdnn::param::{}".format(p.name)
  110. base_td_class = "{}{}".format(e.src_class, e.src_name)
  111. enum_def = 'MgbEnumAliasAttr<"{}", "{}", {}>'.format(
  112. fullname, e.name, base_td_class
  113. )
  114. self._write("def {} : {};".format(td_class, enum_def))
  115. # wrapped with default value
  116. s = e.src_enum
  117. if s.combined:
  118. default_val = "static_cast<{}::{}>({})".format(
  119. fullname, e.name, s.compose_combined_enum(e.get_default())
  120. )
  121. else:
  122. default_val = "{}::{}::{}".format(
  123. fullname,
  124. e.name,
  125. str(s.members[e.get_default()]).split(" ")[0].split("=")[0],
  126. )
  127. wrapped = self._wrapped_with_default_value(td_class, default_val)
  128. self._current_tparams.append("{}:${}".format(wrapped, e.name_field))
  129. def _on_member_field(self, f):
  130. if self._skip_current_param:
  131. return
  132. attr, value = self._ctype2attr(f.dtype.cname, str(f.default))
  133. if str(value) in self._const:
  134. value = "::megdnn::param::{}::{}".format(self._last_param.name, value)
  135. wrapped = self._wrapped_with_default_value(attr, value)
  136. self._current_tparams.append("{}:${}".format(wrapped, f.name))
  137. def _on_const_field(self, f):
  138. self._const.add(str(f.name))
  139. def main():
  140. parser = argparse.ArgumentParser("generate op param tablegen file")
  141. parser.add_argument("input")
  142. parser.add_argument("output")
  143. args = parser.parse_args()
  144. with open(args.input) as fin:
  145. inputs = fin.read()
  146. exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc})
  147. input_hash = hashlib.sha256()
  148. input_hash.update(inputs.encode(encoding="UTF-8"))
  149. input_hash = input_hash.hexdigest()
  150. writer = ConverterWriter()
  151. with open(args.output, "w") as fout:
  152. writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)
  153. if __name__ == "__main__":
  154. main()