You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_init.py 2.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. # -*- coding: utf-8 -*-
  2. import numpy as np
  3. import pytest
  4. from megengine import tensor
  5. from megengine.module import Conv1d, Conv2d, Conv3d, Linear
  6. from megengine.module.init import calculate_fan_in_and_fan_out, fill_
  7. def test_fill_():
  8. x = tensor(np.zeros((2, 3, 4)), dtype=np.float32)
  9. fill_(x, 5.0)
  10. np.testing.assert_array_equal(
  11. x.numpy(), np.full(shape=(2, 3, 4), fill_value=5.0, dtype=np.float32)
  12. )
  13. def test_calculate_fan_in_and_fan_out():
  14. l = Linear(in_features=3, out_features=8)
  15. fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
  16. assert fanin == 3
  17. assert fanout == 8
  18. with pytest.raises(ValueError):
  19. calculate_fan_in_and_fan_out(l.bias)
  20. l = Conv1d(in_channels=2, out_channels=3, kernel_size=5)
  21. fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
  22. assert fanin == 2 * 5
  23. assert fanout == 3 * 5
  24. # FIXME: will be wrong for group conv1d
  25. # l = Conv1d(in_channels=2, out_channels=4, kernel_size=5, groups=2)
  26. # fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
  27. # assert fanin == 2 // 2 * 5
  28. # assert fanout == 4 // 2 * 5
  29. l = Conv2d(in_channels=2, out_channels=3, kernel_size=(5, 7))
  30. fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
  31. assert fanin == 2 * 5 * 7
  32. assert fanout == 3 * 5 * 7
  33. l = Conv2d(in_channels=2, out_channels=4, kernel_size=(5, 7), groups=2)
  34. fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
  35. assert fanin == 2 // 2 * 5 * 7
  36. assert fanout == 4 // 2 * 5 * 7
  37. # FIXME: will be wrong for conv3d
  38. # l = Conv3d(in_channels=2, out_channels=3, kernel_size=(5, 7, 9))
  39. # fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
  40. # assert fanin == 2 * 5 * 7 * 9
  41. # assert fanout == 3 * 5 * 7 * 9
  42. l = Conv3d(in_channels=2, out_channels=4, kernel_size=(5, 7, 9), groups=2)
  43. fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
  44. assert fanin == 2 // 2 * 5 * 7 * 9
  45. assert fanout == 4 // 2 * 5 * 7 * 9