|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
-
- import argparse
- import os
-
- from gen_elemwise_utils import DTYPES
-
-
def main():
    """Generate one CUDA ``cond_take`` kernel instantiation file per dtype.

    Command line:
        --type   backend to generate for; only ``cuda`` is accepted (and is
                 the default).
        output   directory that receives the generated ``<dtype>.cu`` files;
                 it is created (including parents) if it does not exist.

    For every key of ``DTYPES`` (imported from ``gen_elemwise_utils``) a file
    named ``<key>.cu`` is written. It includes ``../kern.inl`` and expands
    ``inst_genidx`` and ``inst_copy`` for
    ``::megdnn::dtype::<DTYPES[key][0]>`` inside the
    ``megdnn::cuda::cond_take`` namespace, then ``#undef``s those macros.
    ``dt_float16`` and ``dt_bfloat16`` files are additionally wrapped in
    ``#if !MEGDNN_DISABLE_FLOAT16`` / ``#endif``.

    A ``generated <path>`` line is printed for each file. After all files are
    written, the output directory's timestamps are refreshed with
    ``os.utime``.
    """
    parser = argparse.ArgumentParser(
        description="generate cond_take kern impl files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["cuda"],
        default="cuda",
        help="generate cuda cond take kernel file",
    )
    parser.add_argument("output", help="output directory")
    args = parser.parse_args()

    # exist_ok avoids the race between an isdir() check and makedirs().
    os.makedirs(args.output, exist_ok=True)

    # ``choices`` already restricts --type. Use a real check rather than an
    # ``assert`` (stripped under ``python -O``) in case choices grow later.
    if args.type != "cuda":
        parser.error("unsupported --type: {}".format(args.type))
    cpp_ext = "cu"

    # These dtypes can be compiled out, so their instantiations are guarded.
    fp16_dtypes = ("dt_float16", "dt_bfloat16")

    for dtype, dtype_info in DTYPES.items():
        megdnn_dtype = dtype_info[0]
        needs_fp16_guard = dtype in fp16_dtypes

        lines = [
            "// generated by gen_cond_take_kern_impls.py",
            '#include "../kern.inl"',
            "",
        ]
        if needs_fp16_guard:
            lines.append("#if !MEGDNN_DISABLE_FLOAT16")
        lines += [
            "namespace megdnn {",
            "namespace cuda {",
            "namespace cond_take {",
            "",
            "inst_genidx(::megdnn::dtype::{})".format(megdnn_dtype),
            "#undef inst_genidx",
            "",
            "inst_copy(::megdnn::dtype::{})".format(megdnn_dtype),
            "#undef inst_copy",
            "#undef inst_copy_",
            "",
            "} // cond_take",
            "} // cuda",
            "} // megdnn",
        ]
        if needs_fp16_guard:
            lines.append("#endif")

        fname = os.path.join(args.output, "{}.{}".format(dtype, cpp_ext))
        with open(fname, "w") as fout:
            # Same bytes as printing each line: every line ends with "\n".
            fout.write("\n".join(lines) + "\n")

        print("generated {}".format(fname))

    # Refresh the directory timestamps so build tools see it as updated.
    os.utime(args.output)
-
-
if __name__ == "__main__":
    # Script entry point: run the generator only when executed directly.
    # main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())
|