You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

repeat.cpp 1.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. #include "test/cuda/fixture.h"
  2. #include "test/common/checker.h"
  3. #include "test/common/tile_repeat.h"
  4. namespace megdnn {
  5. namespace test {
  6. TEST_F(CUDA, REPEAT_FORWARD) {
  7. Checker<RepeatForward> checker(handle_cuda());
  8. auto args = tile_repeat::get_args();
  9. for (auto&& arg : args) {
  10. checker.set_dtype(0, dtype::Float32())
  11. .set_param(arg.param())
  12. .execs({arg.src, {}});
  13. checker.set_dtype(0, dtype::Float16())
  14. .set_param(arg.param())
  15. .execs({arg.src, {}});
  16. }
  17. }
  18. TEST_F(CUDA, REPEAT_BACKWARD) {
  19. Checker<RepeatBackward> checker(handle_cuda());
  20. UniformFloatRNG rng(1, 2);
  21. checker.set_epsilon(1e-2).set_rng(0, &rng);
  22. ;
  23. auto args = tile_repeat::get_args();
  24. for (auto&& arg : args) {
  25. checker.set_dtype(0, dtype::Float32())
  26. .set_dtype(1, dtype::Float32())
  27. .set_param(arg.param())
  28. .execs({arg.dst, arg.src});
  29. checker.set_dtype(0, dtype::Float16())
  30. .set_dtype(1, dtype::Float16())
  31. .set_param(arg.param())
  32. .execs({arg.dst, arg.src});
  33. }
  34. }
  35. } // namespace test
  36. } // namespace megdnn
  37. // vim: syntax=cpp.doxygen