You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

format.py 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. #!/usr/bin/env python3
  2. # This file is part of MegBrain.
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. import argparse
  5. import os
  6. import re
  7. import subprocess
  8. import tempfile
  9. from functools import partial
  10. from multiprocessing import Manager
  11. from tqdm.contrib.concurrent import process_map
  12. # change workspace to MegBrain root dir
  13. os.chdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
  14. failed_files = Manager().list()
  15. def process_file(file, clang_format, write):
  16. source = open(file, "r").read()
  17. source = re.sub(r"MGB_DEFINE(?P<r>(.|\n)*?)// +{", "class MGB_DEFINE\g<r>{", source)
  18. result = subprocess.check_output(
  19. [
  20. clang_format,
  21. "-style=file",
  22. "-verbose",
  23. "-assume-filename={}".format(file),
  24. # file,
  25. ],
  26. input=bytes(source.encode("utf-8")),
  27. )
  28. result = result.decode("utf-8")
  29. result = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", result)
  30. if write:
  31. with tempfile.NamedTemporaryFile(
  32. dir=os.path.dirname(file), delete=False
  33. ) as tmp_file:
  34. tmp_file.write(result.encode("utf-8"))
  35. os.rename(tmp_file.name, file)
  36. else:
  37. ret_code = subprocess.run(
  38. ["diff", "--color=always", file, "-"], input=bytes(result.encode("utf-8")),
  39. ).returncode
  40. # man diff: 0 for same, 1 for different, 2 if trouble.
  41. if ret_code == 2:
  42. raise RuntimeError("format process (without overwrite) failed")
  43. if ret_code != 0:
  44. print(file)
  45. global failed_files
  46. failed_files.append(file)
  47. def main():
  48. parser = argparse.ArgumentParser(
  49. description="Format source files using clang-format, eg: `./tools/format.py src -w`. \
  50. Require clang-format version == 12.0",
  51. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  52. )
  53. parser.add_argument(
  54. "path", nargs="+", help="file name or path based on MegBrain root dir."
  55. )
  56. parser.add_argument(
  57. "-w",
  58. "--write",
  59. action="store_true",
  60. help="use formatted file to replace original file.",
  61. )
  62. parser.add_argument(
  63. "--clang-format",
  64. default=os.getenv("CLANG_FORMAT", "clang-format"),
  65. help="clang-format executable name; it can also be "
  66. "modified via the CLANG_FORMAT environment var",
  67. )
  68. args = parser.parse_args()
  69. format_type = [".cpp", ".c", ".h", ".cu", ".cuh", ".inl"]
  70. def getfiles(path):
  71. rst = []
  72. for p in os.listdir(path):
  73. p = os.path.join(path, p)
  74. if os.path.isdir(p):
  75. rst += getfiles(p)
  76. elif (
  77. os.path.isfile(p)
  78. and not os.path.islink(p)
  79. and os.path.splitext(p)[1] in format_type
  80. ):
  81. rst.append(p)
  82. return rst
  83. files = []
  84. for path in args.path:
  85. if os.path.isdir(path):
  86. files += getfiles(path)
  87. elif os.path.isfile(path):
  88. files.append(path)
  89. else:
  90. raise ValueError("Invalid path {}".format(path))
  91. # check version, we only support 12.0.1 now
  92. version = subprocess.check_output(
  93. [
  94. args.clang_format,
  95. "--version",
  96. ],
  97. )
  98. version = version.decode("utf-8")
  99. need_version = '12.0.1'
  100. if version.find(need_version) < 0:
  101. print('We only support {} now, please install {} version, find version: {}'
  102. .format(need_version, need_version, version))
  103. raise RuntimeError('clang-format version not equal {}'.format(need_version))
  104. process_map(
  105. partial(process_file, clang_format=args.clang_format, write=args.write,),
  106. files,
  107. chunksize=10,
  108. )
  109. if failed_files:
  110. raise RuntimeError("above files are not properly formatted!")
  111. if __name__ == "__main__":
  112. main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。