You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

roi_pooling.cpp 2.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. #include "test/cuda/fixture.h"
  2. #include "test/common/checker.h"
  3. #include "test/common/roi_pooling.h"
  4. namespace megdnn {
  5. namespace test {
  6. TEST_F(CUDA, ROI_POOLING_FORWARD) {
  7. size_t N = 10, C = 3, IH = 102, IW = 108, spatial_scale = 100;
  8. size_t OH = 12, OW = 13, M = 7;
  9. ROIPoolingRNG rng(N);
  10. using Param = ROIPooling::Param;
  11. Param param;
  12. param.scale = spatial_scale;
  13. Checker<ROIPoolingForward> checker(handle_cuda());
  14. auto run = [&](DType dtype) {
  15. for (auto mode : {Param::Mode::MAX, Param::Mode::AVERAGE}) {
  16. param.mode = mode;
  17. checker.set_param(param)
  18. .set_rng(1, &rng)
  19. .set_dtype(0, dtype)
  20. .set_dtype(1, dtype)
  21. .set_dtype(2, dtype)
  22. .set_dtype(3, dtype::Int32())
  23. .execs({{N, C, IH, IW}, {M, 5}, {M, C, OH, OW}, {M, C, OH, OW}});
  24. }
  25. };
  26. run(dtype::Float32());
  27. run(dtype::Float16());
  28. }
  29. TEST_F(CUDA, ROI_POOLING_BACKWARD) {
  30. size_t N = 10, C = 3, IH = 102, IW = 108, spatial_scale = 100;
  31. size_t OH = 12, OW = 13, M = 7;
  32. ROIPoolingRNG rng(N);
  33. UniformIntRNG index_rng(0, OH * OW - 1);
  34. using Param = ROIPooling::Param;
  35. Param param;
  36. param.scale = spatial_scale;
  37. Checker<ROIPoolingBackward> checker(handle_cuda());
  38. checker.set_epsilon(1e-2);
  39. auto run = [&](DType dtype) {
  40. for (auto mode : {Param::Mode::MAX, Param::Mode::AVERAGE}) {
  41. param.mode = mode;
  42. checker.set_param(param)
  43. .set_dtype(0, dtype)
  44. .set_dtype(1, dtype)
  45. .set_dtype(2, dtype)
  46. .set_dtype(4, dtype)
  47. .set_dtype(3, dtype::Int32())
  48. .set_rng(2, &rng)
  49. .set_rng(3, &index_rng)
  50. .execs({{M, C, OH, OW},
  51. {N, C, IH, IW},
  52. {M, 5},
  53. {M, C, OH, OW},
  54. {N, C, IH, IW}});
  55. }
  56. };
  57. run(dtype::Float32());
  58. run(dtype::Float16());
  59. }
  60. } // namespace test
  61. } // namespace megdnn
  62. // vim: syntax=cpp.doxygen