You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

roi_align.cpp 3.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. /**
  2. * \file dnn/test/cuda/roi_align.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/cuda/fixture.h"
  12. #include "test/common/checker.h"
  13. #include "test/common/roi_pooling.h"
  14. namespace megdnn {
  15. namespace test {
  16. TEST_F(CUDA, ROI_ALIGN_FORWARD) {
  17. size_t N = 10, C = 3, IH = 102, IW = 108;
  18. size_t OH = 12, OW = 13, M = 7;
  19. ROIPoolingRNG rng(N);
  20. ConstValue const_0{0};
  21. ConsecutiveRNG consecutive_rng{0.f, 1.f / (N * C * IH * IW * 1.f)};
  22. using Param = ROIAlign::Param;
  23. Param param;
  24. param.spatial_scale = 100;
  25. param.offset = 0.0;
  26. param.pooled_height = OH;
  27. param.pooled_width = OW;
  28. param.sample_height = 16;
  29. param.sample_width = 16;
  30. Checker<ROIAlignForward> checker(handle_cuda());
  31. auto run = [&](DType dtype) {
  32. for (auto mode : {Param::Mode::MAX, Param::Mode::AVERAGE}) {
  33. param.mode = mode;
  34. if (mode == Param::Mode::MAX) {
  35. checker.set_rng(0, &consecutive_rng);
  36. }
  37. checker.set_param(param)
  38. .set_rng(1, &rng)
  39. .set_dtype(0, dtype)
  40. .set_dtype(1, dtype)
  41. .set_dtype(2, dtype)
  42. .set_dtype(3, dtype::Int32())
  43. .execs({{N, C, IH, IW}, {M, 5}, {}, {}});
  44. }
  45. };
  46. run(dtype::Float32());
  47. run(dtype::Float16());
  48. }
  49. TEST_F(CUDA, ROI_ALIGN_BACKWARD) {
  50. size_t N = 10, C = 3, IH = 102, IW = 108;
  51. size_t OH = 12, OW = 13, M = 7;
  52. ROIPoolingRNG rng(N);
  53. ConstValue const_0{0};
  54. using Param = ROIAlign::Param;
  55. Param param;
  56. param.spatial_scale = 100;
  57. param.offset = 0.0;
  58. param.pooled_height = OH;
  59. param.pooled_width = OW;
  60. param.sample_height = 7;
  61. param.sample_width = 7;
  62. UniformIntRNG index_rng(0, param.sample_height * param.sample_width - 1);
  63. Checker<ROIAlignBackward> checker(handle_cuda());
  64. checker.set_epsilon(1e-2);
  65. auto run = [&](DType dtype) {
  66. for (auto mode : {Param::Mode::MAX, Param::Mode::AVERAGE}) {
  67. param.mode = mode;
  68. checker.set_param(param)
  69. .set_dtype(0, dtype)
  70. .set_dtype(1, dtype)
  71. .set_dtype(3, dtype)
  72. .set_dtype(2, dtype::Int32())
  73. .set_rng(1, &rng)
  74. .set_rng(2, &index_rng)
  75. .set_rng(3, &const_0)
  76. .execs({{M, C, OH, OW}, {M, 5}, {M, C, OH, OW}, {N, C, IH, IW}});
  77. }
  78. };
  79. run(dtype::Float32());
  80. run(dtype::Float16());
  81. checker.set_epsilon(5e-2);
  82. run(dtype::BFloat16());
  83. }
  84. } // namespace test
  85. } // namespace megdnn
  86. // vim: syntax=cpp.doxygen