You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

flip.cpp 2.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. #include <gtest/gtest.h>
  2. #include "megdnn.h"
  3. #include "megdnn/oprs.h"
  4. #include "test/common/benchmarker.h"
  5. #include "test/common/checker.h"
  6. #include "test/common/flip.h"
  7. #include "test/common/tensor.h"
  8. #include "test/cuda/fixture.h"
  9. namespace megdnn {
  10. namespace test {
  11. TEST_F(CUDA, FLIP) {
  12. using namespace flip;
  13. std::vector<TestArg> args = get_args();
  14. Checker<Flip> checker(handle_cuda());
  15. checker.set_dtype(0, dtype::Int32());
  16. checker.set_dtype(1, dtype::Int32());
  17. //! test for batch size exceed CUDNN_MAX_BATCH_X_CHANNEL_SIZE
  18. Flip::Param cur_param;
  19. for (bool vertical : {false, true}) {
  20. for (bool horizontal : {false, true}) {
  21. cur_param.horizontal = horizontal;
  22. cur_param.vertical = vertical;
  23. args.emplace_back(cur_param, TensorShape{65535, 3, 4, 1});
  24. args.emplace_back(cur_param, TensorShape{65540, 3, 4, 3});
  25. }
  26. }
  27. for (auto&& arg : args) {
  28. checker.execs({arg.src, {}});
  29. }
  30. }
  31. #if MEGDNN_WITH_BENCHMARK
  32. TEST_F(CUDA, FLIP_BENCHMARK) {
  33. auto run = [&](const TensorShapeArray& shapes) {
  34. Benchmarker<Flip> benchmarker(handle_cuda());
  35. benchmarker.set_dtype(0, dtype::Int32());
  36. benchmarker.set_dtype(1, dtype::Int32());
  37. benchmarker.set_times(5);
  38. Flip::Param param;
  39. #define BENCHMARK_FLIP(is_vertical, is_horizontal) \
  40. param.vertical = is_vertical; \
  41. param.horizontal = is_horizontal; \
  42. benchmarker.set_param(param); \
  43. printf("src:%s vertical==%d horizontal==%d\n", shape.to_string().c_str(), \
  44. is_vertical, is_horizontal); \
  45. benchmarker.execs({shape, {}});
  46. for (auto&& shape : shapes) {
  47. BENCHMARK_FLIP(false, false);
  48. BENCHMARK_FLIP(false, true);
  49. BENCHMARK_FLIP(true, false);
  50. BENCHMARK_FLIP(true, true);
  51. }
  52. #undef BENCHMARK_FLIP
  53. };
  54. TensorShapeArray shapes = {{3, 101, 98, 1}, {3, 101, 98, 3}};
  55. run(shapes);
  56. }
  57. #endif
  58. } // namespace test
  59. } // namespace megdnn
  60. // vim: syntax=cpp.doxygen