You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

group_norm.cpp 1.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. #include "test/cuda/fixture.h"
  2. #include "test/common/checker.h"
  3. namespace megdnn {
  4. namespace test {
  5. TEST_F(CUDA, GROUPNORM_FORWARD) {
  6. using Param = GroupNormForward::Param;
  7. Param param;
  8. param.affine = true;
  9. param.eps = 1e-6;
  10. Checker<GroupNormForward> checker(handle_cuda());
  11. checker.set_epsilon(1e-2);
  12. auto run = [&](DType d) {
  13. for (size_t group : {1, 3})
  14. for (size_t C : {6, 9}) {
  15. param.group = group;
  16. checker.set_param(param)
  17. .set_dtype(0, d)
  18. .set_dtype(1, d)
  19. .set_dtype(2, d)
  20. .set_dtype(3, d)
  21. .set_dtype(4, dtype::Float32())
  22. .set_dtype(5, dtype::Float32())
  23. .execs({{2, C, 2, 1},
  24. {C},
  25. {C},
  26. {2, C, 2, 1},
  27. {2, group},
  28. {2, group}});
  29. }
  30. };
  31. run(dtype::Float32());
  32. run(dtype::Float16());
  33. run(dtype::BFloat16());
  34. }
  35. } // namespace test
  36. } // namespace megdnn
  37. // vim: syntax=cpp.doxygen