You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_flatbuffers_schema.py 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import argparse
  4. import collections
  5. import hashlib
  6. import io
  7. import os
  8. import struct
  9. import textwrap
  10. from gen_param_defs import IndentWriterBase, ParamDef, member_defs
  11. def _cname_to_fbname(cname):
  12. return {
  13. "uint32_t": "uint",
  14. "uint64_t": "ulong",
  15. "int32_t": "int",
  16. "float": "float",
  17. "double": "double",
  18. "DTypeEnum": "DTypeEnum",
  19. "bool": "bool",
  20. }[cname]
  21. def scramble_enum_member_name(name):
  22. s = name.find("<<")
  23. if s != -1:
  24. name = name[0 : name.find("=") + 1] + " " + name[s + 2 :]
  25. if name in ("MIN", "MAX"):
  26. return name + "_"
  27. o_name = name.split(" ")[0].split("=")[0]
  28. if o_name in ("MIN", "MAX"):
  29. return name.replace(o_name, o_name + "_")
  30. return name
  31. class FlatBuffersWriter(IndentWriterBase):
  32. _skip_current_param = False
  33. _last_param = None
  34. _enums = None
  35. _used_enum = None
  36. _cur_const_val = {}
  37. def __call__(self, fout, defs):
  38. param_io = io.StringIO()
  39. super().__call__(param_io)
  40. self._used_enum = set()
  41. self._enums = {}
  42. self._process(defs)
  43. super().__call__(fout)
  44. self._write("// %s", self._get_header())
  45. self._write('include "dtype.fbs";')
  46. self._write("namespace mgb.serialization.fbs.param;\n")
  47. self._write_enums()
  48. self._write(param_io.getvalue())
  49. def _write_enums(self):
  50. for (p, e) in sorted(self._used_enum):
  51. name = p + e
  52. e = self._enums[(p, e)]
  53. self._write_doc(e.name)
  54. attribute = "(bit_flags)" if e.combined else ""
  55. self._write("enum %s%s : uint %s {", p, e.name, attribute, indent=1)
  56. for idx, member in enumerate(e.members):
  57. self._write_doc(member)
  58. self._write("%s,", scramble_enum_member_name(str(member)))
  59. self._write("}\n", indent=-1)
  60. def _write_doc(self, doc):
  61. if not isinstance(doc, member_defs.Doc) or not doc.doc:
  62. return
  63. doc_lines = []
  64. if doc.no_reformat:
  65. doc_lines = doc.raw_lines
  66. else:
  67. doc = doc.doc.replace("\n", " ")
  68. text_width = 80 - len(self._cur_indent) - 4
  69. doc_lines = textwrap.wrap(doc, text_width)
  70. for line in doc_lines:
  71. self._write("/// " + line)
  72. def _on_param_begin(self, p):
  73. self._last_param = p
  74. self._cur_const_val = {}
  75. self._write_doc(p.name)
  76. self._write("table %s {", p.name, indent=1)
  77. def _on_param_end(self, p):
  78. if self._skip_current_param:
  79. self._skip_current_param = False
  80. return
  81. self._write("}\n", indent=-1)
  82. def _on_member_enum(self, e):
  83. p = self._last_param
  84. key = str(p.name), str(e.name)
  85. self._enums[key] = e
  86. if self._skip_current_param:
  87. return
  88. self._write_doc(e.name)
  89. self._used_enum.add(key)
  90. if e.combined:
  91. default = e.compose_combined_enum(e.default)
  92. else:
  93. default = scramble_enum_member_name(
  94. str(e.members[e.default]).split(" ")[0].split("=")[0]
  95. )
  96. self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default)
  97. def _resolve_const(self, v):
  98. while v in self._cur_const_val:
  99. v = self._cur_const_val[v]
  100. return v
  101. def _on_member_field(self, f):
  102. if self._skip_current_param:
  103. return
  104. self._write_doc(f.name)
  105. self._write(
  106. "%s:%s = %s;",
  107. f.name,
  108. _cname_to_fbname(f.dtype.cname),
  109. self._get_fb_default(self._resolve_const(f.default)),
  110. )
  111. def _on_const_field(self, f):
  112. self._cur_const_val[str(f.name)] = str(f.default)
  113. def _on_member_enum_alias(self, e):
  114. if self._skip_current_param:
  115. return
  116. self._used_enum.add((e.src_class, e.src_name))
  117. enum_name = e.src_class + e.src_name
  118. s = e.src_enum
  119. if s.combined:
  120. default = s.compose_combined_enum(e.get_default())
  121. else:
  122. default = scramble_enum_member_name(
  123. str(s.members[e.get_default()]).split(" ")[0].split("=")[0]
  124. )
  125. self._write("%s:%s = %s;", e.name_field, enum_name, default)
  126. def _get_fb_default(self, cppdefault):
  127. if not isinstance(cppdefault, str):
  128. return cppdefault
  129. d = cppdefault
  130. if d.endswith("f"): # 1.f
  131. return d[:-1]
  132. if d.endswith("ull"):
  133. return d[:-3]
  134. if d.startswith("DTypeEnum::"):
  135. return d[11:]
  136. return d
  137. def main():
  138. parser = argparse.ArgumentParser(
  139. "generate FlatBuffers schema of operator param from description file"
  140. )
  141. parser.add_argument("input")
  142. parser.add_argument("output")
  143. args = parser.parse_args()
  144. with open(args.input) as fin:
  145. inputs = fin.read()
  146. exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc})
  147. input_hash = hashlib.sha256()
  148. input_hash.update(inputs.encode(encoding="UTF-8"))
  149. input_hash = input_hash.hexdigest()
  150. writer = FlatBuffersWriter()
  151. with open(args.output, "w") as fout:
  152. writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)
  153. if __name__ == "__main__":
  154. main()