You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

add_update.cpp 2.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. #include "test/cuda/fixture.h"
  2. #include "megdnn/oprs.h"
  3. #include "test/common/checker.h"
  4. namespace megdnn {
  5. namespace test {
  6. TEST_F(CUDA, ADD_UPDATE) {
  7. Checker<AddUpdate> checker(handle_cuda());
  8. checker.set_dtype(0, dtype::Float32())
  9. .set_dtype(1, dtype::Float32())
  10. .execs({{2, 3, 4}, {2, 3, 4}});
  11. checker.set_dtype(0, dtype::Float16())
  12. .set_dtype(1, dtype::Float16())
  13. .execs({{2, 3, 4}, {2, 3, 4}});
  14. checker.set_dtype(0, dtype::BFloat16())
  15. .set_dtype(1, dtype::BFloat16())
  16. .execs({{2, 3, 4}, {2, 3, 4}});
  17. checker.execl(
  18. {{{2, 3, 4}, dtype::Float32()}, {{2, 3, 4}, {16, 4, 1}, dtype::Float32()}});
  19. checker.execl(
  20. {{{2, 3, 4}, dtype::Float16()}, {{2, 3, 4}, {16, 4, 1}, dtype::Float16()}});
  21. checker.execl(
  22. {{{2, 3, 4}, dtype::BFloat16()},
  23. {{2, 3, 4}, {16, 4, 1}, dtype::BFloat16()}});
  24. checker.execl(
  25. {{{2, 3, 4}, {16, 4, 1}, dtype::Float32()}, {{2, 3, 4}, dtype::Float32()}});
  26. checker.execl({{{2, 3, 4}, dtype::Float32()}, {{1}, dtype::Float32()}});
  27. checker.execl({{{2, 3, 4}, dtype::Float32()}, {{2, 1, 4}, dtype::Float32()}});
  28. checker.set_param({2, -1, 3})
  29. .set_dtype(0, dtype::Int32())
  30. .set_dtype(1, dtype::Int32())
  31. .execs({{2, 3, 2}, {2, 3, 2}});
  32. checker.set_dtype(0, dtype::Int8())
  33. .set_dtype(1, dtype::Int8())
  34. .execs({{2, 3, 2}, {2, 3, 2}});
  35. checker.set_dtype(0, dtype::Uint8())
  36. .set_dtype(1, dtype::Uint8())
  37. .execs({{2, 3, 2}, {2, 3, 2}});
  38. // test scalar
  39. checker.set_dtype(0, dtype::Int8()).set_dtype(1, dtype::Int8()).execs({{1}, {1}});
  40. checker.set_dtype(0, dtype::Int8()).set_dtype(1, dtype::Int8()).execs({{4}, {1}});
  41. checker.execl({{{2, 3, 4}, dtype::Int8()}, {{2, 3, 4}, {16, 4, 1}, dtype::Int8()}});
  42. checker.execl({{{2, 3, 4}, dtype::Int8()}, {{1, 3, 1}, {16, 4, 1}, dtype::Int8()}});
  43. checker.execl({{{2, 3, 4}, {16, 4, 1}, dtype::Int8()}, {{2, 3, 4}, dtype::Int8()}});
  44. }
  45. } // namespace test
  46. } // namespace megdnn
  47. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}