
test_rnn.py

# -*- coding: utf-8 -*-
import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
from megengine.device import get_device_count
from megengine.module import LSTM, RNN, LSTMCell, RNNCell


def assert_tuple_equal(src, ref):
    assert len(src) == len(ref)
    for i, j in zip(src, ref):
        assert i == j


@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no algorithm on cuda")
@pytest.mark.parametrize(
    "batch_size, input_size, hidden_size, init_hidden",
    [(3, 10, 20, True), (3, 10, 20, False), (1, 10, 20, False)],
)
def test_rnn_cell(batch_size, input_size, hidden_size, init_hidden):
    rnn_cell = RNNCell(input_size, hidden_size)
    x = mge.random.normal(size=(batch_size, input_size))
    if init_hidden:
        h = F.zeros(shape=(batch_size, hidden_size))
    else:
        h = None
    h_new = rnn_cell(x, h)
    assert_tuple_equal(h_new.shape, (batch_size, hidden_size))


@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no algorithm on cuda")
@pytest.mark.parametrize(
    "batch_size, input_size, hidden_size, init_hidden",
    [(3, 10, 20, True), (3, 10, 20, False), (1, 10, 20, False)],
)
def test_lstm_cell(batch_size, input_size, hidden_size, init_hidden):
    rnn_cell = LSTMCell(input_size, hidden_size)
    x = mge.random.normal(size=(batch_size, input_size))
    if init_hidden:
        h = F.zeros(shape=(batch_size, hidden_size))
        hx = (h, h)
    else:
        hx = None
    h_new, c_new = rnn_cell(x, hx)
    assert_tuple_equal(h_new.shape, (batch_size, hidden_size))
    assert_tuple_equal(c_new.shape, (batch_size, hidden_size))


@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no algorithm on cuda")
@pytest.mark.parametrize(
    "batch_size, seq_len, input_size, hidden_size, num_layers, bidirectional, init_hidden, batch_first",
    [
        (3, 6, 10, 20, 2, False, False, True),
        pytest.param(
            3,
            3,
            10,
            10,
            1,
            True,
            True,
            False,
            marks=pytest.mark.skip(reason="bidirectional will cause cuda oom"),
        ),
    ],
)
def test_rnn(
    batch_size,
    seq_len,
    input_size,
    hidden_size,
    num_layers,
    bidirectional,
    init_hidden,
    batch_first,
):
    rnn = RNN(
        input_size,
        hidden_size,
        batch_first=batch_first,
        num_layers=num_layers,
        bidirectional=bidirectional,
    )
    if batch_first:
        x_shape = (batch_size, seq_len, input_size)
    else:
        x_shape = (seq_len, batch_size, input_size)
    x = mge.random.normal(size=x_shape)
    total_hidden_size = num_layers * (2 if bidirectional else 1) * hidden_size
    if init_hidden:
        h = mge.random.normal(size=(batch_size, total_hidden_size))
    else:
        h = None
    output, h_n = rnn(x, h)
    num_directions = 2 if bidirectional else 1
    if batch_first:
        assert_tuple_equal(
            output.shape, (batch_size, seq_len, num_directions * hidden_size)
        )
    else:
        assert_tuple_equal(
            output.shape, (seq_len, batch_size, num_directions * hidden_size)
        )
    assert_tuple_equal(
        h_n.shape, (num_directions * num_layers, batch_size, hidden_size)
    )


@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no algorithm on cuda")
@pytest.mark.parametrize(
    "batch_size, seq_len, input_size, hidden_size, num_layers, bidirectional, init_hidden, batch_first",
    [
        (3, 10, 20, 20, 1, False, False, True),
        pytest.param(
            3,
            3,
            10,
            10,
            1,
            True,
            True,
            False,
            marks=pytest.mark.skip(reason="bidirectional will cause cuda oom"),
        ),
    ],
)
def test_lstm(
    batch_size,
    seq_len,
    input_size,
    hidden_size,
    num_layers,
    bidirectional,
    init_hidden,
    batch_first,
):
    rnn = LSTM(
        input_size,
        hidden_size,
        batch_first=batch_first,
        num_layers=num_layers,
        bidirectional=bidirectional,
    )
    if batch_first:
        x_shape = (batch_size, seq_len, input_size)
    else:
        x_shape = (seq_len, batch_size, input_size)
    x = mge.random.normal(size=x_shape)
    total_hidden_size = num_layers * (2 if bidirectional else 1) * hidden_size
    if init_hidden:
        h = mge.random.normal(size=(batch_size, total_hidden_size))
        h = (h, h)
    else:
        h = None
    output, h_n = rnn(x, h)
    num_directions = 2 if bidirectional else 1
    if batch_first:
        assert_tuple_equal(
            output.shape, (batch_size, seq_len, num_directions * hidden_size)
        )
    else:
        assert_tuple_equal(
            output.shape, (seq_len, batch_size, num_directions * hidden_size)
        )
    assert_tuple_equal(
        h_n[0].shape, (num_directions * num_layers, batch_size, hidden_size)
    )
    assert_tuple_equal(
        h_n[1].shape, (num_directions * num_layers, batch_size, hidden_size)
    )
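

# ---------------------------------------------------------------------------
# For reference only: a minimal usage sketch of the LSTM module exercised by
# the tests above, derived from the calls they make. The concrete sizes
# (batch=3, seq_len=6, input_size=10, hidden_size=20) are illustrative
# assumptions, not test parameters. Like the tests, this assumes a CPU device
# (the skipif markers note there is no RNN algorithm on CUDA).
if __name__ == "__main__":
    lstm = LSTM(10, 20, batch_first=True, num_layers=1, bidirectional=False)
    x = mge.random.normal(size=(3, 6, 10))  # (batch, seq_len, input_size)
    # Passing hx=None mirrors the init_hidden=False test cases.
    output, (h_n, c_n) = lstm(x, None)
    print(output.shape)  # (3, 6, 20): (batch, seq_len, num_directions * hidden_size)
    print(h_n.shape, c_n.shape)  # (1, 3, 20) each: (layers * directions, batch, hidden)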