You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

lstm.cpp 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. /**
  2. * \file dnn/test/naive/lstm.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing, software
  8. * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  9. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. // #include "test/common/lstm.h"
  12. #include "megdnn/dtype.h"
  13. #include "megdnn/oprs.h"
  14. #include "test/common/checker.h"
  15. #include "test/naive/fixture.h"
  16. namespace megdnn {
  17. namespace test {
// Forward-pass test for the naive-backend LSTM operator. Three parameter
// configurations are run against precomputed expected outputs:
//   1. single layer, unidirectional
//   2. single layer, bidirectional
//   3. two layers, bidirectional
// Each Testcase pair lists the operator's inputs (input, hx, cx, flattened
// weights) followed by empty slots for outputs, and then the expected
// outputs (output, hy, cy, reserve space) with empty slots for inputs.
TEST_F(NAIVE, LSTM_FORWARD) {
    // NOTE(review): the second Checker argument presumably selects
    // "check against supplied expected values" mode -- confirm with
    // test/common/checker.h.
    Checker<LSTM> checker(handle(), true);
    size_t batch_size = 2;
    size_t input_size = 3;
    size_t hidden_size = 2;
    size_t seq_len = 2;
    // Four LSTM gates per hidden unit, hence 4 * hidden_size gate rows.
    size_t gate_hidden_size = 4 * hidden_size;
    LSTM::Param param;
    param.num_layers = 1;
    param.bidirectional = false;
    param.bias = false;  // no bias entries in the flattened weights below
    param.hidden_size = hidden_size;

    // Case 1: single unidirectional layer.
    checker.set_param(param).exect(
            Testcase{
                    TensorValue(
                            {seq_len, batch_size, input_size}, dtype::Float32(),
                            {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}),  // input
                    TensorValue(
                            {batch_size, hidden_size}, dtype::Float32(),
                            {1, 2, 3, 4}),  // hx: initial hidden state
                    TensorValue(
                            {batch_size, hidden_size}, dtype::Float32(),
                            {2, 3, 4, 5}),  // cx: initial cell state
                    TensorValue(
                            {gate_hidden_size, input_size + hidden_size},
                            dtype::Float32(),
                            {3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 3, 6,
                             1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 9, 3, 5, 1,
                             9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1}),  // flattened weights
                    {},
                    {},
                    {},
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    {},
                    TensorValue(
                            {seq_len, batch_size, hidden_size}, dtype::Float32(),
                            {0.9951, 0.9993, 0.9999, 1.0000, 0.9993, 0.9999, 1.0000,
                             1.0000}),  // output
                    TensorValue(
                            {batch_size, hidden_size}, dtype::Float32(),
                            {0.9993, 0.9999, 1.0000, 1.0000}),  // hy: final hidden state
                    TensorValue(
                            {batch_size, hidden_size}, dtype::Float32(),
                            {4.0000, 5.0000, 6.0000, 7.0000}),  // cy: final cell state
                    TensorValue(
                            {2, 2, 2, 2}, dtype::Float32(),
                            {0.995054, 0.999328, 0.99990, 0.999987, 3., 4., 5., 6.,
                             0.999329, 0.999328, 0.99990, 1., 4., 5., 6.,
                             7.})  // reserve space
            });

    // Case 2: single bidirectional layer. hx/cx gain a leading direction
    // dimension of 2 and the flattened-weight width doubles.
    param.bidirectional = true;
    checker.set_param(param).exect(
            Testcase{
                    TensorValue(
                            {seq_len, batch_size, input_size}, dtype::Float32(),
                            {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}),  // input
                    TensorValue(
                            {2, batch_size, hidden_size}, dtype::Float32(),
                            {1, 2, 3, 4, 5, 6, 7, 8}),  // hx
                    TensorValue(
                            {2, batch_size, hidden_size}, dtype::Float32(),
                            {2, 3, 4, 5, 6, 7, 8, 9}),  // cx
                    TensorValue(
                            {gate_hidden_size, 2 * (input_size + hidden_size)},
                            dtype::Float32(),
                            {3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 3, 6, 1, 3, 2,
                             7, 2, 1, 3, 2, 1, 1, 9, 3, 5, 1, 9, 3, 5, 1, 9, 3,
                             5, 1, 9, 3, 5, 1, 3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1,
                             1, 3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 9, 3, 5, 1,
                             9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1}),  // flattened weights
                    {},
                    {},
                    {},
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    {},
                    TensorValue(
                            {seq_len, batch_size, 2 * hidden_size}, dtype::Float32(),
                            {0.9951, 0.9993, 1.0000, 1.0000, 0.9999, 1.0000, 1.0000,
                             1.0000, 0.9993, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000,
                             1.0000, 1.0000}),  // output
                    TensorValue(
                            {2, batch_size, hidden_size}, dtype::Float32(),
                            {0.9993, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
                             1.0000}),  // hy
                    TensorValue(
                            {2, batch_size, hidden_size}, dtype::Float32(),
                            {4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000, 10.0000,
                             11.0000}),  // cy
                    TensorValue(
                            {4, 2, 2, 2}, dtype::Float32(),
                            {0.995054, 0.999328, 0.99990, 0.999987, 3., 4.,
                             5., 6., 0.999329, 0.999328, 0.99990, 1.,
                             4., 5., 6., 7., 1., 0.999328,
                             0.99990, 0.999987, 7., 8., 9., 10.,
                             0.999329, 0.999328, 0.99990, 1., 8., 9.,
                             10., 11.})  // reserve space
            });

    // Case 3: two bidirectional layers. hx/cx leading dim is
    // num_layers * num_directions = 4.
    // NOTE(review): the {8, 22} weight shape presumably packs both layers
    // row-wise: layer 0 contributes 2 * (input_size + hidden_size) = 10
    // columns and layer 1 (fed the bidirectional output) contributes
    // 2 * (2 * hidden_size + hidden_size) = 12 -- confirm against the LSTM
    // layout-deduction code.
    param.num_layers = 2;
    checker.set_param(param).exect(
            Testcase{
                    TensorValue(
                            {seq_len, batch_size, input_size}, dtype::Float32(),
                            {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}),  // input
                    TensorValue(
                            {4, batch_size, hidden_size}, dtype::Float32(),
                            {1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}),  // hx
                    TensorValue(
                            {4, batch_size, hidden_size}, dtype::Float32(),
                            {2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9}),  // cx
                    TensorValue(
                            {8, 22}, dtype::Float32(),
                            {
                                    3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 3, 6, 1, 3,
                                    2, 7, 2, 1, 3, 2, 1, 1, 9, 3, 5, 1, 9, 3, 5, 1,
                                    9, 3, 5, 1, 9, 3, 5, 1, 3, 6, 1, 3, 2, 7, 2, 1,
                                    3, 2, 1, 1, 3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1,
                                    9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1,
                                    3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 2, 7, 2, 1,
                                    3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 2, 7, 2, 1,
                                    3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 2, 7, 2, 1,
                                    3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 2, 7, 2, 1,
                                    9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1,
                                    9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1,
                            }),  // flattened weights
                    {},
                    {},
                    {},
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    {},
                    TensorValue(
                            {seq_len, batch_size, 2 * hidden_size}, dtype::Float32(),
                            {0.9951, 0.9993, 1.0000, 1.0000, 0.9999, 1.0000, 1.0000,
                             1.0000, 0.9993, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000,
                             1.0000, 1.0000}),  // output
                    TensorValue(
                            {4, batch_size, hidden_size}, dtype::Float32(),
                            {0.9993, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
                             1.0000, 0.9993, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000,
                             1.0000, 1.0000}),  // hy
                    TensorValue(
                            {4, batch_size, hidden_size}, dtype::Float32(),
                            {4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000, 10.0000,
                             11.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000,
                             10.0000, 11.0000}),  // cy
                    TensorValue(
                            {8, 2, 2, 2}, dtype::Float32(),
                            {
                                    0.995054, 0.999328, 0.99990, 0.999987, 3.,
                                    4., 5., 6., 0.999329, 0.999328,
                                    0.99990, 1., 4., 5., 6.,
                                    7., 1., 0.999328, 0.99990, 0.999987,
                                    7., 8., 9., 10., 0.999329,
                                    0.999328, 0.99990, 1., 8., 9.,
                                    10., 11., 0.995054, 0.999328, 0.99990,
                                    0.999987, 3., 4., 5., 6.,
                                    0.999329, 0.999328, 0.99990, 1., 4.,
                                    5., 6., 7., 1., 0.999328,
                                    0.99990, 0.999987, 7., 8., 9.,
                                    10., 0.999329, 0.999328, 0.99990, 1.,
                                    8., 9., 10., 11.,
                            })  // reserve space
            });
}
  193. } // namespace test
  194. } // namespace megdnn