You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

roi_pooling.h 890 B

123456789101112131415161718192021222324252627282930313233343536
  1. #include "megdnn/oprs.h"
  2. #include "test/common/random_state.h"
  3. #include "test/common/rng.h"
  4. namespace megdnn {
  5. namespace test {
  6. class ROIPoolingRNG final : public IIDRNG {
  7. public:
  8. ROIPoolingRNG(size_t n) : n(n), idx(0) {}
  9. dt_float32 gen_single_val() override {
  10. std::uniform_real_distribution<dt_float32> distf(0.0f, 1.0f);
  11. std::uniform_int_distribution<int> disti(0, n - 1);
  12. dt_float32 res;
  13. if (idx == 0) {
  14. res = static_cast<dt_float32>(disti(RandomState::generator()));
  15. }
  16. if (idx == 1 || idx == 2) {
  17. res = distf(RandomState::generator()) * 0.5;
  18. } else {
  19. res = distf(RandomState::generator()) * 0.5 + 0.5;
  20. }
  21. idx = (idx + 1) % 5;
  22. return res;
  23. }
  24. private:
  25. size_t n;
  26. size_t idx;
  27. };
  28. } // namespace test
  29. } // namespace megdnn
  30. // vim: syntax=cpp.doxygen