You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_pool2d.py 1.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. # -*- coding: utf-8 -*-
  2. import itertools
  3. import numpy as np
  4. from megengine import Parameter, tensor
  5. from megengine.module import AvgPool2d, MaxPool2d
  6. def test_avg_pool2d():
  7. def test_func(
  8. batch_size,
  9. in_channels,
  10. out_channels,
  11. in_height,
  12. in_width,
  13. kernel_size,
  14. stride,
  15. padding,
  16. ):
  17. pool = AvgPool2d(kernel_size, stride=stride, padding=padding, mode="average")
  18. inp = np.random.normal(
  19. size=(batch_size, in_channels, in_height, in_width)
  20. ).astype(np.float32)
  21. out_height = (in_height + padding * 2 - kernel_size) // stride + 1
  22. out_width = (in_width + padding * 2 - kernel_size) // stride + 1
  23. out = pool(tensor(inp))
  24. inp = np.pad(inp, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
  25. expected = np.zeros(
  26. (batch_size, out_channels, out_height, out_width), dtype=np.float32,
  27. )
  28. for n, c, oh, ow in itertools.product(
  29. *map(range, [batch_size, out_channels, out_height, out_width])
  30. ):
  31. ih, iw = oh * stride, ow * stride
  32. expected[n, c, oh, ow] = np.sum(
  33. inp[n, c, ih : ih + kernel_size, iw : iw + kernel_size,]
  34. ) / (kernel_size * kernel_size)
  35. np.testing.assert_almost_equal(out.numpy(), expected, 1e-5)
  36. test_func(10, 4, 4, 5, 5, 2, 2, 1)
  37. test_func(10, 4, 4, 6, 6, 2, 2, 0)
  38. test_func(10, 16, 16, 14, 14, 2, 2, 0)