You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_flatbuffers_converter.py 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import argparse
  4. import collections
  5. import hashlib
  6. import io
  7. import os
  8. import struct
  9. import textwrap
  10. from gen_param_defs import IndentWriterBase, ParamDef, member_defs
  11. class ConverterWriter(IndentWriterBase):
  12. _skip_current_param = False
  13. _last_param = None
  14. _param_fields = None
  15. _fb_fields = []
  16. def __call__(self, fout, defs):
  17. super().__call__(fout)
  18. self._write("// %s", self._get_header())
  19. self._write("#include <flatbuffers/flatbuffers.h>")
  20. self._write("namespace mgb {")
  21. self._write("namespace serialization {")
  22. self._write("namespace fbs {")
  23. self._process(defs)
  24. self._write("} // namespace fbs")
  25. self._write("} // namespace serialization")
  26. self._write("} // namespace mgb")
  27. def _on_param_begin(self, p):
  28. self._last_param = p
  29. self._param_fields = []
  30. self._fb_fields = ["builder"]
  31. self._write(
  32. "template<>\nstruct ParamConverter<megdnn::param::%s> {", p.name, indent=1
  33. )
  34. self._write("using MegDNNType = megdnn::param::%s;", p.name)
  35. self._write("using FlatBufferType = fbs::param::%s;\n", p.name)
  36. def _on_param_end(self, p):
  37. if self._skip_current_param:
  38. self._skip_current_param = False
  39. return
  40. self._write("static MegDNNType to_param(const FlatBufferType* fb) {", indent=1)
  41. line = "return {"
  42. line += ", ".join(self._param_fields)
  43. line += "};"
  44. self._write(line)
  45. self._write("}\n", indent=-1)
  46. self._write(
  47. "static flatbuffers::Offset<FlatBufferType> to_flatbuffer(flatbuffers::FlatBufferBuilder& builder, const MegDNNType& param) {",
  48. indent=1,
  49. )
  50. line = "return fbs::param::Create{}(".format(str(p.name))
  51. line += ", ".join(self._fb_fields)
  52. line += ");"
  53. self._write(line)
  54. self._write("}", indent=-1)
  55. self._write("};\n", indent=-1)
  56. def _on_member_enum(self, e):
  57. p = self._last_param
  58. key = str(p.name) + str(e.name)
  59. if self._skip_current_param:
  60. return
  61. self._param_fields.append(
  62. "static_cast<megdnn::param::{}::{}>(fb->{}())".format(
  63. str(p.name), str(e.name), e.name_field
  64. )
  65. )
  66. self._fb_fields.append(
  67. "static_cast<fbs::param::{}>(param.{})".format(key, e.name_field)
  68. )
  69. def _on_member_field(self, f):
  70. if self._skip_current_param:
  71. return
  72. if f.dtype.cname == "DTypeEnum":
  73. self._param_fields.append(
  74. "intl::convert_dtype_to_megdnn(fb->{}())".format(f.name)
  75. )
  76. self._fb_fields.append(
  77. "intl::convert_dtype_to_fbs(param.{})".format(f.name)
  78. )
  79. else:
  80. self._param_fields.append("fb->{}()".format(f.name))
  81. self._fb_fields.append("param.{}".format(f.name))
  82. def _on_const_field(self, f):
  83. pass
  84. def _on_member_enum_alias(self, e):
  85. if self._skip_current_param:
  86. return
  87. enum_name = e.src_class + e.src_name
  88. self._param_fields.append(
  89. "static_cast<megdnn::param::{}::{}>(fb->{}())".format(
  90. e.src_class, e.src_name, e.name_field
  91. )
  92. )
  93. self._fb_fields.append(
  94. "static_cast<fbs::param::{}>(param.{})".format(enum_name, e.name_field)
  95. )
  96. def main():
  97. parser = argparse.ArgumentParser(
  98. "generate convert functions between FlatBuffers type and MegBrain type"
  99. )
  100. parser.add_argument("input")
  101. parser.add_argument("output")
  102. args = parser.parse_args()
  103. with open(args.input) as fin:
  104. inputs = fin.read()
  105. exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc})
  106. input_hash = hashlib.sha256()
  107. input_hash.update(inputs.encode(encoding="UTF-8"))
  108. input_hash = input_hash.hexdigest()
  109. writer = ConverterWriter()
  110. with open(args.output, "w") as fout:
  111. writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)
  112. if __name__ == "__main__":
  113. main()