You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_param_defs.py 34 kB


  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import argparse
  4. import collections
  5. import hashlib
  6. import os
  7. import struct
  8. import textwrap
  9. class member_defs:
  10. """contain classes to define members of an opr param"""
  11. Dtype = collections.namedtuple(
  12. "Dtype", ["cname", "pycvt", "pyfmt", "cppjson", "cname_attr"]
  13. )
  14. Dtype.__new__.__defaults__ = ("",)
  15. uint32 = Dtype("uint32_t", "int", "I", "NumberInt")
  16. uint64 = Dtype("uint64_t", "int", "Q", "NumberInt", "alignas(sizeof(uint64_t)) ")
  17. int32 = Dtype("int32_t", "int", "i", "NumberInt")
  18. float32 = Dtype("float", "float", "f", "Number")
  19. float64 = Dtype("double", "float", "d", "Number")
  20. dtype = Dtype("DTypeEnum", "_as_dtype_num", "I", "Number")
  21. bool = Dtype("bool", "bool", "?", "Bool")
  22. class Base:
  23. pass
  24. class Doc:
  25. """wrap an identifier to associate document
  26. note: if the doc starts with a linebreak, it would not be reforamtted.
  27. """
  28. __slots__ = ["id", "doc"]
  29. def __init__(self, id_, doc):
  30. assert isinstance(id_, str) and isinstance(doc, str), (id_, doc)
  31. self.id = id_
  32. self.doc = doc
  33. @property
  34. def no_reformat(self):
  35. """whether reformat is disallowed for this doc string"""
  36. return self.doc.startswith("\n")
  37. @property
  38. def raw_lines(self):
  39. """the doc lines when ``no_format`` is true"""
  40. ret = self.doc.split("\n")
  41. assert not ret[0]
  42. return ret[1:]
  43. @classmethod
  44. def make(cls, v):
  45. """make doc object from str or doc"""
  46. if isinstance(v, cls):
  47. return v
  48. assert isinstance(v, str)
  49. return cls(v, "")
  50. def __str__(self):
  51. return self.id
  52. def __eq__(self, rhs):
  53. if isinstance(rhs, str):
  54. return self.id == rhs
  55. return isinstance(rhs, Doc) and (self.id, self.doc) == (rhs.id, rhs.doc)
  56. class Enum(Base):
  57. """define an enum; the result would contain both an enum class def and its
  58. corresponding data field
  59. :param default:
  60. for normal enum class: index of default member value
  61. for bit combined class: tuple of index of default member value
  62. For example, following representations of the default value for bit
  63. combined class are all equivalent:
  64. Enum(members=('a', 'b', 'c'), default=('a', 'b'), ...)
  65. Enum(members=('a', 'b', 'c'), default=(0, 1), ...)
  66. Enum(members=('a', 'b', 'c'), default=(1 << 0) | (1 << 1), ...)
  67. :attr name_field: name of the data field of this enum in the param
  68. struct
  69. :attr member_alias:
  70. for normal enum class: list of (member, alias) pairs
  71. for bit combined class: list of (tuple of members, alias) paris
  72. """
  73. __slots__ = [
  74. "name",
  75. "name_field",
  76. "members",
  77. "default",
  78. "member_alias",
  79. "combined",
  80. ]
  81. all_enums = {}
  82. """(param_name, name) => enum"""
  83. def __init__(
  84. self,
  85. param_name,
  86. name,
  87. name_field,
  88. members,
  89. default,
  90. member_alias,
  91. combined=False,
  92. ):
  93. name = member_defs.Doc.make(name)
  94. assert name.id[0].isupper()
  95. members = tuple(map(member_defs.Doc.make, members))
  96. self.name = name
  97. self.combined = combined
  98. self.name_field = self.get_name_field(name.id, name_field)
  99. self.members = members
  100. self.default = self.normalize_enum_value(default)
  101. self.all_enums[(param_name, name.id)] = self
  102. assert isinstance(member_alias, list)
  103. self.member_alias = member_alias
  104. @classmethod
  105. def get_name_field(cls, name, name_field):
  106. if name_field is None:
  107. name_field = name[0].lower() + name[1:]
  108. assert isinstance(name_field, str)
  109. return name_field
  110. def normalize_enum_value(self, value):
  111. def normalize(v):
  112. if isinstance(v, str):
  113. for idx, m in enumerate(self.members):
  114. m = str(m).split(" ")[0].split("=")[0]
  115. if v == m:
  116. return idx
  117. raise ValueError("enum member '{}' does not exist.".format(v))
  118. assert isinstance(v, int)
  119. return v
  120. if self.combined:
  121. if isinstance(value, int):
  122. value = self.decompose_combined_enum(value)
  123. assert isinstance(value, tuple)
  124. value = tuple(normalize(i) for i in value)
  125. return value
  126. else:
  127. return normalize(value)
  128. @staticmethod
  129. def decompose_combined_enum(v):
  130. """Integer => tuple of the indexes of the enum members"""
  131. assert isinstance(v, int)
  132. idx = 0
  133. members = []
  134. while v > 0:
  135. if v & 1:
  136. members.append(idx)
  137. idx += 1
  138. v >>= 1
  139. return tuple(members)
  140. def compose_combined_enum(self, v):
  141. """tuple of members => Integer"""
  142. assert self.combined and isinstance(v, tuple)
  143. norm_v = self.normalize_enum_value(v)
  144. return sum(1 << i for i in norm_v)
  145. class Field(Base):
  146. """define a normal data field"""
  147. __slots__ = ["name", "dtype", "default"]
  148. def __init__(self, name, dtype, default):
  149. assert isinstance(dtype, member_defs.Dtype)
  150. self.name = member_defs.Doc.make(name)
  151. self.dtype = dtype
  152. self.default = default
  153. class Const(Base):
  154. """define a const data field"""
  155. __slots__ = ["name", "dtype", "default"]
  156. def __init__(self, name, dtype, default):
  157. assert isinstance(dtype, member_defs.Dtype)
  158. self.name = member_defs.Doc.make(name)
  159. self.dtype = dtype
  160. self.default = default
  161. class EnumAlias(Base):
  162. """alias of enum type from another param"""
  163. __slots__ = ["name", "name_field", "src_class", "src_name", "default"]
  164. def __init__(self, name, name_field, src_class, src_name, default):
  165. self.name = name
  166. self.name_field = member_defs.Enum.get_name_field(name, name_field)
  167. self.src_class = src_class
  168. if src_name is None:
  169. src_name = name
  170. self.src_name = src_name
  171. self.default = default
  172. # TODO: remove this assertion if needed; adding mock param_defs in
  173. # current testing framework is too complicated, and currently we
  174. # only allow aliasing of normal enum
  175. assert not self.src_enum.combined
  176. @property
  177. def src_enum(self):
  178. """source Enum class"""
  179. return member_defs.Enum.all_enums[(self.src_class, self.src_name)]
  180. def get_default(self):
  181. """get default index; fallback to src index if default is not
  182. set"""
  183. if self.default is None:
  184. return self.src_enum.default
  185. return self.src_enum.normalize_enum_value(self.default)
  186. class ParamDef:
  187. """"""
  188. __all_tags = set()
  189. all_param_defs = []
  190. __slots__ = ["name", "members", "tag", "is_legacy"]
  191. def __init__(self, name, doc="", *, version=0, is_legacy=False):
  192. self.members = []
  193. self.all_param_defs.append(self)
  194. h = hashlib.sha256(name.encode("utf-8"))
  195. if version:
  196. h.update(struct.pack("<I", version))
  197. if is_legacy:
  198. name += "V{}".format(version)
  199. self.name = member_defs.Doc(name, doc)
  200. self.tag = int(h.hexdigest()[:8], 16)
  201. self.is_legacy = is_legacy
  202. if self.tag < 1024:
  203. self.tag += 1024
  204. assert (
  205. self.tag not in self.__all_tags
  206. ), "tag hash confliction: name={} tag={}".format(name, self.tag)
  207. self.__all_tags.add(self.tag)
  208. def add_fields(self, dtype, *names_defaults):
  209. assert isinstance(dtype, str)
  210. dtype = getattr(member_defs, dtype)
  211. assert len(names_defaults) % 2 == 0
  212. for i, j in zip(names_defaults[::2], names_defaults[1::2]):
  213. self.members.append(member_defs.Field(i, dtype, j))
  214. return self
  215. def add_enum(self, name, *members, default=0, name_field=None, member_alias=[]):
  216. self.members.append(
  217. member_defs.Enum(
  218. self.name.id, name, name_field, members, default, member_alias
  219. )
  220. )
  221. return self
  222. def add_bit_combination_enum(
  223. self, name, *members, default=tuple(), name_field=None, member_alias=[]
  224. ):
  225. self.members.append(
  226. member_defs.Enum(
  227. self.name.id, name, name_field, members, default, member_alias, True
  228. )
  229. )
  230. return self
  231. def add_enum_alias(
  232. self, name, src_class, src_name=None, name_field=None, default=None
  233. ):
  234. self.members.append(
  235. member_defs.EnumAlias(name, name_field, src_class, src_name, default)
  236. )
  237. return self
  238. def add_const(self, dtype, *names_defaults):
  239. assert isinstance(dtype, str)
  240. dtype = getattr(member_defs, dtype)
  241. assert len(names_defaults) % 2 == 0
  242. for i, j in zip(names_defaults[::2], names_defaults[1::2]):
  243. self.members.append(member_defs.Const(i, dtype, j))
  244. return self
  245. class WriterBase:
  246. """base class for output file writer"""
  247. _fout = None
  248. _input_hash = None
  249. def __call__(self, fout):
  250. self._fout = fout
  251. def set_input_hash(self, h):
  252. self._input_hash = h
  253. return self
  254. def _get_header(self):
  255. return "generated by {} for {}".format(
  256. os.path.basename(__file__), self._input_hash
  257. )
  258. def _process(self, defs):
  259. dispatch = {
  260. member_defs.Enum: self._on_member_enum,
  261. member_defs.EnumAlias: self._on_member_enum_alias,
  262. member_defs.Field: self._on_member_field,
  263. member_defs.Const: self._on_const_field,
  264. }
  265. for i in defs:
  266. assert isinstance(i, ParamDef)
  267. self._on_param_begin(i)
  268. for j in i.members:
  269. dispatch[type(j)](j)
  270. self._on_param_end(i)
  271. def _on_param_begin(self, p):
  272. """:type p: :class:`.ParamDef`"""
  273. def _on_param_end(self, p):
  274. """:type p: :class:`.ParamDef`"""
  275. def _on_member_enum(self, e):
  276. """:type p: :class:`.Enum`"""
  277. def _on_member_enum_alias(self, e):
  278. """:type p: :class:`.EnumAlias`"""
  279. def _on_member_field(self, f):
  280. """:type p: :class:`.Field`"""
  281. def _on_const_field(self, f):
  282. """:type p: :class:`.Const`"""
  283. class IndentWriterBase(WriterBase):
  284. _cur_indent = ""
  285. def _indent(self):
  286. self._cur_indent += " " * 4
  287. def _unindent(self):
  288. self._cur_indent = self._cur_indent[:-4]
  289. def _write(self, content, *fmt, indent=0):
  290. if indent < 0:
  291. self._unindent()
  292. self._fout.write(self._cur_indent)
  293. if fmt:
  294. content = content % fmt
  295. self._fout.write(content)
  296. self._fout.write("\n")
  297. if indent > 0:
  298. self._indent()
  299. class PyWriter(IndentWriterBase):
  300. FieldDef = collections.namedtuple(
  301. "FieldDef", ["name", "cvt", "fmt", "default", "type", "doc"]
  302. )
  303. # see _on_param_end() for the use of those fields
  304. _cur_param_name = None
  305. _cur_fields = None
  306. _cur_struct_fmt = None
  307. _enum_member2num = None
  308. def __init__(self, for_imperative=False):
  309. self._imperative = for_imperative
  310. def __call__(self, fout, defs):
  311. super().__call__(fout)
  312. self._enum_member2num = []
  313. self._write("# %s", self._get_header())
  314. self._write("import struct")
  315. self._write("from . import enum36 as enum")
  316. self._write(
  317. "class _ParamDefBase:\n"
  318. " def serialize(self):\n"
  319. ' tag = struct.pack("I", type(self).TAG)\n'
  320. " pdata = [getattr(self, i) for i in self.__slots__]\n"
  321. " for idx, v in enumerate(pdata):\n"
  322. " if isinstance(v, _EnumBase):\n"
  323. " pdata[idx] = _enum_member2num[id(v)]\n"
  324. " elif isinstance(v, _BitCombinedEnumBase):\n"
  325. " pdata[idx] = v._value_\n"
  326. " return tag + self._packer.pack(*pdata)\n"
  327. "\n"
  328. )
  329. # it's hard to mix custom implemention into enum, just do copy-paste instead
  330. classbody = (
  331. " @classmethod\n"
  332. " def __normalize(cls, val):\n"
  333. " if isinstance(val, str):\n"
  334. ' if not hasattr(cls, "__member_upper_dict__"):\n'
  335. " cls.__member_upper_dict__ = {k.upper(): v\n"
  336. " for k, v in cls.__members__.items()}\n"
  337. " val = cls.__member_upper_dict__.get(val.upper(),val)\n"
  338. " return val\n"
  339. " @classmethod\n"
  340. " def convert(cls, val):\n"
  341. " val = cls.__normalize(val)\n"
  342. " if isinstance(val, cls):\n"
  343. " return val\n"
  344. " return cls(val)\n"
  345. " @classmethod\n"
  346. " def _missing_(cls, value):\n"
  347. " vnorm = cls.__normalize(value)\n"
  348. " if vnorm is not value:\n"
  349. " return cls(vnorm)\n"
  350. " return super()._missing_(value)\n"
  351. "\n"
  352. )
  353. self._write("class _EnumBase(enum.Enum):\n" + classbody)
  354. self._write("class _BitCombinedEnumBase(enum.Flag):\n" + classbody)
  355. if not self._imperative:
  356. self._write(
  357. "def _as_dtype_num(dtype):\n"
  358. " import megbrain.mgb as m\n"
  359. " return m._get_dtype_num(dtype)\n"
  360. "\n"
  361. )
  362. self._write(
  363. "def _as_serialized_dtype(dtype):\n"
  364. " import megbrain.mgb as m\n"
  365. " return m._get_serialized_dtype(dtype)\n"
  366. "\n"
  367. )
  368. else:
  369. self._write(
  370. "def _as_dtype_num(dtype):\n"
  371. " import megengine.core._imperative_rt.utils as m\n"
  372. " return m._get_dtype_num(dtype)\n"
  373. "\n"
  374. )
  375. self._write(
  376. "def _as_serialized_dtype(dtype):\n"
  377. " import megengine.core._imperative_rt.utils as m\n"
  378. " return m._get_serialized_dtype(dtype)\n"
  379. "\n"
  380. )
  381. self._process(defs)
  382. self._write(
  383. '''
  384. class SerializedDType(_ParamDefBase):
  385. TAG = FakeSerializedDType.TAG
  386. __slots__ = ['dtype']
  387. class IdentityPacker:
  388. def pack(self, *args):
  389. assert all([isinstance(x, bytes) for x in args])
  390. return b''.join(args)
  391. _packer = IdentityPacker()
  392. def __init__(self, dtype):
  393. """
  394. :type dtype: :class:`np.dtype` compatible
  395. """
  396. self.dtype = _as_serialized_dtype(dtype)
  397. '''
  398. )
  399. self._write("_enum_member2num = {\n %s}", ",\n ".join(self._enum_member2num))
  400. def _write_doc(self, doc):
  401. assert isinstance(doc, member_defs.Doc)
  402. if not doc.doc:
  403. return
  404. if doc.no_reformat:
  405. self._write('"""')
  406. for i in doc.raw_lines:
  407. self._write(i)
  408. self._write('"""')
  409. return
  410. doc = doc.doc.replace("\n", " ")
  411. textwidth = 80 - len(self._cur_indent)
  412. self._write('"""')
  413. for i in textwrap.wrap(doc, textwidth):
  414. self._write(i)
  415. self._write('"""')
  416. def _on_param_begin(self, p):
  417. self._cur_param_name = str(p.name)
  418. self._cur_fields = []
  419. self._cur_enum_names = []
  420. self._write("class %s(_ParamDefBase):", p.name, indent=1)
  421. self._write_doc(p.name)
  422. self._write("TAG = %d", p.tag)
  423. def _on_param_end(self, p):
  424. # gen slots and packer
  425. self._write(
  426. "__slots__ = [%s]", ", ".join(map('"{.name}"'.format, self._cur_fields))
  427. )
  428. struct_fmt = "".join(i.fmt for i in self._cur_fields)
  429. if not struct_fmt:
  430. struct_fmt = "x"
  431. else:
  432. # add padding at end
  433. max_t = max(struct_fmt, key=struct.calcsize)
  434. struct_fmt += "0{}".format(max_t)
  435. self._write('_packer = struct.Struct("%s")', struct_fmt)
  436. # gen __init__ signature
  437. self._write(
  438. "def __init__(%s):",
  439. ", ".join(
  440. ["self"]
  441. + list("{}={}".format(i.name, i.default) for i in self._cur_fields)
  442. ),
  443. indent=1,
  444. )
  445. # gen __init__ doc
  446. self._write('"""')
  447. for i in self._cur_fields:
  448. self._write(":type {}: :class:`.{}`".format(i.name, i.type))
  449. if i.doc:
  450. self._write(":param {}: {}".format(i.name, i.doc))
  451. self._write('"""')
  452. # gen cvt in __init__
  453. for i in self._cur_fields:
  454. self._write("self.%s = %s", i.name, i.cvt)
  455. self._unindent()
  456. self._unindent()
  457. self._write("")
  458. def _on_member_enum(self, e):
  459. qualname = "{}.{}".format(self._cur_param_name, e.name)
  460. if e.combined:
  461. self._write("class %s(_BitCombinedEnumBase):", e.name, indent=1)
  462. else:
  463. self._write("class %s(_EnumBase):", e.name, indent=1)
  464. self._write_doc(e.name)
  465. for emem in e.members:
  466. if e.combined:
  467. self._write("%s", emem)
  468. self._write_doc(emem)
  469. else:
  470. v = str(emem).split(" ")[0].split("=")[0]
  471. n = int(str(emem).split("=")[1])
  472. self._write('%s = "%s"', v, v)
  473. self._write_doc(emem)
  474. self._enum_member2num.append("id({}.{}):{}".format(qualname, v, n))
  475. for emem, emem_alias in e.member_alias:
  476. em_a = emem_alias.split(" ")[0].split("=")[0]
  477. if e.combined:
  478. self._write("%s = %s", em_a, e.compose_combined_enum(emem))
  479. else:
  480. em = str(emem).split(" ")[0].split("=")[0]
  481. self._write("%s = %s", em_a, em)
  482. self._unindent()
  483. self._write("")
  484. if e.combined:
  485. default = e.compose_combined_enum(e.default)
  486. else:
  487. default = "'{}'".format(
  488. str(e.members[e.default]).split(" ")[0].split("=")[0]
  489. )
  490. self._cur_fields.append(
  491. self.FieldDef(
  492. name=e.name_field,
  493. cvt="{}.convert({})".format(qualname, e.name_field),
  494. fmt="I",
  495. default=default,
  496. type=qualname,
  497. doc=None,
  498. )
  499. )
  500. def _on_member_enum_alias(self, e):
  501. self._write("%s = %s.%s", e.name, e.src_class, e.src_name)
  502. s = e.src_enum
  503. qualname = "{}.{}".format(e.src_class, e.src_name)
  504. if s.combined:
  505. default = s.compose_combined_enum(e.get_default())
  506. else:
  507. default = "'{}'".format(
  508. str(s.members[e.get_default()]).split(" ")[0].split("=")[0]
  509. )
  510. self._cur_fields.append(
  511. self.FieldDef(
  512. name=e.name_field,
  513. cvt="{}.convert({})".format(qualname, e.name_field),
  514. fmt="I",
  515. default=default,
  516. type=qualname,
  517. doc=None,
  518. )
  519. )
  520. def _get_py_default(self, cppdefault):
  521. if not isinstance(cppdefault, str):
  522. return cppdefault
  523. d = cppdefault
  524. if d.endswith("f"): # 1.f
  525. return d[:-1]
  526. if d.endswith("ull"):
  527. return d[:-3]
  528. if d == "false":
  529. return "False"
  530. if d == "true":
  531. return "True"
  532. if d.startswith("DTypeEnum::"):
  533. return '"{}"'.format(d.split(":")[2].lower())
  534. return d
  535. def _on_member_field(self, f):
  536. d = self._get_py_default(f.default)
  537. self._cur_fields.append(
  538. self.FieldDef(
  539. name=f.name,
  540. cvt="{}({})".format(f.dtype.pycvt, f.name),
  541. fmt=f.dtype.pyfmt,
  542. default=d,
  543. type=f.dtype.pycvt,
  544. doc=f.name.doc,
  545. )
  546. )
  547. def _on_const_field(self, f):
  548. d = self._get_py_default(f.default)
  549. self._write_doc(f.name)
  550. self._write("%s = %s", f.name, d)
  551. class CPPWriter(IndentWriterBase):
  552. _param_namespace = "param"
  553. _ctor_args = None
  554. """list of (text in func param, var name); func param name must be var name
  555. appended by an underscore"""
  556. _non_static_members = None
  557. def __call__(self, fout, defs):
  558. super().__call__(fout)
  559. self._write("// %s", self._get_header())
  560. self._write("#pragma once")
  561. self._write('#include "megdnn/dtype.h"')
  562. self._write("#include <stdint.h>")
  563. if self._param_namespace == "param":
  564. self._write("#include <string.h>")
  565. self._write("namespace megdnn {")
  566. self._write("namespace %s {", self._param_namespace)
  567. self._process(defs)
  568. self._write("} // namespace megdnn")
  569. self._write("} // namespace %s", self._param_namespace)
  570. self._write("// vim: syntax=cpp.doxygen")
  571. def _write_doc(self, doc):
  572. assert isinstance(doc, member_defs.Doc)
  573. if not doc.doc:
  574. return
  575. if doc.no_reformat:
  576. self._write("/*")
  577. for i in doc.raw_lines:
  578. self._write("* " + i)
  579. self._write("*/")
  580. return
  581. doc = doc.doc.replace("\n", " ")
  582. textwidth = 80 - len(self._cur_indent) - 4
  583. if len(doc) <= textwidth:
  584. self._write("//! " + doc)
  585. return
  586. self._write("/*!")
  587. for i in textwrap.wrap(doc, textwidth):
  588. self._write(" * " + i)
  589. self._write(" */")
  590. def _on_param_begin(self, p):
  591. self._write_doc(p.name)
  592. self._write("struct %s {", p.name, indent=1)
  593. self._write("static MEGDNN_CONSTEXPR uint32_t TAG = %du;", p.tag)
  594. self._ctor_args = []
  595. self._non_static_members = []
  596. def _add_ctor_args(self, typename, default, varname):
  597. self._ctor_args.append(
  598. ("{} {}_={}".format(typename, varname, default), varname)
  599. )
  600. def _on_param_end(self, p):
  601. """
  602. MegDNN param structures are not packed and we need to initialize the structure
  603. paddings to zero or it would break MegBrain hash system. We do memset(0) in default
  604. ctor and use a trick, wrapping non-static members in a anonymous union which would
  605. copy the object representation in its default copy/move ctor, for copy/move ctor.
  606. > The implicitly-defined copy/move constructor for a non-union class X performs
  607. > a memberwise copy/move of its bases and members. [class.copy.ctor 14]
  608. > The implicitly-defined copy/move constructor for a union X copies the object
  609. > representation (6.9) of X. [class.copy.ctor 15]
  610. """
  611. if self._non_static_members:
  612. self._write("union { struct {")
  613. for i in self._non_static_members:
  614. if isinstance(i, member_defs.Field):
  615. self._write_doc(i.name)
  616. self._write("%s%s %s;", i.dtype.cname_attr, i.dtype.cname, i.name)
  617. else:
  618. assert isinstance(i, (member_defs.Enum, member_defs.EnumAlias))
  619. self._write("%s %s;", i.name, i.name_field)
  620. self._write("}; };")
  621. if self._ctor_args:
  622. pdefs, varnames = zip(*self._ctor_args)
  623. self._write("%s(%s) {", p.name, ", ".join(pdefs), indent=1)
  624. self._write("memset(this, 0, sizeof(*this));")
  625. for var in varnames:
  626. self._write("this->%s = %s_;", var, var)
  627. self._write("}", indent=-1)
  628. self._write("};\n", indent=-1)
  629. def _on_member_enum(self, e):
  630. self._write_doc(e.name)
  631. self._write("enum class %s: uint32_t {", e.name, indent=1)
  632. for i in e.members:
  633. self._write_doc(i)
  634. v = str(i)
  635. if i is not e.members[-1] or e.member_alias:
  636. v += ","
  637. self._write(v)
  638. for mem, alias in e.member_alias:
  639. if e.combined:
  640. self._write("%s = %s,", alias, e.compose_combined_enum(mem))
  641. else:
  642. self._write(
  643. "%s = %s,",
  644. str(alias).split(" ")[0].split("=")[0],
  645. str(mem).split(" ")[0].split("=")[0],
  646. )
  647. self._write("};", indent=-1)
  648. self._non_static_members.append(e)
  649. self._write(
  650. "static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;",
  651. str(e.name).upper(),
  652. len(e.members),
  653. )
  654. if e.combined:
  655. default = "static_cast<{}>({})".format(
  656. e.name, e.compose_combined_enum(e.default)
  657. )
  658. else:
  659. value = str(e.members[e.default])
  660. value = value.split(" ")[0].split("=")[0]
  661. default = "{}::{}".format(e.name, value)
  662. self._add_ctor_args(e.name, default, e.name_field)
  663. def _on_member_enum_alias(self, e):
  664. s = e.src_enum
  665. self._write("using %s = %s::%s;", e.name, e.src_class, e.src_name)
  666. self._non_static_members.append(e)
  667. self._write(
  668. "static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;",
  669. str(e.name).upper(),
  670. len(s.members),
  671. )
  672. if s.combined:
  673. default = "static_cast<{}>({})".format(
  674. e.name, s.compose_combined_enum(e.default)
  675. )
  676. else:
  677. value = str(s.members[e.get_default()])
  678. value = value.split(" ")[0].split("=")[0]
  679. default = "{}::{}".format(e.name, value)
  680. self._add_ctor_args(e.name, default, e.name_field)
  681. def _on_member_field(self, f):
  682. self._non_static_members.append(f)
  683. self._add_ctor_args(f.dtype.cname, f.default, f.name)
  684. def _on_const_field(self, f):
  685. self._write_doc(f.name)
  686. if "int" in f.dtype.cname:
  687. self._write(
  688. "static constexpr %s%s %s = %s;",
  689. f.dtype.cname_attr,
  690. f.dtype.cname,
  691. f.name,
  692. f.default,
  693. )
  694. else:
  695. self._write(
  696. "static const %s%s %s = %s;",
  697. f.dtype.cname_attr,
  698. f.dtype.cname,
  699. f.name,
  700. f.default,
  701. )
  702. class CPPEnumValueWriter(CPPWriter):
  703. _param_namespace = "param_enumv"
  704. def _on_member_enum(self, e):
  705. self._write_doc(e.name)
  706. self._write("struct %s {", e.name, indent=1)
  707. for val in e.members:
  708. self._write_doc(val)
  709. v = str(val)
  710. self._write("static const uint32_t %s;", v)
  711. for mem, alias in e.member_alias:
  712. self._write(
  713. "static const uint32_t %s = %s;",
  714. str(alias).split(" ")[0].split("=")[0],
  715. str(mem).split(" ")[0].split("=")[0],
  716. )
  717. self._write("};", indent=-1)
  718. def _on_member_enum_alias(self, e):
  719. s = e.src_enum
  720. self._write("typedef %s::%s %s;", e.src_class, e.src_name, e.name)
  721. def _on_member_field(self, f):
  722. pass
  723. def _on_const_field(self, f):
  724. pass
  725. class CPPEnumItemWriter(WriterBase):
  726. _class_name = None
  727. _enum_name = None
  728. _enable = False
  729. def __init__(self, enum_def):
  730. self._class_name, self._enum_name = enum_def.split(":")
  731. def __call__(self, fout, defs):
  732. super().__call__(fout)
  733. self._process(defs)
  734. def _on_param_begin(self, p):
  735. self._enable = p.name == self._class_name
  736. def _on_member_enum(self, e):
  737. if self._enable and e.name == self._enum_name:
  738. for i in e.members:
  739. self._fout.write("{}\n".format(i))
  740. class CPPParamJsonFuncWriter(IndentWriterBase):
  741. _param_namespace = "param"
  742. _param_name = None
  743. _items = None
  744. def _write_json_item(self, json_cls, field):
  745. cls2ctype = {
  746. "NumberInt": "int64_t",
  747. "Number": "double",
  748. "Bool": "bool",
  749. }
  750. self._items.append(
  751. '{"%s", json::%s::make(static_cast<%s>(p.%s))},'
  752. % (field, json_cls, cls2ctype[json_cls], field)
  753. )
  754. def __call__(self, fout, defs):
  755. super().__call__(fout)
  756. self._write("// %s", self._get_header())
  757. self._write(
  758. "// this file can only be included in "
  759. "megbrain/src/plugin/impl/opr_footprint.cpp\n"
  760. "// please do not include it directly"
  761. )
  762. self._write('#include "megdnn/opr_param_defs.h"')
  763. self._write("#pragma once")
  764. self._write("using namespace megdnn;")
  765. self._write("namespace mgb {")
  766. self._write("namespace opr {")
  767. self._write("template<class OprParam>")
  768. self._write(
  769. "std::shared_ptr<mgb::json::Value> opr_param_to_json(const OprParam &param);"
  770. )
  771. self._process(defs)
  772. self._write("} // namespace opr")
  773. self._write("} // namespace mgb")
  774. self._write("\n// vim: syntax=cpp.doxygen")
  775. def _on_param_begin(self, p):
  776. self._write("template<>", indent=0)
  777. self._write(
  778. "std::shared_ptr<mgb::json::Value> opr_param_to_json(const param::%s &p) {",
  779. p.name,
  780. indent=1,
  781. )
  782. self._param_name = "param::{}".format(p.name)
  783. self._items = []
  784. def _on_param_end(self, p):
  785. self._write("return json::Object::make({", indent=1)
  786. for i in self._items:
  787. self._write(i, indent=0)
  788. self._write("});", indent=-1)
  789. self._write("}", indent=-1)
  790. def _on_member_enum(self, e):
  791. self._write(
  792. "auto %s2str = [](const %s::%s arg) -> std::string {",
  793. e.name,
  794. self._param_name,
  795. e.name,
  796. indent=1,
  797. )
  798. self._write("switch (arg) {", indent=1)
  799. enum2str = []
  800. if isinstance(e, member_defs.EnumAlias):
  801. members = e.src_enum.members
  802. else:
  803. members = e.members
  804. for i in members:
  805. v = str(i)
  806. v = v.split(" ")[0].split("=")[0]
  807. self._write(
  808. 'case %s::%s::%s: return "%s";',
  809. self._param_name,
  810. e.name,
  811. v,
  812. v,
  813. indent=0,
  814. )
  815. self._write(
  816. 'default: mgb_throw(MegBrainError, "Invalid %s::%s:%%d", static_cast<int>(arg));',
  817. self._param_name,
  818. e.name,
  819. indent=0,
  820. )
  821. self._write("}", indent=-1)
  822. self._write("};", indent=-1)
  823. self._items.append(
  824. '{"%s", json::String::make(%s2str(p.%s))},'
  825. % (e.name_field, e.name, e.name_field)
  826. )
  827. def _on_member_enum_alias(self, e):
  828. self._on_member_enum(e)
  829. def _on_member_field(self, f):
  830. self._write_json_item(f.dtype.cppjson, f.name)
  831. def _on_const_field(self, f):
  832. pass
  833. def main():
  834. parser = argparse.ArgumentParser("generate opr param defs from description file")
  835. parser.add_argument(
  836. "--enumv",
  837. action="store_true",
  838. help="generate c++03 compatible code which only " "contains enum values",
  839. )
  840. parser.add_argument(
  841. "-t", "--type", choices=["c++", "py"], default="c++", help="output type"
  842. )
  843. parser.add_argument(
  844. "--write-enum-items",
  845. help="write enum item names to output file; argument "
  846. "should be given in the CLASS:ENUM format",
  847. )
  848. parser.add_argument(
  849. "--write-cppjson",
  850. help="generate megbrain json serialization implemention" "cpp file",
  851. )
  852. parser.add_argument("input")
  853. parser.add_argument("output")
  854. parser.add_argument(
  855. "--imperative", action="store_true", help="generate files for imperatvie "
  856. )
  857. args = parser.parse_args()
  858. for_imperative = args.imperative
  859. with open(args.input) as fin:
  860. inputs = fin.read()
  861. exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc})
  862. input_hash = hashlib.sha256()
  863. input_hash.update(inputs.encode(encoding="UTF-8"))
  864. input_hash = input_hash.hexdigest()
  865. if args.type == "py":
  866. writer = PyWriter(for_imperative=for_imperative)
  867. else:
  868. assert args.type == "c++"
  869. if args.enumv:
  870. writer = CPPEnumValueWriter()
  871. elif args.write_enum_items:
  872. writer = CPPEnumItemWriter(args.write_enum_items)
  873. else:
  874. writer = CPPWriter()
  875. with open(args.output, "w") as fout:
  876. writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)
  877. if args.write_cppjson:
  878. writer = CPPParamJsonFuncWriter()
  879. with open(args.write_cppjson, "w") as fout:
  880. writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)
  881. if __name__ == "__main__":
  882. main()