You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mask_conv.cpp 1.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. #include "test/cuda/fixture.h"
  2. #include "megdnn/oprs.h"
  3. #include "test/common/benchmarker.h"
  4. #include "test/common/checker.h"
  5. #include "test/common/mask_conv.h"
  6. #include "test/common/rng.h"
  7. using namespace megdnn;
  8. using namespace test;
  9. TEST_F(CUDA, MASK_CONV) {
  10. mask_conv_test(handle_cuda());
  11. }
  12. #if MEGDNN_WITH_BENCHMARK
  13. TEST_F(CUDA, MASK_CONV_BENCHMARK) {
  14. mask_conv_benchmark(handle_cuda());
  15. }
  16. #endif
  17. TEST_F(CUDA, MASK_PROPAGATE) {
  18. Checker<MaskPropagate> checker(handle_cuda());
  19. auto run = [&](size_t IH, size_t IW, size_t FH, size_t FW, size_t SH = 1,
  20. size_t SW = 1, size_t PH = 0, size_t PW = 0, size_t DH = 1,
  21. size_t DW = 1) {
  22. using Param = param::MaskPropagate;
  23. Param param(PH, PW, SH, SW, FH, FW, DH, DW);
  24. TensorShape src_shape({IH, IW}), dst({});
  25. auto rng = std::make_unique<BernoulliRNG>(0.5);
  26. checker.set_param(param).set_rng(0, rng.get()).execs({src_shape, dst});
  27. #undef test
  28. };
  29. #define cb(DType) \
  30. checker.set_dtype(0, DType()); \
  31. run(3, 3, 1, 1); \
  32. run(5, 5, 2, 3, 2, 2); \
  33. run(5, 5, 3, 3, 2, 2, 1, 2); \
  34. run(5, 5, 3, 3, 2, 1, 1, 2); \
  35. run(5, 5, 3, 3, 1, 2, 2, 2); \
  36. run(24, 23, 4, 4, 1, 1, 3, 2); \
  37. run(24, 23, 4, 4, 1, 1, 3, 2, 2, 2); \
  38. run(24, 23, 4, 4, 1, 1, 3, 2, 2, 3); \
  39. run(24, 23, 4, 4, 1, 1, 3, 2, 3, 3);
  40. // cb(dtype::Int32)
  41. MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb);
  42. #undef cb
  43. }