You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

diag.cpp 1.1 kB

1234567891011121314151617181920212223242526272829303132
  1. #include "test/cuda/fixture.h"
  2. #include "megdnn/oprs.h"
  3. #include "test/common/checker.h"
  4. namespace megdnn {
  5. namespace test {
  6. TEST_F(CUDA, DIAG) {
  7. Checker<Diag> checker(handle_cuda());
  8. for (DType dtype :
  9. std::vector<DType>{dtype::Float16(), dtype::Int32(), dtype::Float32()})
  10. for (int k = -5; k < 5; ++k) {
  11. checker.set_param({k});
  12. checker.set_dtype(0, dtype);
  13. checker.set_dtype(1, dtype);
  14. size_t absk = static_cast<size_t>(std::abs(k));
  15. checker.exec(TensorShapeArray{{8}, {8 + absk, 8 + absk}});
  16. //! NOTE: diag for vector or matrix is a vector
  17. auto oshape = [&](int n, int m) -> TensorShape {
  18. size_t o = (k >= 0 ? std::min(m - k, n) : std::min(n + k, m));
  19. return {o};
  20. };
  21. checker.exec(TensorShapeArray{{8, 6}, oshape(8, 6)});
  22. checker.exec(TensorShapeArray{{6, 8}, oshape(6, 8)});
  23. checker.exec(TensorShapeArray{{8, 8}, oshape(8, 8)});
  24. }
  25. }
  26. } // namespace test
  27. } // namespace megdnn
  28. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}