You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

embedding.py 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from typing import Optional
  10. import numpy as np
  11. from ..functional import embedding as embedding_func
  12. from ..tensor import Parameter
  13. from . import init
  14. from .module import Module
  15. class Embedding(Module):
  16. r"""
  17. A simple lookup table that stores embeddings of a fixed dictionary and size.
  18. This module is often used to store word embeddings and retrieve them using indices.
  19. The input to the module is a list of indices, and the output is the corresponding word embeddings.
  20. The indices should less than num_embeddings.
  21. :param num_embeddings: size of embedding dictionary.
  22. :param embedding_dim: size of each embedding vector.
  23. :param padding_idx: should be set to None, not supportted now.
  24. :param max_norm: should be set to None, not supportted now.
  25. :param norm_type: should be set to None, not supportted now.
  26. :param initial_weight: the learnable weights of the module of shape (num_embeddings, embedding_dim).
  27. Examples:
  28. .. testcode::
  29. import numpy as np
  30. import megengine as mge
  31. import megengine.module as M
  32. weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6),(0.1,1.1,2.1,3.1,4.1)], dtype=np.float32))
  33. data = mge.tensor(np.array([(0,1,1),(1,0,1),(0,0,1)], dtype=np.int32))
  34. embedding = M.Embedding(2, 5, initial_weight=weight)
  35. output = embedding(data)
  36. with np.printoptions(precision=6):
  37. print(output.numpy())
  38. Outputs:
  39. .. testoutput::
  40. [[[1.2 2.3 3.4 4.5 5.6]
  41. [0.1 1.1 2.1 3.1 4.1]
  42. [0.1 1.1 2.1 3.1 4.1]]
  43. [[0.1 1.1 2.1 3.1 4.1]
  44. [1.2 2.3 3.4 4.5 5.6]
  45. [0.1 1.1 2.1 3.1 4.1]]
  46. [[1.2 2.3 3.4 4.5 5.6]
  47. [1.2 2.3 3.4 4.5 5.6]
  48. [0.1 1.1 2.1 3.1 4.1]]]
  49. """
  50. def __init__(
  51. self,
  52. num_embeddings: int,
  53. embedding_dim: int,
  54. padding_idx: Optional[int] = None,
  55. max_norm: Optional[float] = None,
  56. norm_type: Optional[float] = None,
  57. initial_weight: Parameter = None,
  58. freeze: bool = False,
  59. ):
  60. super().__init__()
  61. if padding_idx is not None:
  62. raise ValueError("Not support padding index now.")
  63. if max_norm is not None or norm_type is not None:
  64. raise ValueError("Not support weight normalize now.")
  65. self.padding_idx = padding_idx
  66. self.max_norm = max_norm
  67. self.norm_type = norm_type
  68. self.num_embeddings = num_embeddings
  69. self.embedding_dim = embedding_dim
  70. self.freeze = freeze
  71. if initial_weight is None:
  72. self.weight = Parameter(
  73. np.random.uniform(
  74. size=(self.num_embeddings, self.embedding_dim)
  75. ).astype(np.float32)
  76. )
  77. self.reset_parameters()
  78. else:
  79. if initial_weight.shape != (num_embeddings, embedding_dim):
  80. raise ValueError(
  81. "The weight shape should match num_embeddings and embedding_dim"
  82. )
  83. self.weight = Parameter(initial_weight.numpy())
  84. def reset_parameters(self) -> None:
  85. init.normal_(self.weight)
  86. def forward(self, inputs):
  87. if self.freeze:
  88. weight = self.weight.detach()
  89. else:
  90. weight = self.weight
  91. return embedding_func(inputs, weight)
  92. @classmethod
  93. def from_pretrained(
  94. cls,
  95. embeddings: Parameter,
  96. freeze: Optional[bool] = True,
  97. padding_idx: Optional[int] = None,
  98. max_norm: Optional[float] = None,
  99. norm_type: Optional[float] = None,
  100. ):
  101. r"""
  102. Creates Embedding instance from given 2-dimensional FloatTensor.
  103. :param embeddings: tensor contained weight for the embedding.
  104. :param freeze: if ``True``, the weight does not get updated during the learning process. Default: True.
  105. :param padding_idx: should be set to None, not support Now.
  106. :param max_norm: should be set to None, not support Now.
  107. :param norm_type: should be set to None, not support Now.
  108. Examples:
  109. .. testcode::
  110. import numpy as np
  111. import megengine as mge
  112. import megengine.module as M
  113. weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6),(0.1,1.1,2.1,3.1,4.1)], dtype=np.float32))
  114. data = mge.tensor(np.array([(0,1,1),(1,0,1),(0,0,1)], dtype=np.int32))
  115. embedding = M.Embedding.from_pretrained(weight, freeze=False)
  116. output = embedding(data)
  117. print(output.numpy())
  118. Outputs:
  119. .. testoutput::
  120. [[[1.2 2.3 3.4 4.5 5.6]
  121. [0.1 1.1 2.1 3.1 4.1]
  122. [0.1 1.1 2.1 3.1 4.1]]
  123. [[0.1 1.1 2.1 3.1 4.1]
  124. [1.2 2.3 3.4 4.5 5.6]
  125. [0.1 1.1 2.1 3.1 4.1]]
  126. [[1.2 2.3 3.4 4.5 5.6]
  127. [1.2 2.3 3.4 4.5 5.6]
  128. [0.1 1.1 2.1 3.1 4.1]]]
  129. """
  130. embeddings_shape = embeddings.shape
  131. embeddings_dim = len(embeddings_shape)
  132. if embeddings_dim != 2:
  133. raise ValueError("Embeddings parameter is expected to be 2-dimensional")
  134. rows = embeddings_shape[0]
  135. cols = embeddings_shape[1]
  136. embedding = cls(
  137. num_embeddings=rows,
  138. embedding_dim=cols,
  139. initial_weight=embeddings,
  140. padding_idx=padding_idx,
  141. max_norm=max_norm,
  142. norm_type=norm_type,
  143. freeze=freeze,
  144. )
  145. return embedding

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。