/**
 * \file dnn/test/arm_common/lstm.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "test/arm_common/fixture.h"

#include "megdnn/oprs.h"
#include "megdnn/oprs/general.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/task_record_check.h"

using namespace megdnn;
using namespace test;

namespace {
//! In arm_common the reserve tensor is not used, so its layout is cleared
//! here to exclude it from the checker's comparison.
void output_canonizer(const CheckerHelper::TensorValueArray& arr) {
    const TensorND& reserve = arr.back();
    TensorND& modif_reserve = const_cast<TensorND&>(reserve);
    modif_reserve.layout = TensorLayout();
}
}  // namespace
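
//! The ten TensorShapes passed to Checker::exec() in the LSTMCell tests are,
//! in order: input {batch, input_size}, weight_ih {4 * hidden, input_size},
//! bias_ih, hx {batch, hidden}, weight_hh {4 * hidden, hidden}, bias_hh,
//! cx {batch, hidden}, then three outputs left empty for the checker to
//! deduce (the argument order is inferred from the 4 * hidden_size gate
//! packing used below, not stated in this file).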
TEST_F(ARM_COMMON, LSTMCell) {
    Checker<LSTMCell> checker(handle());
    checker.set_output_canonizer(output_canonizer);
    checker.exec(
            {{1, 10},
             {40, 10},
             {1, 40},
             {1, 10},
             {40, 10},
             {1, 40},
             {1, 10},
             {},
             {},
             {}});
    for (size_t batch : {2})
        for (size_t n : {3, 4, 5, 23, 100})
            for (size_t out : {3, 6, 25, 100}) {
                checker.exec(
                        {{batch, n},
                         {out * 4, n},
                         {1, out * 4},
                         {batch, out},
                         {out * 4, out},
                         {1, out * 4},
                         {batch, out},
                         {},
                         {},
                         {}});
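                // Same case again, but with the biases supplied per batch
                // ({batch, out * 4}) instead of the broadcastable {1, out * 4}.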
                checker.exec(
                        {{batch, n},
                         {out * 4, n},
                         {batch, out * 4},
                         {batch, out},
                         {out * 4, out},
                         {batch, out * 4},
                         {batch, out},
                         {},
                         {},
                         {}});
            }
}
TEST_F(ARM_COMMON, LSTMCellRecord) {
    TaskRecordChecker<LSTMCell> checker(0);
    checker.exec(
            {{1, 10},
             {40, 10},
             {1, 40},
             {1, 10},
             {40, 10},
             {1, 40},
             {1, 10},
             {},
             {},
             {}});
}
namespace {
void test_lstm(bool bias, bool direction, Handle* handle) {
    Checker<LSTM> checker(handle, true);
    //! LSTM uses tanh and exp in its gates, so numerical error accumulates
    //! across iterations and can exceed 1e-3; relax the tolerance accordingly.
    checker.set_epsilon(1e-2);
    checker.set_output_canonizer(output_canonizer);
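    //! The eight TensorShapes given to exec() below are, in order: input, hx,
    //! cx, flat weight, then the four outputs (output, hy, cy, reserve), left
    //! empty so the checker deduces their layouts (order inferred from
    //! megdnn's LSTM interface).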
    for (size_t input_size : {2, 8, 13})
        for (size_t hidden_size : {1, 4, 17}) {
            size_t dir_size = direction == false ? 1 : 2;
            LSTM::Param param;
            param.bidirectional = direction;
            size_t gate_hidden_size = 4 * hidden_size;
            param.bias = bias;
            param.hidden_size = hidden_size;
            for (size_t seq_len : {1, 3, 5})
                for (size_t batch_size : {1, 2, 4})
                    for (size_t number_layer : {1, 2, 4, 5, 8}) {
                        size_t flatten_size = 0;
                        for (size_t layer = 0; layer < number_layer; layer++) {
                            for (size_t dir = 0; dir < dir_size; dir++) {
                                flatten_size += layer == 0
                                                      ? input_size
                                                      : dir_size * hidden_size;  // ih
                                flatten_size += hidden_size;                     // hh
                            }
                        }
                        if (bias) {
                            flatten_size += 2 * dir_size * number_layer;
                        }
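                        // Worked example (bidirectional, bias): input_size = 2,
                        // hidden_size = 4, number_layer = 2, dir_size = 2:
                        //   layer 0: 2 dirs * (2 ih + 4 hh) = 12
                        //   layer 1: 2 dirs * (2 * 4 ih + 4 hh) = 24
                        //   bias rows: 2 * 2 * 2 = 8
                        // => flatten_size = 44, so the flat weight below is
                        // {16, 44} with gate_hidden_size = 4 * 4 = 16.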
                        param.num_layers = number_layer;
                        checker.set_param(param).exec(
                                {{seq_len, batch_size, input_size},  // input
                                 {number_layer * dir_size, batch_size,
                                  hidden_size},  // hx
                                 {number_layer * dir_size, batch_size,
                                  hidden_size},                     // cx
                                 {gate_hidden_size, flatten_size},  // flat weight
                                 {},
                                 {},
                                 {},
                                 {}});
                    }
        }
}
}  // namespace
TEST_F(ARM_COMMON, LSTM_FORWARD_NO_BIAS_NO_DIRECTION) {
    test_lstm(false, false, handle());
}
TEST_F(ARM_COMMON, LSTM_FORWARD_BIAS_NO_DIRECTION) {
    test_lstm(true, false, handle());
}
TEST_F(ARM_COMMON, LSTM_FORWARD_DIRECTION_NO_BIAS) {
    test_lstm(false, true, handle());
}
TEST_F(ARM_COMMON, LSTM_FORWARD_DIRECTION_BIAS) {
    test_lstm(true, true, handle());
}
TEST_F(ARM_COMMON, LSTM_FORWARD_RECORD) {
    TaskRecordChecker<LSTM> checker(0);
    size_t input_size = 2;
    size_t hidden_size = 2;
    size_t gate_hidden_size = 4 * hidden_size;
    LSTM::Param param;
    param.bidirectional = false;
    param.bias = false;
    param.hidden_size = hidden_size;
    // checker.set_output_canonizer(output_canonizer);
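    // Unlike test_lstm(), the flat weight below is given per layer as
    // {number_layer, gate_hidden_size, input_size + hidden_size}; that shape
    // only lines up for every layer because input_size == hidden_size here.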
    for (size_t seq_len : {1, 3, 5})
        for (size_t batch_size : {1, 2, 4})
            for (size_t number_layer : {1, 2, 4, 5, 8}) {
                param.num_layers = number_layer;
                checker.set_param(param).exec(
                        {{seq_len, batch_size, input_size},        // input
                         {number_layer, batch_size, hidden_size},  // hx
                         {number_layer, batch_size, hidden_size},  // cx
                         {number_layer, gate_hidden_size,
                          input_size + hidden_size},  // flat weight
                         {},
                         {},
                         {},
                         {}});
            }
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(ARM_COMMON, BENCHMARK_LSTM_FORWARD) {
    Benchmarker<LSTM> optimized_bench(handle());
    constexpr size_t RUNS = 20;
    auto run = [&](size_t hidden_size, size_t input_size) {
        optimized_bench.set_times(RUNS).set_display(true);
        size_t gate_hidden_size = 4 * hidden_size;
        for (bool direction : {false, true}) {
            LSTM::Param param;
            param.hidden_size = hidden_size;
            param.bidirectional = direction;
            param.bias = false;
            size_t dir_size = direction == false ? 1 : 2;
            for (size_t seq_len : {1, 5, 8})
                for (size_t batch_size : {1, 8, 16})
                    for (size_t number_layer : {1}) {
                        param.num_layers = number_layer;
                        size_t flatten_size = 0;
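                        // Flat-weight width accumulated with the same
                        // per-layer packing as in test_lstm() above.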
                        for (size_t layer = 0; layer < number_layer; layer++) {
                            for (size_t dir = 0; dir < dir_size; dir++) {
                                flatten_size += layer == 0
                                                      ? input_size
                                                      : dir_size * hidden_size;  // ih
                                flatten_size += hidden_size;                     // hh
                            }
                        }
                        optimized_bench.set_param(param).exec(
                                {{seq_len, batch_size, input_size},  // input
                                 {number_layer * dir_size, batch_size,
                                  hidden_size},  // hx
                                 {number_layer * dir_size, batch_size,
                                  hidden_size},                     // cx
                                 {gate_hidden_size, flatten_size},  // flat weight
                                 {},
                                 {},
                                 {},
                                 {}});
                    }
        }
    };
    run(512, 256);
}
#endif
// vim: syntax=cpp.doxygen