
test_rnn.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
from megengine.device import get_device_count
from megengine.module import LSTM, RNN, LSTMCell, RNNCell
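
# Helper: compare two shape tuples element by element.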
def assert_tuple_equal(src, ref):
    assert len(src) == len(ref)
    for i, j in zip(src, ref):
        assert i == j
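
# test_rnn_cell: one forward step through RNNCell, with and without an explicit
# initial hidden state, should yield a hidden state of shape (batch_size, hidden_size).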
@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no algorithm on cuda")
@pytest.mark.parametrize(
    "batch_size, input_size, hidden_size, init_hidden",
    [(3, 10, 20, True), (3, 10, 20, False), (1, 10, 20, False)],
)
def test_rnn_cell(batch_size, input_size, hidden_size, init_hidden):
    rnn_cell = RNNCell(input_size, hidden_size)
    x = mge.random.normal(size=(batch_size, input_size))
    if init_hidden:
        h = F.zeros(shape=(batch_size, hidden_size))
    else:
        h = None
    h_new = rnn_cell(x, h)
    assert_tuple_equal(h_new.shape, (batch_size, hidden_size))
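
# test_lstm_cell: one forward step through LSTMCell should yield both a hidden
# state and a cell state of shape (batch_size, hidden_size).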
@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no algorithm on cuda")
@pytest.mark.parametrize(
    "batch_size, input_size, hidden_size, init_hidden",
    [(3, 10, 20, True), (3, 10, 20, False), (1, 10, 20, False)],
)
def test_lstm_cell(batch_size, input_size, hidden_size, init_hidden):
    rnn_cell = LSTMCell(input_size, hidden_size)
    x = mge.random.normal(size=(batch_size, input_size))
    if init_hidden:
        h = F.zeros(shape=(batch_size, hidden_size))
        hx = (h, h)
    else:
        hx = None
    h_new, c_new = rnn_cell(x, hx)
    assert_tuple_equal(h_new.shape, (batch_size, hidden_size))
    assert_tuple_equal(c_new.shape, (batch_size, hidden_size))
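
# test_rnn: run a random sequence through a (possibly multi-layer, possibly
# bidirectional) RNN and check the shapes of the output and the final hidden state.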
@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no algorithm on cuda")
@pytest.mark.parametrize(
    "batch_size, seq_len, input_size, hidden_size, num_layers, bidirectional, init_hidden, batch_first",
    [
        (3, 6, 10, 20, 2, False, False, True),
        pytest.param(
            3,
            3,
            10,
            10,
            1,
            True,
            True,
            False,
            marks=pytest.mark.skip(reason="bidirectional will cause cuda oom"),
        ),
    ],
)
def test_rnn(
    batch_size,
    seq_len,
    input_size,
    hidden_size,
    num_layers,
    bidirectional,
    init_hidden,
    batch_first,
):
    rnn = RNN(
        input_size,
        hidden_size,
        batch_first=batch_first,
        num_layers=num_layers,
        bidirectional=bidirectional,
    )
    if batch_first:
        x_shape = (batch_size, seq_len, input_size)
    else:
        x_shape = (seq_len, batch_size, input_size)
    x = mge.random.normal(size=x_shape)
    total_hidden_size = num_layers * (2 if bidirectional else 1) * hidden_size
    if init_hidden:
        h = mge.random.normal(size=(batch_size, total_hidden_size))
    else:
        h = None
    output, h_n = rnn(x, h)
    num_directions = 2 if bidirectional else 1
    if batch_first:
        assert_tuple_equal(
            output.shape, (batch_size, seq_len, num_directions * hidden_size)
        )
    else:
        assert_tuple_equal(
            output.shape, (seq_len, batch_size, num_directions * hidden_size)
        )
    assert_tuple_equal(
        h_n.shape, (num_directions * num_layers, batch_size, hidden_size)
    )
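
# test_lstm: same as test_rnn, but the final state is a tuple (h_n, c_n) and both
# entries are checked for shape (num_directions * num_layers, batch_size, hidden_size).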
@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no algorithm on cuda")
@pytest.mark.parametrize(
    "batch_size, seq_len, input_size, hidden_size, num_layers, bidirectional, init_hidden, batch_first",
    [
        (3, 10, 20, 20, 1, False, False, True),
        pytest.param(
            3,
            3,
            10,
            10,
            1,
            True,
            True,
            False,
            marks=pytest.mark.skip(reason="bidirectional will cause cuda oom"),
        ),
    ],
)
def test_lstm(
    batch_size,
    seq_len,
    input_size,
    hidden_size,
    num_layers,
    bidirectional,
    init_hidden,
    batch_first,
):
    rnn = LSTM(
        input_size,
        hidden_size,
        batch_first=batch_first,
        num_layers=num_layers,
        bidirectional=bidirectional,
    )
    if batch_first:
        x_shape = (batch_size, seq_len, input_size)
    else:
        x_shape = (seq_len, batch_size, input_size)
    x = mge.random.normal(size=x_shape)
    total_hidden_size = num_layers * (2 if bidirectional else 1) * hidden_size
    if init_hidden:
        h = mge.random.normal(size=(batch_size, total_hidden_size))
        h = (h, h)
    else:
        h = None
    output, h_n = rnn(x, h)
    num_directions = 2 if bidirectional else 1
    if batch_first:
        assert_tuple_equal(
            output.shape, (batch_size, seq_len, num_directions * hidden_size)
        )
    else:
        assert_tuple_equal(
            output.shape, (seq_len, batch_size, num_directions * hidden_size)
        )
    assert_tuple_equal(
        h_n[0].shape, (num_directions * num_layers, batch_size, hidden_size)
    )
    assert_tuple_equal(
        h_n[1].shape, (num_directions * num_layers, batch_size, hidden_size)
    )