You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

matrix_inverse.cpp 2.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. #include "megdnn/oprs/linalg.h"
  2. #include "test/common/rng.h"
  3. #include "test/common/tensor.h"
  4. #include "test/naive/fixture.h"
  5. using namespace megdnn;
  6. using namespace test;
  7. namespace {
  8. void run_check(Handle* handle, const size_t B, const size_t N, const TensorShape& shp) {
  9. SyncedTensor<> input(handle, shp), output(handle, input.layout()),
  10. mul_check(handle, input.layout());
  11. {
  12. auto t = input.tensornd_host();
  13. InvertibleMatrixRNG{}.gen(t);
  14. }
  15. auto opr = handle->create_operator<MatrixInverse>();
  16. auto wk_size = opr->get_workspace_in_bytes(input.layout(), output.layout());
  17. std::unique_ptr<dt_byte[]> wk_storage{new dt_byte[wk_size]};
  18. opr->exec(input.tensornd_dev(), output.tensornd_dev(), {wk_storage.get(), wk_size});
  19. auto batch_mul = handle->create_operator<BatchedMatrixMul>();
  20. auto make_std_tensor = [B, N](SyncedTensor<>& t) {
  21. auto ret = t.tensornd_dev();
  22. ret.layout.ndim = 3;
  23. ret.layout[0] = B;
  24. ret.layout[1] = ret.layout[2] = N;
  25. ret.layout.init_contiguous_stride();
  26. return ret;
  27. };
  28. auto batch_mul_inp = make_std_tensor(input);
  29. auto batch_mul_wk_size = batch_mul->get_workspace_in_bytes(
  30. batch_mul_inp.layout, batch_mul_inp.layout, batch_mul_inp.layout);
  31. std::unique_ptr<dt_byte[]> batch_mul_wk{new dt_byte[batch_mul_wk_size]};
  32. batch_mul->exec(
  33. make_std_tensor(output), batch_mul_inp, make_std_tensor(mul_check),
  34. {batch_mul_wk.get(), batch_mul_wk_size});
  35. auto hptr = mul_check.ptr_host();
  36. for (size_t i = 0; i < B; ++i) {
  37. for (size_t j = 0; j < N; ++j) {
  38. for (size_t k = 0; k < N; ++k) {
  39. auto val = hptr[i * N * N + j * N + k];
  40. if (j == k) {
  41. ASSERT_LT(std::abs(val - 1.f), 1e-4)
  42. << ssprintf("%zu,%zu,%zu/%zu,%zu: %g", i, j, k, N, B, val);
  43. } else {
  44. ASSERT_LT(std::abs(val - 0.f), 1e-4)
  45. << ssprintf("%zu,%zu,%zu/%zu,%zu: %g", i, j, k, N, B, val);
  46. }
  47. }
  48. }
  49. }
  50. }
  51. } // namespace
  52. TEST_F(NAIVE, MATRIX_INVERSE) {
  53. run_check(handle(), 2, 1, {1, 2, 1, 1});
  54. run_check(handle(), 1, 2, {2, 2});
  55. run_check(handle(), 4, 3, {2, 2, 3, 3});
  56. run_check(handle(), 4, 23, {4, 23, 23});
  57. run_check(handle(), 1, 100, {100, 100});
  58. run_check(handle(), 100, 3, {100, 3, 3});
  59. }
  60. // vim: syntax=cpp.doxygen