
rnn.py

# -*- coding: utf-8 -*-
import math
import numbers
from abc import abstractmethod
from typing import Optional, Tuple

import numpy as np

from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm
from ..functional import stack, zeros
from ..tensor import Parameter, Tensor
from . import init
from .module import Module


class RNNCellBase(Module):
    def __init__(
        self, input_size: int, hidden_size: int, bias: bool, num_chunks: int,
    ) -> None:
        # num_chunks indicates the number of gates
        super(RNNCellBase, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # initialize weights
        self.gate_hidden_size = num_chunks * hidden_size
        self.weight_ih = Parameter(
            np.zeros((self.gate_hidden_size, input_size), dtype=np.float32)
        )
        self.weight_hh = Parameter(
            np.zeros((self.gate_hidden_size, hidden_size), dtype=np.float32)
        )
        if bias:
            self.bias_ih = Parameter(
                np.zeros((self.gate_hidden_size), dtype=np.float32)
            )
            self.bias_hh = Parameter(
                np.zeros((self.gate_hidden_size), dtype=np.float32)
            )
        else:
            self.bias_ih = zeros(shape=(self.gate_hidden_size))
            self.bias_hh = zeros(shape=(self.gate_hidden_size))
        self.reset_parameters()
        # if bias is False, bias_ih and bias_hh stay zero: they are plain tensors,
        # not Parameters, so reset_parameters does not touch them

    def reset_parameters(self) -> None:
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    @abstractmethod
    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        raise NotImplementedError("forward not implemented!")
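

# Shape note: num_chunks is the number of gates stacked along the first weight
# dimension, so gate_hidden_size = num_chunks * hidden_size. For example, RNNCell
# below passes num_chunks=1 while LSTMCell passes num_chunks=4, so an
# LSTMCell(10, 20) holds weight_ih of shape (80, 10) and weight_hh of shape (80, 20).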


class RNNCell(RNNCellBase):
    r"""An Elman RNN cell with tanh or ReLU non-linearity.

    .. math::

        h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})

    If :attr:`nonlinearity` is ``'relu'``, then ReLU is used in place of tanh.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``.
            Default: ``'tanh'``

    Inputs: input, hidden
        - **input** of shape `(batch, input_size)`: tensor containing input features
        - **hidden** of shape `(batch, hidden_size)`: tensor containing the initial hidden
          state for each element in the batch. Defaults to zero if not provided.

    Outputs: h'
        - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state
          for each element in the batch

    Shape:
        - Input1: :math:`(N, H_{in})` tensor containing input features where
          :math:`H_{in}` = `input_size`
        - Input2: :math:`(N, H_{out})` tensor containing the initial hidden
          state for each element in the batch where :math:`H_{out}` = `hidden_size`.
          Defaults to zero if not provided.
        - Output: :math:`(N, H_{out})` tensor containing the next hidden state
          for each element in the batch

    Examples:

    .. code-block::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.RNNCell(10, 20)
        inp = mge.tensor(np.random.randn(3, 10), dtype=np.float32)
        hx = mge.tensor(np.random.randn(3, 20), dtype=np.float32)
        out = m(inp, hx)
        print(out.numpy().shape)

    Outputs:

    .. code-block::

        (3, 20)
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        nonlinearity: str = "tanh",
    ) -> None:
        self.nonlinearity = nonlinearity
        super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        if hx is None:
            hx = zeros(shape=(input.shape[0], self.gate_hidden_size))
        op = builtin.RNNCell(nonlineMode=self.nonlinearity)
        return apply(
            op, input, self.weight_ih, self.bias_ih, hx, self.weight_hh, self.bias_hh
        )[0]


class LSTMCell(RNNCellBase):
    r"""A long short-term memory (LSTM) cell.

    .. math::

        \begin{array}{ll}
        i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
        f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
        g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
        o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
        c' = f * c + i * g \\
        h' = o * \tanh(c') \\
        \end{array}

    where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        bias: If ``False``, then the layer does not use bias weights `b_ih` and
            `b_hh`. Default: ``True``

    Inputs: input, (h_0, c_0)
        - **input** of shape `(batch, input_size)`: tensor containing input features
        - **h_0** of shape `(batch, hidden_size)`: tensor containing the initial hidden
          state for each element in the batch.
        - **c_0** of shape `(batch, hidden_size)`: tensor containing the initial cell state
          for each element in the batch.

        If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.

    Outputs: (h_1, c_1)
        - **h_1** of shape `(batch, hidden_size)`: tensor containing the next hidden state
          for each element in the batch
        - **c_1** of shape `(batch, hidden_size)`: tensor containing the next cell state
          for each element in the batch

    Examples:

    .. code-block::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.LSTMCell(10, 20)
        inp = mge.tensor(np.random.randn(3, 10), dtype=np.float32)
        hx = mge.tensor(np.random.randn(3, 20), dtype=np.float32)
        cx = mge.tensor(np.random.randn(3, 20), dtype=np.float32)
        hy, cy = m(inp, (hx, cx))
        print(hy.numpy().shape)
        print(cy.numpy().shape)

    Outputs:

    .. code-block::

        (3, 20)
        (3, 20)
    """

    def __init__(self, input_size: int, hidden_size: int, bias: bool = True,) -> None:
        super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)

    def forward(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tensor]:
        # hx: (h, c)
        if hx is None:
            h = zeros(shape=(input.shape[0], self.hidden_size))
            c = zeros(shape=(input.shape[0], self.hidden_size))
        else:
            h, c = hx
        op = builtin.LSTMCell()
        return apply(
            op, input, self.weight_ih, self.bias_ih, h, self.weight_hh, self.bias_hh, c
        )[:2]


class RNNBase(Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0,
        bidirectional: bool = False,
        proj_size: int = 0,
    ) -> None:
        super(RNNBase, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.num_directions = 2 if self.bidirectional else 1
        self.proj_size = proj_size

        # check validity of dropout
        if (
            not isinstance(dropout, numbers.Number)
            or not 0 <= dropout <= 1
            or isinstance(dropout, bool)
        ):
            raise ValueError(
                "dropout should be a float in [0, 1], which indicates the "
                "probability of an element being zeroed"
            )
        if proj_size < 0:
            raise ValueError(
                "proj_size should be a positive integer or zero to disable projections"
            )
        elif proj_size >= hidden_size:
            raise ValueError("proj_size has to be smaller than hidden_size")

        self.cells = []
        for layer in range(self.num_layers):
            self.cells.append([])
            for _ in range(self.num_directions):
                self.cells[layer].append(self.create_cell(layer))
        # parameters have been initialized during the creation of the cells
        # if flatten, then delete cells
        self._flatten_parameters(self.cells)

    def _flatten_parameters(self, cells):
        gate_hidden_size = cells[0][0].gate_hidden_size
        # total width of the flattened weight matrix: the input and hidden widths
        # of every cell, plus one column per bias vector when bias is enabled
        size_dim1 = 0
        for layer in range(self.num_layers):
            for direction in range(self.num_directions):
                size_dim1 += cells[layer][direction].weight_ih.shape[1]
                size_dim1 += cells[layer][direction].weight_hh.shape[1]
        if self.bias:
            size_dim1 += 2 * self.num_directions * self.num_layers
        self._flatten_weights = Parameter(
            np.zeros((gate_hidden_size, size_dim1), dtype=np.float32)
        )
        self.reset_parameters()
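
    # Worked example for the flattened layout built above (illustrative sizes):
    # for an RNN with input_size=10, hidden_size=20, num_layers=2,
    # bidirectional=True and bias=True, gate_hidden_size is 20 and size_dim1 is
    # 2 * (10 + 20) = 60 for layer 0, plus 2 * (40 + 20) = 120 for layer 1 (whose
    # input is the 2 * 20 bidirectional output of layer 0), plus 2 * 2 * 2 = 8
    # bias columns, giving a (20, 188) _flatten_weights parameter.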

    def reset_parameters(self) -> None:
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    @abstractmethod
    def create_cell(self, layer):
        raise NotImplementedError("create_cell not implemented!")

    @abstractmethod
    def init_hidden(self):
        raise NotImplementedError("init_hidden not implemented!")

    @abstractmethod
    def get_output_from_hidden(self, hx):
        raise NotImplementedError("get_output_from_hidden not implemented!")

    @abstractmethod
    def apply_op(self, input, hx):
        raise NotImplementedError("apply_op not implemented!")

    def _apply_fn_to_hx(self, hx, fn):
        return fn(hx)

    def _stack_h_n(self, h_n):
        return stack(h_n, axis=0)

    def forward(self, input: Tensor, hx=None):
        if self.batch_first:
            batch_size = input.shape[0]
            input = input.transpose((1, 0, 2))  # [seq_len, batch_size, dim]
        else:
            batch_size = input.shape[1]
        if hx is None:
            hx = self.init_hidden(batch_size)
        output, h = self.apply_op(input, hx)
        if self.batch_first:
            output = output.transpose((1, 0, 2))
        return output, h


class RNN(RNNBase):
    r"""Applies a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}`
    non-linearity to an input sequence.

    For each element in the input sequence, each layer computes the following
    function:

    .. math::

        h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
    the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
    previous layer at time `t-1` or the initial hidden state at time `0`.
    If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two RNNs together to form a `stacked RNN`,
            with the second RNN taking in outputs of the first RNN and
            computing the final results. Default: 1
        nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``.
            Default: ``'tanh'``
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
            Note that this does not apply to hidden states. See the
            Inputs/Outputs sections below for details. Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            RNN layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``

    Inputs: input, h_0
        * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or
          :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
          the input sequence.
        * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the
          initial hidden state for each element in the batch. Defaults to zeros if not provided.

        where:

        .. math::

            \begin{aligned}
                N ={} & \text{batch size} \\
                L ={} & \text{sequence length} \\
                D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
                H_{in} ={} & \text{input\_size} \\
                H_{out} ={} & \text{hidden\_size}
            \end{aligned}

    Outputs: output, h_n
        * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
          :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
          `(h_t)` from the last layer of the RNN, for each `t`.
        * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the
          final hidden state for each element in the batch.

    Examples:

    .. code-block::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.RNN(10, 20, 2, batch_first=False, nonlinearity="relu", bias=True, bidirectional=True)
        inp = mge.tensor(np.random.randn(6, 30, 10), dtype=np.float32)
        hx = mge.tensor(np.random.randn(4, 30, 20), dtype=np.float32)
        out, hn = m(inp, hx)
        print(out.numpy().shape)

    Outputs:

    .. code-block::

        (6, 30, 40)
    """

    def __init__(self, *args, **kwargs) -> None:
        self.nonlinearity = kwargs.pop("nonlinearity", "tanh")
        super(RNN, self).__init__(*args, **kwargs)

    def create_cell(self, layer):
        if layer == 0:
            input_size = self.input_size
        else:
            input_size = self.num_directions * self.hidden_size
        return RNNCell(input_size, self.hidden_size, self.bias, self.nonlinearity)

    def init_hidden(self, batch_size):
        hidden_shape = (
            self.num_directions * self.num_layers,
            batch_size,
            self.hidden_size,
        )
        return zeros(shape=hidden_shape)

    def get_output_from_hidden(self, hx):
        return hx

    def apply_op(self, input, hx):
        fwd_mode = (
            BatchNorm.FwdMode.TRAINING if self.training else BatchNorm.FwdMode.INFERENCE
        )
        op = builtin.RNN(
            num_layers=self.num_layers,
            bidirectional=self.bidirectional,
            bias=self.bias,
            hidden_size=self.hidden_size,
            dropout=self.dropout,
            nonlineMode=self.nonlinearity,
            fwd_mode=fwd_mode,
        )
        output, h = apply(op, input, hx, self._flatten_weights)[:2]
        # adding zero-valued terms ties output and h into the same computation
        # graph without changing their values
        output = output + h.sum() * 0
        h = h + output.sum() * 0
        return output, h


class LSTM(RNNBase):
    r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence.

    For each element in the input sequence, each layer computes the following
    function:

    .. math::

        \begin{array}{ll} \\
            i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
            f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
            g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
            o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
            c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
            h_t = o_t \odot \tanh(c_t) \\
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
    state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}`
    is the hidden state of the layer at time `t-1` or the initial hidden
    state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`,
    :math:`o_t` are the input, forget, cell, and output gates, respectively.
    :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.

    In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
    (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
    variable which is :math:`0` with probability :attr:`dropout`.

    If ``proj_size > 0`` is specified, LSTM with projections will be used. This changes
    the LSTM cell in the following way. First, the dimension of :math:`h_t` will be changed from
    ``hidden_size`` to ``proj_size`` (dimensions of :math:`W_{hi}` will be changed accordingly).
    Second, the output hidden state of each layer will be multiplied by a learnable projection
    matrix: :math:`h_t = W_{hr}h_t`. Note that as a consequence, the output of the LSTM
    network will have a different shape as well. See the Inputs/Outputs sections below for exact
    dimensions of all variables. You can find more details in https://arxiv.org/abs/1402.1128.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two LSTMs together to form a `stacked LSTM`,
            with the second LSTM taking in outputs of the first LSTM and
            computing the final results. Default: 1
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
            Note that this does not apply to hidden or cell states. See the
            Inputs/Outputs sections below for details. Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            LSTM layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``
        proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0

    Inputs: input, (h_0, c_0)
        * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or
          :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
          the input sequence.
        * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the
          initial hidden state for each element in the batch.
          Defaults to zeros if (h_0, c_0) is not provided.
        * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
          initial cell state for each element in the batch.
          Defaults to zeros if (h_0, c_0) is not provided.

        where:

        .. math::

            \begin{aligned}
                N ={} & \text{batch size} \\
                L ={} & \text{sequence length} \\
                D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
                H_{in} ={} & \text{input\_size} \\
                H_{cell} ={} & \text{hidden\_size} \\
                H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\
            \end{aligned}

    Outputs: output, (h_n, c_n)
        * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
          :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
          `(h_t)` from the last layer of the LSTM, for each `t`.
        * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the
          final hidden state for each element in the batch.
        * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
          final cell state for each element in the batch.

    Examples:

    .. code-block::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.LSTM(10, 20, 2, batch_first=False, bidirectional=True, bias=True)
        inp = mge.tensor(np.random.randn(6, 30, 10), dtype=np.float32)
        hx = mge.tensor(np.random.randn(4, 30, 20), dtype=np.float32)
        cx = mge.tensor(np.random.randn(4, 30, 20), dtype=np.float32)
        out, (hn, cn) = m(inp, (hx, cx))
        print(out.numpy().shape)

    Outputs:

    .. code-block::

        (6, 30, 40)
    """

    def __init__(self, *args, **kwargs) -> None:
        super(LSTM, self).__init__(*args, **kwargs)

    def create_cell(self, layer):
        if layer == 0:
            input_size = self.input_size
        else:
            input_size = self.num_directions * self.hidden_size
        return LSTMCell(input_size, self.hidden_size, self.bias)

    def init_hidden(self, batch_size):
        hidden_shape = (
            self.num_directions * self.num_layers,
            batch_size,
            self.hidden_size,
        )
        h = zeros(shape=hidden_shape)
        c = zeros(shape=hidden_shape)
        return (h, c)

    def get_output_from_hidden(self, hx):
        return hx[0]

    def apply_op(self, input, hx):
        fwd_mode = (
            BatchNorm.FwdMode.TRAINING if self.training else BatchNorm.FwdMode.INFERENCE
        )
        op = builtin.LSTM(
            num_layers=self.num_layers,
            bidirectional=self.bidirectional,
            bias=self.bias,
            hidden_size=self.hidden_size,
            proj_size=self.proj_size,
            dropout=self.dropout,
            fwd_mode=fwd_mode,
        )
        output, h, c = apply(op, input, hx[0], hx[1], self._flatten_weights)[:3]
        # zero-valued placeholders tie output, h and c into the same computation
        # graph without changing their values
        placeholders = [output.sum() * 0, h.sum() * 0, c.sum() * 0]
        output = output + placeholders[1] + placeholders[2]
        h = h + placeholders[0] + placeholders[2]
        c = c + placeholders[0] + placeholders[1]
        return output, (h, c)

    def _apply_fn_to_hx(self, hx, fn):
        return (fn(hx[0]), fn(hx[1]))

    def _stack_h_n(self, h_n):
        h = [tup[0] for tup in h_n]
        c = [tup[1] for tup in h_n]
        return (stack(h, axis=0), stack(c, axis=0))
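

# A minimal smoke-test sketch (illustrative only). It assumes MegEngine is
# installed and that this file is executed as a module, e.g.
# ``python -m megengine.module.rnn``, since the relative imports above do not
# work when the file is run as a standalone script. It unrolls a single RNNCell
# over a sequence by hand, which is conceptually what one layer of the RNN
# module computes in one direction.
if __name__ == "__main__":
    import numpy as np

    import megengine as mge
    import megengine.module as M

    seq_len, batch, in_dim, hidden_dim = 5, 3, 10, 20
    cell = M.RNNCell(in_dim, hidden_dim)
    xs = mge.tensor(np.random.randn(seq_len, batch, in_dim), dtype=np.float32)
    h = None
    for t in range(seq_len):
        # each step consumes one time slice and the previous hidden state;
        # the first step passes None, so the cell starts from a zero state
        h = cell(xs[t], h)
    print(h.numpy().shape)  # expected: (3, 20)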