You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

diag.cpp 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. /**
  2. * \file dnn/test/naive/diag.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/dtype.h"
  13. #include "megdnn/oprs.h"
  14. #include "test/common/checker.h"
  15. #include "test/naive/fixture.h"
  16. namespace megdnn {
  17. namespace test {
  18. TEST_F(NAIVE, DiagVector2Matrix) {
  19. Checker<Diag> checker(handle(), false);
  20. Diag::Param param;
  21. param.k = 0;
  22. checker.set_param(param).exect(
  23. Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}},
  24. Testcase{
  25. {},
  26. // clang-format off
  27. TensorValue({3, 3}, dtype::Float32(), {1, 0, 0,
  28. 0, 2, 0,
  29. 0, 0, 3})});
  30. // clang-format on
  31. }
  32. TEST_F(NAIVE, DiagVector2Matrix_PositiveK) {
  33. Checker<Diag> checker(handle(), false);
  34. Diag::Param param;
  35. param.k = 1;
  36. checker.set_param(param).exect(
  37. Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}},
  38. Testcase{
  39. {},
  40. // clang-format off
  41. TensorValue({4, 4}, dtype::Float32(), {0, 1, 0, 0,
  42. 0, 0, 2, 0,
  43. 0, 0, 0, 3,
  44. 0, 0, 0, 0,})});
  45. // clang-format on
  46. }
  47. TEST_F(NAIVE, DiagVector2Matrix_NegativeK) {
  48. Checker<Diag> checker(handle(), false);
  49. Diag::Param param;
  50. param.k = -1;
  51. checker.set_param(param).exect(
  52. Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}},
  53. Testcase{
  54. {},
  55. // clang-format off
  56. TensorValue({4, 4}, dtype::Float32(), {0, 0, 0, 0,
  57. 1, 0, 0, 0,
  58. 0, 2, 0, 0,
  59. 0, 0, 3, 0,})});
  60. // clang-format on
  61. }
  62. TEST_F(NAIVE, DiagMatrix2Vector) {
  63. Checker<Diag> checker(handle(), false);
  64. Diag::Param param;
  65. param.k = 0;
  66. checker.set_param(param).exect(
  67. // clang-format off
  68. Testcase{TensorValue({3, 3}, dtype::Float32(), {1, 2, 3,
  69. 4, 5, 6,
  70. 7, 8, 9}),
  71. // clang-format on
  72. {}},
  73. Testcase{{}, TensorValue({3}, dtype::Float32(), {1, 5, 9})});
  74. }
  75. TEST_F(NAIVE, DiagMatrix2Vector_PositiveK) {
  76. Checker<Diag> checker(handle(), false);
  77. Diag::Param param;
  78. param.k = 1;
  79. checker.set_param(param).exect(
  80. // clang-format off
  81. Testcase{TensorValue({3, 3}, dtype::Float32(), {1, 2, 3,
  82. 4, 5, 6,
  83. 7, 8, 9}),
  84. // clang-format on
  85. {}},
  86. Testcase{{}, TensorValue({2}, dtype::Float32(), {2, 6})});
  87. }
  88. TEST_F(NAIVE, DiagMatrix2Vector_NegativeK) {
  89. Checker<Diag> checker(handle(), false);
  90. Diag::Param param;
  91. param.k = -1;
  92. checker.set_param(param).exect(
  93. // clang-format off
  94. Testcase{TensorValue({3, 3}, dtype::Float32(), {1, 2, 3,
  95. 4, 5, 6,
  96. 7, 8, 9}),
  97. // clang-format on
  98. {}},
  99. Testcase{{}, TensorValue({2}, dtype::Float32(), {4, 8})});
  100. }
  101. } // namespace test
  102. } // namespace megdnn