You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

diag.cpp 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. #include "megdnn/dtype.h"
  2. #include "megdnn/oprs.h"
  3. #include "test/common/checker.h"
  4. #include "test/naive/fixture.h"
  5. namespace megdnn {
  6. namespace test {
  7. TEST_F(NAIVE, DiagVector2Matrix) {
  8. Checker<Diag> checker(handle(), false);
  9. Diag::Param param;
  10. param.k = 0;
  11. checker.set_param(param).exect(
  12. Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}},
  13. Testcase{
  14. {},
  15. // clang-format off
  16. TensorValue({3, 3}, dtype::Float32(), {1, 0, 0,
  17. 0, 2, 0,
  18. 0, 0, 3})});
  19. // clang-format on
  20. }
  21. TEST_F(NAIVE, DiagVector2Matrix_PositiveK) {
  22. Checker<Diag> checker(handle(), false);
  23. Diag::Param param;
  24. param.k = 1;
  25. checker.set_param(param).exect(
  26. Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}},
  27. Testcase{
  28. {},
  29. // clang-format off
  30. TensorValue({4, 4}, dtype::Float32(), {0, 1, 0, 0,
  31. 0, 0, 2, 0,
  32. 0, 0, 0, 3,
  33. 0, 0, 0, 0,})});
  34. // clang-format on
  35. }
  36. TEST_F(NAIVE, DiagVector2Matrix_NegativeK) {
  37. Checker<Diag> checker(handle(), false);
  38. Diag::Param param;
  39. param.k = -1;
  40. checker.set_param(param).exect(
  41. Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}},
  42. Testcase{
  43. {},
  44. // clang-format off
  45. TensorValue({4, 4}, dtype::Float32(), {0, 0, 0, 0,
  46. 1, 0, 0, 0,
  47. 0, 2, 0, 0,
  48. 0, 0, 3, 0,})});
  49. // clang-format on
  50. }
  51. TEST_F(NAIVE, DiagMatrix2Vector) {
  52. Checker<Diag> checker(handle(), false);
  53. Diag::Param param;
  54. param.k = 0;
  55. checker.set_param(param).exect(
  56. // clang-format off
  57. Testcase{TensorValue({3, 3}, dtype::Float32(), {1, 2, 3,
  58. 4, 5, 6,
  59. 7, 8, 9}),
  60. // clang-format on
  61. {}},
  62. Testcase{{}, TensorValue({3}, dtype::Float32(), {1, 5, 9})});
  63. }
  64. TEST_F(NAIVE, DiagMatrix2Vector_PositiveK) {
  65. Checker<Diag> checker(handle(), false);
  66. Diag::Param param;
  67. param.k = 1;
  68. checker.set_param(param).exect(
  69. // clang-format off
  70. Testcase{TensorValue({3, 3}, dtype::Float32(), {1, 2, 3,
  71. 4, 5, 6,
  72. 7, 8, 9}),
  73. // clang-format on
  74. {}},
  75. Testcase{{}, TensorValue({2}, dtype::Float32(), {2, 6})});
  76. }
  77. TEST_F(NAIVE, DiagMatrix2Vector_NegativeK) {
  78. Checker<Diag> checker(handle(), false);
  79. Diag::Param param;
  80. param.k = -1;
  81. checker.set_param(param).exect(
  82. // clang-format off
  83. Testcase{TensorValue({3, 3}, dtype::Float32(), {1, 2, 3,
  84. 4, 5, 6,
  85. 7, 8, 9}),
  86. // clang-format on
  87. {}},
  88. Testcase{{}, TensorValue({2}, dtype::Float32(), {4, 8})});
  89. }
  90. } // namespace test
  91. } // namespace megdnn