You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

deformable_conv.cpp 3.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. #include "test/naive/fixture.h"
  2. #include "megdnn/oprs/nn.h"
  3. #include "test/common/benchmarker.h"
  4. #include "test/common/checker.h"
  5. #include "test/common/random_state.h"
  6. using namespace megdnn;
  7. using namespace test;
  8. TEST_F(NAIVE, DEFORMABLE_CONV_FWD) {
  9. Checker<DeformableConv> checker(handle());
  10. DeformableConv::Param param;
  11. UniformIntRNG im_rng{0, 4};
  12. UniformIntRNG filter_rng{0, 4};
  13. UniformIntRNG offset_rng{-2, 2};
  14. UniformIntRNG mask_rng{0, 1};
  15. checker.set_rng(0, &im_rng)
  16. .set_rng(1, &filter_rng)
  17. .set_rng(2, &offset_rng)
  18. .set_rng(3, &mask_rng);
  19. param.pad_h = 1;
  20. param.pad_w = 1;
  21. param.stride_h = 1;
  22. param.stride_w = 1;
  23. param.dilate_h = 1;
  24. param.dilate_w = 1;
  25. param.format = DeformableConv::Param::Format::NCHW;
  26. param.sparse = DeformableConv::Param::Sparse::GROUP;
  27. checker.set_param(param).execs(
  28. {{1, 2, 5, 5},
  29. {2, 1, 1, 3, 3},
  30. {1, 2 * 2 * 3 * 3, 5, 5},
  31. {1, 2 * 3 * 3, 5, 5},
  32. {}});
  33. checker.set_param(param).execs(
  34. {{1, 2, 5, 5},
  35. {2, 1, 1, 3, 3},
  36. {1, 2 * 2 * 3 * 3, 5, 5},
  37. {1, 2 * 3 * 3, 5, 5},
  38. {}});
  39. param.sparse = DeformableConv::Param::Sparse::DENSE;
  40. checker.set_param(param).execs(
  41. {{1, 2, 5, 5},
  42. {2, 2, 3, 3},
  43. {1, 2 * 2 * 3 * 3, 5, 5},
  44. {1, 2 * 3 * 3, 5, 5},
  45. {}});
  46. }
  47. TEST_F(NAIVE, DEFORMABLE_CONV_BWD_FILTER) {
  48. Checker<DeformableConvBackwardFilter> checker(handle());
  49. DeformableConv::Param param;
  50. UniformIntRNG im_rng{0, 4};
  51. UniformIntRNG offset_rng{-2, 2};
  52. UniformIntRNG mask_rng{0, 1};
  53. UniformIntRNG out_grad_rng{0, 1};
  54. checker.set_rng(0, &im_rng)
  55. .set_rng(1, &offset_rng)
  56. .set_rng(2, &mask_rng)
  57. .set_rng(3, &out_grad_rng);
  58. param.pad_h = 1;
  59. param.pad_w = 1;
  60. param.stride_h = 1;
  61. param.stride_w = 1;
  62. param.dilate_h = 1;
  63. param.dilate_w = 1;
  64. param.format = DeformableConv::Param::Format::NCHW;
  65. param.sparse = DeformableConv::Param::Sparse::GROUP;
  66. checker.set_param(param).execs(
  67. {{1, 2, 5, 5},
  68. {1, 2 * 2 * 3 * 3, 5, 5},
  69. {1, 2 * 3 * 3, 5, 5},
  70. {1, 2, 5, 5},
  71. {2, 1, 1, 3, 3}});
  72. }
  73. TEST_F(NAIVE, DEFORMABLE_CONV_BWD_DATA) {
  74. Checker<DeformableConvBackwardData> checker(handle());
  75. DeformableConv::Param param;
  76. ConstValue im_rng{1};
  77. ConstValue filter_rng{0.99};
  78. ConstValue offset_rng{1.1};
  79. ConstValue mask_rng{1};
  80. ConstValue out_grad_rng{1};
  81. checker.set_rng(0, &im_rng)
  82. .set_rng(1, &filter_rng)
  83. .set_rng(2, &offset_rng)
  84. .set_rng(3, &mask_rng)
  85. .set_rng(4, &out_grad_rng);
  86. param.pad_h = 1;
  87. param.pad_w = 1;
  88. param.stride_h = 1;
  89. param.stride_w = 1;
  90. param.dilate_h = 1;
  91. param.dilate_w = 1;
  92. param.format = DeformableConv::Param::Format::NCHW;
  93. param.sparse = DeformableConv::Param::Sparse::GROUP;
  94. checker.set_param(param).execs(
  95. {{1, 2, 5, 5},
  96. {2, 1, 1, 3, 3},
  97. {1, 1 * 2 * 3 * 3, 5, 5},
  98. {1, 1 * 3 * 3, 5, 5},
  99. {1, 2, 5, 5},
  100. {1, 2, 5, 5},
  101. {1, 1 * 2 * 3 * 3, 5, 5},
  102. {1, 1 * 3 * 3, 5, 5}});
  103. }
  104. // vim: syntax=cpp.doxygen