You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

correlation.cpp 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. /**
  2. * \file dnn/src/common/correlation.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/oprs.h"
  13. #include "src/common/utils.h"
  14. namespace megdnn {
  15. void CorrelationBase::deduce_layout_fwd(const TensorLayout& data1,
  16. const TensorLayout& data2,
  17. TensorLayout& dst) {
  18. megdnn_assert_contiguous(data1);
  19. megdnn_assert_contiguous(data2);
  20. megdnn_assert_contiguous(dst);
  21. auto errmsg = [&]() {
  22. return megdnn_layout_msg(data1) + ", " + megdnn_layout_msg(data2) +
  23. ", " + megdnn_layout_msg(dst);
  24. };
  25. MEGDNN_MARK_USED_VAR(errmsg);
  26. using Format = CorrelationBase::Param::Format;
  27. megdnn_assert(param().format == Format::NCHW);
  28. auto data1_dtype = data1.dtype, data2_dtype = data2.dtype;
  29. megdnn_assert(data1_dtype == data2_dtype &&
  30. data1_dtype.category() == DTypeCategory::FLOAT);
  31. megdnn_assert(data1.ndim == 4_z, "%s", errmsg().c_str());
  32. megdnn_assert(data2.ndim == 4_z, "%s", errmsg().c_str());
  33. uint32_t pad_size = param().pad_size;
  34. uint32_t kernel_size = param().kernel_size;
  35. uint32_t stride1 = param().stride1;
  36. uint32_t stride2 = param().stride2;
  37. uint32_t max_displacement = param().max_displacement;
  38. int paddedbottomheight = data1[2] + 2 * pad_size;
  39. int paddedbottomwidth = data1[3] + 2 * pad_size;
  40. uint32_t kernel_radius = (kernel_size - 1) / 2;
  41. uint32_t border_size = max_displacement + kernel_radius;
  42. uint32_t top_width =
  43. ceil(static_cast<float>(paddedbottomwidth - border_size * 2) /
  44. static_cast<float>(stride1));
  45. uint32_t top_height =
  46. ceil(static_cast<float>(paddedbottomheight - border_size * 2) /
  47. static_cast<float>(stride1));
  48. uint32_t neighborhood_grid_radius = max_displacement / stride2;
  49. uint32_t neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
  50. uint32_t top_channels = neighborhood_grid_width * neighborhood_grid_width;
  51. megdnn_assert(top_width >= 1 && top_height >= 1);
  52. dst = TensorLayout{{data1[0], top_channels, top_height, top_width},
  53. data1.dtype};
  54. }
  55. void CorrelationBase::check_layout_fwd(const TensorLayout& data1,
  56. const TensorLayout& data2,
  57. const TensorLayout& dst) {
  58. TensorLayout dst_expected;
  59. megdnn_assert_eq_dtype(data1, dst);
  60. megdnn_assert_eq_shape(data1, data2);
  61. deduce_layout_fwd(data1, data2, dst_expected);
  62. megdnn_assert_eq_shape(dst_expected, dst);
  63. }
  64. void CorrelationForward::deduce_layout(const TensorLayout& data1,
  65. const TensorLayout& data2,
  66. TensorLayout& dst) {
  67. deduce_layout_fwd(data1, data2, dst);
  68. }
  69. void CorrelationForward::check_exec(const TensorLayout& data1,
  70. const TensorLayout& data2,
  71. const TensorLayout& dst,
  72. size_t workspace_in_bytes) {
  73. check_layout_fwd(data1, data2, dst);
  74. auto required_workspace_in_bytes =
  75. get_workspace_in_bytes(data1, data2, dst);
  76. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  77. }
  78. void CorrelationBackwardData1::check_exec(const TensorLayout& diff,
  79. const TensorLayout& data1,
  80. const TensorLayout& data2,
  81. const TensorLayout& grad1,
  82. size_t workspace_in_bytes) {
  83. check_layout_fwd(grad1, data2, diff);
  84. megdnn_assert_eq_shape(data1, data2);
  85. auto required_workspace_in_bytes =
  86. get_workspace_in_bytes(diff, data1, data2, grad1);
  87. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  88. }
  89. void CorrelationBackwardData2::check_exec(const TensorLayout& diff,
  90. const TensorLayout& data1,
  91. const TensorLayout& data2,
  92. const TensorLayout& grad2,
  93. size_t workspace_in_bytes) {
  94. check_layout_fwd(data1, grad2, diff);
  95. megdnn_assert_eq_shape(data1, data2);
  96. auto required_workspace_in_bytes =
  97. get_workspace_in_bytes(diff, data1, data2, grad2);
  98. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  99. }
  100. void CorrelationBackwardData2::deduce_layout(const TensorLayout& diff,
  101. const TensorLayout& data1,
  102. const TensorLayout& data2,
  103. TensorLayout& grad) {
  104. megdnn_assert_eq_shape(data1, data2);
  105. check_layout_fwd(data1, data2, diff);
  106. grad = data2;
  107. }
  108. void CorrelationBackwardData1::deduce_layout(const TensorLayout& diff,
  109. const TensorLayout& data1,
  110. const TensorLayout& data2,
  111. TensorLayout& grad) {
  112. megdnn_assert_eq_shape(data1, data2);
  113. check_layout_fwd(data1, data2, diff);
  114. grad = data1;
  115. }
  116. } // namespace megdnn
  117. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台