You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

matrix_inverse.cpp 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. /**
  2. * \file dnn/src/common/matrix_inverse.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/oprs/linalg.h"
  12. #include "src/common/utils.h"
  13. using namespace megdnn;
  14. void MatrixInverse::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
  15. canonize_params(src, nullptr, nullptr);
  16. dst = src;
  17. }
  18. size_t MatrixInverse::get_workspace_in_bytes(
  19. const TensorLayout& src, const TensorLayout& dst) {
  20. size_t batch, n;
  21. canonize_params(src, &batch, &n);
  22. megdnn_assert(
  23. src.eq_layout(dst), "src and dst unequal: %s vs %s",
  24. src.to_string().c_str(), dst.to_string().c_str());
  25. return get_workspace_in_bytes(batch, n, src.dtype.size());
  26. }
  27. void MatrixInverse::canonize_params(
  28. const TensorLayout& layout, size_t* batch, size_t* n) {
  29. megdnn_assert(
  30. layout.is_contiguous() && layout.ndim >= 2 &&
  31. layout[layout.ndim - 2] == layout[layout.ndim - 1],
  32. "invalid MatrixInverse layout: %s", layout.to_string().c_str());
  33. megdnn_assert(
  34. DNN_FLOAT16_SELECT(layout.dtype == dtype::Float16(), false) ||
  35. layout.dtype == dtype::Float32(),
  36. "MatrixInverse only supports f16 & f32");
  37. if (batch) {
  38. *batch = 1;
  39. for (size_t i = 0; i < layout.ndim - 2; ++i) {
  40. *batch *= layout[i];
  41. }
  42. }
  43. if (n) {
  44. *n = layout[layout.ndim - 1];
  45. }
  46. }
  47. void MatrixInverse::check_exec(
  48. const TensorLayout& src, const TensorLayout& dst, _megdnn_workspace workspace,
  49. size_t* batch, size_t* n) {
  50. canonize_params(src, batch, n);
  51. megdnn_assert(
  52. src.eq_layout(dst), "src and dst unequal: %s vs %s",
  53. src.to_string().c_str(), dst.to_string().c_str());
  54. megdnn_assert(
  55. workspace.size >= get_workspace_in_bytes(*batch, *n, src.dtype.size()));
  56. }
  57. // vim: syntax=cpp.doxygen