
batch_matmul_activation.py

import numpy as np

from ..functional import matmul, relu
from ..tensor import Parameter
from . import init
from .module import Module


class BatchMatMulActivation(Module):
    r"""Batched :func:`~.matmul` with activation (only :func:`~.relu` supported), no transpose anywhere."""

    def __init__(
        self,
        batch: int,
        in_features: int,
        out_features: int,
        bias: bool = True,
        nonlinear_mode="identity",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.batch = batch
        self.out_features = out_features
        self.in_features = in_features
        # One independent (out_features x in_features) weight matrix per batch entry.
        w_shape = (batch, out_features, in_features)
        self.weight = Parameter(np.zeros(w_shape, dtype=np.float32))
        self.bias = None
        if bias:
            b_shape = (out_features,)
            self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
        self.nonlinear_mode = nonlinear_mode.lower()
        self.reset_parameters()

    def _get_fanin(self):
        return self.in_features

    def reset_parameters(self) -> None:
        # Weights ~ N(0, 1 / fan_in); bias (if any) starts at zero.
        fanin = self._get_fanin()
        std = np.sqrt(1 / fanin)
        init.normal_(self.weight, 0.0, std)
        if self.bias is not None:
            init.zeros_(self.bias)

    def _calc_linear(self, x, weight, bias):
        res = matmul(weight, x)
        if self.bias is not None:
            res += bias
        if self.nonlinear_mode == "relu":
            res = relu(res)
        return res

    def forward(self, x):
        return self._calc_linear(x, self.weight, self.bias)

    def _module_info_string(self) -> str:
        return "batch={}, in_features={}, out_features={}, bias={}".format(
            self.batch, self.in_features, self.out_features, self.bias is not None
        )
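
For context, a minimal usage sketch. It assumes this file lives under a package like megengine/module (as the relative imports suggest) and that the class is importable as megengine.module.BatchMatMulActivation; the input layout (batch, in_features, n) is inferred from matmul(weight, x), and bias=False is used here to sidestep broadcasting the (out_features,) bias against the trailing dimension. With the default nonlinear_mode="identity", forward is a plain batched matmul.

import numpy as np
import megengine as mge
from megengine.module import BatchMatMulActivation  # assumed import path

# Four independent 8 -> 16 projections, fused with a ReLU.
m = BatchMatMulActivation(
    batch=4, in_features=8, out_features=16, bias=False, nonlinear_mode="relu"
)

# weight is (4, 16, 8); matmul(weight, x) expects x of shape (4, 8, n).
x = mge.tensor(np.random.randn(4, 8, 32).astype(np.float32))
y = m(x)  # shape (4, 16, 32), with ReLU applied elementwise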