You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

embedding.py 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from typing import Optional
  10. import numpy as np
  11. from ..functional.nn import embedding as embedding_func
  12. from ..tensor import Parameter
  13. from . import init
  14. from .module import Module
  15. class Embedding(Module):
  16. r"""
  17. A simple lookup table that stores embeddings of a fixed dictionary and size.
  18. This module is often used to store word embeddings and retrieve them using indices.
  19. The input to the module is a list of indices, and the output is the corresponding word embeddings.
  20. The indices should less than num_embeddings.
  21. :param num_embeddings: size of embedding dictionary.
  22. :param embedding_dim: size of each embedding vector.
  23. :param padding_idx: should be set to None, not supportted now.
  24. :param max_norm: should be set to None, not supportted now.
  25. :param norm_type: should be set to None, not supportted now.
  26. :param initial_weight: the learnable weights of the module of shape (num_embeddings, embedding_dim).
  27. Examples:
  28. .. testcode::
  29. import numpy as np
  30. import megengine as mge
  31. import megengine.module as M
  32. weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6)], dtype=np.float32))
  33. data = mge.tensor(np.array([(0,0)], dtype=np.int32))
  34. embedding = M.Embedding(1, 5, initial_weight=weight)
  35. output = embedding(data)
  36. with np.printoptions(precision=6):
  37. print(output.numpy())
  38. Outputs:
  39. .. testoutput::
  40. [[[1.2 2.3 3.4 4.5 5.6]
  41. [1.2 2.3 3.4 4.5 5.6]]]
  42. """
  43. def __init__(
  44. self,
  45. num_embeddings: int,
  46. embedding_dim: int,
  47. padding_idx: Optional[int] = None,
  48. max_norm: Optional[float] = None,
  49. norm_type: Optional[float] = None,
  50. initial_weight: Parameter = None,
  51. freeze: bool = False,
  52. ):
  53. super().__init__()
  54. if padding_idx is not None:
  55. raise ValueError("Not support padding index now.")
  56. if max_norm is not None or norm_type is not None:
  57. raise ValueError("Not support weight normalize now.")
  58. self.padding_idx = padding_idx
  59. self.max_norm = max_norm
  60. self.norm_type = norm_type
  61. self.num_embeddings = num_embeddings
  62. self.embedding_dim = embedding_dim
  63. self.freeze = freeze
  64. if initial_weight is None:
  65. self.weight = Parameter(
  66. np.random.uniform(
  67. size=(self.num_embeddings, self.embedding_dim)
  68. ).astype(np.float32)
  69. )
  70. self.reset_parameters()
  71. else:
  72. if initial_weight.numpy().shape != (num_embeddings, embedding_dim):
  73. raise ValueError(
  74. "The weight shape should match num_embeddings and embedding_dim"
  75. )
  76. self.weight = Parameter(initial_weight.numpy())
  77. def reset_parameters(self) -> None:
  78. init.normal_(self.weight)
  79. def forward(self, inputs):
  80. if self.freeze:
  81. weight = self.weight.detach()
  82. else:
  83. weight = self.weight
  84. return embedding_func(inputs, weight)
  85. @classmethod
  86. def from_pretrained(
  87. cls,
  88. embeddings: Parameter,
  89. freeze: Optional[bool] = True,
  90. padding_idx: Optional[int] = None,
  91. max_norm: Optional[float] = None,
  92. norm_type: Optional[float] = None,
  93. ):
  94. r"""
  95. Creates Embedding instance from given 2-dimensional FloatTensor.
  96. :param embeddings: tensor contained weight for the embedding.
  97. :param freeze: if ``True``, the weight does not get updated during the learning process. Default: True.
  98. :param padding_idx: should be set to None, not support Now.
  99. :param max_norm: should be set to None, not support Now.
  100. :param norm_type: should be set to None, not support Now.
  101. Examples:
  102. .. testcode::
  103. import numpy as np
  104. import megengine as mge
  105. import megengine.module as M
  106. weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6)], dtype=np.float32))
  107. data = mge.tensor(np.array([(0,0)], dtype=np.int32))
  108. embedding = M.Embedding.from_pretrained(weight, freeze=False)
  109. output = embedding(data)
  110. print(output.numpy())
  111. Outputs:
  112. .. testoutput::
  113. [[[1.2 2.3 3.4 4.5 5.6]
  114. [1.2 2.3 3.4 4.5 5.6]]]
  115. """
  116. embeddings_shape = embeddings.shape
  117. embeddings_dim = len(embeddings_shape)
  118. if embeddings_dim != 2:
  119. raise ValueError("Embeddings parameter is expected to be 2-dimensional")
  120. rows = embeddings_shape[0]
  121. cols = embeddings_shape[1]
  122. embedding = cls(
  123. num_embeddings=rows,
  124. embedding_dim=cols,
  125. initial_weight=embeddings,
  126. padding_idx=padding_idx,
  127. max_norm=max_norm,
  128. norm_type=norm_type,
  129. freeze=freeze,
  130. )
  131. return embedding

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。