You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

batch_matmul_activation.py 2.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import numpy as np
  9. from ..functional import matmul, relu
  10. from ..tensor import Parameter
  11. from . import init
  12. from .module import Module
  13. class BatchMatMulActivation(Module):
  14. r"""
  15. Batched :func:`~.matmul` with activation(only :func:`~.relu` supported), no transpose anywhere.
  16. """
  17. def __init__(
  18. self,
  19. batch: int,
  20. in_features: int,
  21. out_features: int,
  22. bias: bool = True,
  23. nonlinear_mode="IDENTITY",
  24. **kwargs
  25. ):
  26. super().__init__(**kwargs)
  27. self.batch = batch
  28. self.out_features = out_features
  29. self.in_features = in_features
  30. w_shape = (batch, out_features, in_features)
  31. self.weight = Parameter(np.zeros(w_shape, dtype=np.float32))
  32. self.bias = None
  33. if bias:
  34. b_shape = (out_features,)
  35. self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
  36. self.nonlinear_mode = nonlinear_mode
  37. self.reset_parameters()
  38. def _get_fanin(self):
  39. return self.in_features
  40. def reset_parameters(self) -> None:
  41. fanin = self._get_fanin()
  42. std = np.sqrt(1 / fanin)
  43. init.normal_(self.weight, 0.0, std)
  44. if self.bias is not None:
  45. init.zeros_(self.bias)
  46. def _calc_linear(self, x, weight, bias):
  47. res = matmul(weight, x)
  48. if self.bias is not None:
  49. res += bias
  50. if self.nonlinear_mode == "RELU":
  51. res = relu(res)
  52. return res
  53. def forward(self, x):
  54. return self._calc_linear(x, self.weight, self.bias)
  55. def _module_info_string(self) -> str:
  56. return "batch={}, in_features={}, out_features={}, bias={}".format(
  57. self.batch, self.in_features, self.out_features, self.bias is not None
  58. )

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。