- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import math
- import numbers
- from abc import abstractmethod
- from typing import Optional, Tuple
-
- import numpy as np
-
- from ..core._imperative_rt.core2 import apply
- from ..core.ops import builtin
- from ..device import is_cuda_available
- from ..functional import concat, expand_dims, repeat, stack, zeros
- from ..tensor import Parameter, Tensor
- from . import init
- from .module import Module
-
-
- class RNNCellBase(Module):
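-     r"""Base class for single-step recurrent cells.
-
-     ``num_chunks`` is the number of gates of the concrete cell (1 for
-     :class:`RNNCell`, 4 for :class:`LSTMCell`).  The gate weights are packed
-     along the first axis:
-
-     * ``weight_ih``: ``(num_chunks * hidden_size, input_size)``
-     * ``weight_hh``: ``(num_chunks * hidden_size, hidden_size)``
-     * ``bias_ih`` / ``bias_hh``: ``(num_chunks * hidden_size,)``
-     """
-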
- def __init__(
- self,
- input_size: int,
- hidden_size: int,
- bias: bool,
- num_chunks: int,
- device=None,
- dtype=None,
- ) -> None:
- # num_chunks indicates the number of gates
- super(RNNCellBase, self).__init__()
-
- self.input_size = input_size
- self.hidden_size = hidden_size
- self.bias = bias
-
- # initialize weights
- common_kwargs = {"device": device, "dtype": dtype}
- self.gate_hidden_size = num_chunks * hidden_size
- self.weight_ih = Parameter(
- np.random.uniform(size=(self.gate_hidden_size, input_size)).astype(
- np.float32
- ),
- **common_kwargs,
- )
- self.weight_hh = Parameter(
- np.random.uniform(size=(self.gate_hidden_size, hidden_size)).astype(
- np.float32
- ),
- **common_kwargs,
- )
- if bias:
- self.bias_ih = Parameter(
- np.random.uniform(size=(self.gate_hidden_size)).astype(np.float32),
- **common_kwargs,
- )
- self.bias_hh = Parameter(
- np.random.uniform(size=(self.gate_hidden_size)).astype(np.float32),
- **common_kwargs,
- )
- else:
- self.bias_ih = zeros(shape=(self.gate_hidden_size), **common_kwargs)
- self.bias_hh = zeros(shape=(self.gate_hidden_size), **common_kwargs)
- self.reset_parameters()
-         # when bias is False, bias_ih / bias_hh stay as plain zero tensors
-         # (not Parameters), so reset_parameters() does not touch them
-
- def get_op(self):
- return builtin.RNNCell()
-
- def reset_parameters(self) -> None:
- stdv = 1.0 / math.sqrt(self.hidden_size)
- for weight in self.parameters():
- init.uniform_(weight, -stdv, stdv)
-
- def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
- if hx is None:
- hx = zeros(
- shape=(input.shape[0], self.gate_hidden_size),
- dtype=input.dtype,
- device=input.device,
- )
- op = self.get_op()
- return apply(
- op, input, self.weight_ih, self.bias_ih, hx, self.weight_hh, self.bias_hh
- )[0]
- # return linear(input, self.weight_ih, self.bias_ih) + linear(hx, self.weight_hh, self.bias_hh)
-
-
- class RNNCell(RNNCellBase):
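-     r"""An Elman RNN cell with a tanh or ReLU non-linearity.
-
-     .. math::
-
-         h' = \text{act}(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})
-
-     where ``act`` is ``tanh`` (default) or ``relu``, chosen by ``nonlinearity``.
-
-     Shapes:
-         * input: ``(batch, input_size)``
-         * hx: ``(batch, hidden_size)``, zeros when omitted
-         * returns the new hidden state of shape ``(batch, hidden_size)``
-     """
-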
- def __init__(
- self,
- input_size: int,
- hidden_size: int,
- bias: bool = True,
- nonlinearity: str = "tanh",
- device=None,
- dtype=None,
- ) -> None:
- self.nonlinearity = nonlinearity
- super(RNNCell, self).__init__(
- input_size, hidden_size, bias, num_chunks=1, device=device, dtype=dtype
- )
- # self.activate = tanh if nonlinearity == "tanh" else relu
-
- def get_op(self):
- return builtin.RNNCell(nonlineMode=self.nonlinearity)
-
- def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
- return super().forward(input, hx)
-
-
- class LSTMCell(RNNCellBase):
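-     r"""A long short-term memory (LSTM) cell, following the standard formulation:
-
-     .. math::
-
-         \begin{array}{ll}
-         i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
-         f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
-         g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
-         o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
-         c' = f \odot c + i \odot g \\
-         h' = o \odot \tanh(c')
-         \end{array}
-
-     The four gates share the packed weights of :class:`RNNCellBase` (``num_chunks=4``).
-
-     Shapes:
-         * input: ``(batch, input_size)``
-         * hx: a tuple ``(h, c)``, each ``(batch, hidden_size)``, zeros when omitted
-         * returns the updated tuple ``(h', c')``
-     """
-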
- def __init__(
- self,
- input_size: int,
- hidden_size: int,
- bias: bool = True,
- device=None,
- dtype=None,
- ) -> None:
- super(LSTMCell, self).__init__(
- input_size, hidden_size, bias, num_chunks=4, device=device, dtype=dtype
- )
-
- def get_op(self):
- return builtin.LSTMCell()
-
- def forward(
- self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
- ) -> Tuple[Tensor, Tensor]:
- # hx: (h, c)
- if hx is None:
- h = zeros(
- shape=(input.shape[0], self.hidden_size),
- dtype=input.dtype,
- device=input.device,
- )
- c = zeros(
- shape=(input.shape[0], self.hidden_size),
- dtype=input.dtype,
- device=input.device,
- )
- else:
- h, c = hx
- op = self.get_op()
- return apply(
- op, input, self.weight_ih, self.bias_ih, h, self.weight_hh, self.bias_hh, c
- )[:2]
-
-
- def is_gpu(device: str) -> bool:
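-     """Return True if ``device`` resolves to a CUDA device.
-
-     ``xpux`` lets MegEngine pick the device, so it only counts as GPU when
-     CUDA is available.
-     """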
- if "xpux" in device and is_cuda_available():
- return True
- if "gpu" in device:
- return True
- return False
-
-
- class RNNBase(Module):
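-     r"""Base class for multi-layer recurrent networks (:class:`RNN`, :class:`LSTM`).
-
-     Args:
-         input_size: number of input features at each time step.
-         hidden_size: number of features of the hidden state.
-         num_layers: number of stacked recurrent layers.
-         bias: whether the cells use bias terms.
-         batch_first: if True, input and output are ``(batch, seq_len, feature)``
-             instead of ``(seq_len, batch, feature)``.
-         dropout: dropout probability, a number in ``[0, 1]``.
-         bidirectional: if True, each layer runs in both directions and the
-             per-direction outputs are concatenated.
-         proj_size: if nonzero, must be smaller than ``hidden_size``;
-             0 disables projection.
-     """
-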
- def __init__(
- self,
- input_size: int,
- hidden_size: int,
- num_layers: int = 1,
- bias: bool = True,
- batch_first: bool = False,
- dropout: float = 0,
- bidirectional: bool = False,
- proj_size: int = 0,
- device=None,
- dtype=None,
- ) -> None:
- super(RNNBase, self).__init__()
- # self.mode = mode
- self.input_size = input_size
- self.hidden_size = hidden_size
- self.num_layers = num_layers
- self.bias = bias
- self.batch_first = batch_first
- self.dropout = float(dropout)
- self.bidirectional = bidirectional
- self.num_directions = 2 if self.bidirectional else 1
- self.proj_size = proj_size
-
- # check validity of dropout
- if (
- not isinstance(dropout, numbers.Number)
- or not 0 <= dropout <= 1
- or isinstance(dropout, bool)
- ):
- raise ValueError(
- "Dropout should be a float in [0, 1], which indicates the probability "
- "of an element to be zero"
- )
-
- if proj_size < 0:
- raise ValueError(
- "proj_size should be a positive integer or zero to disable projections"
- )
- elif proj_size >= hidden_size:
- raise ValueError("proj_size has to be smaller than hidden_size")
-
- self.cells = []
- for layer in range(self.num_layers):
- self.cells.append([])
- for _ in range(self.num_directions):
- self.cells[layer].append(self.create_cell(layer, device, dtype))
-         # cell parameters are initialized when each cell is created;
-         # _flatten_parameters then packs them into a single flattened weight
-         # tensor consumed by the fused op
- self._flatten_parameters(device, dtype, self.cells)
-
- def _flatten_parameters(self, device, dtype, cells):
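-         # allocate one flat (gate_hidden_size, size_dim1) parameter big enough to
-         # hold every cell's weight_ih / weight_hh plus one bias column each;
-         # the fused RNN/LSTM op reads its weights from this flattened tensor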
- gate_hidden_size = cells[0][0].gate_hidden_size
- size_dim1 = 0
- for layer in range(self.num_layers):
- for direction in range(self.num_directions):
- size_dim1 += cells[layer][direction].weight_ih.shape[1]
- size_dim1 += cells[layer][direction].weight_hh.shape[1]
- # if self.bias:
- # size_dim1 += 2 * self.num_directions * self.num_layers
- size_dim1 += 2 * self.num_directions * self.num_layers
- self._flatten_weights = Parameter(
- np.zeros((gate_hidden_size, size_dim1), dtype=np.float32)
- )
- self.reset_parameters()
- # TODO: if no bias, set the bias to zero
-
- def reset_parameters(self) -> None:
- stdv = 1.0 / math.sqrt(self.hidden_size)
- for weight in self.parameters():
- init.uniform_(weight, -stdv, stdv)
-
- @abstractmethod
- def create_cell(self, layer, device, dtype):
- raise NotImplementedError("Cell not implemented !")
-
- @abstractmethod
- def init_hidden(self):
- raise NotImplementedError("init_hidden not implemented !")
-
- @abstractmethod
- def get_output_from_hidden(self, hx):
- raise NotImplementedError("get_output_from_hidden not implemented !")
-
- @abstractmethod
- def apply_op(self, input, hx):
- raise NotImplementedError("apply_op not implemented !")
-
- def _apply_fn_to_hx(self, hx, fn):
- return fn(hx)
-
- def _stack_h_n(self, h_n):
- return stack(h_n, axis=0)
-
- def forward(self, input: Tensor, hx=None):
- if self.batch_first:
- batch_size = input.shape[0]
- input = input.transpose((1, 0, 2)) # [seq_len, batch_size, dim]
- else:
- batch_size = input.shape[1]
- if hx is None:
- hx = self.init_hidden(batch_size, input.device, input.dtype)
-
-         if is_gpu(str(input.device)):
-             # fused kernel path: hand the whole sequence to the builtin op
-             output, h = self.apply_op(input, hx)
-             if self.batch_first:
-                 output = output.transpose((1, 0, 2))
-             return output, h
-
-         # fallback: step through the sequence with per-layer, per-direction cells
- order_settings = [(0, input.shape[0]), (input.shape[0] - 1, -1, -1)]
- h_n = []
- for layer in range(self.num_layers):
- layer_outputs = []
- for direction in range(self.num_directions):
- direction_outputs = [None for _ in range(input.shape[0])]
- cell = self.cells[layer][direction]
- hidden = self._apply_fn_to_hx(
- hx, lambda x: x[layer * self.num_directions + direction]
- )
- for step in range(*(order_settings[direction])):
- hidden = cell(input[step], hidden) # [batch_size, hidden_size]
- direction_outputs[step] = self.get_output_from_hidden(hidden)
- direction_output = stack(
- direction_outputs, axis=0
- ) # [seq_len, batch_size, hidden_size]
- layer_outputs.append(direction_output)
- h_n.append(hidden)
- layer_output = concat(
- layer_outputs, axis=-1
- ) # [seq_len, batch_size, D*hidden_size]
- input = layer_output
- if self.batch_first:
- layer_output = layer_output.transpose((1, 0, 2))
- return layer_output, self._stack_h_n(h_n)
-
-
- class RNN(RNNBase):
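-     r"""Multi-layer Elman RNN with tanh or ReLU non-linearity.
-
-     Accepts the :class:`RNNBase` arguments plus the keyword ``nonlinearity``
-     ("tanh" by default, or "relu").
-
-     Shapes (with ``batch_first=False``):
-         * input: ``(seq_len, batch, input_size)``
-         * hx: ``(num_directions * num_layers, batch, hidden_size)``, zeros when omitted
-         * returns ``(output, h_n)`` with output of shape
-           ``(seq_len, batch, num_directions * hidden_size)``
-
-     Illustrative example (shapes only; assumes the class is exported as
-     ``megengine.module.RNN``):
-
-     .. code-block:: python
-
-         import numpy as np
-         import megengine as mge
-         import megengine.module as M
-
-         rnn = M.RNN(input_size=10, hidden_size=20, num_layers=2)
-         x = mge.tensor(np.random.randn(6, 3, 10), dtype="float32")
-         output, h_n = rnn(x)  # output: (6, 3, 20), h_n: (2, 3, 20)
-     """
-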
- def __init__(self, *args, **kwargs) -> None:
- self.nonlinearity = kwargs.pop("nonlinearity", "tanh")
- super(RNN, self).__init__(*args, **kwargs)
-
- def create_cell(self, layer, device, dtype):
- if layer == 0:
- input_size = self.input_size
- else:
- input_size = self.num_directions * self.hidden_size
- return RNNCell(
- input_size, self.hidden_size, self.bias, self.nonlinearity, device, dtype
- )
-
- def init_hidden(self, batch_size, device, dtype):
- hidden_shape = (
- self.num_directions * self.num_layers,
- batch_size,
- self.hidden_size,
- )
- return zeros(shape=hidden_shape, dtype=dtype, device=device)
-
- def get_output_from_hidden(self, hx):
- return hx
-
- def apply_op(self, input, hx):
- op = builtin.RNN(
- num_layers=self.num_layers,
- bidirectional=self.bidirectional,
- bias=self.bias,
- hidden_size=self.hidden_size,
- proj_size=self.proj_size,
- dropout=self.dropout,
- nonlineMode=self.nonlinearity,
- )
- output, h = apply(op, input, hx, self._flatten_weights)[:2]
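-         # numerically a no-op: the zero-valued cross terms make output and h
-         # depend on each other, likely to keep both op outputs on the gradient
-         # path of the computation graph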
- output = output + h.sum() * 0
- h = h + output.sum() * 0
- return output, h
-
-
- class LSTM(RNNBase):
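-     r"""Multi-layer long short-term memory (LSTM) network.
-
-     Takes the :class:`RNNBase` arguments.  The hidden state is a tuple
-     ``(h, c)``, each of shape ``(num_directions * num_layers, batch, hidden_size)``;
-     zeros are used when it is omitted.
-
-     Illustrative example (shapes only; assumes the class is exported as
-     ``megengine.module.LSTM``):
-
-     .. code-block:: python
-
-         import numpy as np
-         import megengine as mge
-         import megengine.module as M
-
-         lstm = M.LSTM(input_size=10, hidden_size=20, batch_first=True)
-         x = mge.tensor(np.random.randn(3, 6, 10), dtype="float32")
-         output, (h_n, c_n) = lstm(x)  # output: (3, 6, 20), h_n/c_n: (1, 3, 20)
-     """
-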
- def __init__(self, *args, **kwargs) -> None:
- super(LSTM, self).__init__(*args, **kwargs)
-
- def create_cell(self, layer, device, dtype):
- if layer == 0:
- input_size = self.input_size
- else:
- input_size = self.num_directions * self.hidden_size
- return LSTMCell(input_size, self.hidden_size, self.bias, device, dtype)
-
- def init_hidden(self, batch_size, device, dtype):
- hidden_shape = (
- self.num_directions * self.num_layers,
- batch_size,
- self.hidden_size,
- )
- h = zeros(shape=hidden_shape, dtype=dtype, device=device)
- c = zeros(shape=hidden_shape, dtype=dtype, device=device)
- return (h, c)
-
- def get_output_from_hidden(self, hx):
- return hx[0]
-
- def apply_op(self, input, hx):
- op = builtin.LSTM(
- num_layers=self.num_layers,
- bidirectional=self.bidirectional,
- bias=self.bias,
- hidden_size=self.hidden_size,
- proj_size=self.proj_size,
- dropout=self.dropout,
- )
- output, h, c = apply(op, input, hx[0], hx[1], self._flatten_weights)[:3]
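-         # numerically no-ops: the zero placeholders make output, h and c depend
-         # on one another, likely to keep every op output on the gradient path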
- placeholders = [output.sum() * 0, h.sum() * 0, c.sum() * 0]
- output = output + placeholders[1] + placeholders[2]
- h = h + placeholders[0] + placeholders[2]
- c = c + placeholders[0] + placeholders[1]
- return output, (h, c)
-
- def _apply_fn_to_hx(self, hx, fn):
- return (fn(hx[0]), fn(hx[1]))
-
- def _stack_h_n(self, h_n):
- h = [tup[0] for tup in h_n]
- c = [tup[1] for tup in h_n]
- return (stack(h, axis=0), stack(c, axis=0))