
rnn.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
import numbers
from abc import abstractmethod
from typing import Optional, Tuple

import numpy as np

from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..device import is_cuda_available
from ..functional import concat, expand_dims, repeat, stack, zeros
from ..functional.nn import concat
from ..tensor import Parameter, Tensor
from . import init
from .module import Module


class RNNCellBase(Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool,
        num_chunks: int,
        device=None,
        dtype=None,
    ) -> None:
        # num_chunks indicates the number of gates
        super(RNNCellBase, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # initialize weights
        common_kwargs = {"device": device, "dtype": dtype}
        self.gate_hidden_size = num_chunks * hidden_size
        self.weight_ih = Parameter(
            np.random.uniform(size=(self.gate_hidden_size, input_size)).astype(
                np.float32
            ),
            **common_kwargs,
        )
        self.weight_hh = Parameter(
            np.random.uniform(size=(self.gate_hidden_size, hidden_size)).astype(
                np.float32
            ),
            **common_kwargs,
        )
        if bias:
            self.bias_ih = Parameter(
                np.random.uniform(size=(self.gate_hidden_size)).astype(np.float32),
                **common_kwargs,
            )
            self.bias_hh = Parameter(
                np.random.uniform(size=(self.gate_hidden_size)).astype(np.float32),
                **common_kwargs,
            )
        else:
            self.bias_ih = zeros(shape=(self.gate_hidden_size), **common_kwargs)
            self.bias_hh = zeros(shape=(self.gate_hidden_size), **common_kwargs)
        self.reset_parameters()
        # if bias is False, self.bias_ih / self.bias_hh remain zero

    def get_op(self):
        return builtin.RNNCell()

    def reset_parameters(self) -> None:
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        if hx is None:
            # for RNNCell num_chunks is 1, so gate_hidden_size equals hidden_size here
            hx = zeros(
                shape=(input.shape[0], self.gate_hidden_size),
                dtype=input.dtype,
                device=input.device,
            )
        op = self.get_op()
        return apply(
            op, input, self.weight_ih, self.bias_ih, hx, self.weight_hh, self.bias_hh
        )[0]
        # return linear(input, self.weight_ih, self.bias_ih) + linear(hx, self.weight_hh, self.bias_hh)


class RNNCell(RNNCellBase):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        nonlinearity: str = "tanh",
        device=None,
        dtype=None,
    ) -> None:
        self.nonlinearity = nonlinearity
        super(RNNCell, self).__init__(
            input_size, hidden_size, bias, num_chunks=1, device=device, dtype=dtype
        )
        # self.activate = tanh if nonlinearity == "tanh" else relu

    def get_op(self):
        return builtin.RNNCell(nonlineMode=self.nonlinearity)

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        return super().forward(input, hx)


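# A minimal RNNCell usage sketch (assuming this module is exported as
# megengine.module, as the package layout above suggests); shapes follow the
# forward signature of RNNCellBase:
#
#     import numpy as np
#     import megengine as mge
#     import megengine.module as M
#
#     cell = M.RNNCell(input_size=8, hidden_size=16)
#     x = mge.tensor(np.random.randn(4, 8).astype("float32"))  # [batch, input_size]
#     h = None                                                 # zero-initialized hidden state
#     for _ in range(5):                                       # unroll 5 time steps
#         h = cell(x, h)                                       # [batch, hidden_size]

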
class LSTMCell(RNNCellBase):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super(LSTMCell, self).__init__(
            input_size, hidden_size, bias, num_chunks=4, device=device, dtype=dtype
        )

    def get_op(self):
        return builtin.LSTMCell()

    def forward(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tensor]:
        # hx: (h, c)
        if hx is None:
            h = zeros(
                shape=(input.shape[0], self.hidden_size),
                dtype=input.dtype,
                device=input.device,
            )
            c = zeros(
                shape=(input.shape[0], self.hidden_size),
                dtype=input.dtype,
                device=input.device,
            )
        else:
            h, c = hx
        op = self.get_op()
        return apply(
            op, input, self.weight_ih, self.bias_ih, h, self.weight_hh, self.bias_hh, c
        )[:2]


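# A minimal LSTMCell usage sketch under the same assumptions as the RNNCell
# example above; the cell consumes and returns an (h, c) state pair:
#
#     cell = M.LSTMCell(input_size=8, hidden_size=16)
#     x = mge.tensor(np.random.randn(4, 8).astype("float32"))  # [batch, input_size]
#     state = None                                             # zero-initialized (h, c)
#     for _ in range(5):
#         state = cell(x, state)                               # h, c: each [batch, hidden_size]

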
def is_gpu(device: str) -> bool:
    if "xpux" in device and is_cuda_available():
        return True
    if "gpu" in device:
        return True
    return False


class RNNBase(Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0,
        bidirectional: bool = False,
        proj_size: int = 0,
        device=None,
        dtype=None,
    ) -> None:
        super(RNNBase, self).__init__()
        # self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.num_directions = 2 if self.bidirectional else 1
        self.proj_size = proj_size

        # check validity of dropout
        if (
            not isinstance(dropout, numbers.Number)
            or not 0 <= dropout <= 1
            or isinstance(dropout, bool)
        ):
            raise ValueError(
                "Dropout should be a float in [0, 1], which indicates the probability "
                "of an element to be zero"
            )

        if proj_size < 0:
            raise ValueError(
                "proj_size should be a positive integer or zero to disable projections"
            )
        elif proj_size >= hidden_size:
            raise ValueError("proj_size has to be smaller than hidden_size")

        self.cells = []
        for layer in range(self.num_layers):
            self.cells.append([])
            for _ in range(self.num_directions):
                self.cells[layer].append(self.create_cell(layer, device, dtype))
        # parameters have been initialized during the creation of the cells
        # if flatten, then delete cells
        self._flatten_parameters(device, dtype, self.cells)

    def _flatten_parameters(self, device, dtype, cells):
        gate_hidden_size = cells[0][0].gate_hidden_size
        size_dim1 = 0
        for layer in range(self.num_layers):
            for direction in range(self.num_directions):
                size_dim1 += cells[layer][direction].weight_ih.shape[1]
                size_dim1 += cells[layer][direction].weight_hh.shape[1]
        # if self.bias:
        #     size_dim1 += 2 * self.num_directions * self.num_layers
        size_dim1 += 2 * self.num_directions * self.num_layers
        self._flatten_weights = Parameter(
            np.zeros((gate_hidden_size, size_dim1), dtype=np.float32)
        )
        self.reset_parameters()
        # TODO: if no bias, set the bias to zero

    def reset_parameters(self) -> None:
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    @abstractmethod
    def create_cell(self, layer, device, dtype):
        raise NotImplementedError("Cell not implemented !")

    @abstractmethod
    def init_hidden(self):
        raise NotImplementedError("init_hidden not implemented !")

    @abstractmethod
    def get_output_from_hidden(self, hx):
        raise NotImplementedError("get_output_from_hidden not implemented !")

    @abstractmethod
    def apply_op(self, input, hx):
        raise NotImplementedError("apply_op not implemented !")

    def _apply_fn_to_hx(self, hx, fn):
        return fn(hx)

    def _stack_h_n(self, h_n):
        return stack(h_n, axis=0)

    def forward(self, input: Tensor, hx=None):
        if self.batch_first:
            batch_size = input.shape[0]
            input = input.transpose((1, 0, 2))  # [seq_len, batch_size, dim]
        else:
            batch_size = input.shape[1]
        if hx is None:
            hx = self.init_hidden(batch_size, input.device, input.dtype)

        # the fused op is always used, so forward returns here; the step-by-step
        # fallback below is kept in the source but is currently unreachable
        output, h = self.apply_op(input, hx)
        if self.batch_first:
            output = output.transpose((1, 0, 2))
        return output, h

        if is_gpu(str(input.device)) or True:
            # return output, h_n
            output, h = self.apply_op(input, hx)
            if self.batch_first:
                output = output.transpose((1, 0, 2))
            return output, h

        # iteration order per direction: forward scans 0 .. seq_len-1,
        # backward scans seq_len-1 .. 0
        order_settings = [(0, input.shape[0]), (input.shape[0] - 1, -1, -1)]
        h_n = []

        for layer in range(self.num_layers):
            layer_outputs = []
            for direction in range(self.num_directions):
                direction_outputs = [None for _ in range(input.shape[0])]
                cell = self.cells[layer][direction]
                hidden = self._apply_fn_to_hx(
                    hx, lambda x: x[layer * self.num_directions + direction]
                )
                for step in range(*(order_settings[direction])):
                    hidden = cell(input[step], hidden)  # [batch_size, hidden_size]
                    direction_outputs[step] = self.get_output_from_hidden(hidden)
                direction_output = stack(
                    direction_outputs, axis=0
                )  # [seq_len, batch_size, hidden_size]
                layer_outputs.append(direction_output)
                h_n.append(hidden)
            layer_output = concat(
                layer_outputs, axis=-1
            )  # [seq_len, batch_size, D*hidden_size]
            input = layer_output

        if self.batch_first:
            layer_output = layer_output.transpose((1, 0, 2))
        return layer_output, self._stack_h_n(h_n)


class RNN(RNNBase):
    def __init__(self, *args, **kwargs) -> None:
        self.nonlinearity = kwargs.pop("nonlinearity", "tanh")
        super(RNN, self).__init__(*args, **kwargs)

    def create_cell(self, layer, device, dtype):
        if layer == 0:
            input_size = self.input_size
        else:
            input_size = self.num_directions * self.hidden_size
        return RNNCell(
            input_size, self.hidden_size, self.bias, self.nonlinearity, device, dtype
        )

    def init_hidden(self, batch_size, device, dtype):
        hidden_shape = (
            self.num_directions * self.num_layers,
            batch_size,
            self.hidden_size,
        )
        return zeros(shape=hidden_shape, dtype=dtype, device=device)

    def get_output_from_hidden(self, hx):
        return hx

    def apply_op(self, input, hx):
        op = builtin.RNN(
            num_layers=self.num_layers,
            bidirectional=self.bidirectional,
            bias=self.bias,
            hidden_size=self.hidden_size,
            proj_size=self.proj_size,
            dropout=self.dropout,
            nonlineMode=self.nonlinearity,
        )
        output, h = apply(op, input, hx, self._flatten_weights)[:2]
        # the zero-valued terms below leave the values unchanged but make each
        # result depend on both op outputs
        output = output + h.sum() * 0
        h = h + output.sum() * 0
        return output, h


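# A minimal RNN usage sketch under the same assumptions as the cell examples
# above; with batch_first=True the input layout is [batch, seq_len, input_size]:
#
#     rnn = M.RNN(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
#     x = mge.tensor(np.random.randn(4, 10, 8).astype("float32"))
#     output, h_n = rnn(x)
#     # output: [batch, seq_len, num_directions * hidden_size] -> (4, 10, 16)
#     # h_n:    [num_directions * num_layers, batch, hidden_size] -> (2, 4, 16)

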
class LSTM(RNNBase):
    def __init__(self, *args, **kwargs) -> None:
        super(LSTM, self).__init__(*args, **kwargs)

    def create_cell(self, layer, device, dtype):
        if layer == 0:
            input_size = self.input_size
        else:
            input_size = self.num_directions * self.hidden_size
        return LSTMCell(input_size, self.hidden_size, self.bias, device, dtype)

    def init_hidden(self, batch_size, device, dtype):
        hidden_shape = (
            self.num_directions * self.num_layers,
            batch_size,
            self.hidden_size,
        )
        h = zeros(shape=hidden_shape, dtype=dtype, device=device)
        c = zeros(shape=hidden_shape, dtype=dtype, device=device)
        return (h, c)

    def get_output_from_hidden(self, hx):
        return hx[0]

    def apply_op(self, input, hx):
        op = builtin.LSTM(
            num_layers=self.num_layers,
            bidirectional=self.bidirectional,
            bias=self.bias,
            hidden_size=self.hidden_size,
            proj_size=self.proj_size,
            dropout=self.dropout,
        )
        output, h, c = apply(op, input, hx[0], hx[1], self._flatten_weights)[:3]
        # same zero-valued-term trick as in RNN.apply_op: values are unchanged,
        # but output, h and c each depend on all three op outputs
        placeholders = [output.sum() * 0, h.sum() * 0, c.sum() * 0]
        output = output + placeholders[1] + placeholders[2]
        h = h + placeholders[0] + placeholders[2]
        c = c + placeholders[0] + placeholders[1]
        return output, (h, c)

    def _apply_fn_to_hx(self, hx, fn):
        return (fn(hx[0]), fn(hx[1]))

    def _stack_h_n(self, h_n):
        h = [tup[0] for tup in h_n]
        c = [tup[1] for tup in h_n]
        return (stack(h, axis=0), stack(c, axis=0))


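# A minimal LSTM usage sketch under the same assumptions as the examples above;
# with the default batch_first=False the input layout is [seq_len, batch, input_size]:
#
#     lstm = M.LSTM(input_size=8, hidden_size=16, bidirectional=True)
#     x = mge.tensor(np.random.randn(10, 4, 8).astype("float32"))
#     output, (h_n, c_n) = lstm(x)
#     # output:   [seq_len, batch, num_directions * hidden_size] -> (10, 4, 32)
#     # h_n, c_n: [num_directions * num_layers, batch, hidden_size] -> (2, 4, 16)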