You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

gen_list.py 2.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. from generator import (
  2. GenerateGemmOperations,
  3. GenerateGemvOperations,
  4. GenerateConv2dOperations,
  5. GenerateDeconvOperations,
  6. GenerateDwconv2dFpropOperations,
  7. GenerateDwconv2dDgradOperations,
  8. )
  9. class GenArg:
  10. def __init__(self, gen_op, gen_type):
  11. self.operations = gen_op
  12. self.type = gen_type
  13. def write_op_list(f, gen_op, gen_type):
  14. if gen_op == "gemm":
  15. operations = GenerateGemmOperations(GenArg(gen_op, gen_type))
  16. elif gen_op == "gemv":
  17. operations = GenerateGemvOperations(GenArg(gen_op, gen_type))
  18. elif gen_op == "conv2d":
  19. operations = GenerateConv2dOperations(GenArg(gen_op, gen_type))
  20. elif gen_op == "deconv":
  21. operations = GenerateDeconvOperations(GenArg(gen_op, gen_type))
  22. elif gen_op == "dwconv2d_fprop":
  23. operations = GenerateDwconv2dFpropOperations(GenArg(gen_op, gen_type))
  24. elif gen_op == "dwconv2d_dgrad":
  25. operations = GenerateDwconv2dDgradOperations(GenArg(gen_op, gen_type))
  26. elif gen_op == "dwconv2d_wgrad":
  27. pass
  28. for op in operations:
  29. f.write(' "%s.cu",\n' % op.procedural_name())
  30. if gen_op != "gemv":
  31. f.write(' "all_%s_%s_operations.cu",\n' % (gen_op, gen_type))
  32. if __name__ == "__main__":
  33. with open("list.bzl", "w") as f:
  34. f.write("# Generated by dnn/scripts/cutlass_generator/gen_list.py\n\n")
  35. f.write("cutlass_gen_list = [\n")
  36. write_op_list(f, "gemm", "simt")
  37. write_op_list(f, "gemm", "tensorop1688")
  38. write_op_list(f, "gemm", "tensorop884")
  39. write_op_list(f, "gemv", "simt")
  40. write_op_list(f, "deconv", "simt")
  41. write_op_list(f, "deconv", "tensorop8816")
  42. write_op_list(f, "conv2d", "simt")
  43. write_op_list(f, "conv2d", "tensorop8816")
  44. write_op_list(f, "conv2d", "tensorop8832")
  45. write_op_list(f, "dwconv2d_fprop", "simt")
  46. write_op_list(f, "dwconv2d_fprop", "tensorop884")
  47. write_op_list(f, "dwconv2d_dgrad", "simt")
  48. write_op_list(f, "dwconv2d_dgrad", "tensorop884")
  49. f.write("]")