You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

format.py 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. #!/usr/bin/env python3
  2. # This file is part of MegBrain.
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. import argparse
  5. import os
  6. import re
  7. import subprocess
  8. import tempfile
  9. from functools import partial
  10. from multiprocessing import Manager
  11. from tqdm.contrib.concurrent import process_map
# Change the working directory to the MegBrain root dir (the parent of the
# directory containing this script), so that the paths passed on the command
# line are resolved relative to the root. Note this runs at import time.
os.chdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
# Process-shared list (multiprocessing.Manager proxy): process_file appends the
# paths of files that are not properly formatted, and main() raises if it is
# non-empty after all workers have finished.
# NOTE(review): the Manager is created at import time. With the "spawn" start
# method (default on macOS and Windows) worker processes re-import this module
# and would not share the parent's list (starting a Manager during worker
# bootstrap may even fail) - confirm this tool is only run with "fork" (Linux).
failed_files = Manager().list()
  15. def process_file(file, clang_format, write):
  16. source = open(file, "r").read()
  17. source = re.sub(r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", source)
  18. source, count = re.subn(r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", source)
  19. result = subprocess.check_output(
  20. [
  21. clang_format,
  22. "-style=file",
  23. "-verbose",
  24. "-assume-filename={}".format(file),
  25. # file,
  26. ],
  27. input=bytes(source.encode("utf-8")),
  28. )
  29. result = result.decode("utf-8")
  30. if count:
  31. result = re.sub(r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2 \\", result)
  32. result = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", result)
  33. if write:
  34. with tempfile.NamedTemporaryFile(
  35. dir=os.path.dirname(file), delete=False
  36. ) as tmp_file:
  37. tmp_file.write(result.encode("utf-8"))
  38. os.rename(tmp_file.name, file)
  39. else:
  40. ret_code = subprocess.run(
  41. ["diff", "--color=always", file, "-"], input=bytes(result.encode("utf-8")),
  42. ).returncode
  43. # man diff: 0 for same, 1 for different, 2 if trouble.
  44. if ret_code == 2:
  45. raise RuntimeError("format process (without overwrite) failed")
  46. if ret_code != 0:
  47. print(file)
  48. global failed_files
  49. failed_files.append(file)
  50. def main():
  51. parser = argparse.ArgumentParser(
  52. description="Format source files using clang-format, eg: `./tools/format.py src -w`. \
  53. Require clang-format version == 12.0",
  54. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  55. )
  56. parser.add_argument(
  57. "path", nargs="+", help="file name or path based on MegBrain root dir."
  58. )
  59. parser.add_argument(
  60. "-w",
  61. "--write",
  62. action="store_true",
  63. help="use formatted file to replace original file.",
  64. )
  65. parser.add_argument(
  66. "--clang-format",
  67. default=os.getenv("CLANG_FORMAT", "clang-format"),
  68. help="clang-format executable name; it can also be "
  69. "modified via the CLANG_FORMAT environment var",
  70. )
  71. args = parser.parse_args()
  72. format_type = [".cpp", ".c", ".h", ".cu", ".cuh", ".inl"]
  73. def getfiles(path):
  74. rst = []
  75. for p in os.listdir(path):
  76. p = os.path.join(path, p)
  77. if os.path.isdir(p):
  78. rst += getfiles(p)
  79. elif (
  80. os.path.isfile(p)
  81. and not os.path.islink(p)
  82. and os.path.splitext(p)[1] in format_type
  83. ):
  84. rst.append(p)
  85. return rst
  86. files = []
  87. for path in args.path:
  88. if os.path.isdir(path):
  89. files += getfiles(path)
  90. elif os.path.isfile(path):
  91. files.append(path)
  92. else:
  93. raise ValueError("Invalid path {}".format(path))
  94. # check version, we only support 12.0.1 now
  95. version = subprocess.check_output(
  96. [
  97. args.clang_format,
  98. "--version",
  99. ],
  100. )
  101. version = version.decode("utf-8")
  102. need_version = '12.0.1'
  103. if version.find(need_version) < 0:
  104. print('We only support {} now, please install {} version, find version: {}'
  105. .format(need_version, need_version, version))
  106. raise RuntimeError('clang-format version not equal {}'.format(need_version))
  107. process_map(
  108. partial(process_file, clang_format=args.clang_format, write=args.write,),
  109. files,
  110. chunksize=10,
  111. )
  112. if failed_files:
  113. raise RuntimeError("above files are not properly formatted!")
# Script entry point, e.g. `./tools/format.py src -w` (see main()).
if __name__ == "__main__":
    main()