You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

compare_binary_iodump.py 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. #! /usr/bin/env python3
  2. import argparse
  3. import os
  4. import struct
  5. import textwrap
  6. from pathlib import Path
  7. import numpy as np
  8. def load_tensor_binary(fobj):
  9. """Load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual
  10. tensor value dump is implemented by ``mgb::debug::dump_tensor``.
  11. Args:
  12. fobj: file object, or a string that contains the file name.
  13. Returns:
  14. tuple ``(tensor_value, tensor_name)``.
  15. """
  16. if isinstance(fobj, str):
  17. with open(fobj, "rb") as fin:
  18. return load_tensor_binary(fin)
  19. DTYPE_LIST = {
  20. 0: np.float32,
  21. 1: np.uint8,
  22. 2: np.int8,
  23. 3: np.int16,
  24. 4: np.int32,
  25. # 5: _mgb.intb1,
  26. # 6: _mgb.intb2,
  27. # 7: _mgb.intb4,
  28. 8: None,
  29. 9: np.float16,
  30. # quantized dtype start from 100000
  31. # see MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE in
  32. # dnn/include/megdnn/dtype.h
  33. 100000: np.uint8,
  34. 100001: np.int32,
  35. 100002: np.int8,
  36. }
  37. header_fmt = struct.Struct("III")
  38. name_len, dtype, max_ndim = header_fmt.unpack(fobj.read(header_fmt.size))
  39. assert (
  40. DTYPE_LIST[dtype] is not None
  41. ), "Cannot load this tensor: dtype Byte is unsupported."
  42. shape = list(struct.unpack("I" * max_ndim, fobj.read(max_ndim * 4)))
  43. while shape[-1] == 0:
  44. shape.pop(-1)
  45. name = fobj.read(name_len).decode("ascii")
  46. return np.fromfile(fobj, dtype=DTYPE_LIST[dtype]).reshape(shape), name
  47. def check(v0, v1, name, max_err):
  48. v0 = np.ascontiguousarray(v0, dtype=np.float32)
  49. v1 = np.ascontiguousarray(v1, dtype=np.float32)
  50. assert np.isfinite(v0.sum()) and np.isfinite(
  51. v1.sum()
  52. ), "{} not finite: sum={} vs sum={}".format(name, v0.sum(), v1.sum())
  53. assert v0.shape == v1.shape, "{} shape mismatch: {} vs {}".format(
  54. name, v0.shape, v1.shape
  55. )
  56. vdiv = np.max([np.abs(v0), np.abs(v1), np.ones_like(v0)], axis=0)
  57. err = np.abs(v0 - v1) / vdiv
  58. rst = err > max_err
  59. if rst.sum():
  60. idx = tuple(i[0] for i in np.nonzero(rst))
  61. raise AssertionError(
  62. "{} not equal: "
  63. "shape={} nonequal_idx={} v0={} v1={} err={}".format(
  64. name, v0.shape, idx, v0[idx], v1[idx], err[idx]
  65. )
  66. )
  67. def main():
  68. parser = argparse.ArgumentParser(
  69. description=(
  70. "compare tensor dumps generated BinaryOprIODump plugin, "
  71. "it can compare two dirs or two single files"
  72. ),
  73. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  74. )
  75. parser.add_argument("input0", help="dirname or filename")
  76. parser.add_argument("input1", help="dirname or filename")
  77. parser.add_argument(
  78. "-e", "--max-err", type=float, default=1e-3, help="max allowed error"
  79. )
  80. parser.add_argument(
  81. "-s", "--stop-on-error", action="store_true", help="do not compare "
  82. )
  83. args = parser.parse_args()
  84. files0 = set()
  85. files1 = set()
  86. if os.path.isdir(args.input0):
  87. assert os.path.isdir(args.input1)
  88. name0 = set()
  89. name1 = set()
  90. for i in os.listdir(args.input0):
  91. files0.add(str(Path(args.input0) / i))
  92. name0.add(i)
  93. for i in os.listdir(args.input1):
  94. files1.add(str(Path(args.input1) / i))
  95. name1.add(i)
  96. assert name0 == name1, "dir files mismatch: a-b={} b-a={}".format(
  97. name0 - name1, name1 - name0
  98. )
  99. else:
  100. files0.add(args.input0)
  101. files1.add(args.input1)
  102. files0 = sorted(files0)
  103. files1 = sorted(files1)
  104. for i, j in zip(files0, files1):
  105. val0, name0 = load_tensor_binary(i)
  106. val1, name1 = load_tensor_binary(j)
  107. name = "{}: \n{}\n{}\n".format(
  108. i, "\n ".join(textwrap.wrap(name0)), "\n ".join(textwrap.wrap(name1))
  109. )
  110. try:
  111. check(val0, val1, name, args.max_err)
  112. except Exception as exc:
  113. if args.stop_on_error:
  114. raise exc
  115. print(exc)
  116. if __name__ == "__main__":
  117. main()