|
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
-
- import os
- import argparse
- import itertools
-
# Kernel base name per instruction type; main() uses it for the output file
# name, the included .cuinl file and the instantiated ``do_<name>`` function.
PREFIXES = {
    "dp4a": "conv_bias_int8_implicit_gemm_cdiv4hwn4",
    "imma": "conv_bias_int8_implicit_gemm",
}

# id -> (NonlineMode enumerator substituted into the template, file name tag).
ACTIVATIONS = {
    1: ("IDENTITY", "_id"),
    2: ("RELU", "_relu"),
    3: ("H_SWISH", "_hswish"),
}

# id -> (bias visitor type substituted into the template, file name tag).
# main() currently only instantiates entry 2 (per-channel).
BIASES = {
    1: ("PerElementBiasVisitor", "_per_elem"),
    2: ("PerChannelBiasVisitor", "_per_chan"),
}

# Kernel variant suffixes per instruction type. The imma list is every
# (variant, tile shape) combination, grouped by variant with the tile shape
# varying fastest.
SUFFIXES = {
    "dp4a": [
        "",
        "_ld_64bit",
        "_ld_64bit_unroll_width",
        "_unroll_width",
    ],
    "imma": [
        "_imma{}_cdiv4hwn4{}".format(shape, variant)
        for variant in ("", "_reorder_filter", "_unroll_width")
        for shape in ("16x16x16", "8x32x16", "32x8x16")
    ],
}
-
def main():
    """Generate the CUDA int8 conv-bias kernel instantiation files.

    Command line:
        --type {dp4a,imma}  which kernel family to generate (default: dp4a)
        output              directory that receives the generated files

    For every (kernel suffix, activation) combination of the selected type,
    one file named ``<prefix><suffix><bias tag><activation tag>.cu`` is
    written into ``output``. Each file holds a "generated by" comment, an
    ``#include`` of ``../<prefix><suffix>.cuinl`` and one explicit template
    instantiation of ``do_<prefix><suffix>``. Only the per-channel bias
    visitor (``BIASES[2]``) is instantiated. The path of each generated file
    is printed to stdout.

    Raises:
        OSError: if the output directory cannot be created or a file cannot
            be written.
    """
    parser = argparse.ArgumentParser(
        description='generate cuda conv bias (dp4a/imma) kern impl files',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--type', type=str, choices=['dp4a', 'imma'],
                        default='dp4a',
                        help='generate cuda conv bias kernel file')
    parser.add_argument('output', help='output directory')
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of isdir() + makedirs();
    # it still raises if the path exists but is not a directory.
    os.makedirs(args.output, exist_ok=True)

    # Instantiation template; the upper-case PREFIX / SUFFIX / BIAS /
    # ACTIVATION tokens are substituted per generated file. The matching is
    # case-sensitive, so "ConvBias" and "Activation<" are left untouched.
    inst = '''
template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS,
IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::ACTIVATION>>>(
const int8_t* d_src,
const int8_t* d_filter,
BIAS bias,
IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::ACTIVATION>> epilogue,
const ConvParam& param,
float alpha,
float beta,
cudaStream_t stream);'''

    # Both lookups are the same for every generated file, so do them once.
    prefix = PREFIXES[args.type]
    bias_type, bias_tag = BIASES[2]

    # product() iterates suffixes in the outer position and activations in
    # the inner one, i.e. the same order as two nested loops.
    combos = itertools.product(SUFFIXES[args.type], ACTIVATIONS.values())
    for suffix, (act_name, act_tag) in combos:
        fname = os.path.join(
            args.output,
            "{}{}{}{}.cu".format(prefix, suffix, bias_tag, act_tag))

        # Render the body before opening the file so a failure here cannot
        # leave a truncated file behind.
        cur_inst = (inst.replace("PREFIX", prefix)
                        .replace("SUFFIX", suffix)
                        .replace("BIAS", bias_type)
                        .replace("ACTIVATION", act_name))

        with open(fname, "w") as fout:
            print('// generated by gen_cuda_conv_bias_kern_impls.py',
                  file=fout)
            print('#include "../{}{}.cuinl"'.format(prefix, suffix),
                  file=fout)
            print(cur_inst, file=fout)

        print('generated {}'.format(fname))

    # Set the directory's access/modification times to "now".
    # NOTE(review): presumably so a build system sees the output directory
    # as freshly updated -- confirm against the build rules.
    os.utime(args.output)
-
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|