You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

embedding.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. # -*- coding: utf-8 -*-
  2. from typing import Optional
  3. import numpy as np
  4. from ..functional.nn import embedding as embedding_func
  5. from ..tensor import Parameter
  6. from . import init
  7. from .module import Module
  8. class Embedding(Module):
  9. r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
  10. This module is often used to store word embeddings and retrieve them using indices.
  11. The input to the module is a list of indices, and the output is the corresponding word embeddings.
  12. The indices should less than num_embeddings.
  13. Args:
  14. num_embeddings: size of embedding dictionary.
  15. embedding_dim: size of each embedding vector.
  16. padding_idx: should be set to None, not supportted now.
  17. max_norm: should be set to None, not supportted now.
  18. norm_type: should be set to None, not supportted now.
  19. initial_weight: the learnable weights of the module of shape (num_embeddings, embedding_dim).
  20. Examples:
  21. >>> import numpy as np
  22. >>> weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6)], dtype=np.float32))
  23. >>> data = mge.tensor(np.array([(0,0)], dtype=np.int32))
  24. >>> embedding = M.Embedding(1, 5, initial_weight=weight)
  25. >>> output = embedding(data)
  26. >>> with np.printoptions(precision=6):
  27. ... print(output.numpy())
  28. [[[1.2 2.3 3.4 4.5 5.6]
  29. [1.2 2.3 3.4 4.5 5.6]]]
  30. """
  31. def __init__(
  32. self,
  33. num_embeddings: int,
  34. embedding_dim: int,
  35. padding_idx: Optional[int] = None,
  36. max_norm: Optional[float] = None,
  37. norm_type: Optional[float] = None,
  38. initial_weight: Parameter = None,
  39. freeze: bool = False,
  40. **kwargs
  41. ):
  42. super().__init__(**kwargs)
  43. if padding_idx is not None:
  44. raise ValueError("Not support padding index now.")
  45. if max_norm is not None or norm_type is not None:
  46. raise ValueError("Not support weight normalize now.")
  47. self.padding_idx = padding_idx
  48. self.max_norm = max_norm
  49. self.norm_type = norm_type
  50. self.num_embeddings = num_embeddings
  51. self.embedding_dim = embedding_dim
  52. self.freeze = freeze
  53. if initial_weight is None:
  54. self.weight = Parameter(
  55. np.random.uniform(
  56. size=(self.num_embeddings, self.embedding_dim)
  57. ).astype(np.float32)
  58. )
  59. self.reset_parameters()
  60. else:
  61. if initial_weight.numpy().shape != (num_embeddings, embedding_dim):
  62. raise ValueError(
  63. "The weight shape should match num_embeddings and embedding_dim"
  64. )
  65. self.weight = Parameter(initial_weight.numpy())
  66. def reset_parameters(self) -> None:
  67. init.normal_(self.weight)
  68. def forward(self, inputs):
  69. if self.freeze:
  70. weight = self.weight.detach()
  71. else:
  72. weight = self.weight
  73. return embedding_func(inputs, weight)
  74. @classmethod
  75. def from_pretrained(
  76. cls,
  77. embeddings: Parameter,
  78. freeze: Optional[bool] = True,
  79. padding_idx: Optional[int] = None,
  80. max_norm: Optional[float] = None,
  81. norm_type: Optional[float] = None,
  82. ):
  83. r"""Creates Embedding instance from given 2-dimensional FloatTensor.
  84. Args:
  85. embeddings: tensor contained weight for the embedding.
  86. freeze: if ``True``, the weight does not get updated during the learning process. Default: True.
  87. padding_idx: should be set to None, not support Now.
  88. max_norm: should be set to None, not support Now.
  89. norm_type: should be set to None, not support Now.
  90. Examples:
  91. >>> import numpy as np
  92. >>> weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6)], dtype=np.float32))
  93. >>> data = mge.tensor(np.array([(0,0)], dtype=np.int32))
  94. >>> embedding = M.Embedding.from_pretrained(weight, freeze=False)
  95. >>> output = embedding(data)
  96. >>> output.numpy()
  97. array([[[1.2, 2.3, 3.4, 4.5, 5.6],
  98. [1.2, 2.3, 3.4, 4.5, 5.6]]], dtype=float32)
  99. """
  100. embeddings_shape = embeddings.shape
  101. embeddings_dim = len(embeddings_shape)
  102. if embeddings_dim != 2:
  103. raise ValueError("Embeddings parameter is expected to be 2-dimensional")
  104. rows = embeddings_shape[0]
  105. cols = embeddings_shape[1]
  106. embedding = cls(
  107. num_embeddings=rows,
  108. embedding_dim=cols,
  109. initial_weight=embeddings,
  110. padding_idx=padding_idx,
  111. max_norm=max_norm,
  112. norm_type=norm_type,
  113. freeze=freeze,
  114. )
  115. return embedding