
opr_impl.cpp

/**
 * \file dnn/src/naive/rnn/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/naive/rnn/opr_impl.h"

#include <algorithm>
#include <cstring>

#include "megdnn/dtype.h"
#include "megdnn/oprs/base.h"
#include "megdnn/oprs/general.h"
#include "src/common/opr_delegate.h"
#include "src/common/rnn.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include "src/naive/matrix_mul/opr_impl.h"
#include "src/naive/rnn/funcs.h"
#include "src/naive/rnn/rnn.h"
namespace megdnn {
namespace naive {

using rnn::RNNCellWeightWrapper;
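
// Forward pass of the naive RNN operator: unpack the flattened weight tensor
// into per-layer cell wrappers, then run the layer-by-layer time loop in
// rnn::exec_internal. hx/hy hold the initial/final hidden states; the
// intermediate per-step states go to reserve_space for the backward pass.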
void RNNImpl::exec(
        _megdnn_tensor_in input, _megdnn_tensor_in hx,
        _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
        _megdnn_tensor_out hy, _megdnn_tensor_out reserve_space,
        _megdnn_workspace workspace) {
    auto _param = param();
    size_t D = _param.bidirectional ? 2 : 1;
    size_t num_layers = _param.num_layers;
    size_t input_size = input.layout.shape[2];
    std::vector<RNNCellWeightWrapper> cells;
    size_t used_workspace_size = rnn::get_cells<RNNCellWeightWrapper>(
            D, num_layers, input_size, _param.hidden_size, _param.bias, cells,
            flatten_weights, workspace);

    Workspace new_workspace(
            workspace.raw_ptr + used_workspace_size,
            workspace.size - used_workspace_size);
    TensorNDArray states, states_new;
    states.push_back(hx);
    states_new.push_back(hy);
    rnn::exec_internal<RNNCellWeightWrapper, RNNCellForward>(
            cells, input, states, states_new, output, reserve_space, num_layers, D,
            this->handle(), new_workspace);
}
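
// Workspace for the forward pass: the per-cell requirement reported by
// rnn::get_workspace_in_bytes, plus a zero-filled fake bias tensor when the
// operator is configured without bias, plus one output-sized scratch buffer.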
size_t RNNImpl::get_workspace_in_bytes(
        const TensorLayout& input, const TensorLayout& hx,
        const TensorLayout& flatten_weights, const TensorLayout& output,
        const TensorLayout& hy, const TensorLayout& reserve_space) {
    size_t workspace_size = rnn::get_workspace_in_bytes<RNNCellForward>(
            input, flatten_weights, param().hidden_size,
            param().bidirectional ? 2 : 1, this->handle());
    if (!param().bias) {  // use fake bias (all 0)
        TensorLayout bias_layout = {{param().hidden_size}, flatten_weights.dtype};
        workspace_size += bias_layout.span().dist_byte();
    }
    workspace_size += output.span().dist_byte();
    return workspace_size;
}
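
// The reserve space stores one hidden state per (layer, direction, time step):
// num_layers * D * seq_len copies of a {batch_size, hidden_size} tensor.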
size_t RNNImpl::get_reserve_size_in_bytes(const TensorLayout& input) {
    size_t num_layers = param().num_layers;
    size_t D = param().bidirectional ? 2 : 1;
    size_t seq_len = input.shape[0];
    size_t batch_size = input.shape[1];
    TensorLayout state_layout{{batch_size, param().hidden_size}, input.dtype};
    return num_layers * D * seq_len * state_layout.span().dist_byte();
}
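
// Backward pass of the naive RNN operator: reconstruct each layer's inputs,
// outputs and per-step hidden states from x, y and reserve_space, then
// propagate gradients layer by layer through rnn::backward_exec_internal,
// producing dx, dhx and the flattened weight gradient dw.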
void RNNBackwardImpl::exec(
        _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
        _megdnn_tensor_in dy, _megdnn_tensor_in dhy, _megdnn_tensor_in flatten_weights,
        _megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx,
        _megdnn_tensor_out dw, _megdnn_workspace workspace) {
    TensorNDArray layer_inputs;
    // layer_inputs.push_back(x);
    TensorNDArray layer_outputs;
    std::vector<std::vector<TensorNDArray>> cell_seq_states;
    size_t num_layers = param().num_layers;
    size_t D = param().bidirectional ? 2 : 1;
    // size_t seq_len = x.layout.shape[0];
    // size_t batch_size = x.layout.shape[1];
    size_t input_size = x.layout.shape[2];
    size_t hidden_size = param().hidden_size;
    size_t used_workspace_size = 0;

    // get cells
    std::vector<RNNCellWeightWrapper> cells;
    used_workspace_size += rnn::get_cells(
            D, num_layers, input_size, hidden_size, param().bias, cells,
            flatten_weights, workspace);

    // extract intermediate states from reserve space
    /*for (int layer = 0; layer < num_layers; ++layer) {
        TensorND layer_output{workspace_ptr, y.layout};
        workspace_ptr = static_cast<uint8_t*>(workspace_ptr) +
                        layer_output.layout.span().dist_byte();
        for (int d = 0; d < D; ++d) {
            cell_seq_states.push_back(std::vector<TensorNDArray>());
            // the reverse direction is stored in reversed sequence order
            for (int i = 0; i < seq_len; ++i) {
                size_t step = i;
                if (d == 1)
                    step = seq_len - i - 1;
                size_t offset = ((layer * D + d) * seq_len + step) *
                                cell_output_layout.span().dist_byte();
                TensorND hy{
                        static_cast<uint8_t*>(reserve_space.raw_ptr) + offset,
                        cell_output_layout};
                // states
                cell_seq_states[cell_seq_states.size() - 1].push_back({hy});
                // output
                offset = i * D * cell_output_layout.span().dist_byte();
                memcpy(static_cast<uint8_t*>(layer_output.raw_ptr) + offset,
                       hy.raw_ptr, hy.layout.span().dist_byte());
            }
        }
        cell_seq_outputs.push_back(layer_output);
        if (layer != num_layers - 1)
            layer_inputs.push_back(layer_output);
    }*/

    // map the RNN-level nonlinearity onto the RNNCell-level enum
    param::RNNCell::NonlineMode nonlineMode;
    using ModeRNN = param::RNN::NonlineMode;
    using ModeRNNCell = param::RNNCell::NonlineMode;
    switch (param().nonlineMode) {
        case ModeRNN::RELU:
            nonlineMode = ModeRNNCell::RELU;
            break;
        case ModeRNN::TANH:
            nonlineMode = ModeRNNCell::TANH;
            break;
        default:
            // guard so nonlineMode is never left uninitialized
            megdnn_assert(false, "unsupported nonlinear mode for RNN backward");
    }

    // get formatted inputs
    Workspace new_workspace = Workspace(
            workspace.raw_ptr + used_workspace_size,
            workspace.size - used_workspace_size);
    used_workspace_size += rnn::get_inputs_for_exec<RNNCellWeightWrapper>(
            x, y, reserve_space, num_layers, D, hidden_size, cells, layer_inputs,
            layer_outputs, cell_seq_states, nonlineMode, new_workspace);

    // dhy arr, dhx arr
    TensorNDArray dhy_arr = {dhy}, dhx_arr = {dhx};

    // exec
    new_workspace = Workspace(
            workspace.raw_ptr + used_workspace_size,
            workspace.size - used_workspace_size);
    rnn::backward_exec_internal<RNNCellWeightWrapper>(
            cells, D, num_layers, input_size, param().bias, nonlineMode, layer_inputs,
            layer_outputs, cell_seq_states, dy, dhy_arr, dx, dhx_arr, dw,
            this->handle(), new_workspace);
}
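
// Workspace for the backward pass. The terms below track what exec allocates:
// per-cell backward scratch, two fake biases when bias is disabled (one here,
// one inside backward_exec_internal), buffers for the reconstructed layer
// inputs/outputs, and gradient tensors for one cell's weights (wih, whh, bias).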
size_t RNNBackwardImpl::get_workspace_in_bytes(
        const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
        const TensorLayout& dy, const TensorLayout& dhy,
        const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
        const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw) {
    size_t D = param().bidirectional ? 2 : 1;
    size_t num_layers = param().num_layers;
    size_t hidden_size = param().hidden_size;
    size_t gate_hidden_size = hidden_size;
    size_t max_input_size = std::max(x.shape[2], D * hidden_size);

    size_t workspace_size = RNNCellWeightWrapper::backward_workspace_size_in_bytes(
            this->handle(), x.shape[1], param().hidden_size, max_input_size, x.dtype);
    if (!param().bias) {  // use fake bias (all 0)
        TensorLayout bias_layout = {{gate_hidden_size}, flatten_weights.dtype};
        // times 2 because another bias is allocated in backward_exec_internal
        workspace_size += bias_layout.span().dist_byte() * 2;
    }
    workspace_size += num_layers * y.span().dist_byte();
    // add back exec workspace size
    workspace_size += y.span().dist_byte() * 2;
    workspace_size += x.span().dist_byte() * 2;
    TensorLayout wih{{gate_hidden_size, max_input_size}, flatten_weights.dtype};
    TensorLayout whh{{gate_hidden_size, hidden_size}, flatten_weights.dtype};
    TensorLayout bias{{gate_hidden_size}, flatten_weights.dtype};
    workspace_size += wih.span().dist_byte();
    workspace_size += whh.span().dist_byte();
    workspace_size += bias.span().dist_byte();
    return workspace_size;
}
}  // namespace naive
}  // namespace megdnn