You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_elemwise_kern_impls.py 2.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import argparse
  5. import itertools
  6. from gen_elemwise_utils import ARITIES, DTYPES, MODES
  7. def main():
  8. parser = argparse.ArgumentParser(
  9. description='generate elemwise impl files',
  10. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  11. parser.add_argument('--type', type=str, choices=['cuda',
  12. 'hip',
  13. 'cpp'],
  14. default='cpp', help='generate cuda/hip kernel file')
  15. parser.add_argument('output', help='output directory')
  16. args = parser.parse_args()
  17. if not os.path.isdir(args.output):
  18. os.makedirs(args.output)
  19. if args.type == 'cuda':
  20. cpp_ext = 'cu'
  21. elif args.type == 'hip':
  22. cpp_ext = 'cpp.hip'
  23. else:
  24. assert args.type == 'cpp'
  25. cpp_ext = 'cpp'
  26. for anum, ctype in itertools.product(ARITIES.keys(), DTYPES.keys()):
  27. for mode in MODES[(anum, DTYPES[ctype][1])]:
  28. formode = 'MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)'.format(mode)
  29. fname = '{}_{}.{}'.format(mode, ctype, cpp_ext)
  30. fname = os.path.join(args.output, fname)
  31. with open(fname, 'w') as fout:
  32. w = lambda s: print(s, file=fout)
  33. w('// generated by gen_elemwise_kern_impls.py')
  34. if ctype == 'dt_float16' or ctype == 'dt_bfloat16':
  35. w('#if !MEGDNN_DISABLE_FLOAT16')
  36. w('#define KERN_IMPL_MODE(cb) {}'.format(formode))
  37. w('#define KERN_IMPL_ARITY {}'.format(anum))
  38. w('#define KERN_IMPL_CTYPE {}'.format(ctype))
  39. w('#include "../kern_impl.inl"')
  40. if ctype == 'dt_float16' or ctype == 'dt_bfloat16':
  41. w('#endif')
  42. print('generated {}'.format(fname))
  43. os.utime(args.output)
# Run the generator only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()