You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

gen_list.py 1.3 kB

1234567891011121314151617181920212223242526272829303132333435363738
  1. from generator import (
  2. GenerateGemmOperations,
  3. GenerateGemvOperations,
  4. GenerateConv2dOperations,
  5. GenerateDeconvOperations,
  6. )
  7. class GenArg:
  8. def __init__(self, gen_op, gen_type):
  9. self.operations = gen_op
  10. self.type = gen_type
  11. def write_op_list(f, gen_op, gen_type):
  12. if gen_op == "gemm":
  13. operations = GenerateGemmOperations(GenArg(gen_op, gen_type))
  14. elif gen_op == "gemv":
  15. operations = GenerateGemvOperations(GenArg(gen_op, gen_type))
  16. elif gen_op == "conv2d":
  17. operations = GenerateConv2dOperations(GenArg(gen_op, gen_type))
  18. elif gen_op == "deconv":
  19. operations = GenerateDeconvOperations(GenArg(gen_op, gen_type))
  20. for op in operations:
  21. f.write(' "%s.cu",\n' % op.procedural_name())
  22. if __name__ == "__main__":
  23. with open("list.bzl", "w") as f:
  24. f.write("# Generated by dnn/scripts/cutlass_generator/gen_list.py\n\n")
  25. f.write("cutlass_gen_list = [\n")
  26. write_op_list(f, "gemm", "simt")
  27. write_op_list(f, "gemv", "simt")
  28. write_op_list(f, "deconv", "simt")
  29. write_op_list(f, "conv2d", "simt")
  30. write_op_list(f, "conv2d", "tensorop8816")
  31. write_op_list(f, "conv2d", "tensorop8832")
  32. f.write("]")

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台。