You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_cond_take_kern_impls.py 1.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import argparse
  4. import os
  5. from gen_elemwise_utils import DTYPES
  6. def main():
  7. parser = argparse.ArgumentParser(
  8. description="generate elemwise impl files",
  9. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  10. )
  11. parser.add_argument(
  12. "--type",
  13. type=str,
  14. choices=["cuda"],
  15. default="cuda",
  16. help="generate cuda cond take kernel file",
  17. )
  18. parser.add_argument("output", help="output directory")
  19. args = parser.parse_args()
  20. if not os.path.isdir(args.output):
  21. os.makedirs(args.output)
  22. assert args.type == "cuda"
  23. cpp_ext = "cu"
  24. for dtype in DTYPES.keys():
  25. fname = "{}.{}".format(dtype, cpp_ext)
  26. fname = os.path.join(args.output, fname)
  27. with open(fname, "w") as fout:
  28. w = lambda s: print(s, file=fout)
  29. w("// generated by gen_cond_take_kern_impls.py")
  30. w('#include "../kern.inl"')
  31. w("")
  32. if dtype == "dt_float16" or dtype == "dt_bfloat16":
  33. w("#if !MEGDNN_DISABLE_FLOAT16")
  34. w("namespace megdnn {")
  35. w("namespace cuda {")
  36. w("namespace cond_take {")
  37. w("")
  38. w("inst_genidx(::megdnn::dtype::{})".format(DTYPES[dtype][0]))
  39. w("#undef inst_genidx")
  40. w("")
  41. w("inst_copy(::megdnn::dtype::{})".format(DTYPES[dtype][0]))
  42. w("#undef inst_copy")
  43. w("#undef inst_copy_")
  44. w("")
  45. w("} // cond_take")
  46. w("} // cuda")
  47. w("} // megdnn")
  48. if dtype == "dt_float16" or dtype == "dt_bfloat16":
  49. w("#endif")
  50. print("generated {}".format(fname))
  51. os.utime(args.output)
# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()