
lstm.cpp

#include "test/arm_common/fixture.h"

#include "megdnn/oprs.h"
#include "megdnn/oprs/general.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/task_record_check.h"

using namespace megdnn;
using namespace test;

namespace {
//! In arm_common the reserve tensor is not used, so its layout is cleared
//! before comparison.
void output_canonizer(const CheckerHelper::TensorValueArray& arr) {
    const TensorND& reserve = arr.back();
    TensorND& modif_reserve = const_cast<TensorND&>(reserve);
    modif_reserve.layout = TensorLayout();
}
}  // namespace

TEST_F(ARM_COMMON, LSTMCell) {
    Checker<LSTMCell> checker(handle());
    checker.set_output_canonizer(output_canonizer);
    checker.exec(
            {{1, 10},
             {40, 10},
             {1, 40},
             {1, 10},
             {40, 10},
             {1, 40},
             {1, 10},
             {},
             {},
             {}});
    for (size_t batch : {2})
        for (size_t n : {3, 4, 5, 23, 100})
            for (size_t out : {3, 6, 25, 100}) {
                checker.exec(
                        {{batch, n},
                         {out * 4, n},
                         {1, out * 4},
                         {batch, out},
                         {out * 4, out},
                         {1, out * 4},
                         {batch, out},
                         {},
                         {},
                         {}});
                checker.exec(
                        {{batch, n},
                         {out * 4, n},
                         {batch, out * 4},
                         {batch, out},
                         {out * 4, out},
                         {batch, out * 4},
                         {batch, out},
                         {},
                         {},
                         {}});
            }
}

TEST_F(ARM_COMMON, LSTMCellRecord) {
    TaskRecordChecker<LSTMCell> checker(0);
    checker.exec(
            {{1, 10},
             {40, 10},
             {1, 40},
             {1, 10},
             {40, 10},
             {1, 40},
             {1, 10},
             {},
             {},
             {}});
}

namespace {
void test_lstm(bool bias, bool direction, Handle* handle) {
    Checker<LSTM> checker(handle, true);
    //! Because LSTM computes tanh and exp, the accumulated error after more
    //! iterations can exceed 1e-3, so the epsilon is relaxed.
    checker.set_epsilon(1e-2);
    checker.set_output_canonizer(output_canonizer);
    for (size_t input_size : {2, 8, 13})
        for (size_t hidden_size : {1, 4, 17}) {
            size_t dir_size = direction == false ? 1 : 2;
            LSTM::Param param;
            param.bidirectional = direction;
            size_t gate_hidden_size = 4 * hidden_size;
            param.bias = bias;
            param.hidden_size = hidden_size;
            for (size_t seq_len : {1, 3, 5})
                for (size_t batch_size : {1, 2, 4})
                    for (size_t number_layer : {1, 2, 4, 5, 8}) {
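                        // The flat weight is a {gate_hidden_size, flatten_size}
                        // matrix; the column count below sums, per layer and
                        // direction, the ih input width (input_size for layer
                        // 0, dir_size * hidden_size for deeper layers) plus
                        // the hh width (hidden_size), and appends two bias
                        // entries per direction and layer when bias is on.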
                        size_t flatten_size = 0;
                        for (size_t layer = 0; layer < number_layer; layer++) {
                            for (size_t dir = 0; dir < dir_size; dir++) {
                                flatten_size += layer == 0
                                                      ? input_size
                                                      : dir_size * hidden_size;  // ih
                                flatten_size += hidden_size;                     // hh
                            }
                        }
                        if (bias) {
                            flatten_size += 2 * dir_size * number_layer;
                        }
                        param.num_layers = number_layer;
                        checker.set_param(param).exec(
                                {{seq_len, batch_size, input_size},  // input
                                 {number_layer * dir_size, batch_size,
                                  hidden_size},  // hx
                                 {number_layer * dir_size, batch_size,
                                  hidden_size},                     // hy
                                 {gate_hidden_size, flatten_size},  // flat weight
                                 {},
                                 {},
                                 {},
                                 {}});
                    }
        }
}
}  // namespace

TEST_F(ARM_COMMON, LSTM_FORWARD_NO_BIAS_NO_DIRECTION) {
    test_lstm(false, false, handle());
}

TEST_F(ARM_COMMON, LSTM_FORWARD_BIAS_NO_DIRECTION) {
    test_lstm(true, false, handle());
}

TEST_F(ARM_COMMON, LSTM_FORWARD_DIRECTION_NO_BIAS) {
    test_lstm(false, true, handle());
}

TEST_F(ARM_COMMON, LSTM_FORWARD_DIRECTION_BIAS) {
    test_lstm(true, true, handle());
}

TEST_F(ARM_COMMON, LSTM_FORWARD_RECORD) {
    TaskRecordChecker<LSTM> checker(0);
    size_t input_size = 2;
    size_t hidden_size = 2;
    size_t gate_hidden_size = 4 * hidden_size;
    LSTM::Param param;
    param.bidirectional = false;
    param.bias = false;
    param.hidden_size = hidden_size;
    // checker.set_output_canonizer(output_canonizer);
    for (size_t seq_len : {1, 3, 5})
        for (size_t batch_size : {1, 2, 4})
            for (size_t number_layer : {1, 2, 4, 5, 8}) {
                param.num_layers = number_layer;
                checker.set_param(param).exec(
                        {{seq_len, batch_size, input_size},        // input
                         {number_layer, batch_size, hidden_size},  // hx
                         {number_layer, batch_size, hidden_size},  // hy
                         {number_layer, gate_hidden_size,
                          input_size + hidden_size},  // flat weight
                         {},
                         {},
                         {},
                         {}});
            }
}

#if MEGDNN_WITH_BENCHMARK
TEST_F(ARM_COMMON, BENCHMARK_LSTM_FORWARD) {
    Benchmarker<LSTM> optimized_bench(handle());
    auto run = [&](size_t hidden_size, size_t input_size) {
        optimized_bench.set_times(20).set_display(true);
        size_t gate_hidden_size = 4 * hidden_size;
        for (bool direction : {false, true}) {
            LSTM::Param param;
            param.hidden_size = hidden_size;
            param.bidirectional = direction;
            param.bias = false;
            size_t dir_size = direction == false ? 1 : 2;
            for (size_t seq_len : {1, 5, 8})
                for (size_t batch_size : {1, 8, 16})
                    for (size_t number_layer : {1}) {
                        param.num_layers = number_layer;
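                        // Flat-weight column count, computed the same way as
                        // in test_lstm above (no bias rows here).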
                        size_t flatten_size = 0;
                        for (size_t layer = 0; layer < number_layer; layer++) {
                            for (size_t dir = 0; dir < dir_size; dir++) {
                                flatten_size += layer == 0
                                                      ? input_size
                                                      : dir_size * hidden_size;  // ih
                                flatten_size += hidden_size;                     // hh
                            }
                        }
                        optimized_bench.set_param(param).exec(
                                {{seq_len, batch_size, input_size},  // input
                                 {number_layer * dir_size, batch_size,
                                  hidden_size},  // hx
                                 {number_layer * dir_size, batch_size,
                                  hidden_size},                     // hy
                                 {gate_hidden_size, flatten_size},  // flat weight
                                 {},
                                 {},
                                 {},
                                 {}});
                    }
        }
    };
    run(512, 256);
}
#endif

// vim: syntax=cpp.doxygen