
opr_impl.cpp 3.2 kB

/**
 * \file dnn/src/naive/batched_matrix_mul/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/naive/batched_matrix_mul/opr_impl.h"
#include "src/naive/matrix_mul/opr_impl.h"
#include "src/naive/handle.h"
#include "src/common/utils.h"

namespace megdnn {
namespace naive {

BatchedMatrixMulForwardImpl::BatchedMatrixMulForwardImpl(Handle* handle)
        : BatchedMatrixMulForward(handle),
          m_opr(this->handle()->create_operator<MatrixMulForward>()) {}

size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes(
        const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) {
    MEGDNN_MARK_USED_VAR(A);
    MEGDNN_MARK_USED_VAR(B);
    MEGDNN_MARK_USED_VAR(C);
    return 0;
}

void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
                                       _megdnn_tensor_out C,
                                       _megdnn_workspace workspace) {
    check_exec(A.layout, B.layout, C.layout, workspace.size);
    m_opr->param() = this->param();
    auto N = A.layout.shape[0];
    // Per-batch views: drop the leading batch axis so the inner
    // MatrixMulForward operator sees plain 2-D tensors.
    TensorND A_, B_, C_;
    A_.raw_ptr = A.raw_ptr;
    A_.layout = A.layout.remove_axis(0);
    B_.raw_ptr = B.raw_ptr;
    B_.layout = B.layout.remove_axis(0);
    C_.raw_ptr = C.raw_ptr;
    C_.layout = C.layout.remove_axis(0);
    // Batch strides in bytes, used to advance the raw pointers between
    // iterations.
    auto Astrd = A.layout.dtype.size() * A.layout.stride[0],
         Bstrd = B.layout.dtype.size() * B.layout.stride[0],
         Cstrd = C.layout.dtype.size() * C.layout.stride[0];
    auto advance_ptr = [](TensorND& dest, ptrdiff_t d) {
        dest.raw_ptr = static_cast<void*>(static_cast<dt_byte*>(dest.raw_ptr) + d);
    };
    // Run one 2-D matrix multiplication per batch.
    rep(n, N) {
        m_opr->exec(A_, B_, C_, workspace);
        advance_ptr(A_, Astrd);
        advance_ptr(B_, Bstrd);
        advance_ptr(C_, Cstrd);
    }
}

std::vector<BatchedMatrixMulForward::Algorithm*>
BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/,
                                                const TensorLayout& /*B*/,
                                                const TensorLayout& /*C*/) {
    return {static_cast<HandleImpl*>(handle())->default_batched_matmul_fwd_algo()};
}

BatchedMatrixMulForward::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
        const TensorLayout& /*A*/, const TensorLayout& /*B*/,
        const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
        const AlgoAttribute& /*attr*/) {
    return static_cast<HandleImpl*>(handle())->default_batched_matmul_fwd_algo();
}

BatchedMatrixMulForward::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_from_desc(const AlgorithmDesc& desc) {
    Algorithm* ret =
            static_cast<HandleImpl*>(handle())->default_batched_matmul_fwd_algo();
    megdnn_assert(desc == ret->info().desc);
    return ret;
}

} // namespace naive
} // namespace megdnn

// vim: syntax=cpp.doxygen
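
For orientation, here is a minimal standalone sketch (not MegEngine code) of the computation the naive kernel above performs: one independent 2-D matrix product per batch, with the input pointers advanced by the batch stride between iterations. The helper names matmul_2d and batched_matmul are illustrative only, and the sketch assumes contiguous row-major (batch, rows, cols) layouts, whereas the real operator works with arbitrary batch strides and dtypes.

#include <cstddef>

// Plain row-major 2-D matrix multiply: C (M x N) = A (M x K) * B (K x N).
void matmul_2d(const float* A, const float* B, float* C,
               size_t M, size_t K, size_t N) {
    for (size_t i = 0; i < M; ++i) {
        for (size_t j = 0; j < N; ++j) {
            float acc = 0.f;
            for (size_t k = 0; k < K; ++k)
                acc += A[i * K + k] * B[k * N + j];
            C[i * N + j] = acc;
        }
    }
}

// Batched version: with contiguous layouts the batch stride is simply the
// size of one 2-D slice, so each iteration offsets into the flat buffers.
void batched_matmul(const float* A, const float* B, float* C,
                    size_t batch, size_t M, size_t K, size_t N) {
    for (size_t b = 0; b < batch; ++b)
        matmul_2d(A + b * M * K, B + b * K * N, C + b * M * N, M, K, N);
}

Looping over the batch and reusing the single-matrix operator keeps the naive backend simple; it serves as a correctness reference rather than a fast implementation.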

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU device and that the driver is installed. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.