
svd.cpp 5.2 kB

#include "test/common/svd.h"
#include "test/common/checker.h"
#include "test/common/rng.h"
#include "test/common/tensor.h"
#include "test/common/utils.h"
#include "test/common/workspace_wrapper.h"

#include <algorithm>  // std::min
#include <cstring>    // memset

using namespace megdnn;
using namespace test;

using Param = SVDForward::Param;

namespace {
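// Expand a packed vector of singular values `v`, shape (..., min(m, n)),
// into the dense matrix `diag`, shape (..., m, n): zero everywhere except
// the leading diagonal, which receives the values of `v`. `diag` must be
// contiguous.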
template <typename T>
void fill_diag(const TensorND& v, TensorND& diag) {
    const auto& layout = diag.layout;
    megdnn_assert_contiguous(layout);
    megdnn_assert(layout.ndim >= 2);
    size_t block_cnt = 1;
    for (size_t i = 0; i < layout.ndim - 2; i++) {
        block_cnt *= layout[i];
    }
    size_t m = layout[layout.ndim - 2];
    size_t n = layout[layout.ndim - 1];
    size_t mn = std::min(m, n);
    auto v_ptr = v.ptr<T>();
    auto ptr = diag.ptr<T>();
    memset(ptr, 0, diag.layout.span().dist_byte());
    auto ld = layout.stride[layout.ndim - 2];
    for (size_t blk = 0; blk < block_cnt; blk++) {
        size_t offset(0), s_offset(0);
        if (block_cnt > 1) {
            offset = blk * layout.stride[layout.ndim - 3];
            s_offset = blk * v.layout.stride[v.layout.ndim - 2];
        }
        for (size_t i = 0; i < mn; i++) {
            ptr[offset + i * ld + i] = v_ptr[s_offset + i];
        }
    }
}
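// Batched matrix product A @ B on the device, returning a freshly allocated
// result tensor. Used below to recompose u * diag(s) * vt.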
std::shared_ptr<Tensor<>> matmul(Handle* handle, const TensorND& A, const TensorND& B) {
    auto matmul_opr = handle->create_operator<BatchedMatrixMul>();
    TensorLayout result_layout;
    matmul_opr->deduce_layout(A.layout, B.layout, result_layout);
    std::shared_ptr<Tensor<>> result(new Tensor<>(handle, result_layout));
    WorkspaceWrapper ws(
            handle,
            matmul_opr->get_workspace_in_bytes(A.layout, B.layout, result->layout()));
    matmul_opr->exec(A, B, result->tensornd(), ws.workspace());
    return result;
}

}  // namespace
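// Build the testcase list: a few fixed shapes plus a (rows, cols) sweep that
// exercises compute_uv and full_matrices; every input matrix is filled with
// normally distributed random data.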
std::vector<SVDTestcase> SVDTestcase::make() {
    std::vector<SVDTestcase> ret;
    auto param = Param(false /* compute_uv */, false /* full_matrices */);
    auto add_shape = [&ret, &param](const TensorShape& shape) {
        ret.push_back({param, TensorLayout{shape, dtype::Float32()}});
    };
    add_shape({1, 7, 7});
    add_shape({1, 3, 7});
    add_shape({1, 7, 3});
    for (size_t rows : {2, 3, 5, 7, 10, 32, 100}) {
        for (size_t cols : {2, 3, 5, 7, 10, 32, 100}) {
            param.compute_uv = false;
            param.full_matrices = false;
            add_shape({3, rows, cols});
            param.compute_uv = true;
            add_shape({2, rows, cols});
            param.full_matrices = true;
            add_shape({3, rows, cols});
        }
    }
    NormalRNG data_rng;
    auto fill_data = [&](TensorND& data) {
        auto sz = data.layout.span().dist_byte(), szf = sz / sizeof(dt_float32);
        auto pf = static_cast<dt_float32*>(data.raw_ptr());
        data_rng.fill_fast_float32(pf, szf);
    };
    for (auto&& i : ret) {
        i.m_mem.reset(new dt_float32[i.m_mat.layout.span().dist_elem()]);
        i.m_mat.reset_ptr(i.m_mem.get());
        fill_data(i.m_mat);
    }
    return ret;
}
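// Run `opr` on this testcase and copy the device results back to host.
// When compute_uv is set, also reconstructs u * diag(s) * vt so the caller
// can check it against the original input matrix.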
SVDTestcase::Result SVDTestcase::run(SVDForward* opr) {
    auto handle = opr->handle();
    auto src = make_tensor_h2d(handle, m_mat);

    // Deduce layout
    TensorLayout u_layout, s_layout, vt_layout;
    opr->param() = m_param;
    opr->deduce_layout(m_mat.layout, u_layout, s_layout, vt_layout);

    // Alloc tensor on device
    Tensor<> u{handle, u_layout}, s{handle, s_layout}, vt{handle, vt_layout};
    WorkspaceWrapper ws(
            handle,
            opr->get_workspace_in_bytes(m_mat.layout, u_layout, s_layout, vt_layout));
    opr->exec(*src, u.tensornd(), s.tensornd(), vt.tensornd(), ws.workspace());
    auto u_host = make_tensor_d2h(handle, u.tensornd());
// s_host is defined as a macro in wsdk8/Include/shared/inaddr.h.
// Surprise! It's Windows.
#undef s_host
    auto s_host = make_tensor_d2h(handle, s.tensornd());
    auto vt_host = make_tensor_d2h(handle, vt.tensornd());
    if (m_param.compute_uv) {
        // Copy back singular values, build diag(s).
        // Use the array form of unique_ptr so delete[] is invoked.
        std::unique_ptr<dt_float32[]> diag_s_host_mem(
                new dt_float32[m_mat.layout.span().dist_elem()]);
        TensorLayout diag_layout = m_mat.layout;
        if (!m_param.full_matrices) {
            // Reduced SVD: diag(s) is square with side min(m, n)
            SmallVector<size_t> shape;
            for (int i = 0; i < (int)diag_layout.ndim - 2; i++) {
                shape.push_back(diag_layout[i]);
            }
            size_t x = std::min(
                    diag_layout[diag_layout.ndim - 1],
                    diag_layout[diag_layout.ndim - 2]);
            shape.push_back(x);
            shape.push_back(x);
            diag_layout = {shape, diag_layout.dtype};
        }
        TensorND diag_s_host{diag_s_host_mem.get(), diag_layout};
        fill_diag<dt_float32>(*s_host, diag_s_host);

        // Try to recover the original matrix as u * diag(s) * vt
        auto diag_s_dev = make_tensor_h2d(handle, diag_s_host);
        auto tmp = matmul(handle, u.tensornd(), *diag_s_dev);
        auto recovered = matmul(handle, tmp->tensornd(), vt.tensornd());
        return {u_host, s_host, vt_host,
                make_tensor_d2h(handle, recovered->tensornd())};
    }
    return {u_host, s_host, vt_host, nullptr};
}
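// A minimal consumption sketch (hypothetical; actual backend tests live in
// their own harnesses, where `handle` comes from the test fixture):
//
//     auto opr = handle->create_operator<SVDForward>();
//     for (auto&& tc : SVDTestcase::make()) {
//         auto result = tc.run(opr.get());
//         // When tc.m_param.compute_uv is set, `result` carries the
//         // recovered matrix u * diag(s) * vt for comparison with tc.m_mat.
//     }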
// vim: syntax=cpp.doxygen