#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Generate CUDA int8 conv-bias kernel instantiation files.

For the selected kernel family (``--type dp4a`` or ``--type imma``) this
script writes one ``.cu`` file per (suffix, activation) pair into the output
directory.  Each file includes the matching ``.cuinl`` header from the parent
directory and contains a single explicit template instantiation.
"""

import argparse
import itertools
import os

# Kernel-name prefix for each kernel family selectable with ``--type``.
PREFIXES = {
    "dp4a": "conv_bias_int8_implicit_gemm_cdiv4hwn4",
    "imma": "conv_bias_int8_implicit_gemm",
}

# id -> (name substituted for ACTIVATION in the template, file-name suffix).
ACTIVATIONS = {
    1: ("IDENTITY", "_id"),
    2: ("RELU", "_relu"),
    3: ("H_SWISH", "_hswish"),
}

# id -> (type name substituted for BIAS in the template, file-name suffix).
BIASES = {
    1: ("PerElementBiasVisitor", "_per_elem"),
    2: ("PerChannelBiasVisitor", "_per_chan"),
}

# Kernel-name suffixes generated for each kernel family.
SUFFIXES = {
    "dp4a": ["", "_ld_64bit", "_ld_64bit_unroll_width", "_unroll_width"],
    "imma": [
        "_imma16x16x16_cdiv4hwn4",
        "_imma8x32x16_cdiv4hwn4",
        "_imma32x8x16_cdiv4hwn4",
        "_imma16x16x16_cdiv4hwn4_reorder_filter",
        "_imma8x32x16_cdiv4hwn4_reorder_filter",
        "_imma32x8x16_cdiv4hwn4_reorder_filter",
        "_imma16x16x16_cdiv4hwn4_unroll_width",
        "_imma8x32x16_cdiv4hwn4_unroll_width",
        "_imma32x8x16_cdiv4hwn4_unroll_width",
    ],
}

# Explicit-instantiation template.  PREFIX, SUFFIX, BIAS and ACTIVATION are
# placeholders replaced textually for every generated file.
# NOTE(review): the angle-bracket template argument lists were missing from the
# text under review (leaving unbalanced '>' and no ACTIVATION placeholder) and
# have been reconstructed here -- confirm the exact spelling against the
# do_* declarations in the .cuinl headers.
_INSTANTIATION_TEMPLATE = """
template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS,
        IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::ACTIVATION>>>(
        const int8_t* d_src,
        const int8_t* d_filter,
        BIAS bias,
        IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::ACTIVATION>> epilogue,
        const ConvParam& param,
        float alpha,
        float beta,
        cudaStream_t stream);"""


def _render_kernel_source(prefix, suffix, bias_type, activation):
    """Return the full text of one generated ``.cu`` file."""
    instantiation = (
        _INSTANTIATION_TEMPLATE.replace("PREFIX", prefix)
        .replace("SUFFIX", suffix)
        .replace("BIAS", bias_type)
        .replace("ACTIVATION", activation)
    )
    lines = [
        "// generated by gen_cuda_conv_bias_kern_impls.py",
        '#include "../{}{}.cuinl"'.format(prefix, suffix),
        instantiation,
    ]
    return "\n".join(lines) + "\n"


def main(argv=None):
    """Parse the command line and write the kernel instantiation files.

    Args:
        argv: Optional list of command-line arguments (without the program
            name).  ``None`` makes argparse read ``sys.argv``, which is what
            happens when the script is run directly.
    """
    parser = argparse.ArgumentParser(
        description="generate cuda conv bias (dp4a/imma) kern impl files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["dp4a", "imma"],
        default="dp4a",
        help="generate cuda conv bias kernel file",
    )
    parser.add_argument("output", help="output directory")
    args = parser.parse_args(argv)

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(args.output, exist_ok=True)

    prefix = PREFIXES[args.type]
    # Only the per-channel bias variant is generated; BIASES[1] is unused here.
    bias_type, bias_tag = BIASES[2]

    for suffix in SUFFIXES[args.type]:
        for activation, act_tag in ACTIVATIONS.values():
            fname = os.path.join(
                args.output,
                "{}{}{}{}.cu".format(prefix, suffix, bias_tag, act_tag),
            )
            with open(fname, "w") as fout:
                fout.write(
                    _render_kernel_source(prefix, suffix, bias_type, activation)
                )
            print("generated {}".format(fname))

    # Set the output directory's access/modification times to "now";
    # presumably so a build system sees the directory as freshly updated.
    os.utime(args.output)


if __name__ == "__main__":
    main()