You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_elemwise_special_kern_impls.py 1.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import argparse
  4. import os
  5. from gen_elemwise_utils import DTYPES
  6. def main():
  7. parser = argparse.ArgumentParser(
  8. description="generate elemwise impl files",
  9. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  10. )
  11. parser.add_argument(
  12. "--type",
  13. type=str,
  14. choices=["cuda", "hip"],
  15. default="cuda",
  16. help="generate cuda/hip elemwise special kernel file",
  17. )
  18. parser.add_argument("output", help="output directory")
  19. args = parser.parse_args()
  20. if not os.path.isdir(args.output):
  21. os.makedirs(args.output)
  22. if args.type == "cuda":
  23. cpp_ext = "cu"
  24. else:
  25. assert args.type == "hip"
  26. cpp_ext = "cpp.hip"
  27. for dtype in DTYPES.keys():
  28. fname = "special_{}.{}".format(dtype, cpp_ext)
  29. fname = os.path.join(args.output, fname)
  30. with open(fname, "w") as fout:
  31. w = lambda s: print(s, file=fout)
  32. w("// generated by gen_elemwise_special_kern_impls.py")
  33. if dtype == "dt_float16" or dtype == "dt_bfloat16":
  34. w("#if !MEGDNN_DISABLE_FLOAT16")
  35. w('#include "../special_kerns.inl"')
  36. w("INST(::megdnn::dtype::{})".format(DTYPES[dtype][0]))
  37. w("#undef INST")
  38. w("}")
  39. w("}")
  40. if dtype == "dt_float16" or dtype == "dt_bfloat16":
  41. w("#endif")
  42. print("generated {}".format(fname))
  43. os.utime(args.output)
  44. if __name__ == "__main__":
  45. main()