You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

embedding.py 5.8 kB

Lines 1–171
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from typing import Optional
  10. import numpy as np
  11. from ..functional import embedding as embedding_func
  12. from ..tensor_nn import Parameter
  13. from . import init
  14. from .module import Module
  15. class Embedding(Module):
  16. r"""
  17. A simple lookup table that stores embeddings of a fixed dictionary and size.
  18. This module is often used to store word embeddings and retrieve them using indices.
  19. The input to the module is a list of indices, and the output is the corresponding word embeddings.
  20. The indices should less than num_embeddings.
  21. :param num_embeddings: size of embedding dictionary.
  22. :param embedding_dim: size of each embedding vector.
  23. :param padding_idx: should be set to None, not support now.
  24. :param max_norm: should be set to None, not support now.
  25. :param norm_type: should be set to None, not support now.
  26. :param initial_weight: the learnable weights of the module of shape (num_embeddings, embedding_dim).
  27. Examples:
  28. .. testcode::
  29. import numpy as np
  30. import megengine as mge
  31. import megengine.module as M
  32. weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6),(0.1,1.1,2.1,3.1,4.1)], dtype=np.float32))
  33. data = mge.tensor(np.array([(0,1,1),(1,0,1),(0,0,1)], dtype=np.int32))
  34. embedding = M.Embedding(2, 5, initial_weight=weight)
  35. output = embedding(data)
  36. with np.printoptions(precision=6):
  37. print(output.numpy())
  38. Outputs:
  39. .. testoutput::
  40. [[[1.2 2.3 3.4 4.5 5.6]
  41. [0.1 1.1 2.1 3.1 4.1]
  42. [0.1 1.1 2.1 3.1 4.1]]
  43. [[0.1 1.1 2.1 3.1 4.1]
  44. [1.2 2.3 3.4 4.5 5.6]
  45. [0.1 1.1 2.1 3.1 4.1]]
  46. [[1.2 2.3 3.4 4.5 5.6]
  47. [1.2 2.3 3.4 4.5 5.6]
  48. [0.1 1.1 2.1 3.1 4.1]]]
  49. """
  50. def __init__(
  51. self,
  52. num_embeddings: int,
  53. embedding_dim: int,
  54. padding_idx: Optional[int] = None,
  55. max_norm: Optional[float] = None,
  56. norm_type: Optional[float] = None,
  57. initial_weight: Parameter = None,
  58. ):
  59. super().__init__()
  60. if padding_idx is not None:
  61. raise ValueError("Not support padding index now.")
  62. if max_norm is not None or norm_type is not None:
  63. raise ValueError("Not support weight normalize now.")
  64. self.padding_idx = padding_idx
  65. self.max_norm = max_norm
  66. self.norm_type = norm_type
  67. self.num_embeddings = num_embeddings
  68. self.embedding_dim = embedding_dim
  69. if initial_weight is None:
  70. self.weight = Parameter(
  71. np.random.uniform(
  72. size=(self.num_embeddings, self.embedding_dim)
  73. ).astype(np.float32)
  74. )
  75. self.reset_parameters()
  76. else:
  77. if initial_weight.shape != (num_embeddings, embedding_dim):
  78. raise ValueError(
  79. "The weight shape should match num_embeddings and embedding_dim"
  80. )
  81. self.weight = Parameter(initial_weight.numpy())
  82. def reset_parameters(self) -> None:
  83. init.normal_(self.weight)
  84. def forward(self, inputs):
  85. return embedding_func(inputs, self.weight)
  86. @classmethod
  87. def from_pretrained(
  88. cls,
  89. embeddings: Parameter,
  90. freeze: Optional[bool] = True,
  91. padding_idx: Optional[int] = None,
  92. max_norm: Optional[float] = None,
  93. norm_type: Optional[float] = None,
  94. ):
  95. r"""
  96. Creates Embedding instance from given 2-dimensional FloatTensor.
  97. :param embeddings: Tensor contained weight for the embedding.
  98. :param freeze: If ``True``, the weight does not get updated during the learning process. Default: ``True``.
  99. :param padding_idx: should be set to None, not support Now.
  100. :param max_norm: should be set to None, not support Now.
  101. :param norm_type: should be set to None, not support Now.
  102. Examples:
  103. .. testcode::
  104. import numpy as np
  105. import megengine as mge
  106. import megengine.module as M
  107. weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6),(0.1,1.1,2.1,3.1,4.1)], dtype=np.float32))
  108. data = mge.tensor(np.array([(0,1,1),(1,0,1),(0,0,1)], dtype=np.int32))
  109. embedding = M.Embedding.from_pretrained(weight, freeze=False)
  110. output = embedding(data)
  111. print(output.numpy())
  112. Outputs:
  113. .. testoutput::
  114. [[[1.2 2.3 3.4 4.5 5.6]
  115. [0.1 1.1 2.1 3.1 4.1]
  116. [0.1 1.1 2.1 3.1 4.1]]
  117. [[0.1 1.1 2.1 3.1 4.1]
  118. [1.2 2.3 3.4 4.5 5.6]
  119. [0.1 1.1 2.1 3.1 4.1]]
  120. [[1.2 2.3 3.4 4.5 5.6]
  121. [1.2 2.3 3.4 4.5 5.6]
  122. [0.1 1.1 2.1 3.1 4.1]]]
  123. """
  124. embeddings_shape = embeddings.shape
  125. embeddings_dim = len(embeddings_shape)
  126. if embeddings_dim != 2:
  127. raise ValueError("Embeddings parameter is expected to be 2-dimensional")
  128. rows = embeddings_shape[0]
  129. cols = embeddings_shape[1]
  130. embedding = cls(
  131. num_embeddings=rows,
  132. embedding_dim=cols,
  133. initial_weight=embeddings,
  134. padding_idx=padding_idx,
  135. max_norm=max_norm,
  136. norm_type=norm_type,
  137. )
  138. embedding.weight.requires_grad = not freeze
  139. return embedding

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。