You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

lstm_cell.cpp 8.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. #include "src/common/lstm_cell.h"
  2. #include "megdnn/oprs.h"
  3. #include "src/common/utils.h"
  4. namespace megdnn {
  5. void LSTMCell::deduce_layout(
  6. const TensorLayout& input, const TensorLayout& weight_ih,
  7. const TensorLayout& bias_ih, const TensorLayout& hx,
  8. const TensorLayout& weight_hh, const TensorLayout& bias_hh,
  9. const TensorLayout& cx, TensorLayout& h_new, TensorLayout& c_new,
  10. TensorLayout& gates) {
  11. h_new = TensorLayout(hx, hx.dtype);
  12. c_new = TensorLayout(cx, cx.dtype);
  13. auto opr = handle()->create_operator<RNNCellForward>();
  14. opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
  15. opr->deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates);
  16. }
  17. void LSTMCell::check_exec(
  18. const TensorLayout& input, const TensorLayout& weight_ih,
  19. const TensorLayout& bias_ih, const TensorLayout& hx,
  20. const TensorLayout& weight_hh, const TensorLayout& bias_hh,
  21. const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new,
  22. const TensorLayout& gates, size_t workspace_in_bytes) {
  23. TensorLayout h_new_expected, c_new_expected, gates_expected;
  24. auto errmsg = [&]() {
  25. std::string msg;
  26. msg.append("input=");
  27. msg.append(input.to_string());
  28. msg.append(", weight_ih=");
  29. msg.append(weight_ih.to_string());
  30. msg.append(", bias_ih=");
  31. msg.append(bias_ih.to_string());
  32. msg.append(", hx=");
  33. msg.append(hx.to_string());
  34. msg.append(", weight_hh=");
  35. msg.append(weight_hh.to_string());
  36. msg.append(", bias_hh=");
  37. msg.append(bias_hh.to_string());
  38. msg.append(", cx=");
  39. msg.append(cx.to_string());
  40. return msg;
  41. };
  42. #define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str());
  43. ASSERT_BRIEF(input.ndim == 2)
  44. ASSERT_BRIEF(input.shape[1] == weight_ih.shape[1])
  45. ASSERT_BRIEF(weight_ih.shape[0] == weight_hh.shape[0])
  46. ASSERT_BRIEF(weight_hh.shape[0] == 4 * weight_hh.shape[1])
  47. ASSERT_BRIEF(bias_ih.shape[0] == bias_hh.shape[0])
  48. ASSERT_BRIEF(hx.ndim == 2)
  49. ASSERT_BRIEF(hx.shape[0] == input.shape[0])
  50. ASSERT_BRIEF(hx.shape[1] == cx.shape[1]) // hidden_size
  51. ASSERT_BRIEF(cx.ndim == 2)
  52. ASSERT_BRIEF(cx.shape[0] == input.shape[0])
  53. ASSERT_BRIEF(cx.shape[1] == weight_hh.shape[1])
  54. #undef ASSERT_BRIEF
  55. deduce_layout(
  56. input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new_expected,
  57. c_new_expected, gates_expected);
  58. megdnn_assert_eq_layout(h_new_expected, h_new);
  59. megdnn_assert_eq_layout(c_new_expected, c_new);
  60. megdnn_assert_eq_layout(gates_expected, gates);
  61. auto required_workspace_in_bytes = get_workspace_in_bytes(
  62. input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates);
  63. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  64. }
  65. } // namespace megdnn
  66. namespace megdnn {
  67. namespace lstm_cell {
  68. size_t get_workspace_in_bytes(
  69. const TensorLayout& input, const TensorLayout& weight_ih,
  70. const TensorLayout& bias_ih, const TensorLayout& hx,
  71. const TensorLayout& weight_hh, const TensorLayout& bias_hh,
  72. const TensorLayout& /*cx*/, const TensorLayout& /*h_new*/,
  73. const TensorLayout& /*c_new*/, const TensorLayout& gates, Handle* handle) {
  74. TensorLayout tmp_layout;
  75. auto opr = handle->create_operator<RNNCellForward>();
  76. opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
  77. opr->deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, tmp_layout);
  78. size_t rnn_cell_need = opr->get_workspace_in_bytes(
  79. input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates);
  80. size_t lstm_cell_need = 2 * tmp_layout.span().dist_byte();
  81. return rnn_cell_need > lstm_cell_need ? rnn_cell_need : lstm_cell_need;
  82. }
// Executes one LSTM cell step.
// Step 1: RNNCellForward with an identity nonlinearity writes the raw gate
//         pre-activations into `gates`.
// Step 2: the pre-activations are rearranged from per-sample packing into
//         four contiguous (batch, hidden) blocks in the workspace, then the
//         LSTM activations (sigmoid on i/f/o, tanh on g) and elementwise
//         state updates run:
//             c_new = f * cx + i * g
//             h_new = o * tanh(c_new)
// Workspace layout: [copy_gates: one gates-span][tmp: one gates-span],
// matching the 2x gates-span requirement of get_workspace_in_bytes.
// NOTE(review): `gates` is overwritten with the gate-major rearrangement
// (not left as the raw pre-activations) — confirm callers expect this.
void exec(
        _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
        _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
        _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
        _megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle) {
    // Raw gate pre-activations, written into `gates`.
    auto opr = handle->create_operator<RNNCellForward>();
    opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
    opr->exec(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates, workspace);
    // activation
    size_t batch_size = hx.layout.shape[0];
    size_t hidden_size = hx.layout.shape[1];
    // TypeCvt is used as a plain strided copy here (src/dst dtypes match).
    auto copy_opr = handle->create_operator<TypeCvtForward>();
    // copy_gates sits at the start of the workspace and receives the
    // gate-major rearrangement of `gates`.
    TensorND copy_gates{static_cast<void*>(workspace.raw_ptr), gates.layout};
    TensorLayout hidden_layout{TensorShape{hidden_size}, hx.layout.dtype};
    TensorLayout gateinfo_layout{TensorShape{batch_size, hidden_size}, hx.layout.dtype};
    // Rearrange: source row i, gate j starts at (4*i + j) * hidden elements;
    // destination block j starts at j * batch * hidden, row i at i * hidden.
    for (size_t i = 0; i < batch_size; i++) {
        for (size_t j = 0; j < 4; j++) {
            TensorND half_step_states{
                    // output
                    static_cast<uint8_t*>(gates.raw_ptr()) +
                            (4 * i + j) * hidden_layout.span().dist_byte(),
                    hidden_layout};
            TensorND half_step_output{
                    static_cast<uint8_t*>(copy_gates.raw_ptr()) +
                            j * gateinfo_layout.span().dist_byte() +
                            i * hidden_layout.span().dist_byte(),
                    hidden_layout};
            copy_opr->exec(half_step_states, half_step_output);
        }
    }
    // Second scratch region, directly after copy_gates; holds the activated gates.
    void* workspace_ptr = workspace.raw_ptr + copy_gates.layout.span().dist_byte();
    // Write the rearranged gates back into the `gates` output as well.
    copy_opr->exec(copy_gates, gates);
    // sigmoid: i f
    TensorND tmp{static_cast<void*>(workspace_ptr), copy_gates.layout};
    // Blocks i and f are adjacent, so one sigmoid over a (batch, 2*hidden)
    // view covers both at once.
    TensorLayout gates_ifo_layout{
            TensorShape({batch_size, hidden_size * 2}), copy_gates.layout.dtype};
    TensorND gates_ifo_origin{copy_gates.raw_ptr(), gates_ifo_layout};
    TensorND gates_ifo{tmp.raw_ptr(), gates_ifo_layout};
    auto sigmoid = handle->create_operator<ElemwiseForward>();
    sigmoid->param().mode = Elemwise::Param::Mode::SIGMOID;
    sigmoid->exec({gates_ifo_origin}, gates_ifo);
    // tanh: g (third block, at byte offset of 2 * batch * hidden elements)
    TensorLayout g_layout{
            TensorShape({batch_size, hidden_size}), copy_gates.layout.dtype};
    TensorND g_origin{
            static_cast<char*>(copy_gates.raw_ptr()) +
                    gates_ifo_layout.span().dist_byte(),
            g_layout};
    TensorND g{
            static_cast<char*>(tmp.raw_ptr()) + gates_ifo_layout.span().dist_byte(),
            g_layout};
    auto tanh = handle->create_operator<ElemwiseForward>();
    tanh->param().mode = Elemwise::Param::Mode::TANH;
    tanh->exec({g_origin}, g);
    // sigmoid: o (fourth block, after the first three gates)
    TensorLayout three_gates_ifo_layout{
            TensorShape({batch_size, hidden_size * 3}), copy_gates.layout.dtype};
    TensorLayout o_layout{
            TensorShape({batch_size, hidden_size}), copy_gates.layout.dtype};
    TensorND o_origin{
            static_cast<char*>(copy_gates.raw_ptr()) +
                    three_gates_ifo_layout.span().dist_byte(),
            o_layout};
    TensorND o{
            static_cast<char*>(tmp.raw_ptr()) +
                    three_gates_ifo_layout.span().dist_byte(),
            o_layout};
    sigmoid->exec({o_origin}, o);
    // extract i f o
    TensorND i{static_cast<char*>(tmp.raw_ptr()), g_layout};
    TensorND f{
            static_cast<char*>(tmp.raw_ptr()) + g_layout.span().dist_byte(), g_layout};
    // calculate new cell state: c_new = f * cx + i * g
    // (FUSE_MUL_ADD4 over {f, cx, i, g} — assumed to compute a*b + c*d;
    // confirm against the Elemwise mode documentation.)
    auto elewise_mul_add = handle->create_operator<ElemwiseForward>();
    elewise_mul_add->param().mode = Elemwise::Param::Mode::FUSE_MUL_ADD4;
    elewise_mul_add->exec({f, cx, i, g}, c_new);
    // calculate new hidden state: h_new = o * tanh(c_new)
    tanh->exec({c_new}, h_new);
    auto elewise_mul = handle->create_operator<ElemwiseForward>();
    elewise_mul->param().mode = Elemwise::Param::Mode::MUL;
    elewise_mul->exec({o, h_new}, h_new);
}
  165. } // namespace lstm_cell
  166. } // namespace megdnn