You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

lstm.cpp 9.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. // #include "test/common/lstm.h"
  2. #include "megdnn/dtype.h"
  3. #include "megdnn/oprs.h"
  4. #include "test/common/checker.h"
  5. #include "test/naive/fixture.h"
  6. namespace megdnn {
  7. namespace test {
  8. TEST_F(NAIVE, LSTM_FORWARD) {
  9. Checker<LSTM> checker(handle(), true);
  10. size_t batch_size = 2;
  11. size_t input_size = 3;
  12. size_t hidden_size = 2;
  13. size_t seq_len = 2;
  14. size_t gate_hidden_size = 4 * hidden_size;
  15. LSTM::Param param;
  16. param.num_layers = 1;
  17. param.bidirectional = false;
  18. param.bias = false;
  19. param.hidden_size = hidden_size;
  20. checker.set_param(param).exect(
  21. Testcase{
  22. TensorValue(
  23. {seq_len, batch_size, input_size}, dtype::Float32(),
  24. {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), // input
  25. TensorValue(
  26. {batch_size, hidden_size}, dtype::Float32(),
  27. {1, 2, 3, 4}), // hx
  28. TensorValue(
  29. {batch_size, hidden_size}, dtype::Float32(),
  30. {2, 3, 4, 5}), // cx
  31. TensorValue(
  32. {gate_hidden_size, input_size + hidden_size},
  33. dtype::Float32(),
  34. {3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 3, 6,
  35. 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 9, 3, 5, 1,
  36. 9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1}), // flattern weights
  37. {},
  38. {},
  39. {},
  40. {}},
  41. Testcase{
  42. {},
  43. {},
  44. {},
  45. {},
  46. TensorValue(
  47. {seq_len, batch_size, hidden_size}, dtype::Float32(),
  48. {0.9951, 0.9993, 0.9999, 1.0000, 0.9993, 0.9999, 1.0000,
  49. 1.0000}), // output
  50. TensorValue(
  51. {batch_size, hidden_size}, dtype::Float32(),
  52. {0.9993, 0.9999, 1.0000, 1.0000}), // hy
  53. TensorValue(
  54. {batch_size, hidden_size}, dtype::Float32(),
  55. {4.0000, 5.0000, 6.0000, 7.0000}), // cy
  56. TensorValue(
  57. {2, 2, 2, 2}, dtype::Float32(),
  58. {0.995054, 0.999328, 0.99990, 0.999987, 3., 4., 5., 6.,
  59. 0.999329, 0.999328, 0.99990, 1., 4., 5., 6.,
  60. 7.}) // reserve space
  61. });
  62. param.bidirectional = true;
  63. checker.set_param(param).exect(
  64. Testcase{
  65. TensorValue(
  66. {seq_len, batch_size, input_size}, dtype::Float32(),
  67. {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), // input
  68. TensorValue(
  69. {2, batch_size, hidden_size}, dtype::Float32(),
  70. {1, 2, 3, 4, 5, 6, 7, 8}), // hx
  71. TensorValue(
  72. {2, batch_size, hidden_size}, dtype::Float32(),
  73. {2, 3, 4, 5, 6, 7, 8, 9}), // cx
  74. TensorValue(
  75. {gate_hidden_size, 2 * (input_size + hidden_size)},
  76. dtype::Float32(),
  77. {3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 3, 6, 1, 3, 2,
  78. 7, 2, 1, 3, 2, 1, 1, 9, 3, 5, 1, 9, 3, 5, 1, 9, 3,
  79. 5, 1, 9, 3, 5, 1, 3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1,
  80. 1, 3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 9, 3, 5, 1,
  81. 9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1}), // flattern weights
  82. {},
  83. {},
  84. {},
  85. {}},
  86. Testcase{
  87. {},
  88. {},
  89. {},
  90. {},
  91. TensorValue(
  92. {seq_len, batch_size, 2 * hidden_size}, dtype::Float32(),
  93. {0.9951, 0.9993, 1.0000, 1.0000, 0.9999, 1.0000, 1.0000,
  94. 1.0000, 0.9993, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000,
  95. 1.0000, 1.0000}), // output
  96. TensorValue(
  97. {2, batch_size, hidden_size}, dtype::Float32(),
  98. {0.9993, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
  99. 1.0000}), // hy
  100. TensorValue(
  101. {2, batch_size, hidden_size}, dtype::Float32(),
  102. {4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000, 10.0000,
  103. 11.0000}), // cy
  104. TensorValue(
  105. {4, 2, 2, 2}, dtype::Float32(),
  106. {0.995054, 0.999328, 0.99990, 0.999987, 3., 4.,
  107. 5., 6., 0.999329, 0.999328, 0.99990, 1.,
  108. 4., 5., 6., 7., 1., 0.999328,
  109. 0.99990, 0.999987, 7., 8., 9., 10.,
  110. 0.999329, 0.999328, 0.99990, 1., 8., 9.,
  111. 10., 11.}) // reserve space
  112. });
  113. param.num_layers = 2;
  114. checker.set_param(param).exect(
  115. Testcase{
  116. TensorValue(
  117. {seq_len, batch_size, input_size}, dtype::Float32(),
  118. {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), // input
  119. TensorValue(
  120. {4, batch_size, hidden_size}, dtype::Float32(),
  121. {1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}), // hx
  122. TensorValue(
  123. {4, batch_size, hidden_size}, dtype::Float32(),
  124. {2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9}), // cx
  125. TensorValue(
  126. {8, 22}, dtype::Float32(),
  127. {
  128. 3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 3, 6, 1, 3,
  129. 2, 7, 2, 1, 3, 2, 1, 1, 9, 3, 5, 1, 9, 3, 5, 1,
  130. 9, 3, 5, 1, 9, 3, 5, 1, 3, 6, 1, 3, 2, 7, 2, 1,
  131. 3, 2, 1, 1, 3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1,
  132. 9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1,
  133. 3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 2, 7, 2, 1,
  134. 3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 2, 7, 2, 1,
  135. 3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 2, 7, 2, 1,
  136. 3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 2, 7, 2, 1,
  137. 9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1,
  138. 9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1,
  139. }), // flattern weights
  140. {},
  141. {},
  142. {},
  143. {}},
  144. Testcase{
  145. {},
  146. {},
  147. {},
  148. {},
  149. TensorValue(
  150. {seq_len, batch_size, 2 * hidden_size}, dtype::Float32(),
  151. {0.9951, 0.9993, 1.0000, 1.0000, 0.9999, 1.0000, 1.0000,
  152. 1.0000, 0.9993, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000,
  153. 1.0000, 1.0000}), // output
  154. TensorValue(
  155. {4, batch_size, hidden_size}, dtype::Float32(),
  156. {0.9993, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
  157. 1.0000, 0.9993, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000,
  158. 1.0000, 1.0000}), // hy
  159. TensorValue(
  160. {4, batch_size, hidden_size}, dtype::Float32(),
  161. {4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000, 10.0000,
  162. 11.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000,
  163. 10.0000, 11.0000}), // cy
  164. TensorValue(
  165. {8, 2, 2, 2}, dtype::Float32(),
  166. {
  167. 0.995054, 0.999328, 0.99990, 0.999987, 3.,
  168. 4., 5., 6., 0.999329, 0.999328,
  169. 0.99990, 1., 4., 5., 6.,
  170. 7., 1., 0.999328, 0.99990, 0.999987,
  171. 7., 8., 9., 10., 0.999329,
  172. 0.999328, 0.99990, 1., 8., 9.,
  173. 10., 11., 0.995054, 0.999328, 0.99990,
  174. 0.999987, 3., 4., 5., 6.,
  175. 0.999329, 0.999328, 0.99990, 1., 4.,
  176. 5., 6., 7., 1., 0.999328,
  177. 0.99990, 0.999987, 7., 8., 9.,
  178. 10., 0.999329, 0.999328, 0.99990, 1.,
  179. 8., 9., 10., 11.,
  180. }) // reserve space
  181. });
  182. }
  183. } // namespace test
  184. } // namespace megdnn