
rnn.py 24 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
import numbers
from abc import abstractmethod
from typing import Optional, Tuple

import numpy as np

from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm
from ..functional import stack, zeros
from ..tensor import Parameter, Tensor
from . import init
from .module import Module


class RNNCellBase(Module):
    def __init__(
        self, input_size: int, hidden_size: int, bias: bool, num_chunks: int,
    ) -> None:
        # num_chunks indicates the number of gates
        super(RNNCellBase, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # initialize weights
        self.gate_hidden_size = num_chunks * hidden_size
        self.weight_ih = Parameter(
            np.zeros((self.gate_hidden_size, input_size), dtype=np.float32)
        )
        self.weight_hh = Parameter(
            np.zeros((self.gate_hidden_size, hidden_size), dtype=np.float32)
        )
        if bias:
            self.bias_ih = Parameter(
                np.zeros((self.gate_hidden_size), dtype=np.float32)
            )
            self.bias_hh = Parameter(
                np.zeros((self.gate_hidden_size), dtype=np.float32)
            )
        else:
            self.bias_ih = zeros(shape=(self.gate_hidden_size))
            self.bias_hh = zeros(shape=(self.gate_hidden_size))
        self.reset_parameters()
        # if bias is False self.bias will remain zero

    def reset_parameters(self) -> None:
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    @abstractmethod
    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        raise NotImplementedError("forward not implemented !")


class RNNCell(RNNCellBase):
  58. r"""An Elman RNN cell with tanh or ReLU non-linearity.
  59. .. math::
  60. h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})
  61. If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh.
  62. Args:
  63. input_size: The number of expected features in the input `x`
  64. hidden_size: The number of features in the hidden state `h`
  65. bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
  66. Default: ``True``
  67. nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
  68. Inputs: input, hidden
  69. - **input** of shape `(batch, input_size)`: tensor containing input features
  70. - **hidden** of shape `(batch, hidden_size)`: tensor containing the initial hidden
  71. state for each element in the batch.
  72. Defaults to zero if not provided.
  73. Outputs: h'
  74. - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state
  75. for each element in the batch
  76. Shape:
  77. - Input1: :math:`(N, H_{in})` tensor containing input features where
  78. :math:`H_{in}` = `input_size`
  79. - Input2: :math:`(N, H_{out})` tensor containing the initial hidden
  80. state for each element in the batch where :math:`H_{out}` = `hidden_size`
  81. Defaults to zero if not provided.
  82. - Output: :math:`(N, H_{out})` tensor containing the next hidden state
  83. for each element in the batch
  84. Examples:
  85. .. code-block::
  86. import numpy as np
  87. import megengine as mge
  88. import megengine.module as M
  89. m = M.RNNCell(10, 20)
  90. inp = mge.tensor(np.random.randn(3, 10), dtype=np.float32)
  91. hx = mge.tensor(np.random.randn(3, 20), dtype=np.float32)
  92. out = m(inp, hx)
  93. print(out.numpy().shape)
  94. Outputs:
  95. .. code-block::
  96. (3, 20)
  97. """
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        nonlinearity: str = "tanh",
    ) -> None:
        self.nonlinearity = nonlinearity
        super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        if hx is None:
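            # No initial hidden state given: default to zeros. num_chunks is 1
            # for a vanilla RNN cell, so gate_hidden_size equals hidden_size and
            # the default state has shape (batch, hidden_size).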
            hx = zeros(shape=(input.shape[0], self.gate_hidden_size),)
        op = builtin.RNNCell(nonlineMode=self.nonlinearity)
        return apply(
            op, input, self.weight_ih, self.bias_ih, hx, self.weight_hh, self.bias_hh
        )[0]


class LSTMCell(RNNCellBase):
  115. r"""A long short-term memory (LSTM) cell.
  116. .. math::
  117. \begin{array}{ll}
  118. i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
  119. f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
  120. g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
  121. o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
  122. c' = f * c + i * g \\
  123. h' = o * \tanh(c') \\
  124. \end{array}
  125. where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
  126. Args:
  127. input_size: The number of expected features in the input `x`
  128. hidden_size: The number of features in the hidden state `h`
  129. bias: If ``False``, then the layer does not use bias weights `b_ih` and
  130. `b_hh`. Default: ``True``
  131. Inputs: input, (h_0, c_0)
  132. - **input** of shape `(batch, input_size)`: tensor containing input features
  133. - **h_0** of shape `(batch, hidden_size)`: tensor containing the initial hidden
  134. state for each element in the batch.
  135. - **c_0** of shape `(batch, hidden_size)`: tensor containing the initial cell state
  136. for each element in the batch.
  137. If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.
  138. Outputs: (h_1, c_1)
  139. - **h_1** of shape `(batch, hidden_size)`: tensor containing the next hidden state
  140. for each element in the batch
  141. - **c_1** of shape `(batch, hidden_size)`: tensor containing the next cell state
  142. for each element in the batch
  143. Examples:
  144. .. code-block::
  145. import numpy as np
  146. import megengine as mge
  147. import megengine.module as M
  148. m = M.LSTMCell(10, 20)
  149. inp = mge.tensor(np.random.randn(3, 10), dtype=np.float32)
  150. hx = mge.tensor(np.random.randn(3, 20), dtype=np.float32)
  151. cx = mge.tensor(np.random.randn(3, 20), dtype=np.float32)
  152. hy, cy = m(inp, (hx, cx))
  153. print(hy.numpy().shape)
  154. print(cy.numpy().shape)
  155. Outputs:
  156. .. code-block::
  157. (3, 20)
  158. (3, 20)
  159. """
    def __init__(self, input_size: int, hidden_size: int, bias: bool = True,) -> None:
        super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)

    def forward(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tensor]:
        # hx: (h, c)
        if hx is None:
            h = zeros(shape=(input.shape[0], self.hidden_size))
            c = zeros(shape=(input.shape[0], self.hidden_size))
        else:
            h, c = hx
        op = builtin.LSTMCell()
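        # The cell op is applied to (input, weight_ih, bias_ih, h, weight_hh,
        # bias_hh, c); the first two of its outputs are taken as the new (h, c).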
        return apply(
            op, input, self.weight_ih, self.bias_ih, h, self.weight_hh, self.bias_hh, c
        )[:2]


class RNNBase(Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0,
        bidirectional: bool = False,
        proj_size: int = 0,
    ) -> None:
        super(RNNBase, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.num_directions = 2 if self.bidirectional else 1
        self.proj_size = proj_size
        # check validity of dropout
        if (
            not isinstance(dropout, numbers.Number)
            or not 0 <= dropout <= 1
            or isinstance(dropout, bool)
        ):
            raise ValueError(
                "Dropout should be a float in [0, 1], which indicates the probability "
                "of an element to be zero"
            )
        if proj_size < 0:
            raise ValueError(
                "proj_size should be a positive integer or zero to disable projections"
            )
        elif proj_size >= hidden_size:
            raise ValueError("proj_size has to be smaller than hidden_size")
        self.cells = []
        for layer in range(self.num_layers):
            self.cells.append([])
            for _ in range(self.num_directions):
                self.cells[layer].append(self.create_cell(layer))
        # parameters have been initialized during the creation of the cells
        # if flatten, then delete cells
        self._flatten_parameters(self.cells)
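    # The cells above carry individually initialized weights. ``_flatten_parameters``
    # allocates a single (gate_hidden_size, size_dim1) buffer whose second dimension
    # sums the weight_ih and weight_hh columns of every cell, plus two bias columns
    # per cell when ``bias`` is True; the fused RNN/LSTM op below is given this
    # flattened buffer rather than the per-cell parameters.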
    def _flatten_parameters(self, cells):
        gate_hidden_size = cells[0][0].gate_hidden_size
        size_dim1 = 0
        for layer in range(self.num_layers):
            for direction in range(self.num_directions):
                size_dim1 += cells[layer][direction].weight_ih.shape[1]
                size_dim1 += cells[layer][direction].weight_hh.shape[1]
        if self.bias:
            size_dim1 += 2 * self.num_directions * self.num_layers
        self._flatten_weights = Parameter(
            np.zeros((gate_hidden_size, size_dim1), dtype=np.float32)
        )
        self.reset_parameters()
    def reset_parameters(self) -> None:
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    @abstractmethod
    def create_cell(self, layer):
        raise NotImplementedError("Cell not implemented !")

    @abstractmethod
    def init_hidden(self):
        raise NotImplementedError("init_hidden not implemented !")

    @abstractmethod
    def get_output_from_hidden(self, hx):
        raise NotImplementedError("get_output_from_hidden not implemented !")

    @abstractmethod
    def apply_op(self, input, hx):
        raise NotImplementedError("apply_op not implemented !")

    def _apply_fn_to_hx(self, hx, fn):
        return fn(hx)

    def _stack_h_n(self, h_n):
        return stack(h_n, axis=0)

    def forward(self, input: Tensor, hx=None):
        if self.batch_first:
            batch_size = input.shape[0]
            input = input.transpose((1, 0, 2))  # [seq_len, batch_size, dim]
        else:
            batch_size = input.shape[1]
        if hx is None:
            hx = self.init_hidden(batch_size)
        output, h = self.apply_op(input, hx)
        if self.batch_first:
            output = output.transpose((1, 0, 2))
        return output, h


class RNN(RNNBase):
  267. r"""Applies a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to an
  268. input sequence.
  269. For each element in the input sequence, each layer computes the following
  270. function:
  271. .. math::
  272. h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})
  273. where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
  274. the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
  275. previous layer at time `t-1` or the initial hidden state at time `0`.
  276. If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`.
  277. Args:
  278. input_size: The number of expected features in the input `x`
  279. hidden_size: The number of features in the hidden state `h`
  280. num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
  281. would mean stacking two RNNs together to form a `stacked RNN`,
  282. with the second RNN taking in outputs of the first RNN and
  283. computing the final results. Default: 1
  284. nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
  285. bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
  286. Default: ``True``
  287. batch_first: If ``True``, then the input and output tensors are provided
  288. as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
  289. Note that this does not apply to hidden or cell states. See the
  290. Inputs/Outputs sections below for details. Default: ``False``
  291. dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
  292. RNN layer except the last layer, with dropout probability equal to
  293. :attr:`dropout`. Default: 0
  294. bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``
  295. Inputs: input, h_0
  296. * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or
  297. :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
  298. the input sequence. The input can also be a packed variable length sequence.
  299. See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
  300. :func:`torch.nn.utils.rnn.pack_sequence` for details.
  301. * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden
  302. state for each element in the batch. Defaults to zeros if not provided.
  303. where:
  304. .. math::
  305. \begin{aligned}
  306. N ={} & \text{batch size} \\
  307. L ={} & \text{sequence length} \\
  308. D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
  309. H_{in} ={} & \text{input\_size} \\
  310. H_{out} ={} & \text{hidden\_size}
  311. \end{aligned}
  312. Outputs: output, h_n
  313. * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
  314. :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
  315. `(h_t)` from the last layer of the RNN, for each `t`. If a
  316. :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
  317. will also be a packed sequence.
  318. * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state
  319. for each element in the batch.
  320. Examples:
  321. .. code-block::
  322. import numpy as np
  323. import megengine as mge
  324. import megengine.module as M
  325. m = M.RNN(10,20,2,batch_first=False,nonlinearity="relu",bias=True,bidirectional=True)
  326. inp = mge.tensor(np.random.randn(6, 30, 10), dtype=np.float32)
  327. hx = mge.tensor(np.random.randn(4, 30, 20), dtype=np.float32)
  328. out, hn = m(inp, hx)
  329. print(out.numpy().shape)
  330. Outputs:
  331. .. code-block::
  332. (6, 30, 40)
  333. """
    def __init__(self, *args, **kwargs) -> None:
        self.nonlinearity = kwargs.pop("nonlinearity", "tanh")
        super(RNN, self).__init__(*args, **kwargs)

    def create_cell(self, layer):
        if layer == 0:
            input_size = self.input_size
        else:
            input_size = self.num_directions * self.hidden_size
        return RNNCell(input_size, self.hidden_size, self.bias, self.nonlinearity)

    def init_hidden(self, batch_size):
        hidden_shape = (
            self.num_directions * self.num_layers,
            batch_size,
            self.hidden_size,
        )
        return zeros(shape=hidden_shape)

    def get_output_from_hidden(self, hx):
        return hx

    def apply_op(self, input, hx):
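        # The fused RNN op takes a forward-mode flag; the BatchNorm.FwdMode enum
        # values (TRAINING / INFERENCE) are reused for that flag here, presumably
        # so that inter-layer dropout is only active while the module is training.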
        fwd_mode = (
            BatchNorm.FwdMode.TRAINING if self.training else BatchNorm.FwdMode.INFERENCE
        )
        op = builtin.RNN(
            num_layers=self.num_layers,
            bidirectional=self.bidirectional,
            bias=self.bias,
            hidden_size=self.hidden_size,
            dropout=self.dropout,
            nonlineMode=self.nonlinearity,
            fwd_mode=fwd_mode,
        )
        output, h = apply(op, input, hx, self._flatten_weights)[:2]
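        # The zero-valued sums below have no numerical effect; they appear to make
        # ``output`` and ``h`` depend on each other in the computation graph so that
        # gradients flow through both results of the fused op.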
        output = output + h.sum() * 0
        h = h + output.sum() * 0
        return output, h


class LSTM(RNNBase):
  370. r"""Applies a multi-layer long short-term memory LSTM to an input
  371. sequence.
  372. For each element in the input sequence, each layer computes the following
  373. function:
  374. .. math::
  375. \begin{array}{ll} \\
  376. i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
  377. f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
  378. g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
  379. o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
  380. c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
  381. h_t = o_t \odot \tanh(c_t) \\
  382. \end{array}
  383. where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
  384. state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}`
  385. is the hidden state of the layer at time `t-1` or the initial hidden
  386. state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`,
  387. :math:`o_t` are the input, forget, cell, and output gates, respectively.
  388. :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
  389. In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
  390. (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
  391. dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
  392. variable which is :math:`0` with probability :attr:`dropout`.
  393. If ``proj_size > 0`` is specified, LSTM with projections will be used. This changes
  394. the LSTM cell in the following way. First, the dimension of :math:`h_t` will be changed from
  395. ``hidden_size`` to ``proj_size`` (dimensions of :math:`W_{hi}` will be changed accordingly).
  396. Second, the output hidden state of each layer will be multiplied by a learnable projection
  397. matrix: :math:`h_t = W_{hr}h_t`. Note that as a consequence of this, the output
  398. of LSTM network will be of different shape as well. See Inputs/Outputs sections below for exact
  399. dimensions of all variables. You can find more details in https://arxiv.org/abs/1402.1128.
  400. Args:
  401. input_size: The number of expected features in the input `x`
  402. hidden_size: The number of features in the hidden state `h`
  403. num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
  404. would mean stacking two LSTMs together to form a `stacked LSTM`,
  405. with the second LSTM taking in outputs of the first LSTM and
  406. computing the final results. Default: 1
  407. bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
  408. Default: ``True``
  409. batch_first: If ``True``, then the input and output tensors are provided
  410. as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
  411. Note that this does not apply to hidden or cell states. See the
  412. Inputs/Outputs sections below for details. Default: ``False``
  413. dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
  414. LSTM layer except the last layer, with dropout probability equal to
  415. :attr:`dropout`. Default: 0
  416. bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``
  417. proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0
  418. Inputs: input, (h_0, c_0)
  419. * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or
  420. :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
  421. the input sequence. The input can also be a packed variable length sequence.
  422. See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
  423. :func:`torch.nn.utils.rnn.pack_sequence` for details.
  424. * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the
  425. initial hidden state for each element in the batch.
  426. Defaults to zeros if (h_0, c_0) is not provided.
  427. * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
  428. initial cell state for each element in the batch.
  429. Defaults to zeros if (h_0, c_0) is not provided.
  430. where:
  431. .. math::
  432. \begin{aligned}
  433. N ={} & \text{batch size} \\
  434. L ={} & \text{sequence length} \\
  435. D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
  436. H_{in} ={} & \text{input\_size} \\
  437. H_{cell} ={} & \text{hidden\_size} \\
  438. H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\
  439. \end{aligned}
  440. Outputs: output, (h_n, c_n)
  441. * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
  442. :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
  443. `(h_t)` from the last layer of the LSTM, for each `t`. If a
  444. :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
  445. will also be a packed sequence.
  446. * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the
  447. final hidden state for each element in the batch.
  448. * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
  449. final cell state for each element in the batch.
  450. Examples:
  451. .. code-block::
  452. import numpy as np
  453. import megengine as mge
  454. import megengine.module as M
  455. m = M.LSTM(10, 20, 2, batch_first=False, bidirectional=True, bias=True)
  456. inp = mge.tensor(np.random.randn(6, 30, 10), dtype=np.float32)
  457. hx = mge.tensor(np.random.randn(4, 30, 20), dtype=np.float32)
  458. cx = mge.tensor(np.random.randn(4, 30, 20), dtype=np.float32)
  459. out, (hn, cn) = m(inp,(hx,cx))
  460. print(out.numpy().shape)
  461. Outputs:
  462. .. code-block::
  463. (6, 30, 40)
  464. """
    def __init__(self, *args, **kwargs) -> None:
        super(LSTM, self).__init__(*args, **kwargs)

    def create_cell(self, layer):
        if layer == 0:
            input_size = self.input_size
        else:
            input_size = self.num_directions * self.hidden_size
        return LSTMCell(input_size, self.hidden_size, self.bias)

    def init_hidden(self, batch_size):
        hidden_shape = (
            self.num_directions * self.num_layers,
            batch_size,
            self.hidden_size,
        )
        h = zeros(shape=hidden_shape)
        c = zeros(shape=hidden_shape)
        return (h, c)

    def get_output_from_hidden(self, hx):
        return hx[0]

    def apply_op(self, input, hx):
        fwd_mode = (
            BatchNorm.FwdMode.TRAINING if self.training else BatchNorm.FwdMode.INFERENCE
        )
        op = builtin.LSTM(
            num_layers=self.num_layers,
            bidirectional=self.bidirectional,
            bias=self.bias,
            hidden_size=self.hidden_size,
            proj_size=self.proj_size,
            dropout=self.dropout,
            fwd_mode=fwd_mode,
        )
        output, h, c = apply(op, input, hx[0], hx[1], self._flatten_weights)[:3]
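        # As in RNN.apply_op, the zero-valued placeholders below add nothing
        # numerically; they appear to tie ``output``, ``h`` and ``c`` together in
        # the computation graph so gradients reach every result of the fused op.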
        placeholders = [output.sum() * 0, h.sum() * 0, c.sum() * 0]
        output = output + placeholders[1] + placeholders[2]
        h = h + placeholders[0] + placeholders[2]
        c = c + placeholders[0] + placeholders[1]
        return output, (h, c)

    def _apply_fn_to_hx(self, hx, fn):
        return (fn(hx[0]), fn(hx[1]))

    def _stack_h_n(self, h_n):
        h = [tup[0] for tup in h_n]
        c = [tup[1] for tup in h_n]
        return (stack(h, axis=0), stack(c, axis=0))