
funcs.h

/**
 * \file dnn/src/naive/rnn/funcs.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// #ifndef _RNN_H
// #define _RNN_H
#include "megdnn/oprs.h"

namespace megdnn {
namespace naive {
namespace rnn {
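// Single-step cell primitives. They are only declared here; the definitions are
// provided out of line (see the funcs.tpp include commented out at the bottom of
// this file) and are instantiated for a concrete cell operator (CellOpr), e.g. an
// RNN or LSTM cell.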
template <typename CellOpr>
void cell_opr_exec(
        _megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
        _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih,
        _megdnn_tensor_in bias_hh, const TensorNDArray& states,
        TensorNDArray& states_new, _megdnn_workspace workspace,
        param::RNNCell::NonlineMode nonline_mode, Handle* handle);

template <typename CellOpr>
size_t cell_opr_get_workspace_in_bytes(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& weight_hh, const TensorLayout& bias_ih,
        const TensorLayout& bias_hh, const TensorLayout& hx, Handle* handle);
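// Workspace for the whole (multi-layer, possibly bidirectional) forward pass: the
// largest of
//   * the cell workspace of the first layer (input width = input_size),
//   * the cell workspace of deeper layers (input width = D * hidden_size),
//   * the workspace of the concat used to merge the D per-direction outputs.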
template <typename CellOpr>
size_t get_workspace_in_bytes(
        const TensorLayout& input, const TensorLayout& flatten_weights,
        size_t hidden_size,
        size_t D,  // num_directions
        Handle* handle) {
    size_t seq_len = input.shape[0];
    size_t batch_size = input.shape[1];
    size_t input_size = input.shape[2];
    size_t gate_hidden_size = flatten_weights.shape[0];
    // concat workspace
    TensorLayout direction_output_layout{
            TensorShape{seq_len, batch_size, hidden_size}, input.dtype};
    TensorLayout output_layout{{seq_len, batch_size, D * hidden_size}, input.dtype};
    TensorLayoutArray layer_layouts;
    for (size_t i = 0; i < D; ++i)
        layer_layouts.push_back(direction_output_layout);
    auto concat_opr = handle->create_operator<ConcatForward>();
    concat_opr->param().axis = -1;
    size_t concat_workspace =
            concat_opr->get_workspace_in_bytes(layer_layouts, output_layout);
    // cell workspace
    TensorLayout weight_ih{{gate_hidden_size, input_size}, flatten_weights.dtype};
    TensorLayout D_weight_ih{
            {gate_hidden_size, D * hidden_size}, flatten_weights.dtype};
    TensorLayout weight_hh{{gate_hidden_size, hidden_size}, flatten_weights.dtype};
    TensorLayout bias{{gate_hidden_size}, flatten_weights.dtype};
    TensorLayout hx{{batch_size, hidden_size}, input.dtype};
    TensorLayout cell_input = {{input.shape[1], input.shape[2]}, input.dtype};
    TensorLayout D_cell_input = {{input.shape[1], D * hidden_size}, input.dtype};
    size_t cell_workspace = cell_opr_get_workspace_in_bytes<CellOpr>(
            cell_input, weight_ih, weight_hh, bias, bias, hx, handle);
    size_t D_cell_workspace = cell_opr_get_workspace_in_bytes<CellOpr>(
            D_cell_input, D_weight_ih, weight_hh, bias, bias, hx, handle);
    return std::max(std::max(cell_workspace, D_cell_workspace), concat_workspace);
}
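// Forward pass over all layers and directions.
//   * `output` holds the [seq_len, batch_size, D * hidden_size] result; the first
//     layer writes into it directly, deeper layers write into a temporary buffer
//     carved from the workspace and copy it back when the layer is finished.
//   * `states_new` keeps the running hidden state of every (layer, direction) cell,
//     packed one {batch_size, hidden_size} block per cell.
//   * Every intermediate step state is additionally appended to `reserve_space`,
//     which the backward pass replays later (see get_inputs_for_exec below).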
template <class Cell, typename CellOpr>
void exec_internal(
        std::vector<Cell>& cells, _megdnn_tensor_in input, const TensorNDArray& states,
        TensorNDArray& states_new, _megdnn_tensor_out output,
        _megdnn_tensor_out reserve_space, size_t num_layers, size_t D,
        param::RNNCell::NonlineMode nonline_mode, Handle* handle,
        _megdnn_workspace workspace) {
    size_t seq_len = input.layout.shape[0];
    size_t batch_size = input.layout.shape[1];
    size_t input_size = input.layout.shape[2];
    size_t hidden_size = cells[0].weight_hh.layout.shape[1];
    TensorLayout cell_output_layout{
            TensorShape{batch_size, hidden_size}, states[0].layout.dtype};
    TensorLayout cell_first_input_layout{
            TensorShape{batch_size, input_size}, input.layout.dtype};
    TensorLayout cell_input_layout{
            TensorShape{batch_size, D * hidden_size}, input.layout.dtype};
    TensorLayout direction_output_layout{
            TensorShape{seq_len, batch_size, hidden_size}, output.layout.dtype};
    TensorND tmp_output{workspace.raw_ptr, output.layout};
    _megdnn_workspace new_workspace{
            workspace.raw_ptr + tmp_output.layout.span().dist_byte(),
            workspace.size - tmp_output.layout.span().dist_byte()};
    auto cell_opr = handle->create_operator<CellOpr>();
    auto copy_opr = handle->create_operator<TypeCvtForward>();
    // copy states to states_new
    for (size_t i = 0; i < states.size(); ++i)
        copy_opr->exec(states[i], states_new[i]);
    void* reserve_ptr = reserve_space.raw_ptr();
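    // reserve_ptr is a bump cursor into reserve_space: every time step of every
    // (layer, direction) cell appends its freshly computed states here so that the
    // backward pass can recover them without recomputation.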
    // first layer (layer 0): reads the raw input sequence
    for (size_t d = 0; d < D; ++d) {
        size_t cell_idx = d;
        auto& cell = cells[cell_idx];
        TensorNDArray cur_states;
        size_t states_offset = cell_idx * cell_output_layout.span().dist_byte();
        for (size_t i = 0; i < states.size(); ++i) {
            cur_states.push_back(TensorND{
                    static_cast<uint8_t*>(states_new[i].raw_ptr()) + states_offset,
                    cell_output_layout});
        }
        for (size_t i = 0; i < seq_len; ++i) {
            size_t step = d == 0 ? i : seq_len - 1 - i;
            TensorND step_input{
                    static_cast<uint8_t*>(input.raw_ptr()) +
                            step * cell_first_input_layout.span().dist_byte(),
                    cell_first_input_layout};
            TensorND step_output{
                    static_cast<uint8_t*>(output.raw_ptr()) +
                            (step * D) * cell_output_layout.span().dist_byte() +
                            d * cell_output_layout.span().dist_byte() / batch_size,
                    cell_output_layout};
            TensorNDArray tmp_states;
            for (size_t s = 0; s < cur_states.size(); ++s) {
                tmp_states.push_back(TensorND{reserve_ptr, cur_states[s].layout});
                size_t size_in_bytes = cur_states[s].layout.span().dist_byte();
                reserve_ptr = static_cast<uint8_t*>(reserve_ptr) + size_in_bytes;
            }
            cell_opr_exec<CellOpr>(
                    step_input, cell.weight_ih, cell.weight_hh, cell.bias_ih,
                    cell.bias_hh, cur_states, tmp_states, new_workspace, nonline_mode,
                    handle);
            for (size_t s = 0; s < tmp_states.size(); ++s) {
                copy_opr->exec(tmp_states[s], cur_states[s]);
            }
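            // The output tensor is [seq_len, batch_size, D * hidden_size]. For a
            // bidirectional run the two directions are interleaved along the last
            // axis, so the {batch_size, hidden_size} cell output cannot be written
            // contiguously: copy it row by row, one {hidden_size} slice per batch
            // element, with a stride of D * hidden_size in the destination.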
            TensorLayout half_output_layout{
                    TensorShape{hidden_size}, states[0].layout.dtype};
            if (D == 2) {
                for (size_t i = 0; i < batch_size; i++) {
                    TensorND half_cur_states{
                            // output
                            static_cast<uint8_t*>(cur_states[0].raw_ptr()) +
                                    i * half_output_layout.span().dist_byte(),
                            half_output_layout};
                    TensorND half_step_output{
                            static_cast<uint8_t*>(step_output.raw_ptr()) +
                                    i * half_output_layout.span().dist_byte() * 2,
                            half_output_layout};
                    copy_opr->exec(half_cur_states, half_step_output);
                }
            } else
                copy_opr->exec(cur_states[0], step_output);
        }
    }
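    // Deeper layers read the previous layer's full [seq_len, batch_size,
    // D * hidden_size] output and write their own result into tmp_output (carved
    // from the workspace), which is copied back into `output` once the layer is
    // done, so `output` always holds the last finished layer.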
    for (size_t layer = 1; layer < num_layers; ++layer) {
        for (size_t d = 0; d < D; ++d) {
            size_t cell_idx = layer * D + d;
            auto& cell = cells[cell_idx];
            TensorNDArray cur_states;
            size_t states_offset = cell_idx * cell_output_layout.span().dist_byte();
            for (size_t i = 0; i < states.size(); ++i) {
                cur_states.push_back(TensorND{
                        static_cast<uint8_t*>(states_new[i].raw_ptr()) + states_offset,
                        cell_output_layout});
            }
            for (size_t i = 0; i < seq_len; ++i) {
                size_t step = d == 0 ? i : seq_len - 1 - i;
                TensorND step_input{
                        static_cast<uint8_t*>(output.raw_ptr()) +
                                step * cell_input_layout.span().dist_byte(),
                        cell_input_layout};
                TensorND step_output{
                        static_cast<uint8_t*>(tmp_output.raw_ptr()) +
                                (step * D) * cell_output_layout.span().dist_byte() +
                                d * cell_output_layout.span().dist_byte() / batch_size,
                        cell_output_layout};
                TensorNDArray tmp_states;
                for (size_t s = 0; s < cur_states.size(); ++s) {
                    tmp_states.push_back(TensorND{reserve_ptr, cur_states[s].layout});
                    size_t size_in_bytes = cur_states[s].layout.span().dist_byte();
                    reserve_ptr = static_cast<uint8_t*>(reserve_ptr) + size_in_bytes;
                }
                cell_opr_exec<CellOpr>(
                        step_input, cell.weight_ih, cell.weight_hh, cell.bias_ih,
                        cell.bias_hh, cur_states, tmp_states, new_workspace,
                        nonline_mode, handle);
                // copy states to cur_states
                for (size_t s = 0; s < tmp_states.size(); ++s) {
                    copy_opr->exec(tmp_states[s], cur_states[s]);
                }
                TensorLayout half_output_layout{
                        TensorShape{hidden_size}, states[0].layout.dtype};
                if (D == 2) {
                    for (size_t i = 0; i < batch_size; i++) {
                        TensorND half_cur_states{
                                // output
                                static_cast<uint8_t*>(cur_states[0].raw_ptr()) +
                                        i * half_output_layout.span().dist_byte(),
                                half_output_layout};
                        TensorND half_step_output{
                                static_cast<uint8_t*>(step_output.raw_ptr()) +
                                        i * half_output_layout.span().dist_byte() * 2,
                                half_output_layout};
                        copy_opr->exec(half_cur_states, half_step_output);
                    }
                } else
                    copy_opr->exec(cur_states[0], step_output);
            }
        }
        copy_opr->exec(tmp_output, output);
    }
}
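// Build one Cell view per (layer, direction) on top of the packed flatten_weights
// buffer. Weights are laid out consecutively in (layer, direction) order, the first
// layer taking input_size inputs and every later layer D * hidden_size. Returns the
// workspace requirement reported by the first cell.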
template <class Cell>
size_t get_cells(
        size_t D, size_t num_layers, size_t input_size, size_t hidden_size, bool bias,
        std::vector<Cell>& cells, _megdnn_tensor_in flatten_weights,
        _megdnn_workspace workspace) {
    cells.reserve(D * num_layers);
    void* weight_ptr = flatten_weights.raw_ptr();
    for (size_t layer = 0; layer < num_layers; ++layer) {
        for (size_t d = 0; d < D; ++d) {
            size_t cell_input_size = D * hidden_size;
            if (layer == 0)
                cell_input_size = input_size;
            Cell cell(
                    weight_ptr, hidden_size, cell_input_size, bias,
                    flatten_weights.layout.dtype, workspace);
            weight_ptr =
                    static_cast<uint8_t*>(weight_ptr) + cell.weight_size_in_bytes();
            cells.push_back(cell);
        }
    }
    return cells[0].workspace_size_in_bytes();
}
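// Reconstruct, for the backward pass, the per-layer inputs/outputs and the states of
// every cell at every time step from the reserve_space filled in by exec_internal
// during the forward pass. The per-layer outputs are materialized in the
// caller-provided workspace; the function returns how many workspace bytes it used.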
template <class Cell>
size_t get_inputs_for_exec(
        _megdnn_tensor_in x, _megdnn_tensor_in y,
        const std::vector<TensorNDArray> unfold_hx, _megdnn_tensor_in reserve_space,
        size_t num_layers, size_t D, size_t hidden_size, const std::vector<Cell>& cells,
        TensorNDArray& layer_inputs, TensorNDArray& layer_outputs,
        std::vector<std::vector<TensorNDArray>>& cell_seq_states,
        param::RNNCell::NonlineMode /*nonlineMode*/, _megdnn_workspace workspace) {
    // return used workspace size
    layer_inputs.push_back(x);
    size_t seq_len = x.layout.shape[0];
    size_t batch_size = x.layout.shape[1];
    size_t num_states = cells[0].num_states();
    TensorLayout cell_output_layout{{batch_size, hidden_size}, y.layout.dtype};
    TensorLayout direction_output_layout{
            {seq_len, batch_size, hidden_size}, y.layout.dtype};
    void* workspace_ptr = workspace.raw_ptr;
    // extract intermediate states from reserve space
    for (size_t layer = 0; layer < num_layers; ++layer) {
        TensorND layer_output{workspace_ptr, y.layout};
        workspace_ptr = static_cast<uint8_t*>(workspace_ptr) +
                        layer_output.layout.span().dist_byte();
        for (size_t d = 0; d < D; ++d) {
            cell_seq_states.push_back(std::vector<TensorNDArray>());
            // initial state of this (layer, direction) cell
            cell_seq_states[cell_seq_states.size() - 1].push_back(
                    {unfold_hx[layer * D + d]});
            // reverse direction is stored with reversed order of sequence order
            for (size_t i = 0; i < seq_len; ++i) {
                size_t step = i;
                if (d == 1)
                    step = seq_len - i - 1;
                size_t offset = ((layer * D + d) * seq_len + step) *
                                cell_output_layout.span().dist_byte() * num_states;
                TensorNDArray cur_states;
                for (size_t s = 0; s < num_states; ++s) {
                    TensorND h{
                            static_cast<uint8_t*>(reserve_space.raw_ptr()) + offset +
                                    s * cell_output_layout.span().dist_byte(),
                            cell_output_layout};
                    cur_states.push_back(h);
                }
                TensorND hy{
                        static_cast<uint8_t*>(reserve_space.raw_ptr()) + offset,
                        cell_output_layout};
                // states
                cell_seq_states[cell_seq_states.size() - 1].push_back(cur_states);
                // output
                offset = i * D * cell_output_layout.span().dist_byte();
                memcpy(static_cast<uint8_t*>(layer_output.raw_ptr()) + offset,
                       hy.raw_ptr(), hy.layout.span().dist_byte());
            }
        }
        layer_outputs.push_back(layer_output);
        if (layer != num_layers - 1)
            layer_inputs.push_back(layer_output);
    }
    return static_cast<uint8_t*>(workspace_ptr) -
           static_cast<uint8_t*>((void*)workspace.raw_ptr);
}
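// Backward (BPTT) over the whole stack: produces the input gradient dx, the initial
// state gradients dstates, and the packed weight gradient dw, given the forward
// tensors recovered by get_inputs_for_exec above.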
template <class Cell>
void backward_exec_internal(
        std::vector<Cell>& cells, size_t D, size_t num_layers, size_t input_size,
        bool bias, param::RNNCell::NonlineMode nonlineMode,
        const TensorNDArray& layer_inputs, const TensorNDArray& layer_outputs,
        const std::vector<std::vector<TensorNDArray>>& cell_seq_states,
        _megdnn_tensor_in dy, const TensorNDArray& dhy, _megdnn_tensor_out dx,
        TensorNDArray& dstates, _megdnn_tensor_out dw, Handle* handle,
        _megdnn_workspace workspace) {
    /*
        layer_inputs: array of the input of each layer. Element 0 has shape
            [seq_len, batch_size, input_size]; the other elements have shape
            [seq_len, batch_size, D * hidden_size].
        layer_outputs: array of the outputs of each layer. To access the outputs of
            the cell at (layer, d), use layer_outputs[layer]; the shape is
            [seq_len, batch_size, output_size (D * hidden_size)], in sequence order.
        cell_seq_states: array of the states of each cell at each step. To access
            the states of the cell at (layer, d) at sequence step `step`, use
            cell_seq_states[layer * D + d][step].
    */
    size_t seq_len = layer_inputs[0].layout.shape[0];
    size_t batch_size = layer_inputs[0].layout.shape[1];
    DType dtype = layer_inputs[0].layout.dtype;
    size_t cell_y_size = layer_outputs[0].layout.shape[2] / D;
    size_t hidden_size = cell_y_size;
    TensorLayout cell_y_layout = {{batch_size, cell_y_size}, dtype};
    void* workspace_ptr = workspace.raw_ptr;
    TensorND layer_output_grad{
            workspace_ptr, {{seq_len, batch_size, D * hidden_size}, dtype}};
    workspace_ptr = static_cast<uint8_t*>(workspace_ptr) +
                    layer_output_grad.layout.span().dist_byte();
    memcpy(layer_output_grad.raw_ptr(), dy.raw_ptr(), dy.layout.span().dist_byte());
    TensorNDArray direction_dx_arr;
    for (size_t i = 0; i < D; ++i) {
        TensorLayout direction_dx_layout{{seq_len, batch_size, hidden_size}, dtype};
        direction_dx_arr.push_back(TensorND(workspace_ptr, direction_dx_layout));
        workspace_ptr = static_cast<uint8_t*>(workspace_ptr) +
                        direction_dx_layout.span().dist_byte();
    }
    TensorNDArray L0_direction_dx_arr;
    for (size_t i = 0; i < D; ++i) {
        TensorLayout direction_dx_layout{{seq_len, batch_size, input_size}, dtype};
        L0_direction_dx_arr.push_back(TensorND(workspace_ptr, direction_dx_layout));
        workspace_ptr = static_cast<uint8_t*>(workspace_ptr) +
                        direction_dx_layout.span().dist_byte();
    }
    std::vector<TensorNDArray> dstates_arr;
    for (size_t layer = 0; layer < num_layers; ++layer) {
        for (size_t d = 0; d < D; ++d) {
            TensorNDArray cell_states;
            cell_states.reserve(dstates.size());
            for (size_t i = 0; i < dstates.size(); ++i) {
                size_t offset = (layer * D + d) * cell_y_layout.span().dist_byte();
                TensorND dhx_cell{
                        static_cast<uint8_t*>(dstates[i].raw_ptr()) + offset,
                        cell_y_layout};
                memcpy(dhx_cell.raw_ptr(),
                       static_cast<uint8_t*>(dhy[i].raw_ptr()) + offset,
                       cell_y_layout.span().dist_byte());
                cell_states.emplace_back(dhx_cell);
            }
            dstates_arr.push_back(cell_states);
        }
    }
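    // Weight gradients are accumulated directly into dw: zero it first, then let
    // get_cells build Cell views (cell_grads) that alias the packed dw buffer.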
    memset(dw.raw_ptr(), 0, dw.layout.span().dist_byte());
    std::vector<Cell> cell_grads;
    size_t used_workspace_size = static_cast<uint8_t*>(workspace_ptr) -
                                 static_cast<uint8_t*>((void*)(workspace.raw_ptr));
    workspace_ptr =
            static_cast<uint8_t*>(workspace_ptr) +
            get_cells(
                    D, num_layers, input_size, hidden_size, bias, cell_grads, dw,
                    Workspace(
                            workspace.raw_ptr + used_workspace_size,
                            workspace.size - used_workspace_size));
    auto add_opr = handle->create_operator<ElemwiseForward>();
    add_opr->param().mode = Elemwise::Mode::ADD;
    auto copy_opr = handle->create_operator<TypeCvtForward>();
    // initialize dx to zero
    memset(dx.raw_ptr(), 0, dx.layout.span().dist_byte());
    // calculate grads
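    // Loop order: layers from top to bottom, then each direction, then time steps in
    // the reverse of the forward order (d == 0 ran forward in time, so its BPTT runs
    // backward; d == 1 already ran backward, so its BPTT runs forward). Per-step
    // weight gradients are computed into temporaries and summed into cell_grad.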
    for (int layer = (int)num_layers - 1; layer >= 0; --layer) {
        for (int d = (int)D - 1; d >= 0; --d) {
            Cell& cell = cells[layer * D + d];
            Cell& cell_grad = cell_grads[layer * D + d];
            size_t input_size = layer_inputs[layer].layout.shape[2];
            const TensorND& x_arr = layer_inputs[layer];
            const TensorND& y_arr = layer_outputs[layer];
            TensorLayout x_layout = {{batch_size, input_size}, dtype};
            // tmp tensors
            void* tmp_workspace_ptr = workspace_ptr;
            TensorND dwi_tmp{tmp_workspace_ptr, cell_grad.weight_ih.layout};
            tmp_workspace_ptr = static_cast<uint8_t*>(tmp_workspace_ptr) +
                                dwi_tmp.layout.span().dist_byte();
            TensorND dwh_tmp{tmp_workspace_ptr, cell_grad.weight_hh.layout};
            tmp_workspace_ptr = static_cast<uint8_t*>(tmp_workspace_ptr) +
                                dwh_tmp.layout.span().dist_byte();
            TensorND dbias_tmp{tmp_workspace_ptr, cell_grad.bias_ih.layout};
            tmp_workspace_ptr = static_cast<uint8_t*>(tmp_workspace_ptr) +
                                dbias_tmp.layout.span().dist_byte();
            size_t used_workspace_size =
                    static_cast<uint8_t*>(tmp_workspace_ptr) -
                    static_cast<uint8_t*>((void*)(workspace.raw_ptr));
            for (size_t i = 0; i < seq_len; ++i) {
                size_t step = i;
                if (d == 0)
                    step = seq_len - i - 1;
                TensorND x{
                        static_cast<uint8_t*>(x_arr.raw_ptr()) +
                                step * x_layout.span().dist_byte(),
                        x_layout},
                        y{static_cast<uint8_t*>(y_arr.raw_ptr()) +
                                  (step * D + d) * cell_y_layout.span().dist_byte(),
                          cell_y_layout};
                const TensorNDArray& cell_states = cell_seq_states[layer * D + d][step];
                TensorNDArray& dstates_new = dstates_arr[layer * D + d];
                TensorND dy_t{
                        static_cast<uint8_t*>(layer_output_grad.raw_ptr()) +
                                (step * D + d) * cell_y_layout.span().dist_byte(),
                        cell_y_layout};
                add_opr->exec({dstates_new[0], dy_t}, dy_t);
                TensorND dx_t;
                if (layer == 0)
                    dx_t = {static_cast<uint8_t*>(L0_direction_dx_arr[d].raw_ptr()) +
                                    step * x_layout.span().dist_byte(),
                            x_layout};
                else
                    dx_t = {static_cast<uint8_t*>(direction_dx_arr[d].raw_ptr()) +
                                    step * x_layout.span().dist_byte(),
                            x_layout};
                TensorNDArray douts = {dy_t};
                for (size_t s = 1; s < dstates_new.size(); ++s)
                    douts.push_back(dstates_new[s]);
                cell.backward(
                        handle, nonlineMode, x, cell_states, y, douts, dx_t,
                        dstates_new, dwi_tmp, dwh_tmp, dbias_tmp,
                        Workspace(
                                workspace.raw_ptr + used_workspace_size,
                                workspace.size - used_workspace_size));
                // add step gradient to overall gradient
                add_opr->exec({dwi_tmp, cell_grad.weight_ih}, cell_grad.weight_ih);
                add_opr->exec({dwh_tmp, cell_grad.weight_hh}, cell_grad.weight_hh);
                add_opr->exec({dbias_tmp, cell_grad.bias_ih}, cell_grad.bias_ih);
                add_opr->exec({dbias_tmp, cell_grad.bias_hh}, cell_grad.bias_hh);
            }
        }
        // Combine the per-direction input gradients: for layer 0 accumulate them
        // into dx; for deeper layers write them (interleaved when D == 2) into
        // layer_output_grad, which becomes dy for the layer below.
        if (layer == 0) {
            for (size_t i = 0; i < D; ++i)
                add_opr->exec({L0_direction_dx_arr[i], dx}, dx);
        } else {
            if (D == 1)
                copy_opr->exec(direction_dx_arr[0], layer_output_grad);
            else {
                for (size_t t = 0; t < seq_len; ++t) {
                    size_t offset = t * D * cell_y_layout.span().dist_byte();
                    for (size_t d = 0; d < D; ++d) {
                        TensorND src{
                                static_cast<uint8_t*>(direction_dx_arr[d].raw_ptr()) +
                                        offset,
                                cell_y_layout};
                        TensorND dst{
                                static_cast<uint8_t*>(layer_output_grad.raw_ptr()) +
                                        offset + d * cell_y_layout.span().dist_byte(),
                                cell_y_layout};
                        copy_opr->exec(src, dst);
                    }
                }
            }
        }
    }
}
}  // namespace rnn
}  // namespace naive
}  // namespace megdnn
// #include "funcs.tpp"
// #endif