You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_cuda_batch_conv_bias_kern_impls.py 2.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import argparse
  4. import itertools
  5. import os
  6. PREFIXES = {
  7. "dp4a": [
  8. ("batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4", True),
  9. ("batch_conv_bias_int8_gemm_ncdiv4hw4", False),
  10. ("batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128", False),
  11. ]
  12. }
  13. ACTIVATIONS = {1: ("IDENTITY", "_id"), 2: ("RELU", "_relu"), 3: ("H_SWISH", "_hswish")}
  14. BIASES = {
  15. 1: ("PerElementBiasVisitor", "_per_elem"),
  16. 2: ("PerChannelBiasVisitor", "_per_chan"),
  17. }
  18. SUFFIXES = {"dp4a": [""], "imma": [""]}
  19. def main():
  20. parser = argparse.ArgumentParser(
  21. description="generate cuda batch conv bias (dp4a/imma) kern impl files",
  22. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  23. )
  24. parser.add_argument(
  25. "--type",
  26. type=str,
  27. choices=["dp4a", "imma"],
  28. default="dp4a",
  29. help="generate cuda conv bias kernel file",
  30. )
  31. parser.add_argument("output", help="output directory")
  32. args = parser.parse_args()
  33. if not os.path.isdir(args.output):
  34. os.makedirs(args.output)
  35. inst = """
  36. template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS,
  37. IConvEpilogue<Activation<megdnn::param_enumv::BatchConvBias::NonlineMode::ACTIVATION>>>(
  38. const int8_t* d_src,
  39. const int8_t* d_filter, WORKSPACE
  40. BIAS bias,
  41. IConvEpilogue<Activation<megdnn::param_enumv::BatchConvBias::NonlineMode::ACTIVATION>> epilogue,
  42. const ConvParam& param,
  43. float alpha,
  44. float beta,
  45. cudaStream_t stream);"""
  46. for prefix in PREFIXES[args.type]:
  47. for suffix in SUFFIXES[args.type]:
  48. for _, act in ACTIVATIONS.items():
  49. has_workspace = prefix[1]
  50. bias = BIASES[2]
  51. fname = "{}{}{}{}.cu".format(prefix[0], suffix, bias[1], act[1])
  52. fname = os.path.join(args.output, fname)
  53. with open(fname, "w") as fout:
  54. w = lambda s: print(s, file=fout)
  55. w("// generated by gen_batch_cuda_conv_bias_kern_impls.py")
  56. cur_inst = (
  57. inst.replace("PREFIX", prefix[0])
  58. .replace("SUFFIX", suffix)
  59. .replace("BIAS", bias[0])
  60. .replace("ACTIVATION", act[0])
  61. )
  62. if has_workspace:
  63. cur_inst = cur_inst.replace("WORKSPACE", "\nint* d_workspace, ")
  64. else:
  65. cur_inst = cur_inst.replace("WORKSPACE", "")
  66. w('#include "../{}{}.cuinl"'.format(prefix[0], suffix))
  67. w(cur_inst)
  68. print("generated {}".format(fname))
  69. os.utime(args.output)
  70. if __name__ == "__main__":
  71. main()