You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

deformable_ps_roi_pooling.cpp 3.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. #include "megdnn/oprs/nn.h"
  2. #include "src/cuda/utils.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/random_state.h"
  5. #include "test/common/roi_pooling.h"
  6. #include "test/cuda/benchmark.h"
  7. #include "test/cuda/fixture.h"
  8. using namespace megdnn;
  9. using namespace test;
  10. TEST_F(CUDA, DEFORMABLE_PSROI_POOLING_FWD) {
  11. Checker<DeformablePSROIPooling> checker(handle_cuda());
  12. auto run = [&checker](
  13. size_t N, size_t C, size_t IH, size_t IW, size_t OH, size_t OW,
  14. bool no_trans, size_t nr_bbox, size_t nr_cls, size_t part_sz,
  15. size_t sample_per_part, float trans_std, float spatial_scale) {
  16. DeformablePSROIPooling::Param param;
  17. param.no_trans = no_trans;
  18. param.pooled_h = OH;
  19. param.pooled_w = OW;
  20. param.trans_std = trans_std;
  21. param.spatial_scale = spatial_scale;
  22. param.part_size = part_sz;
  23. param.sample_per_part = sample_per_part;
  24. ROIPoolingRNG rois(N);
  25. checker.set_rng(1, &rois);
  26. checker.set_param(param).execs(
  27. {{N, C, IH, IW}, {nr_bbox, 5}, {nr_cls, 2, OH, OW}, {}, {}});
  28. };
  29. run(2, 4, 5, 5, 3, 3, true, 2, 2, 1, 1, 1.f, 1.f);
  30. run(2, 4, 5, 5, 3, 3, false, 2, 2, 1, 1, 1.f, 1.f);
  31. run(2, 4, 5, 5, 3, 3, false, 2, 2, 1, 1, 0.5f, 1.5f);
  32. run(2, 4, 100, 100, 60, 60, false, 2, 2, 1, 1, 0.5f, 1.5f);
  33. run(10, 3, 102, 108, 12, 13, false, 7, 2, 2, 2, 0.5f, 1.5f);
  34. run(2, 32, 100, 100, 50, 50, false, 16, 4, 1, 1, 1.f, 1.f);
  35. }
  36. TEST_F(CUDA, DEFORMABLE_PSROI_POOLING_BWD) {
  37. Checker<DeformablePSROIPoolingBackward> checker(handle_cuda());
  38. auto run = [&checker](
  39. size_t N, size_t C, size_t IH, size_t IW, size_t OH, size_t OW,
  40. bool no_trans, size_t nr_bbox, size_t nr_cls, size_t part_sz,
  41. size_t sample_per_part, float trans_std, float spatial_scale) {
  42. DeformablePSROIPooling::Param param;
  43. param.no_trans = no_trans;
  44. param.pooled_h = OH;
  45. param.pooled_w = OW;
  46. param.trans_std = trans_std;
  47. param.spatial_scale = spatial_scale;
  48. param.part_size = part_sz;
  49. param.sample_per_part = sample_per_part;
  50. ROIPoolingRNG rois(N);
  51. checker.set_rng(1, &rois);
  52. checker.set_param(param).execs({
  53. {N, C, IH, IW}, // data
  54. {nr_bbox, 5}, // rois
  55. {nr_cls, 2, OH, OW}, // trans
  56. {nr_bbox, C, OH, OW}, // out_diff
  57. {nr_bbox, C, OH, OW}, // out_count
  58. {N, C, IH, IW}, // data_diff
  59. {nr_cls, 2, OH, OW} // trans_diff
  60. });
  61. };
  62. run(2, 4, 5, 5, 3, 3, true, 2, 2, 1, 1, 1.f, 1.f);
  63. run(2, 4, 5, 5, 3, 3, false, 2, 2, 2, 2, 1.f, 1.f);
  64. run(2, 4, 5, 5, 3, 3, false, 2, 2, 1, 1, 1.f, 1.f);
  65. run(2, 4, 5, 5, 3, 3, false, 2, 2, 1, 1, 0.5f, 1.5f);
  66. run(2, 4, 100, 100, 60, 60, false, 2, 2, 1, 1, 0.5f, 1.5f);
  67. run(10, 3, 102, 108, 12, 13, false, 7, 2, 2, 2, 0.5f, 1.5f);
  68. run(2, 32, 100, 100, 50, 50, false, 16, 4, 1, 1, 1.f, 1.f);
  69. }
  70. // vim: syntax=cpp.doxygen