You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_rnn.py 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import numpy as np
  10. import pytest
  11. import megengine as mge
  12. import megengine.functional as F
  13. from megengine.module import LSTM, RNN, LSTMCell, RNNCell
  14. def assert_tuple_equal(src, ref):
  15. assert len(src) == len(ref)
  16. for i, j in zip(src, ref):
  17. assert i == j
  18. @pytest.mark.parametrize(
  19. "batch_size, input_size, hidden_size, init_hidden",
  20. [(3, 10, 20, True), (3, 10, 20, False), (1, 10, 20, False), (0, 10, 20, True)],
  21. )
  22. def test_rnn_cell(batch_size, input_size, hidden_size, init_hidden):
  23. rnn_cell = RNNCell(input_size, hidden_size)
  24. x = mge.random.normal(size=(batch_size, input_size))
  25. if init_hidden:
  26. h = F.zeros(shape=(batch_size, hidden_size))
  27. else:
  28. h = None
  29. h_new = rnn_cell(x, h)
  30. assert_tuple_equal(h_new.shape, (batch_size, hidden_size))
  31. # is batch_size == 0 tolerated ? it will cause error in slice operation xx[:, ...]
  32. @pytest.mark.parametrize(
  33. "batch_size, input_size, hidden_size, init_hidden",
  34. [(3, 10, 20, True), (3, 10, 20, False), (1, 10, 20, False)],
  35. )
  36. def test_lstm_cell(batch_size, input_size, hidden_size, init_hidden):
  37. rnn_cell = LSTMCell(input_size, hidden_size)
  38. x = mge.random.normal(size=(batch_size, input_size))
  39. if init_hidden:
  40. h = F.zeros(shape=(batch_size, hidden_size))
  41. hx = (h, h)
  42. else:
  43. hx = None
  44. h_new, c_new = rnn_cell(x, hx)
  45. assert_tuple_equal(h_new.shape, (batch_size, hidden_size))
  46. assert_tuple_equal(c_new.shape, (batch_size, hidden_size))
  47. @pytest.mark.parametrize(
  48. "batch_size, seq_len, input_size, hidden_size, num_layers, bidirectional, init_hidden, batch_first",
  49. [
  50. (3, 6, 10, 20, 2, False, False, True),
  51. pytest.param(
  52. 3,
  53. 3,
  54. 10,
  55. 10,
  56. 1,
  57. True,
  58. True,
  59. False,
  60. marks=pytest.mark.skip(reason="bidirectional will cause cuda oom"),
  61. ),
  62. ],
  63. )
  64. # (0, 1, 1, 1, 1, False, True, False)])
  65. def test_rnn(
  66. batch_size,
  67. seq_len,
  68. input_size,
  69. hidden_size,
  70. num_layers,
  71. bidirectional,
  72. init_hidden,
  73. batch_first,
  74. ):
  75. rnn = RNN(
  76. input_size,
  77. hidden_size,
  78. batch_first=batch_first,
  79. num_layers=num_layers,
  80. bidirectional=bidirectional,
  81. )
  82. if batch_first:
  83. x_shape = (batch_size, seq_len, input_size)
  84. else:
  85. x_shape = (seq_len, batch_size, input_size)
  86. x = mge.random.normal(size=x_shape)
  87. total_hidden_size = num_layers * (2 if bidirectional else 1) * hidden_size
  88. if init_hidden:
  89. h = mge.random.normal(size=(batch_size, total_hidden_size))
  90. else:
  91. h = None
  92. output, h_n = rnn(x, h)
  93. num_directions = 2 if bidirectional else 1
  94. if batch_first:
  95. assert_tuple_equal(
  96. output.shape, (batch_size, seq_len, num_directions * hidden_size)
  97. )
  98. else:
  99. assert_tuple_equal(
  100. output.shape, (seq_len, batch_size, num_directions * hidden_size)
  101. )
  102. assert_tuple_equal(
  103. h_n.shape, (num_directions * num_layers, batch_size, hidden_size)
  104. )
  105. @pytest.mark.parametrize(
  106. "batch_size, seq_len, input_size, hidden_size, num_layers, bidirectional, init_hidden, batch_first",
  107. [
  108. (3, 10, 20, 20, 1, False, False, True),
  109. pytest.param(
  110. 3,
  111. 3,
  112. 10,
  113. 10,
  114. 1,
  115. True,
  116. True,
  117. False,
  118. marks=pytest.mark.skip(reason="bidirectional will cause cuda oom"),
  119. ),
  120. ],
  121. )
  122. # (0, 1, 1, 1, 1, False, True, False)])
  123. def test_lstm(
  124. batch_size,
  125. seq_len,
  126. input_size,
  127. hidden_size,
  128. num_layers,
  129. bidirectional,
  130. init_hidden,
  131. batch_first,
  132. ):
  133. rnn = LSTM(
  134. input_size,
  135. hidden_size,
  136. batch_first=batch_first,
  137. num_layers=num_layers,
  138. bidirectional=bidirectional,
  139. )
  140. if batch_first:
  141. x_shape = (batch_size, seq_len, input_size)
  142. else:
  143. x_shape = (seq_len, batch_size, input_size)
  144. x = mge.random.normal(size=x_shape)
  145. total_hidden_size = num_layers * (2 if bidirectional else 1) * hidden_size
  146. if init_hidden:
  147. h = mge.random.normal(size=(batch_size, total_hidden_size))
  148. h = (h, h)
  149. else:
  150. h = None
  151. output, h_n = rnn(x, h)
  152. num_directions = 2 if bidirectional else 1
  153. if batch_first:
  154. assert_tuple_equal(
  155. output.shape, (batch_size, seq_len, num_directions * hidden_size)
  156. )
  157. else:
  158. assert_tuple_equal(
  159. output.shape, (seq_len, batch_size, num_directions * hidden_size)
  160. )
  161. assert_tuple_equal(
  162. h_n[0].shape, (num_directions * num_layers, batch_size, hidden_size)
  163. )
  164. assert_tuple_equal(
  165. h_n[1].shape, (num_directions * num_layers, batch_size, hidden_size)
  166. )
  167. if __name__ == "__main__":
  168. test_lstm(5, 10, 10, 20, 1, False, False, True)