You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

gen_cond_take_kern_impls.py 1.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import argparse
  5. from gen_elemwise_utils import DTYPES
  6. def main():
  7. parser = argparse.ArgumentParser(
  8. description='generate elemwise impl files',
  9. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  10. parser.add_argument('--type', type=str, choices=['cuda'],
  11. default='cuda',
  12. help='generate cuda cond take kernel file')
  13. parser.add_argument('output', help='output directory')
  14. args = parser.parse_args()
  15. if not os.path.isdir(args.output):
  16. os.makedirs(args.output)
  17. assert args.type =='cuda'
  18. cpp_ext = 'cu'
  19. for dtype in DTYPES.keys():
  20. fname = '{}.{}'.format(dtype, cpp_ext)
  21. fname = os.path.join(args.output, fname)
  22. with open(fname, 'w') as fout:
  23. w = lambda s: print(s, file=fout)
  24. w('// generated by gen_cond_take_kern_impls.py')
  25. w('#include "../kern.inl"')
  26. w('')
  27. if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
  28. w('#if !MEGDNN_DISABLE_FLOAT16')
  29. w('namespace megdnn {')
  30. w('namespace cuda {')
  31. w('namespace cond_take {')
  32. w('')
  33. w('inst_genidx(::megdnn::dtype::{})'.format(DTYPES[dtype][0]))
  34. w('#undef inst_genidx')
  35. w('')
  36. w('inst_copy(::megdnn::dtype::{})'.format(DTYPES[dtype][0]))
  37. w('#undef inst_copy')
  38. w('#undef inst_copy_')
  39. w('')
  40. w('} // cond_take')
  41. w('} // cuda')
  42. w('} // megdnn')
  43. if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
  44. w('#endif')
  45. print('generated {}'.format(fname))
  46. os.utime(args.output)
  47. if __name__ == '__main__':
  48. main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台