You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

gen_cuda_conv_bias_kern_impls.py 2.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import argparse
  4. import itertools
  5. import os
  6. PREFIXES = {
  7. "dp4a": "conv_bias_int8_implicit_gemm_cdiv4hwn4",
  8. "imma": "conv_bias_int8_implicit_gemm",
  9. }
  10. ACTIVATIONS = {1: ("IDENTITY", "_id"), 2: ("RELU", "_relu"), 3: ("H_SWISH", "_hswish")}
  11. BIASES = {
  12. 1: ("PerElementBiasVisitor", "_per_elem"),
  13. 2: ("PerChannelBiasVisitor", "_per_chan"),
  14. }
  15. SUFFIXES = {
  16. "dp4a": ["", "_ld_64bit", "_ld_64bit_unroll_width", "_unroll_width"],
  17. "imma": [
  18. "_imma16x16x16_cdiv4hwn4",
  19. "_imma8x32x16_cdiv4hwn4",
  20. "_imma32x8x16_cdiv4hwn4",
  21. "_imma16x16x16_cdiv4hwn4_reorder_filter",
  22. "_imma8x32x16_cdiv4hwn4_reorder_filter",
  23. "_imma32x8x16_cdiv4hwn4_reorder_filter",
  24. "_imma16x16x16_cdiv4hwn4_unroll_width",
  25. "_imma8x32x16_cdiv4hwn4_unroll_width",
  26. "_imma32x8x16_cdiv4hwn4_unroll_width",
  27. ],
  28. }
  29. def main():
  30. parser = argparse.ArgumentParser(
  31. description="generate cuda conv bias (dp4a/imma) kern impl files",
  32. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  33. )
  34. parser.add_argument(
  35. "--type",
  36. type=str,
  37. choices=["dp4a", "imma"],
  38. default="dp4a",
  39. help="generate cuda conv bias kernel file",
  40. )
  41. parser.add_argument("output", help="output directory")
  42. args = parser.parse_args()
  43. if not os.path.isdir(args.output):
  44. os.makedirs(args.output)
  45. inst = """
  46. template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS,
  47. IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::ACTIVATION>>>(
  48. const int8_t* d_src,
  49. const int8_t* d_filter,
  50. BIAS bias,
  51. IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::ACTIVATION>> epilogue,
  52. const ConvParam& param,
  53. float alpha,
  54. float beta,
  55. cudaStream_t stream);"""
  56. for suffix in SUFFIXES[args.type]:
  57. for _, act in ACTIVATIONS.items():
  58. prefix = PREFIXES[args.type]
  59. bias = BIASES[2]
  60. fname = "{}{}{}{}.cu".format(prefix, suffix, bias[1], act[1])
  61. fname = os.path.join(args.output, fname)
  62. with open(fname, "w") as fout:
  63. w = lambda s: print(s, file=fout)
  64. w("// generated by gen_cuda_conv_bias_kern_impls.py")
  65. cur_inst = (
  66. inst.replace("PREFIX", prefix)
  67. .replace("SUFFIX", suffix)
  68. .replace("BIAS", bias[0])
  69. .replace("ACTIVATION", act[0])
  70. )
  71. w('#include "../{}{}.cuinl"'.format(prefix, suffix))
  72. w(cur_inst)
  73. print("generated {}".format(fname))
  74. os.utime(args.output)
  75. if __name__ == "__main__":
  76. main()