You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

linear.py 1.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import numpy as np
  10. from ..core import Parameter
  11. from ..functional import linear
  12. from . import init
  13. from .module import Module
  14. class Linear(Module):
  15. r"""Applies a linear transformation to the input. For instance, if input
  16. is x, then output y is:
  17. .. math::
  18. y = xW^T + b
  19. where :math:`y_i= \sum_j W_{ij} x_j + b_i`
  20. :param in_features: size of each input sample.
  21. :param out_features: size of each output sample.
  22. :param bias: If set to ``False``, the layer will not learn an additive bias.
  23. Default: ``True``
  24. """
  25. def __init__(
  26. self, in_features: int, out_features: int, bias: bool = True, **kwargs
  27. ):
  28. super().__init__(**kwargs)
  29. self.out_features = out_features
  30. self.in_features = in_features
  31. w_shape = (out_features, in_features)
  32. self.weight = Parameter(np.zeros(w_shape, dtype=np.float32))
  33. self.bias = None
  34. if bias:
  35. b_shape = (out_features,)
  36. self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
  37. self.reset_parameters()
  38. def _get_fanin(self):
  39. return self.in_features
  40. def reset_parameters(self) -> None:
  41. fanin = self._get_fanin()
  42. std = np.sqrt(1 / fanin)
  43. init.normal_(self.weight, 0.0, std)
  44. if self.bias is not None:
  45. init.zeros_(self.bias)
  46. def forward(self, x):
  47. return linear(x, self.weight, self.bias)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，无需区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。

Contributors (1)