You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rotate.cpp 2.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. #include "test/common/rotate.h"
  2. #include "test/common/benchmarker.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/task_record_check.h"
  5. #include "test/aarch64/fixture.h"
  6. namespace megdnn {
  7. namespace test {
  8. TEST_F(AARCH64, ROTATE) {
  9. using namespace rotate;
  10. std::vector<TestArg> args = get_args();
  11. Checker<Rotate> checker(handle());
  12. for (auto&& arg : args) {
  13. checker.set_param(arg.param)
  14. .set_dtype(0, arg.dtype)
  15. .set_dtype(1, arg.dtype)
  16. .execs({arg.src, {}});
  17. }
  18. }
  19. TEST_F(AARCH64, ROTATE_RECORD) {
  20. using namespace rotate;
  21. std::vector<TestArg> args = get_args();
  22. TaskRecordChecker<Rotate> checker(0);
  23. for (auto&& arg : args) {
  24. checker.set_param(arg.param)
  25. .set_dtype(0, arg.dtype)
  26. .set_dtype(1, arg.dtype)
  27. .execs({arg.src, {}});
  28. }
  29. }
  30. TEST_F(AARCH64, BENCHMARK_ROTATE) {
  31. using namespace rotate;
  32. using Param = param::Rotate;
  33. #define BENCHMARK_PARAM(benchmarker) \
  34. benchmarker.set_param(param); \
  35. benchmarker.set_dtype(0, dtype::Uint8());
  36. auto run = [&](const TensorShapeArray& shapes, Param param) {
  37. auto handle_naive = create_cpu_handle(2);
  38. Benchmarker<Rotate> benchmarker(handle());
  39. Benchmarker<Rotate> benchmarker_naive(handle_naive.get());
  40. BENCHMARK_PARAM(benchmarker);
  41. BENCHMARK_PARAM(benchmarker_naive);
  42. for (auto&& shape : shapes) {
  43. printf("execute %s: current---naive\n", shape.to_string().c_str());
  44. benchmarker.execs({shape, {}});
  45. benchmarker_naive.execs({shape, {}});
  46. }
  47. };
  48. Param param;
  49. TensorShapeArray shapes = {
  50. {1, 100, 100, 1},
  51. {2, 100, 100, 3},
  52. };
  53. param.clockwise = true;
  54. run(shapes, param);
  55. param.clockwise = false;
  56. run(shapes, param);
  57. }
  58. } // namespace test
  59. } // namespace megdnn
  60. // vim: syntax=cpp.doxygen