You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

deformable_ps_roi_pooling.cpp 1.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. #include "test/naive/fixture.h"
  2. #include "megdnn/oprs/nn.h"
  3. #include "test/common/benchmarker.h"
  4. #include "test/common/checker.h"
  5. #include "test/common/random_state.h"
  6. using namespace megdnn;
  7. using namespace test;
  8. TEST_F(NAIVE, DEFORMABLE_PSROI_POOLING_FWD) {
  9. Checker<DeformablePSROIPooling> checker(handle());
  10. DeformablePSROIPooling::Param param;
  11. param.no_trans = true;
  12. param.pooled_h = 3;
  13. param.pooled_w = 3;
  14. param.trans_std = 1.f;
  15. param.spatial_scale = 1.f;
  16. param.part_size = 1;
  17. param.sample_per_part = 1;
  18. UniformIntRNG data{0, 4};
  19. UniformIntRNG rois{0, 4};
  20. UniformIntRNG trans{-2, 2};
  21. checker.set_rng(0, &data).set_rng(1, &rois).set_rng(2, &trans);
  22. checker.set_param(param).execs({{4, 2, 5, 5}, {2, 5}, {4, 2, 5, 5}, {}, {}});
  23. }
  24. TEST_F(NAIVE, DEFORMABLE_PSROI_POOLING_BWD) {
  25. Checker<DeformablePSROIPoolingBackward> checker(handle());
  26. DeformablePSROIPoolingBackward::Param param;
  27. param.no_trans = true;
  28. param.pooled_h = 3;
  29. param.pooled_w = 3;
  30. param.trans_std = 1.f;
  31. param.spatial_scale = 1.f;
  32. param.part_size = 1;
  33. param.sample_per_part = 1;
  34. UniformIntRNG data{0, 4};
  35. UniformIntRNG rois{0, 4};
  36. UniformIntRNG trans{-2, 2};
  37. UniformIntRNG out_diff{-2, 2};
  38. UniformIntRNG out_count{-2, 2};
  39. checker.set_rng(0, &data)
  40. .set_rng(1, &rois)
  41. .set_rng(2, &trans)
  42. .set_rng(3, &out_diff)
  43. .set_rng(4, &out_count);
  44. checker.set_param(param).execs(
  45. {{4, 2, 5, 5}, // data
  46. {2, 5}, // rois
  47. {4, 2, 5, 5}, // trans
  48. {2, 2, 3, 3}, // out_diff
  49. {2, 2, 3, 3}, // out_count
  50. {4, 2, 5, 5},
  51. {4, 2, 5, 5}});
  52. }
  53. // vim: syntax=cpp.doxygen