You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

lstmcell.cpp 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. /**
  2. * \file dnn/test/naive/lstmcell.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing, software
  8. * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  9. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/dtype.h"
  12. #include "megdnn/oprs.h"
  13. #include "test/common/checker.h"
  14. #include "test/naive/fixture.h"
  15. namespace megdnn {
  16. namespace test {
  17. TEST_F(NAIVE, LSTMCELL) {
  18. Checker<LSTMCell> checker(handle(), true);
  19. for (size_t batch : {1, 4})
  20. for (size_t n : {3, 4, 5, 23, 100})
  21. for (size_t out : {3, 6, 25, 100}) {
  22. checker.exec(
  23. {{batch, n},
  24. {out * 4, n},
  25. {1, out * 4},
  26. {batch, out},
  27. {out * 4, out},
  28. {1, out * 4},
  29. {batch, out},
  30. {},
  31. {},
  32. {}});
  33. }
  34. size_t batch_size = 2;
  35. size_t input_size = 3;
  36. size_t hidden_size = 2;
  37. checker.exect(
  38. Testcase{
  39. TensorValue(
  40. {batch_size, input_size}, dtype::Float32(),
  41. {1, 2, 3, 4, 5, 6}), // input
  42. TensorValue(
  43. {4 * hidden_size, input_size}, dtype::Float32(),
  44. {
  45. 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
  46. 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
  47. 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
  48. 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
  49. }), // weight_ih
  50. TensorValue(
  51. {4 * hidden_size}, dtype::Float32(),
  52. {0, 0, 0, 0, 0, 0, 0, 0}), // bias_ih
  53. TensorValue(
  54. {batch_size, hidden_size}, dtype::Float32(),
  55. {1, 2, 3, 4}), // hx
  56. TensorValue(
  57. {4 * hidden_size, hidden_size}, dtype::Float32(),
  58. {0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
  59. 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
  60. 0.3535, 0.3535}), // weight_hh
  61. TensorValue(
  62. {4 * hidden_size}, dtype::Float32(),
  63. {0, 0, 0, 0, 0, 0, 0, 0}), // bias_hh
  64. TensorValue(
  65. {batch_size, hidden_size}, dtype::Float32(),
  66. {2, 3, 4, 5}), // cx
  67. {},
  68. {},
  69. {}},
  70. Testcase{
  71. {},
  72. {},
  73. {},
  74. {},
  75. {},
  76. {},
  77. {},
  78. TensorValue(
  79. {batch_size, hidden_size}, dtype::Float32(),
  80. {0.9541, 0.9593, 0.9995, 0.9996}), // hy
  81. TensorValue(
  82. {batch_size, hidden_size}, dtype::Float32(),
  83. {2.8771, 3.8373, 4.9979, 5.9975}), // cy
  84. TensorValue(
  85. {batch_size, 4 * hidden_size}, dtype::Float32(),
  86. {3.18198, 3.18198, 7.7781, 7.7781, 3.18198, 3.18198,
  87. 7.77817, 7.77817, 3.18198, 3.18198, 7.77817, 7.77817,
  88. 3.18198, 3.18198, 7.77817, 7.77817}), // cy
  89. });
  90. batch_size = 2;
  91. input_size = 2;
  92. hidden_size = 1;
  93. checker.exect(
  94. Testcase{
  95. TensorValue(
  96. {batch_size, input_size}, dtype::Float32(),
  97. {1, 2, 3, 4}), // input
  98. TensorValue(
  99. {4 * hidden_size, input_size}, dtype::Float32(),
  100. {0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
  101. 0.3535}), // weight_ih
  102. TensorValue(
  103. {4 * hidden_size}, dtype::Float32(),
  104. {0.3535, 0.3535, 0.3535, 0.3535}), // bias_ih
  105. TensorValue(
  106. {batch_size, hidden_size}, dtype::Float32(), {1, 2}), // hx
  107. TensorValue(
  108. {4 * hidden_size, hidden_size}, dtype::Float32(),
  109. {0.3535, 0.3535, 0.3535, 0.3535}), // weight_hh
  110. TensorValue(
  111. {4 * hidden_size}, dtype::Float32(),
  112. {0.3535, 0.3535, 0.3535, 0.3535}), // bias_hh
  113. TensorValue(
  114. {batch_size, hidden_size}, dtype::Float32(), {4, 5}), // cx
  115. {},
  116. {},
  117. {}},
  118. Testcase{
  119. {},
  120. {},
  121. {},
  122. {},
  123. {},
  124. {},
  125. {},
  126. TensorValue(
  127. {batch_size, hidden_size}, dtype::Float32(),
  128. {0.8927, 0.9799}), // hy
  129. TensorValue(
  130. {batch_size, hidden_size}, dtype::Float32(),
  131. {4.4393, 5.8788}), // cy
  132. TensorValue(
  133. {batch_size, 4 * hidden_size}, dtype::Float32(),
  134. {2.1210, 3.8885, 2.1210, 3.8885, 2.1210, 3.8885, 2.1210,
  135. 3.8885}), // gates
  136. });
  137. }
  138. } // namespace test
  139. } // namespace megdnn