You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

correlation.cpp 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. /**
  2. * \file dnn/test/cuda/correlation.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "test/cuda/fixture.h"
  13. #include "test/common/checker.h"
  14. #include "test/common/correlation.h"
  15. namespace megdnn {
  16. namespace test {
  17. TEST_F(CUDA, CORRELATION_FORWARD) {
  18. using namespace correlation;
  19. std::vector<TestArg> args = get_args();
  20. Checker<Correlation> checker(handle_cuda());
  21. for (auto&& arg : args) {
  22. checker.set_param(arg.param)
  23. .set_dtype(0, dtype::Float32())
  24. .set_dtype(1, dtype::Float32())
  25. .execs({arg.data1, arg.data2, {}});
  26. }
  27. }
  28. TEST_F(CUDA, CORRELATION_BACKWARDDATA1) {
  29. ConstValue const_0{0};
  30. using Param = CorrelationBackwardData1::Param;
  31. Param param;
  32. param.is_multiply = true;
  33. param.format = Param::Format::NCHW;
  34. param.stride1 = 2;
  35. param.stride2 = 2;
  36. param.kernel_size = 3;
  37. param.pad_size = 4;
  38. Checker<CorrelationBackwardData1> checker(handle_cuda());
  39. checker.set_epsilon(1e-2);
  40. uint32_t pad_size = param.pad_size;
  41. uint32_t kernel_size = param.kernel_size;
  42. uint32_t stride1 = param.stride1;
  43. uint32_t stride2 = param.stride2;
  44. uint32_t max_displacement = param.max_displacement;
  45. auto run = [&](DType dtype) {
  46. for (size_t N : {1, 3})
  47. for (size_t C : {1, 3})
  48. for (size_t OH : {10, 100})
  49. for (size_t OW : {10, 100}) {
  50. int paddedbottomheight = OH + 2 * pad_size;
  51. int paddedbottomwidth = OW + 2 * pad_size;
  52. uint32_t kernel_radius = (kernel_size - 1) / 2;
  53. uint32_t border_size = max_displacement + kernel_radius;
  54. uint32_t top_width =
  55. ceil(static_cast<float>(
  56. paddedbottomwidth - border_size * 2) /
  57. static_cast<float>(stride1));
  58. uint32_t top_height =
  59. ceil(static_cast<float>(
  60. paddedbottomheight - border_size * 2) /
  61. static_cast<float>(stride1));
  62. uint32_t neighborhood_grid_radius = max_displacement / stride2;
  63. uint32_t neighborhood_grid_width =
  64. neighborhood_grid_radius * 2 + 1;
  65. uint32_t top_channels =
  66. neighborhood_grid_width * neighborhood_grid_width;
  67. checker.set_param(param)
  68. .set_dtype(0, dtype)
  69. .set_dtype(1, dtype)
  70. .set_dtype(2, dtype)
  71. .set_dtype(3, dtype)
  72. .execs({{N, top_channels, top_height, top_width},
  73. {N, C, OH, OW},
  74. {N, C, OH, OW},
  75. {N, C, OH, OW}});
  76. }
  77. };
  78. run(dtype::Float32());
  79. run(dtype::Float16());
  80. checker.set_epsilon(5e-2);
  81. run(dtype::BFloat16());
  82. }
  83. TEST_F(CUDA, CORRELATION_BACKWARDDATA2) {
  84. ConstValue const_0{0};
  85. using Param = CorrelationBackwardData2::Param;
  86. Param param;
  87. param.is_multiply = true;
  88. param.format = Param::Format::NCHW;
  89. param.stride1 = 2;
  90. param.stride2 = 2;
  91. param.kernel_size = 3;
  92. param.pad_size = 4;
  93. Checker<CorrelationBackwardData2> checker(handle_cuda());
  94. checker.set_epsilon(1e-2);
  95. uint32_t pad_size = param.pad_size;
  96. uint32_t kernel_size = param.kernel_size;
  97. uint32_t stride1 = param.stride1;
  98. uint32_t stride2 = param.stride2;
  99. uint32_t max_displacement = param.max_displacement;
  100. auto run = [&](DType dtype) {
  101. for (size_t N : {1, 3})
  102. for (size_t C : {1, 3})
  103. for (size_t OH : {10, 100})
  104. for (size_t OW : {10, 100}) {
  105. int paddedbottomheight = OH + 2 * pad_size;
  106. int paddedbottomwidth = OW + 2 * pad_size;
  107. uint32_t kernel_radius = (kernel_size - 1) / 2;
  108. uint32_t border_size = max_displacement + kernel_radius;
  109. uint32_t top_width =
  110. ceil(static_cast<float>(
  111. paddedbottomwidth - border_size * 2) /
  112. static_cast<float>(stride1));
  113. uint32_t top_height =
  114. ceil(static_cast<float>(
  115. paddedbottomheight - border_size * 2) /
  116. static_cast<float>(stride1));
  117. uint32_t neighborhood_grid_radius = max_displacement / stride2;
  118. uint32_t neighborhood_grid_width =
  119. neighborhood_grid_radius * 2 + 1;
  120. uint32_t top_channels =
  121. neighborhood_grid_width * neighborhood_grid_width;
  122. checker.set_param(param)
  123. .set_dtype(0, dtype)
  124. .set_dtype(1, dtype)
  125. .set_dtype(2, dtype)
  126. .set_dtype(3, dtype)
  127. .execs({{N, top_channels, top_height, top_width},
  128. {N, C, OH, OW},
  129. {N, C, OH, OW},
  130. {N, C, OH, OW}});
  131. }
  132. };
  133. run(dtype::Float32());
  134. run(dtype::Float16());
  135. checker.set_epsilon(5e-2);
  136. run(dtype::BFloat16());
  137. }
  138. } // namespace test
  139. } // namespace megdnn
  140. // vim: syntax=cpp.doxygen