You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_list.py 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. from generator import ( # isort: skip; isort: skip
  2. GenerateConv2dOperations,
  3. GenerateDeconvOperations,
  4. GenerateDwconv2dDgradOperations,
  5. GenerateDwconv2dFpropOperations,
  6. GenerateDwconv2dWgradOperations,
  7. GenerateGemmOperations,
  8. GenerateGemvOperations,
  9. )
  10. class GenArg:
  11. def __init__(self, gen_op, gen_type):
  12. self.operations = gen_op
  13. self.type = gen_type
  14. def write_op_list(f, gen_op, gen_type):
  15. if gen_op == "gemm":
  16. operations = GenerateGemmOperations(GenArg(gen_op, gen_type))
  17. elif gen_op == "gemv":
  18. operations = GenerateGemvOperations(GenArg(gen_op, gen_type))
  19. elif gen_op == "conv2d":
  20. operations = GenerateConv2dOperations(GenArg(gen_op, gen_type))
  21. elif gen_op == "deconv":
  22. operations = GenerateDeconvOperations(GenArg(gen_op, gen_type))
  23. elif gen_op == "dwconv2d_fprop":
  24. operations = GenerateDwconv2dFpropOperations(GenArg(gen_op, gen_type))
  25. elif gen_op == "dwconv2d_dgrad":
  26. operations = GenerateDwconv2dDgradOperations(GenArg(gen_op, gen_type))
  27. elif gen_op == "dwconv2d_wgrad":
  28. operations = GenerateDwconv2dWgradOperations(GenArg(gen_op, gen_type))
  29. for op in operations:
  30. f.write(' "%s.cu",\n' % op.procedural_name())
  31. if gen_op != "gemv":
  32. f.write(' "all_%s_%s_operations.cu",\n' % (gen_op, gen_type))
  33. # Write down a list of merged filenames
  34. def write_merge_file_name(f, gen_op, gen_type, split_number):
  35. for i in range(0, split_number):
  36. f.write(' "{}_{}_{}.cu",\n'.format(gen_op, gen_type, i))
  37. if gen_op != "gemv":
  38. f.write(' "all_{}_{}_operations.cu",\n'.format(gen_op, gen_type))
  39. if __name__ == "__main__":
  40. with open("list.bzl", "w") as f:
  41. f.write("# Generated by dnn/scripts/cutlass_generator/gen_list.py\n\n")
  42. f.write("cutlass_gen_list = [\n")
  43. write_merge_file_name(f, "gemm", "simt", 2)
  44. write_merge_file_name(f, "gemm", "tensorop884", 30)
  45. write_merge_file_name(f, "gemm", "tensorop1688", 2)
  46. write_merge_file_name(f, "gemv", "simt", 2)
  47. write_merge_file_name(f, "deconv", "simt", 2)
  48. write_merge_file_name(f, "deconv", "tensorop8816", 4)
  49. write_merge_file_name(f, "conv2d", "simt", 2)
  50. write_merge_file_name(f, "conv2d", "tensorop8816", 4)
  51. write_merge_file_name(f, "conv2d", "tensorop8832", 4)
  52. write_merge_file_name(f, "dwconv2d_fprop", "simt", 2)
  53. write_merge_file_name(f, "dwconv2d_fprop", "tensorop884", 4)
  54. write_merge_file_name(f, "dwconv2d_dgrad", "simt", 2)
  55. write_merge_file_name(f, "dwconv2d_dgrad", "tensorop884", 4)
  56. write_merge_file_name(f, "dwconv2d_wgrad", "simt", 2)
  57. write_merge_file_name(f, "dwconv2d_wgrad", "tensorop884", 4)
  58. f.write("]")