You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

roi_align.cpp 2.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. #include "test/cuda/fixture.h"
  2. #include "test/common/checker.h"
  3. #include "test/common/roi_pooling.h"
  4. namespace megdnn {
  5. namespace test {
  6. TEST_F(CUDA, ROI_ALIGN_FORWARD) {
  7. size_t N = 10, C = 3, IH = 102, IW = 108;
  8. size_t OH = 12, OW = 13, M = 7;
  9. ROIPoolingRNG rng(N);
  10. ConstValue const_0{0};
  11. ConsecutiveRNG consecutive_rng{0.f, 1.f / (N * C * IH * IW * 1.f)};
  12. using Param = ROIAlign::Param;
  13. Param param;
  14. param.spatial_scale = 100;
  15. param.offset = 0.0;
  16. param.pooled_height = OH;
  17. param.pooled_width = OW;
  18. param.sample_height = 16;
  19. param.sample_width = 16;
  20. Checker<ROIAlignForward> checker(handle_cuda());
  21. auto run = [&](DType dtype) {
  22. for (auto mode : {Param::Mode::MAX, Param::Mode::AVERAGE}) {
  23. param.mode = mode;
  24. if (mode == Param::Mode::MAX) {
  25. checker.set_rng(0, &consecutive_rng);
  26. }
  27. checker.set_param(param)
  28. .set_rng(1, &rng)
  29. .set_dtype(0, dtype)
  30. .set_dtype(1, dtype)
  31. .set_dtype(2, dtype)
  32. .set_dtype(3, dtype::Int32())
  33. .execs({{N, C, IH, IW}, {M, 5}, {}, {}});
  34. }
  35. };
  36. run(dtype::Float32());
  37. run(dtype::Float16());
  38. }
  39. TEST_F(CUDA, ROI_ALIGN_BACKWARD) {
  40. size_t N = 10, C = 3, IH = 102, IW = 108;
  41. size_t OH = 12, OW = 13, M = 7;
  42. ROIPoolingRNG rng(N);
  43. ConstValue const_0{0};
  44. using Param = ROIAlign::Param;
  45. Param param;
  46. param.spatial_scale = 100;
  47. param.offset = 0.0;
  48. param.pooled_height = OH;
  49. param.pooled_width = OW;
  50. param.sample_height = 7;
  51. param.sample_width = 7;
  52. UniformIntRNG index_rng(0, param.sample_height * param.sample_width - 1);
  53. Checker<ROIAlignBackward> checker(handle_cuda());
  54. checker.set_epsilon(1e-2);
  55. auto run = [&](DType dtype) {
  56. for (auto mode : {Param::Mode::MAX, Param::Mode::AVERAGE}) {
  57. param.mode = mode;
  58. checker.set_param(param)
  59. .set_dtype(0, dtype)
  60. .set_dtype(1, dtype)
  61. .set_dtype(3, dtype)
  62. .set_dtype(2, dtype::Int32())
  63. .set_rng(1, &rng)
  64. .set_rng(2, &index_rng)
  65. .set_rng(3, &const_0)
  66. .execs({{M, C, OH, OW}, {M, 5}, {M, C, OH, OW}, {N, C, IH, IW}});
  67. }
  68. };
  69. run(dtype::Float32());
  70. run(dtype::Float16());
  71. checker.set_epsilon(5e-2);
  72. run(dtype::BFloat16());
  73. }
  74. } // namespace test
  75. } // namespace megdnn
  76. // vim: syntax=cpp.doxygen