You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_list.py 1.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. from generator import (
  2. GenerateGemmOperations,
  3. GenerateGemvOperations,
  4. GenerateConv2dOperations,
  5. GenerateDeconvOperations,
  6. )
  7. class GenArg:
  8. def __init__(self, gen_op, gen_type):
  9. self.operations = gen_op
  10. self.type = gen_type
  11. def write_op_list(f, gen_op, gen_type):
  12. if gen_op == "gemm":
  13. operations = GenerateGemmOperations(GenArg(gen_op, gen_type))
  14. elif gen_op == "gemv":
  15. operations = GenerateGemvOperations(GenArg(gen_op, gen_type))
  16. elif gen_op == "conv2d":
  17. operations = GenerateConv2dOperations(GenArg(gen_op, gen_type))
  18. elif gen_op == "deconv":
  19. operations = GenerateDeconvOperations(GenArg(gen_op, gen_type))
  20. for op in operations:
  21. f.write(' "%s.cu",\n' % op.procedural_name())
  22. if gen_op != "gemv":
  23. f.write(' "all_%s_%s_operations.cu",\n' % (gen_op, gen_type))
  24. if __name__ == "__main__":
  25. with open("list.bzl", "w") as f:
  26. f.write("# Generated by dnn/scripts/cutlass_generator/gen_list.py\n\n")
  27. f.write("cutlass_gen_list = [\n")
  28. write_op_list(f, "gemm", "simt")
  29. write_op_list(f, "gemm", "tensorop1688")
  30. write_op_list(f, "gemm", "tensorop884")
  31. write_op_list(f, "gemv", "simt")
  32. write_op_list(f, "deconv", "simt")
  33. write_op_list(f, "deconv", "tensorop8816")
  34. write_op_list(f, "conv2d", "simt")
  35. write_op_list(f, "conv2d", "tensorop8816")
  36. write_op_list(f, "conv2d", "tensorop8832")
  37. f.write("]")