You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

embed_cache.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. #
  5. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. #
  7. # Unless required by applicable law or agreed to in writing,
  8. # software distributed under the License is distributed on an
  9. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. # 为了保证全局图优化里的 profiling 结果不受到 ci 环境的影响,所以把写死的 profiling 数据存到了 cache 里去,
  11. # 每次跑测试会从内存 cache 里读取 profiling 结果,然后根据 profiling 结果去做全局图优化,这样确保每次运行
  12. # 结果都是一致的。
  13. # ProfilerCache 可以支持把内存中 cache 下来的 profiling 数据 dump 成文件。
  14. # 这个脚本就是用于把 dump 出去的 cache 文件打包成 cache 的头文件,用于测试时读取数据,构建 InMemory 的 ProfilerCache 。
  15. # 如果在 src/gopt/test/layout_transform_pass.cpp 里新添加了全局图优化相关的测试,则需要考虑用这个脚本来
  16. # 更新 cache 头文件中的 profiling 数据。
  17. # 1. 首先将 src/gopt/test/layout_transform_pass.cpp 中的 `#define MGB_WITH_CACHED_TEST 1` 修改为
  18. # `#define MGB_WITH_CACHED_TEST 0`
  19. # 2. 编译megbrain_test,并运行所有全局图优化相关测试:
  20. # ./megbrain_test --gtest_filter="*LayoutTransform*"
  21. # 3. 用这个脚本把所有的cache文件打包在一起
  22. # python3 embed_cache.py -o cache_data.h -r $(ls /path/to/cache/*.cache)
  23. # 4. 将步骤1中的 define 语句改回原样,这样 profile 过程就会使用 cache 下来的数据。
  24. # 5. 最后可以重新构建一下 megbrain_test ,确保测试结果正确。
  25. import os.path
  26. import logging
  27. import hashlib
  28. import argparse
  29. import struct
  30. import itertools
  31. import sys
  32. import subprocess
  33. import re
  34. logger = logging.getLogger(__name__)
  35. logging.basicConfig(level=logging.WARNING, format='%(asctime)-15s %(message)s')
  36. CHAR_MAP = {i: r'{}'.format(i) for i in range(256)}
  37. def _u32(data):
  38. return struct.unpack('<I', data)[0]
  39. class CacheDataGenerator:
  40. _cache_files = None
  41. def __init__(self, cache_files, remove_plat_info = True):
  42. self._cache_files = cache_files
  43. self._remove_plat_info = remove_plat_info
  44. def _get_hash(self):
  45. return _u32(self._hash.digest()[:4])
  46. def gen_cache_data(self, fpath):
  47. fname = os.path.basename(fpath)
  48. with open(fpath, 'rb') as fcache:
  49. cache_data = fcache.read()
  50. if self._remove_plat_info:
  51. for matched in re.finditer(
  52. rb"(layout_transform_profile:plat=.*);dev=.*;cap=\d.\d",
  53. cache_data
  54. ):
  55. plat_info = matched.group(1)
  56. cat_info = cache_data[matched.span()[0] - 4: matched.span()[1]]
  57. cache_data = re.sub(cat_info, struct.pack('I', len(plat_info)) + plat_info, cache_data)
  58. cache_data = struct.unpack(
  59. "<{}B".format(len(cache_data)), cache_data)
  60. ret = list(map(CHAR_MAP.__getitem__, cache_data))
  61. for i in range(50, len(ret), 50):
  62. ret[i] = '\n' + ret[i]
  63. return ','.join(ret)
  64. def gen_cache_data_header(self, fout, src_map):
  65. fout.write('// generated embed_cache.py\n')
  66. fout.write('#include <vector>\n')
  67. fout.write('#include <stdint.h>\n')
  68. for k, v in sorted(src_map.items()):
  69. fout.write("""
  70. static const std::vector<uint8_t> {} = {{
  71. """.format(k.replace('.', '_')))
  72. fout.write('{}'.format(v))
  73. fout.write('};\n')
  74. def invoke(self, output):
  75. logger.info('generate cache_data.h ...')
  76. fname2cache_data = {}
  77. for fname in self._cache_files:
  78. base, ext = os.path.splitext(os.path.basename(fname))
  79. assert ext == ".cache", "ext: {}, fname {}".format(ext, fname)
  80. assert base not in fname2cache_data, "duplicated kernel: " + base
  81. fname2cache_data[base] = self.gen_cache_data(fname)
  82. with open(output, 'w') as fout:
  83. self.gen_cache_data_header(fout, fname2cache_data)
  84. logger.info('done')
  85. if __name__ == '__main__':
  86. parser = argparse.ArgumentParser(
  87. description='embed cubin into cpp source file',
  88. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  89. parser.add_argument('-o', '--output', help='output source file',
  90. required=True)
  91. parser.add_argument(
  92. "-r",
  93. "--remove-plat-info",
  94. action='store_true',
  95. default=True,
  96. help="whether remove platform infomation in the cache (default: True)"
  97. )
  98. parser.add_argument('cache', help='cache files to be embedded', nargs='+')
  99. args = parser.parse_args()
  100. cache_generator = CacheDataGenerator(args.cache, args.remove_plat_info)
  101. cache_generator.invoke(args.output)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台