You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

embedding.py 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from typing import Optional
  10. import numpy as np
  11. from ..functional.nn import embedding as embedding_func
  12. from ..tensor import Parameter
  13. from . import init
  14. from .module import Module
  15. class Embedding(Module):
  16. r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
  17. This module is often used to store word embeddings and retrieve them using indices.
  18. The input to the module is a list of indices, and the output is the corresponding word embeddings.
  19. The indices should less than num_embeddings.
  20. Args:
  21. num_embeddings: size of embedding dictionary.
  22. embedding_dim: size of each embedding vector.
  23. padding_idx: should be set to None, not supportted now.
  24. max_norm: should be set to None, not supportted now.
  25. norm_type: should be set to None, not supportted now.
  26. initial_weight: the learnable weights of the module of shape (num_embeddings, embedding_dim).
  27. Examples:
  28. >>> import numpy as np
  29. >>> weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6)], dtype=np.float32))
  30. >>> data = mge.tensor(np.array([(0,0)], dtype=np.int32))
  31. >>> embedding = M.Embedding(1, 5, initial_weight=weight)
  32. >>> output = embedding(data)
  33. >>> with np.printoptions(precision=6):
  34. ... print(output.numpy())
  35. [[[1.2 2.3 3.4 4.5 5.6]
  36. [1.2 2.3 3.4 4.5 5.6]]]
  37. """
  38. def __init__(
  39. self,
  40. num_embeddings: int,
  41. embedding_dim: int,
  42. padding_idx: Optional[int] = None,
  43. max_norm: Optional[float] = None,
  44. norm_type: Optional[float] = None,
  45. initial_weight: Parameter = None,
  46. freeze: bool = False,
  47. **kwargs
  48. ):
  49. super().__init__(**kwargs)
  50. if padding_idx is not None:
  51. raise ValueError("Not support padding index now.")
  52. if max_norm is not None or norm_type is not None:
  53. raise ValueError("Not support weight normalize now.")
  54. self.padding_idx = padding_idx
  55. self.max_norm = max_norm
  56. self.norm_type = norm_type
  57. self.num_embeddings = num_embeddings
  58. self.embedding_dim = embedding_dim
  59. self.freeze = freeze
  60. if initial_weight is None:
  61. self.weight = Parameter(
  62. np.random.uniform(
  63. size=(self.num_embeddings, self.embedding_dim)
  64. ).astype(np.float32)
  65. )
  66. self.reset_parameters()
  67. else:
  68. if initial_weight.numpy().shape != (num_embeddings, embedding_dim):
  69. raise ValueError(
  70. "The weight shape should match num_embeddings and embedding_dim"
  71. )
  72. self.weight = Parameter(initial_weight.numpy())
  73. def reset_parameters(self) -> None:
  74. init.normal_(self.weight)
  75. def forward(self, inputs):
  76. if self.freeze:
  77. weight = self.weight.detach()
  78. else:
  79. weight = self.weight
  80. return embedding_func(inputs, weight)
  81. @classmethod
  82. def from_pretrained(
  83. cls,
  84. embeddings: Parameter,
  85. freeze: Optional[bool] = True,
  86. padding_idx: Optional[int] = None,
  87. max_norm: Optional[float] = None,
  88. norm_type: Optional[float] = None,
  89. ):
  90. r"""Creates Embedding instance from given 2-dimensional FloatTensor.
  91. Args:
  92. embeddings: tensor contained weight for the embedding.
  93. freeze: if ``True``, the weight does not get updated during the learning process. Default: True.
  94. padding_idx: should be set to None, not support Now.
  95. max_norm: should be set to None, not support Now.
  96. norm_type: should be set to None, not support Now.
  97. Examples:
  98. >>> import numpy as np
  99. >>> weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6)], dtype=np.float32))
  100. >>> data = mge.tensor(np.array([(0,0)], dtype=np.int32))
  101. >>> embedding = M.Embedding.from_pretrained(weight, freeze=False)
  102. >>> output = embedding(data)
  103. >>> output.numpy()
  104. array([[[1.2, 2.3, 3.4, 4.5, 5.6],
  105. [1.2, 2.3, 3.4, 4.5, 5.6]]], dtype=float32)
  106. """
  107. embeddings_shape = embeddings.shape
  108. embeddings_dim = len(embeddings_shape)
  109. if embeddings_dim != 2:
  110. raise ValueError("Embeddings parameter is expected to be 2-dimensional")
  111. rows = embeddings_shape[0]
  112. cols = embeddings_shape[1]
  113. embedding = cls(
  114. num_embeddings=rows,
  115. embedding_dim=cols,
  116. initial_weight=embeddings,
  117. padding_idx=padding_idx,
  118. max_norm=max_norm,
  119. norm_type=norm_type,
  120. freeze=freeze,
  121. )
  122. return embedding