You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

lrn.cpp 899 B

1234567891011121314151617181920212223242526272829303132
  1. #include "test/cuda/fixture.h"
  2. #include "test/common/checker.h"
  3. #include "test/common/local.h"
  4. namespace megdnn {
  5. namespace test {
  6. TEST_F(CUDA, LRN_FORWARD) {
  7. Checker<LRNForward> checker(handle_cuda());
  8. for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
  9. checker.set_dtype(0, dtype);
  10. checker.execs({{2, 11, 12, 13}, {}});
  11. for (size_t w = 10; w <= 50; ++w) {
  12. checker.execs({{2, w, 12, 13}, {}});
  13. }
  14. }
  15. }
  16. TEST_F(CUDA, LRN_BACKWARD) {
  17. Checker<LRNBackward> checker(handle_cuda());
  18. auto shape = TensorShape{2, 11, 12, 13};
  19. checker.set_dtype(0, dtype::Float32());
  20. checker.exec(TensorShapeArray{shape, shape, shape, shape});
  21. checker.set_dtype(1, dtype::Float32());
  22. checker.exec(TensorShapeArray{shape, shape, shape, shape});
  23. }
  24. } // namespace test
  25. } // namespace megdnn
  26. // vim: syntax=cpp.doxygen