You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_impl.cpp 2.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. /**
  2. * \file dnn/src/naive/argmxx/opr_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/naive/argmxx/opr_impl.h"
  12. #include "src/common/utils.h"
  13. #include "src/common/reduce_helper.h"
  14. #include "src/naive/handle.h"
  15. #include <numeric>
  16. namespace {
  17. using namespace megdnn;
  18. template <bool is_max> struct traits;
  19. template <> struct traits<true> {
  20. static const float init;
  21. static bool better_than(float lhs, float rhs)
  22. { return lhs > rhs; }
  23. };
  24. const float traits<true>::init = std::numeric_limits<float>::lowest();
  25. template <> struct traits<false> {
  26. static const float init;
  27. static float better_than(float lhs, float rhs)
  28. { return lhs < rhs; }
  29. };
  30. const float traits<false>::init = std::numeric_limits<float>::max();
  31. template <typename T, bool is_max>
  32. void exec_forward(_megdnn_tensor_in src,
  33. _megdnn_tensor_out dst,
  34. const ArgmxxBase::Param &param)
  35. {
  36. size_t A, B, C;
  37. reduce::get_ABC(src.layout, A, B, C, param.axis);
  38. for (size_t a = 0; a < A; ++a) for (size_t c = 0; c < C; ++c) {
  39. float best_val = traits<is_max>::init;
  40. size_t best_arg = 0;
  41. for (size_t b = 0; b < B; ++b) {
  42. float curr_val = float(src.ptr<T>()[(a*B+b)*C+c]);
  43. if (traits<is_max>::better_than(curr_val, best_val)) {
  44. best_val = curr_val;
  45. best_arg = b;
  46. }
  47. }
  48. dst.ptr<dt_int32>()[a*C+c] = best_arg;
  49. }
  50. }
  51. } // anonymous namespace
  52. namespace megdnn {
  53. namespace naive {
  54. void ArgmaxForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  55. _megdnn_workspace workspace)
  56. {
  57. check_exec(src.layout, dst.layout, workspace.size);
  58. #define cb(DType) \
  59. if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
  60. using ctype = typename DTypeTrait<DType>::ctype; \
  61. MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl*>(handle()), \
  62. exec_forward<ctype MEGDNN_COMMA true>(src, dst, param())); \
  63. }
  64. MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
  65. #undef cb
  66. }
  67. void ArgminForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  68. _megdnn_workspace workspace)
  69. {
  70. check_exec(src.layout, dst.layout, workspace.size);
  71. #define cb(DType) \
  72. if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
  73. using ctype = typename DTypeTrait<DType>::ctype; \
  74. MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl*>(handle()), \
  75. exec_forward<ctype MEGDNN_COMMA false>(src, dst, param())); \
  76. }
  77. MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
  78. #undef cb
  79. }
  80. } // namespace naive
  81. } // namespace megdnn
  82. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台