#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Generate the per-dtype elemwise "special kernel" implementation files.

One source file named ``special_<dtype>.<ext>`` is written into the output
directory for every key of ``gen_elemwise_utils.DTYPES``.  The extension is
``cu`` for ``--type cuda`` and ``cpp.hip`` for ``--type hip``.
"""

import argparse
import os

from gen_elemwise_utils import DTYPES

# dtypes whose generated code is wrapped in ``#if !MEGDNN_DISABLE_FLOAT16``.
_FLOAT16_DTYPES = ("dt_float16", "dt_bfloat16")


def _special_kern_lines(dtype, dtype_class):
    """Return the lines of the generated source file for one dtype.

    Args:
        dtype: key of ``DTYPES`` (e.g. ``"dt_float16"``); decides whether the
            float16 preprocessor guard is emitted.
        dtype_class: name substituted into ``INST(::megdnn::dtype::<name>)``.

    Returns:
        A list of strings, one per output line, without trailing newlines.
    """
    guarded = dtype in _FLOAT16_DTYPES

    lines = ["// generated by gen_elemwise_special_kern_impls.py"]
    if guarded:
        lines.append("#if !MEGDNN_DISABLE_FLOAT16")
    lines.append('#include "../special_kerns.inl"')
    lines.append("INST(::megdnn::dtype::{})".format(dtype_class))
    lines.append("#undef INST")
    # NOTE(review): the two closing braces presumably terminate scopes opened
    # inside special_kerns.inl -- confirm against that header.
    lines.append("}")
    lines.append("}")
    if guarded:
        lines.append("#endif")
    return lines


def main():
    """Parse the command line and write one kernel file per dtype.

    Command-line arguments:
        --type {cuda,hip}: backend to generate for (default ``cuda``).
        output: directory that receives the generated files; created if it
            does not exist.
    """
    parser = argparse.ArgumentParser(
        description="generate elemwise impl files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["cuda", "hip"],
        default="cuda",
        help="generate cuda/hip elemwise special kernel file",
    )
    parser.add_argument("output", help="output directory")
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of isdir() + makedirs() when
    # another process creates the same directory concurrently.
    os.makedirs(args.output, exist_ok=True)

    if args.type == "cuda":
        cpp_ext = "cu"
    else:
        # argparse ``choices`` already restricts the value; this documents the
        # invariant for the remaining branch.
        assert args.type == "hip"
        cpp_ext = "cpp.hip"

    for dtype in DTYPES:
        fname = os.path.join(
            args.output, "special_{}.{}".format(dtype, cpp_ext)
        )
        # The first element of the DTYPES entry is used as the megdnn dtype
        # class name in the INST(...) instantiation.
        lines = _special_kern_lines(dtype, DTYPES[dtype][0])
        with open(fname, "w") as fout:
            for line in lines:
                print(line, file=fout)
        print("generated {}".format(fname))

    # Set the directory's access/modification times to "now".
    # NOTE(review): presumably lets the build system see the output directory
    # as up to date -- confirm with the build rules that invoke this script.
    os.utime(args.output)


if __name__ == "__main__":
    main()