
matrix_inverse.cpp 2.9 kB

/**
 * \file dnn/test/naive/matrix_inverse.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/oprs/linalg.h"
#include "test/common/rng.h"
#include "test/common/tensor.h"
#include "test/naive/fixture.h"

using namespace megdnn;
using namespace test;

namespace {

//! run MatrixInverse on a batch of random invertible matrices and verify
//! that input * output is (numerically) the identity matrix
void run_check(Handle* handle, const size_t B, const size_t N,
               const TensorShape& shp) {
    SyncedTensor<> input(handle, shp), output(handle, input.layout()),
            mul_check(handle, input.layout());
    {
        // fill the input with random invertible matrices
        auto t = input.tensornd_host();
        InvertibleMatrixRNG{}.gen(t);
    }
    auto opr = handle->create_operator<MatrixInverse>();
    auto wk_size = opr->get_workspace_in_bytes(input.layout(), output.layout());
    std::unique_ptr<dt_byte[]> wk_storage{new dt_byte[wk_size]};
    opr->exec(input.tensornd_dev(), output.tensornd_dev(),
              {wk_storage.get(), wk_size});

    // multiply the computed inverse with the original input; the result
    // should be a batch of identity matrices
    auto batch_mul = handle->create_operator<BatchedMatrixMul>();
    // view an arbitrary input layout as the canonical (B, N, N) form
    auto make_std_tensor = [B, N](SyncedTensor<>& t) {
        auto ret = t.tensornd_dev();
        ret.layout.ndim = 3;
        ret.layout[0] = B;
        ret.layout[1] = ret.layout[2] = N;
        ret.layout.init_contiguous_stride();
        return ret;
    };
    auto batch_mul_inp = make_std_tensor(input);
    auto batch_mul_wk_size = batch_mul->get_workspace_in_bytes(
            batch_mul_inp.layout, batch_mul_inp.layout, batch_mul_inp.layout);
    std::unique_ptr<dt_byte[]> batch_mul_wk{new dt_byte[batch_mul_wk_size]};
    batch_mul->exec(make_std_tensor(output), batch_mul_inp,
                    make_std_tensor(mul_check),
                    {batch_mul_wk.get(), batch_mul_wk_size});

    // check that mul_check is close to the identity: ones on the diagonal,
    // zeros everywhere else
    auto hptr = mul_check.ptr_host();
    for (size_t i = 0; i < B; ++i) {
        for (size_t j = 0; j < N; ++j) {
            for (size_t k = 0; k < N; ++k) {
                auto val = hptr[i * N * N + j * N + k];
                if (j == k) {
                    ASSERT_LT(std::abs(val - 1.f), 1e-4) << ssprintf(
                            "%zu,%zu,%zu/%zu,%zu: %g", i, j, k, N, B, val);
                } else {
                    ASSERT_LT(std::abs(val - 0.f), 1e-4) << ssprintf(
                            "%zu,%zu,%zu/%zu,%zu: %g", i, j, k, N, B, val);
                }
            }
        }
    }
}

}  // namespace

TEST_F(NAIVE, MATRIX_INVERSE) {
    // arguments are (B, N, shape): the last two dims of the shape form the
    // N x N matrices and the leading dims multiply to the batch size B
    run_check(handle(), 2, 1, {1, 2, 1, 1});
    run_check(handle(), 1, 2, {2, 2});
    run_check(handle(), 4, 3, {2, 2, 3, 3});
    run_check(handle(), 4, 23, {4, 23, 23});
    run_check(handle(), 1, 100, {100, 100});
    run_check(handle(), 100, 3, {100, 3, 3});
}

// vim: syntax=cpp.doxygen
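The test never compares against precomputed inverses; it relies on the algebraic identity inv(A) * A = I and only checks the product against the identity within a tolerance. The following standalone sketch illustrates the same idea for a hand-written 2x2 inverse; the helpers (inverse2x2, matmul2x2, is_identity) are illustrative only and are not part of the MegDNN API.

// Standalone illustration of the verification idea used in the test above:
// compute an inverse, multiply it back with the original matrix and check
// the product against the identity. Helper names are hypothetical.
#include <array>
#include <cmath>
#include <cstdio>

using Mat2 = std::array<float, 4>;  // row-major 2x2 matrix

// closed-form inverse of a 2x2 matrix (assumes non-zero determinant)
Mat2 inverse2x2(const Mat2& a) {
    float det = a[0] * a[3] - a[1] * a[2];
    return {a[3] / det, -a[1] / det, -a[2] / det, a[0] / det};
}

Mat2 matmul2x2(const Mat2& a, const Mat2& b) {
    return {a[0] * b[0] + a[1] * b[2], a[0] * b[1] + a[1] * b[3],
            a[2] * b[0] + a[3] * b[2], a[2] * b[1] + a[3] * b[3]};
}

// true if m is within eps of the 2x2 identity matrix
bool is_identity(const Mat2& m, float eps = 1e-4f) {
    const Mat2 eye{1.f, 0.f, 0.f, 1.f};
    for (size_t i = 0; i < 4; ++i) {
        if (std::abs(m[i] - eye[i]) >= eps)
            return false;
    }
    return true;
}

int main() {
    Mat2 a{4.f, 7.f, 2.f, 6.f};  // invertible: det = 10
    Mat2 prod = matmul2x2(inverse2x2(a), a);
    std::printf("inverse check %s\n", is_identity(prod) ? "passed" : "failed");
    return is_identity(prod) ? 0 : 1;
}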

The MegEngine installation package already bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build to choose between. To run GPU programs, make sure the machine has a GPU device and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
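As a quick sanity check that a GPU and its driver are actually visible before running GPU programs, something along the lines of the following can be used. This is a minimal sketch against the standard CUDA runtime API (cudaGetDeviceCount / cudaGetDeviceProperties), not a MegEngine-specific utility, and it assumes the CUDA toolkit is available for compilation (e.g. nvcc gpu_check.cu).

// Minimal GPU/driver visibility check via the CUDA runtime API.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int count = 0;
    cudaError_t err = cudaGetDeviceCount(&count);
    if (err != cudaSuccess || count == 0) {
        std::printf("no usable GPU: %s\n", cudaGetErrorString(err));
        return 1;
    }
    for (int i = 0; i < count; ++i) {
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, i);
        std::printf("device %d: %s (compute %d.%d)\n", i, prop.name,
                    prop.major, prop.minor);
    }
    return 0;
}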