You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_impl.cpp 2.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. /**
  2. * \file dnn/src/naive/relayout/opr_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/naive/relayout/opr_impl.h"
  12. #include "megdnn/tensor_iter.h"
  13. #include "src/common/utils.h"
  14. #include "src/naive/handle.h"
  15. #include "midout.h"
  16. MIDOUT_DECL(naive_relayout)
  17. using namespace megdnn;
  18. using namespace naive;
  19. namespace {
  20. template <typename ctype>
  21. void do_copy(const TensorND& dst, const TensorND& src) {
  22. auto idst = tensor_iter_valonly<ctype>(dst).begin(),
  23. isrc = tensor_iter_valonly<ctype>(src).begin();
  24. for (size_t i = 0, it = dst.layout.total_nr_elems(); i < it; ++i) {
  25. *idst = *isrc;
  26. ++idst;
  27. ++isrc;
  28. }
  29. }
  30. bool is_cpu_handle(Handle* handle) {
  31. megcorePlatform_t plat;
  32. megcoreDeviceHandle_t dh;
  33. megcoreGetDeviceHandle(handle->megcore_computing_handle(), &dh);
  34. megcoreGetPlatform(dh, &plat);
  35. return plat == megcorePlatformCPU;
  36. }
  37. } // namespace
  38. void RelayoutForwardImpl::exec(
  39. _megdnn_tensor_in src0, _megdnn_tensor_out dst0, Handle* src_handle) {
  40. check_cpu_handle(src_handle);
  41. TensorND src = src0, dst = dst0;
  42. check_layout_and_canonize(src.layout, dst.layout);
  43. do_exec(src, dst);
  44. }
  45. void RelayoutForwardImpl::do_exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
  46. MIDOUT_BEGIN(naive_relayout, midout_iv(0)) {
  47. switch (src.layout.dtype.enumv()) {
  48. #define cb(_dt) \
  49. case DTypeEnum::_dt: { \
  50. MEGDNN_DISPATCH_CPU_KERN_OPR( \
  51. do_copy<DTypeTrait<dtype::_dt>::ctype>(dst, src)); \
  52. return; \
  53. }
  54. MEGDNN_FOREACH_DTYPE_NAME(cb)
  55. MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb)
  56. #undef cb
  57. default:
  58. megdnn_throw("bad dtype");
  59. }
  60. }
  61. MIDOUT_END();
  62. }
  63. void RelayoutForwardImpl::check_cpu_handle(Handle* handle) {
  64. megdnn_assert(
  65. !handle || handle == this->handle() || is_cpu_handle(handle),
  66. "relayout from non-CPU to CPU not supported");
  67. }
  68. // vim: syntax=cpp.doxygen