You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_elemwise_kern_impls.py 1.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import argparse
  4. import itertools
  5. import os
  6. from gen_elemwise_utils import ARITIES, DTYPES, MODES
  7. def main():
  8. parser = argparse.ArgumentParser(
  9. description="generate elemwise impl files",
  10. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  11. )
  12. parser.add_argument(
  13. "--type",
  14. type=str,
  15. choices=["cuda", "hip", "cpp"],
  16. default="cpp",
  17. help="generate cuda/hip kernel file",
  18. )
  19. parser.add_argument("output", help="output directory")
  20. args = parser.parse_args()
  21. if not os.path.isdir(args.output):
  22. os.makedirs(args.output)
  23. if args.type == "cuda":
  24. cpp_ext = "cu"
  25. elif args.type == "hip":
  26. cpp_ext = "cpp.hip"
  27. else:
  28. assert args.type == "cpp"
  29. cpp_ext = "cpp"
  30. for anum, ctype in itertools.product(ARITIES.keys(), DTYPES.keys()):
  31. for mode in MODES[(anum, DTYPES[ctype][1])]:
  32. formode = "MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)".format(mode)
  33. fname = "{}_{}.{}".format(mode, ctype, cpp_ext)
  34. fname = os.path.join(args.output, fname)
  35. with open(fname, "w") as fout:
  36. w = lambda s: print(s, file=fout)
  37. w("// generated by gen_elemwise_kern_impls.py")
  38. if ctype == "dt_float16" or ctype == "dt_bfloat16":
  39. w("#if !MEGDNN_DISABLE_FLOAT16")
  40. w("#define KERN_IMPL_MODE(cb) {}".format(formode))
  41. w("#define KERN_IMPL_ARITY {}".format(anum))
  42. w("#define KERN_IMPL_CTYPE {}".format(ctype))
  43. w('#include "../kern_impl.inl"')
  44. if ctype == "dt_float16" or ctype == "dt_bfloat16":
  45. w("#endif")
  46. print("generated {}".format(fname))
  47. os.utime(args.output)
  48. if __name__ == "__main__":
  49. main()