You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

cutlass_float32_simt.cpp 4.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. /**
  2. * \file dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "src/cuda/cutlass/singleton.h"
  13. #include "src/cuda/handle.h"
  14. #include "src/cuda/matrix_mul/algos.h"
  15. #include "src/cuda/utils.h"
  16. #if CUDA_VERSION >= 9020
  17. using namespace megdnn;
  18. using namespace cuda;
  19. bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available(
  20. const SizeArgs& args) const {
  21. bool available =
  22. args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
  23. args.layout_a.dtype == dtype::Float32() &&
  24. args.layout_b.dtype == dtype::Float32() &&
  25. args.layout_c.dtype == dtype::Float32();
  26. int n = args.layout_c.shape[1];
  27. auto&& device_prop = cuda::current_device_prop();
  28. int y_grid_limit = device_prop.maxGridSize[1];
  29. // limit y grid
  30. available &= ((n + m_algo_param.threadblock_n - 1) /
  31. m_algo_param.threadblock_n <=
  32. y_grid_limit);
  33. return available;
  34. }
  35. size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes(
  36. const SizeArgs& args) const {
  37. return 0_z;
  38. }
  39. void MatrixMulForwardImpl::AlgoFloat32SIMT::do_exec(
  40. const ExecArgs& args) const {
  41. int64_t lda = args.tensor_a.layout.stride[0],
  42. ldb = args.tensor_b.layout.stride[0],
  43. ldc = args.tensor_c.layout.stride[0];
  44. auto&& param = args.opr->param();
  45. int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
  46. k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];
  47. cutlass::gemm::GemmCoord problem_size{m, n, k};
  48. auto&& stream = cuda_stream(args.opr->handle());
  49. int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr);
  50. // \note these constants of cutlass epilogue will be passed to struct
  51. // `GemmArguments` by pointer and interpreted as ElementCompute*, a
  52. // different dtype here results in undefined epilogue behaviors
  53. float alpha = 1.f, beta = 0.f;
  54. using namespace cutlass::library;
  55. auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor
  56. : LayoutTypeID::kRowMajor;
  57. auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor
  58. : LayoutTypeID::kRowMajor;
  59. int alignment = min_alignment_requirement();
  60. GemmKey key{NumericTypeID::kF32,
  61. layoutA,
  62. NumericTypeID::kF32,
  63. layoutB,
  64. NumericTypeID::kF32,
  65. LayoutTypeID::kRowMajor,
  66. NumericTypeID::kF32,
  67. m_algo_param.threadblock_m,
  68. m_algo_param.threadblock_n,
  69. m_algo_param.threadblock_k,
  70. m_algo_param.warp_m,
  71. m_algo_param.warp_n,
  72. m_algo_param.warp_k,
  73. 1,
  74. 1,
  75. 1,
  76. 2,
  77. alignment,
  78. alignment,
  79. SplitKMode::kNone};
  80. const Operation* op = Singleton::get().operation_table.find_op(key);
  81. GemmArguments gemm_args{problem_size,
  82. args.tensor_a.raw_ptr,
  83. args.tensor_b.raw_ptr,
  84. args.tensor_c.raw_ptr,
  85. args.tensor_c.raw_ptr,
  86. lda,
  87. ldb,
  88. ldc,
  89. ldc,
  90. 1,
  91. &alpha,
  92. &beta};
  93. cutlass_check(op->run(&gemm_args, workspace, stream));
  94. }
  95. #endif
  96. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台