You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

correlation.cpp 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. #include "test/cuda/fixture.h"
  2. #include "test/common/checker.h"
  3. #include "test/common/correlation.h"
  4. namespace megdnn {
  5. namespace test {
  6. TEST_F(CUDA, CORRELATION_FORWARD) {
  7. using namespace correlation;
  8. std::vector<TestArg> args = get_args();
  9. Checker<Correlation> checker(handle_cuda());
  10. for (auto&& arg : args) {
  11. checker.set_param(arg.param)
  12. .set_dtype(0, dtype::Float32())
  13. .set_dtype(1, dtype::Float32())
  14. .execs({arg.data1, arg.data2, {}});
  15. }
  16. }
  17. TEST_F(CUDA, CORRELATION_BACKWARDDATA1) {
  18. ConstValue const_0{0};
  19. using Param = CorrelationBackwardData1::Param;
  20. Param param;
  21. param.is_multiply = true;
  22. param.format = Param::Format::NCHW;
  23. param.stride1 = 2;
  24. param.stride2 = 2;
  25. param.kernel_size = 3;
  26. param.pad_size = 4;
  27. Checker<CorrelationBackwardData1> checker(handle_cuda());
  28. checker.set_epsilon(1e-2);
  29. uint32_t pad_size = param.pad_size;
  30. uint32_t kernel_size = param.kernel_size;
  31. uint32_t stride1 = param.stride1;
  32. uint32_t stride2 = param.stride2;
  33. uint32_t max_displacement = param.max_displacement;
  34. auto run = [&](DType dtype) {
  35. for (size_t N : {1, 3})
  36. for (size_t C : {1, 3})
  37. for (size_t OH : {10, 100})
  38. for (size_t OW : {10, 100}) {
  39. int paddedbottomheight = OH + 2 * pad_size;
  40. int paddedbottomwidth = OW + 2 * pad_size;
  41. uint32_t kernel_radius = (kernel_size - 1) / 2;
  42. uint32_t border_size = max_displacement + kernel_radius;
  43. uint32_t top_width =
  44. ceil(static_cast<float>(
  45. paddedbottomwidth - border_size * 2) /
  46. static_cast<float>(stride1));
  47. uint32_t top_height =
  48. ceil(static_cast<float>(
  49. paddedbottomheight - border_size * 2) /
  50. static_cast<float>(stride1));
  51. uint32_t neighborhood_grid_radius = max_displacement / stride2;
  52. uint32_t neighborhood_grid_width =
  53. neighborhood_grid_radius * 2 + 1;
  54. uint32_t top_channels =
  55. neighborhood_grid_width * neighborhood_grid_width;
  56. checker.set_param(param)
  57. .set_dtype(0, dtype)
  58. .set_dtype(1, dtype)
  59. .set_dtype(2, dtype)
  60. .set_dtype(3, dtype)
  61. .execs({{N, top_channels, top_height, top_width},
  62. {N, C, OH, OW},
  63. {N, C, OH, OW},
  64. {N, C, OH, OW}});
  65. }
  66. };
  67. run(dtype::Float32());
  68. run(dtype::Float16());
  69. checker.set_epsilon(5e-2);
  70. run(dtype::BFloat16());
  71. }
  72. TEST_F(CUDA, CORRELATION_BACKWARDDATA2) {
  73. ConstValue const_0{0};
  74. using Param = CorrelationBackwardData2::Param;
  75. Param param;
  76. param.is_multiply = true;
  77. param.format = Param::Format::NCHW;
  78. param.stride1 = 2;
  79. param.stride2 = 2;
  80. param.kernel_size = 3;
  81. param.pad_size = 4;
  82. Checker<CorrelationBackwardData2> checker(handle_cuda());
  83. checker.set_epsilon(1e-2);
  84. uint32_t pad_size = param.pad_size;
  85. uint32_t kernel_size = param.kernel_size;
  86. uint32_t stride1 = param.stride1;
  87. uint32_t stride2 = param.stride2;
  88. uint32_t max_displacement = param.max_displacement;
  89. auto run = [&](DType dtype) {
  90. for (size_t N : {1, 3})
  91. for (size_t C : {1, 3})
  92. for (size_t OH : {10, 100})
  93. for (size_t OW : {10, 100}) {
  94. int paddedbottomheight = OH + 2 * pad_size;
  95. int paddedbottomwidth = OW + 2 * pad_size;
  96. uint32_t kernel_radius = (kernel_size - 1) / 2;
  97. uint32_t border_size = max_displacement + kernel_radius;
  98. uint32_t top_width =
  99. ceil(static_cast<float>(
  100. paddedbottomwidth - border_size * 2) /
  101. static_cast<float>(stride1));
  102. uint32_t top_height =
  103. ceil(static_cast<float>(
  104. paddedbottomheight - border_size * 2) /
  105. static_cast<float>(stride1));
  106. uint32_t neighborhood_grid_radius = max_displacement / stride2;
  107. uint32_t neighborhood_grid_width =
  108. neighborhood_grid_radius * 2 + 1;
  109. uint32_t top_channels =
  110. neighborhood_grid_width * neighborhood_grid_width;
  111. checker.set_param(param)
  112. .set_dtype(0, dtype)
  113. .set_dtype(1, dtype)
  114. .set_dtype(2, dtype)
  115. .set_dtype(3, dtype)
  116. .execs({{N, top_channels, top_height, top_width},
  117. {N, C, OH, OW},
  118. {N, C, OH, OW},
  119. {N, C, OH, OW}});
  120. }
  121. };
  122. run(dtype::Float32());
  123. run(dtype::Float16());
  124. checker.set_epsilon(5e-2);
  125. run(dtype::BFloat16());
  126. }
  127. } // namespace test
  128. } // namespace megdnn
  129. // vim: syntax=cpp.doxygen