You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

format.py 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. #!/usr/bin/env python3
  2. # This file is part of MegBrain.
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. import argparse
  5. import os
  6. import re
  7. import subprocess
  8. import tempfile
  9. from functools import partial
  10. from multiprocessing import Manager
  11. from tqdm.contrib.concurrent import process_map
  12. # change workspace to MegBrain root dir
  13. os.chdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
  14. failed_files = Manager().list()
  15. def process_file(file, clang_format, write):
  16. original_source = open(file, "r").read()
  17. source = original_source
  18. source = re.sub(r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", source)
  19. source, count = re.subn(r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", source)
  20. result = subprocess.check_output(
  21. [
  22. clang_format,
  23. "-style=file",
  24. "-verbose",
  25. "-assume-filename={}".format(file),
  26. # file,
  27. ],
  28. input=bytes(source.encode("utf-8")),
  29. )
  30. result = result.decode("utf-8")
  31. if count:
  32. result = re.sub(r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2 \\", result)
  33. result = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", result)
  34. if write and original_source != result:
  35. with tempfile.NamedTemporaryFile(
  36. dir=os.path.dirname(file), delete=False
  37. ) as tmp_file:
  38. tmp_file.write(result.encode("utf-8"))
  39. os.rename(tmp_file.name, file)
  40. else:
  41. ret_code = subprocess.run(
  42. ["diff", "--color=always", file, "-"], input=bytes(result.encode("utf-8")),
  43. ).returncode
  44. # man diff: 0 for same, 1 for different, 2 if trouble.
  45. if ret_code == 2:
  46. raise RuntimeError("format process (without overwrite) failed")
  47. if ret_code != 0:
  48. print(file)
  49. global failed_files
  50. failed_files.append(file)
  51. def main():
  52. parser = argparse.ArgumentParser(
  53. description="Format source files using clang-format, eg: `./tools/format.py src -w`. \
  54. Require clang-format version == 12.0",
  55. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  56. )
  57. parser.add_argument(
  58. "path", nargs="+", help="file name or path based on MegBrain root dir."
  59. )
  60. parser.add_argument(
  61. "-w",
  62. "--write",
  63. action="store_true",
  64. help="use formatted file to replace original file.",
  65. )
  66. parser.add_argument(
  67. "--clang-format",
  68. default=os.getenv("CLANG_FORMAT", "clang-format"),
  69. help="clang-format executable name; it can also be "
  70. "modified via the CLANG_FORMAT environment var",
  71. )
  72. args = parser.parse_args()
  73. format_type = [".cpp", ".c", ".h", ".cu", ".cuh", ".inl"]
  74. def getfiles(path):
  75. rst = []
  76. for p in os.listdir(path):
  77. p = os.path.join(path, p)
  78. if os.path.isdir(p):
  79. rst += getfiles(p)
  80. elif (
  81. os.path.isfile(p)
  82. and not os.path.islink(p)
  83. and os.path.splitext(p)[1] in format_type
  84. ):
  85. rst.append(p)
  86. return rst
  87. files = []
  88. for path in args.path:
  89. if os.path.isdir(path):
  90. files += getfiles(path)
  91. elif os.path.isfile(path):
  92. files.append(path)
  93. else:
  94. raise ValueError("Invalid path {}".format(path))
  95. # check version, we only support 12.0.1 now
  96. version = subprocess.check_output(
  97. [
  98. args.clang_format,
  99. "--version",
  100. ],
  101. )
  102. version = version.decode("utf-8")
  103. need_version = '12.0.1'
  104. if version.find(need_version) < 0:
  105. print('We only support {} now, please install {} version, find version: {}'
  106. .format(need_version, need_version, version))
  107. raise RuntimeError('clang-format version not equal {}'.format(need_version))
  108. process_map(
  109. partial(process_file, clang_format=args.clang_format, write=args.write,),
  110. files,
  111. chunksize=10,
  112. )
  113. if failed_files:
  114. raise RuntimeError("above files are not properly formatted!")
  115. if __name__ == "__main__":
  116. main()