You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

svd.h 1.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. /**
  2. * \file dnn/test/common/svd.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "./checker.h"
  13. #include "megdnn/oprs.h"
  14. namespace megdnn {
  15. namespace test {
  16. class SVDTestcase {
  17. std::unique_ptr<dt_float32> m_mem;
  18. SVDTestcase(const SVDForward::Param& param, const TensorLayout& mat)
  19. : m_param{param}, m_mat{nullptr, mat} {}
  20. public:
  21. SVDForward::Param m_param;
  22. TensorND m_mat;
  23. struct Result {
  24. std::shared_ptr<TensorND> u;
  25. std::shared_ptr<TensorND> s;
  26. std::shared_ptr<TensorND> vt;
  27. std::shared_ptr<TensorND> recovered_mat;
  28. };
  29. Result run(SVDForward* opr);
  30. static std::vector<SVDTestcase> make();
  31. };
  32. } // namespace test
  33. } // namespace megdnn
  34. // vim: syntax=cpp.doxygen