You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

linear.py 2.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import numpy as np
  2. from ..functional.nn import linear
  3. from ..tensor import Parameter
  4. from . import init
  5. from .module import Module
  6. class Linear(Module):
  7. r"""Applies a linear transformation to the input. For instance, if input
  8. is x, then output y is:
  9. .. math::
  10. y = xW^T + b
  11. where :math:`y_i= \sum_j W_{ij} x_j + b_i`
  12. Args:
  13. in_features: size of each input sample.
  14. out_features: size of each output sample.
  15. bias: if it's ``False``, the layer will not learn an additional ``bias``.
  16. Default: ``True``
  17. Examples:
  18. >>> import numpy as np
  19. >>> m = M.Linear(in_features=3, out_features=1)
  20. >>> inp = mge.tensor(np.arange(0, 6).astype("float32").reshape(2, 3))
  21. >>> oup = m(inp)
  22. >>> oup.numpy().shape
  23. (2, 1)
  24. """
  25. def __init__(
  26. self,
  27. in_features: int,
  28. out_features: int,
  29. bias: bool = True,
  30. compute_mode: str = "default",
  31. **kwargs
  32. ):
  33. super().__init__(**kwargs)
  34. self.out_features = out_features
  35. self.in_features = in_features
  36. w_shape = (out_features, in_features)
  37. self.weight = Parameter(np.zeros(w_shape, dtype=np.float32))
  38. self.bias = None
  39. if bias:
  40. b_shape = (out_features,)
  41. self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
  42. self.compute_mode = compute_mode
  43. self.reset_parameters()
  44. def _get_fanin(self):
  45. return self.in_features
  46. def reset_parameters(self) -> None:
  47. fanin = self._get_fanin()
  48. std = np.sqrt(1 / fanin)
  49. init.normal_(self.weight, 0.0, std)
  50. if self.bias is not None:
  51. init.zeros_(self.bias)
  52. def _calc_linear(self, x, weight, bias):
  53. return linear(x, weight, bias, compute_mode=self.compute_mode)
  54. def forward(self, x):
  55. return self._calc_linear(x, self.weight, self.bias)
  56. def _module_info_string(self) -> str:
  57. return "in_features={}, out_features={}, bias={}".format(
  58. self.in_features, self.out_features, self.bias is not None
  59. )