Browse Source

test(mge/imperative): add module test to imperative

GitOrigin-RevId: e690bc42b0
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
0067dcf068
12 changed files with 1443 additions and 126 deletions
  1. +3
    -3
      imperative/python/megengine/functional/nn.py
  2. +13
    -10
      imperative/python/megengine/module/batchnorm.py
  3. +7
    -1
      imperative/python/megengine/module/conv.py
  4. +24
    -0
      imperative/python/test/unit/module/test_activation.py
  5. +419
    -0
      imperative/python/test/unit/module/test_batchnorm.py
  6. +110
    -0
      imperative/python/test/unit/module/test_conv.py
  7. +46
    -0
      imperative/python/test/unit/module/test_external.py
  8. +27
    -0
      imperative/python/test/unit/module/test_init.py
  9. +614
    -0
      imperative/python/test/unit/module/test_module.py
  10. +91
    -0
      imperative/python/test/unit/module/test_qat.py
  11. +89
    -0
      imperative/python/test/unit/module/test_tensor.py
  12. +0
    -112
      imperative/python/test/unit/test_module.py

+ 3
- 3
imperative/python/megengine/functional/nn.py View File

@@ -209,7 +209,7 @@ def conv_transpose2d(
dilate_w=dilate_w,
strategy=get_conv_execution_strategy(),
)
(output,) = apply(op, inp, weight)
(output,) = apply(op, weight, inp)
if bias is not None:
output += bias
return output
@@ -241,7 +241,7 @@ def local_conv2d(
pad_w=pad_w,
dilate_h=dilate_h,
dilate_w=dilate_w,
strategy=get_conv_execution_strategy(),
# strategy=get_conv_execution_strategy(),
)
(output,) = apply(op, inp, weight)
if bias is not None:
@@ -724,7 +724,7 @@ def sync_batch_norm(
"""
assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode)
_channels = input.shape[1]
_ndim = len(input.shape)
_ndim = input.ndim
_param_shape = (1, _channels) + (1,) * (_ndim - 2)

if training:


+ 13
- 10
imperative/python/megengine/module/batchnorm.py View File

@@ -12,7 +12,7 @@ import numpy as np

from ..distributed.group import WORLD, Group
from ..functional import batch_norm2d, sync_batch_norm
from ..tensor_nn import Buffer, Parameter
from ..tensor_nn import Buffer, Parameter, Tensor
from . import init
from .module import Module

@@ -74,12 +74,12 @@ class _BatchNorm(Module):

_ndims = len(inp.shape)
if _ndims != 4:
origin_shape = inp.shapeof()
origin_shape = inp.shape
if _ndims == 2:
n, c = inp.shapeof(0), inp.shapeof(1)
n, c = inp.shape[0], inp.shape[1]
new_shape = (n, c, 1, 1)
elif _ndims == 3:
n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2)
n, c, h = inp.shape[0], inp.shape[1], inp.shape[2]
new_shape = (n, c, h, 1)

inp = inp.reshape(new_shape)
@@ -127,7 +127,7 @@ class SyncBatchNorm(_BatchNorm):
affine=True,
track_running_stats=True,
freeze=False,
group: Optional[Group] = None,
group: Optional[Group] = WORLD,
) -> None:
super().__init__(
num_features, eps, momentum, affine, track_running_stats, freeze
@@ -145,13 +145,16 @@ class SyncBatchNorm(_BatchNorm):

_ndims = len(inp.shape)
if _ndims != 4:
origin_shape = inp.shapeof()
new_shape = Tensor([1, 1, 1, 1], device=inp.device)
origin_shape = inp.shape
if _ndims == 2:
n, c = inp.shape[0], inp.shape[1]
new_shape = (n, c, 1, 1)
new_shape[:2] = origin_shape[:2]
elif _ndims == 3:
n, c, h = inp.shape[0], inp.shape[1], inp.shape[2]
new_shape = (n, c, h, 1)
new_shape[:3] = origin_shape[:3]
else:
raise ValueError(
"expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
)

inp = inp.reshape(new_shape)



+ 7
- 1
imperative/python/megengine/module/conv.py View File

@@ -376,7 +376,13 @@ class LocalConv2d(Conv2d):

def forward(self, inp):
return local_conv2d(
inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode
inp,
self.weight,
None,
self.stride,
self.padding,
self.dilation,
self.conv_mode,
)




+ 24
- 0
imperative/python/test/unit/module/test_activation.py View File

@@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

import megengine as mge
from megengine.module import LeakyReLU
from megengine.test import assertTensorClose


def test_leaky_relu():
    """LeakyReLU must reproduce the NumPy reference formula exactly."""
    slope = 0.1
    samples = np.array([-8, -12, 6, 10]).astype(np.float32)

    module = LeakyReLU(slope)
    result = module(mge.tensor(samples))

    # Reference: the positive part passes through, the negative part is scaled.
    positive_part = np.maximum(0, samples)
    negative_part = slope * np.minimum(0, samples)
    assertTensorClose(result.numpy(), positive_part + negative_part, max_err=0)

+ 419
- 0
imperative/python/test/unit/module/test_batchnorm.py View File

@@ -0,0 +1,419 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import multiprocessing as mp
import platform

import numpy as np
import pytest

import megengine as mge
import megengine.distributed as dist
from megengine import tensor
from megengine.core._trace_option import use_tensor_shape
from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
from megengine.tensor import Tensor
from megengine.test import assertTensorClose


@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
)
@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
@pytest.mark.isolated_distributed
def test_syncbn():
    """Run SyncBatchNorm on several ranks and compare with a NumPy reference.

    The reference statistics are computed on the full batch in the parent
    process. Every rank then receives one contiguous slice of the last axis
    and must produce the matching slice of the normalized output, together
    with the running statistics of the full batch.
    """
    nr_chan = 8
    data_shape = (3, nr_chan, 4, 16)
    momentum = 0.9
    eps = 1e-5
    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
    steps = 4
    nr_ranks = 2
    # Width of the slice of the last axis handed to each rank (previously a
    # literal 8 that silently depended on data_shape and nr_ranks).
    shard_width = data_shape[3] // nr_ranks
    server = dist.Server(0)
    port = server.py_server_port

    def worker(rank, data, yv_expect, running_mean, running_var):
        # Without enough GPUs the worker exits cleanly and checks nothing.
        if mge.get_device_count("gpu") < nr_ranks:
            return
        dist.init_process_group("localhost", port, nr_ranks, rank, rank)
        bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
        data_tensor = tensor([])
        for i in range(steps):
            data_tensor.set_value(data[i])
            yv = bn(data_tensor)

        # Only the output of the last step is compared.
        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
        assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
        assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6)

    xv = []
    for i in range(steps):
        xv.append(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
        # Move channels last and flatten so each column holds one channel.
        xv_transposed = np.transpose(xv[i], [0, 2, 3, 1]).reshape(
            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
        )

        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)

        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
        sd = np.sqrt(var_biased + eps)

        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)

        yv_expect = (xv[i] - mean) / sd

    def shard(array, rank):
        # Slice of the last axis that belongs to the given rank.
        return array[:, :, :, rank * shard_width : (rank + 1) * shard_width]

    data = [[shard(xv[j], i) for j in range(steps)] for i in range(nr_ranks)]

    procs = []
    for rank in range(nr_ranks):
        p = mp.Process(
            target=worker,
            args=(
                rank,
                data[rank],
                shard(yv_expect, rank),
                running_mean,
                running_var,
            ),
        )
        p.start()
        procs.append(p)

    try:
        for p in procs:
            p.join(10)
            # exitcode stays None when the worker is still running after the timeout.
            assert p.exitcode == 0
    finally:
        # Do not leave hung workers behind when a join timed out or a check failed.
        for p in procs:
            if p.is_alive():
                p.terminate()
                p.join()


def test_batchnorm():
    """Validate BatchNorm1d on 3-D input against a NumPy reference.

    Three training steps check the normalized output and the running
    statistics. The module is then switched to inference, where it must
    normalize with the stored running statistics and leave them untouched.
    """
    channels = 8
    shape = (3, channels, 4)
    stat_shape = (1, channels, 1)
    momentum = 0.9
    bn = BatchNorm1d(channels, momentum=momentum)
    ref_mean = np.zeros(stat_shape, dtype=np.float32)
    ref_var = np.ones(stat_shape, dtype=np.float32)
    data = tensor([])

    for _ in range(3):
        xv = np.random.normal(loc=2.3, size=shape).astype(np.float32)
        # Mean over axis 0, then over axis 2, keeping a (1, C, 1) layout.
        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
        # Channels last, flattened: one column per channel.
        per_channel = np.transpose(xv, [0, 2, 1]).reshape(
            (shape[0] * shape[2], channels)
        )

        biased = np.var(per_channel, axis=0).reshape(stat_shape)
        sd = np.sqrt(biased + bn.eps)

        unbiased = np.var(per_channel, axis=0, ddof=1).reshape(stat_shape)
        ref_mean = ref_mean * momentum + mean * (1 - momentum)
        ref_var = ref_var * momentum + unbiased * (1 - momentum)

        data.set_value(xv)
        yv = bn(data)
        expected = (xv - mean) / sd

        assertTensorClose(expected, yv.numpy(), max_err=5e-6)
        assertTensorClose(
            ref_mean.reshape(-1), bn.running_mean.numpy().reshape(-1), max_err=5e-6
        )
        assertTensorClose(
            ref_var.reshape(-1), bn.running_var.numpy().reshape(-1), max_err=5e-6
        )

    # Inference mode: statistics are frozen and repeated calls agree exactly.
    frozen_mean = bn.running_mean.numpy()
    frozen_var = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=shape).astype(np.float32)
    data.set_value(xv)
    first = bn(data)
    second = bn(data)
    assertTensorClose(first.numpy(), second.numpy(), max_err=0)
    assertTensorClose(frozen_mean, bn.running_mean.numpy(), max_err=0)
    assertTensorClose(frozen_var, bn.running_var.numpy(), max_err=0)
    expected = (xv - ref_mean) / np.sqrt(ref_var + bn.eps)
    assertTensorClose(expected, first.numpy(), max_err=5e-6)


@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
)
@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
@pytest.mark.isolated_distributed
def test_syncbn1d():
    """Validate SyncBatchNorm on 3-D input against a NumPy reference.

    Three training steps check the normalized output and the running
    statistics. The module is then switched to inference, where it must
    normalize with the stored running statistics and leave them untouched.
    """
    channels = 8
    shape = (3, channels, 4)
    stat_shape = (1, channels, 1)
    momentum = 0.9
    bn = SyncBatchNorm(channels, momentum=momentum)
    ref_mean = np.zeros(stat_shape, dtype=np.float32)
    ref_var = np.ones(stat_shape, dtype=np.float32)
    data = tensor([])

    for _ in range(3):
        xv = np.random.normal(loc=2.3, size=shape).astype(np.float32)
        # Mean over axis 0, then over axis 2, keeping a (1, C, 1) layout.
        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
        # Channels last, flattened: one column per channel.
        per_channel = np.transpose(xv, [0, 2, 1]).reshape(
            (shape[0] * shape[2], channels)
        )

        biased = np.var(per_channel, axis=0).reshape(stat_shape)
        sd = np.sqrt(biased + bn.eps)

        unbiased = np.var(per_channel, axis=0, ddof=1).reshape(stat_shape)
        ref_mean = ref_mean * momentum + mean * (1 - momentum)
        ref_var = ref_var * momentum + unbiased * (1 - momentum)

        data.set_value(xv)
        yv = bn(data)
        expected = (xv - mean) / sd

        assertTensorClose(expected, yv.numpy(), max_err=5e-6)
        assertTensorClose(
            ref_mean.reshape(-1), bn.running_mean.numpy().reshape(-1), max_err=5e-6
        )
        assertTensorClose(
            ref_var.reshape(-1), bn.running_var.numpy().reshape(-1), max_err=5e-6
        )

    # Inference mode: statistics are frozen and repeated calls agree exactly.
    frozen_mean = bn.running_mean.numpy()
    frozen_var = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=shape).astype(np.float32)
    data.set_value(xv)
    first = bn(data)
    second = bn(data)
    assertTensorClose(first.numpy(), second.numpy(), max_err=0)
    assertTensorClose(frozen_mean, bn.running_mean.numpy(), max_err=0)
    assertTensorClose(frozen_var, bn.running_var.numpy(), max_err=0)
    expected = (xv - ref_mean) / np.sqrt(ref_var + bn.eps)
    assertTensorClose(expected, first.numpy(), max_err=5e-6)


def test_batchnorm2d():
    """Validate BatchNorm2d on 4-D input against a NumPy reference.

    Three training steps check the normalized output and the running
    statistics. The module is then switched to inference, where it must
    normalize with the stored running statistics and leave them untouched.
    """
    channels = 8
    shape = (3, channels, 16, 16)
    stat_shape = (1, channels, 1, 1)
    momentum = 0.9
    bn = BatchNorm2d(channels, momentum=momentum)
    ref_mean = np.zeros(stat_shape, dtype=np.float32)
    ref_var = np.ones(stat_shape, dtype=np.float32)
    data = tensor([])

    for _ in range(3):
        xv = np.random.normal(loc=2.3, size=shape).astype(np.float32)
        # Channels last, flattened: one column per channel.
        per_channel = np.transpose(xv, [0, 2, 3, 1]).reshape(
            (shape[0] * shape[2] * shape[3], channels)
        )

        mean = np.mean(per_channel, axis=0).reshape(1, channels, 1, 1)

        biased = np.var(per_channel, axis=0).reshape(stat_shape)
        sd = np.sqrt(biased + bn.eps)

        unbiased = np.var(per_channel, axis=0, ddof=1).reshape(stat_shape)
        ref_mean = ref_mean * momentum + mean * (1 - momentum)
        ref_var = ref_var * momentum + unbiased * (1 - momentum)

        data.set_value(xv)
        yv = bn(data)
        expected = (xv - mean) / sd

        assertTensorClose(expected, yv.numpy(), max_err=5e-6)
        assertTensorClose(ref_mean, bn.running_mean.numpy(), max_err=5e-6)
        assertTensorClose(ref_var, bn.running_var.numpy(), max_err=5e-6)

    # Inference mode: statistics are frozen and repeated calls agree exactly.
    frozen_mean = bn.running_mean.numpy()
    frozen_var = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=shape).astype(np.float32)
    data.set_value(xv)
    first = bn(data)
    second = bn(data)
    assertTensorClose(first.numpy(), second.numpy(), max_err=0)
    assertTensorClose(frozen_mean, bn.running_mean.numpy(), max_err=0)
    assertTensorClose(frozen_var, bn.running_var.numpy(), max_err=0)
    expected = (xv - ref_mean) / np.sqrt(ref_var + bn.eps)
    assertTensorClose(expected, first.numpy(), max_err=5e-6)


@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
)
@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
@pytest.mark.isolated_distributed
def test_syncbn2d():
    """Validate SyncBatchNorm on 4-D input against a NumPy reference.

    Three training steps check the normalized output and the running
    statistics. The module is then switched to inference, where it must
    normalize with the stored running statistics and leave them untouched.
    """
    channels = 8
    shape = (3, channels, 16, 16)
    stat_shape = (1, channels, 1, 1)
    momentum = 0.9
    bn = SyncBatchNorm(channels, momentum=momentum)
    ref_mean = np.zeros(stat_shape, dtype=np.float32)
    ref_var = np.ones(stat_shape, dtype=np.float32)
    data = tensor([])

    for _ in range(3):
        xv = np.random.normal(loc=2.3, size=shape).astype(np.float32)
        # Channels last, flattened: one column per channel.
        per_channel = np.transpose(xv, [0, 2, 3, 1]).reshape(
            (shape[0] * shape[2] * shape[3], channels)
        )

        mean = np.mean(per_channel, axis=0).reshape(1, channels, 1, 1)

        biased = np.var(per_channel, axis=0).reshape(stat_shape)
        sd = np.sqrt(biased + bn.eps)

        unbiased = np.var(per_channel, axis=0, ddof=1).reshape(stat_shape)
        ref_mean = ref_mean * momentum + mean * (1 - momentum)
        ref_var = ref_var * momentum + unbiased * (1 - momentum)

        data.set_value(xv)
        yv = bn(data)
        expected = (xv - mean) / sd

        assertTensorClose(expected, yv.numpy(), max_err=5e-6)
        assertTensorClose(ref_mean, bn.running_mean.numpy(), max_err=5e-6)
        assertTensorClose(ref_var, bn.running_var.numpy(), max_err=5e-6)

    # Inference mode: statistics are frozen and repeated calls agree exactly.
    frozen_mean = bn.running_mean.numpy()
    frozen_var = bn.running_var.numpy()
    bn.training = False
    xv = np.random.normal(loc=2.3, size=shape).astype(np.float32)
    data.set_value(xv)
    first = bn(data)
    second = bn(data)
    assertTensorClose(first.numpy(), second.numpy(), max_err=0)
    assertTensorClose(frozen_mean, bn.running_mean.numpy(), max_err=0)
    assertTensorClose(frozen_var, bn.running_var.numpy(), max_err=0)
    expected = (xv - ref_mean) / np.sqrt(ref_var + bn.eps)
    assertTensorClose(expected, first.numpy(), max_err=5e-6)


def test_batchnorm_no_stats():
    """BatchNorm1d without running statistics always normalizes per batch.

    The module is flipped to inference halfway through; the expected output
    is computed from the current batch in every iteration.
    """
    channels = 8
    shape = (3, channels, 4)
    bn = BatchNorm1d(8, track_running_stats=False)
    data = tensor([])
    for step in range(4):
        if step == 2:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=shape).astype(np.float32)
        # Mean over axis 0, then over axis 2, keeping a (1, C, 1) layout.
        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
        per_channel = np.transpose(xv, [0, 2, 1]).reshape(
            (shape[0] * shape[2], channels)
        )
        var = np.var(per_channel, axis=0).reshape((1, channels, 1))
        sd = np.sqrt(var + bn.eps)

        data.set_value(xv)
        yv = bn(data)
        expected = (xv - mean) / sd

        assertTensorClose(expected, yv.numpy(), max_err=5e-6)


@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
)
@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
@pytest.mark.isolated_distributed
def test_syncbn_no_stats():
    """SyncBatchNorm without running statistics always normalizes per batch.

    The module is flipped to inference halfway through; the expected output
    is computed from the current batch in every iteration.
    """
    channels = 8
    shape = (3, channels, 4)
    bn = SyncBatchNorm(8, track_running_stats=False)
    data = tensor([])
    for step in range(4):
        if step == 2:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=shape).astype(np.float32)
        # Mean over axis 0, then over axis 2, keeping a (1, C, 1) layout.
        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
        per_channel = np.transpose(xv, [0, 2, 1]).reshape(
            (shape[0] * shape[2], channels)
        )
        var = np.var(per_channel, axis=0).reshape((1, channels, 1))
        sd = np.sqrt(var + bn.eps)

        data.set_value(xv)
        yv = bn(data)
        expected = (xv - mean) / sd

        assertTensorClose(expected, yv.numpy(), max_err=5e-6)


def test_batchnorm2d_no_stats():
    """BatchNorm2d without running statistics always normalizes per batch.

    The module is flipped to inference halfway through; the expected output
    is computed from the current batch in every iteration.
    """
    channels = 8
    shape = (3, channels, 16, 16)
    stat_shape = (1, channels, 1, 1)
    bn = BatchNorm2d(8, track_running_stats=False)
    data = tensor([])
    for step in range(4):
        if step == 2:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=shape).astype(np.float32)
        # Channels last, flattened: one column per channel.
        per_channel = np.transpose(xv, [0, 2, 3, 1]).reshape(
            (shape[0] * shape[2] * shape[3], channels)
        )

        mean = np.mean(per_channel, axis=0).reshape(1, channels, 1, 1)
        var = np.var(per_channel, axis=0).reshape(stat_shape)
        sd = np.sqrt(var + bn.eps)

        data.set_value(xv)
        yv = bn(data)
        expected = (xv - mean) / sd

        assertTensorClose(expected, yv.numpy(), max_err=5e-6)


@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
)
@pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape")
@pytest.mark.isolated_distributed
def test_syncbn2d_no_stats():
    """SyncBatchNorm on 4-D input without running statistics.

    The module is flipped to inference halfway through; the expected output
    is computed from the current batch in every iteration.
    """
    channels = 8
    shape = (3, channels, 16, 16)
    stat_shape = (1, channels, 1, 1)
    bn = SyncBatchNorm(8, track_running_stats=False)
    data = tensor([])
    for step in range(4):
        if step == 2:
            bn.training = False
        xv = np.random.normal(loc=2.3, size=shape).astype(np.float32)
        # Channels last, flattened: one column per channel.
        per_channel = np.transpose(xv, [0, 2, 3, 1]).reshape(
            (shape[0] * shape[2] * shape[3], channels)
        )

        mean = np.mean(per_channel, axis=0).reshape(1, channels, 1, 1)
        var = np.var(per_channel, axis=0).reshape(stat_shape)
        sd = np.sqrt(var + bn.eps)

        data.set_value(xv)
        yv = bn(data)
        expected = (xv - mean) / sd

        assertTensorClose(expected, yv.numpy(), max_err=5e-6)

+ 110
- 0
imperative/python/test/unit/module/test_conv.py View File

@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import itertools

import numpy as np

from megengine import Parameter, tensor
from megengine.module import ConvTranspose2d, LocalConv2d
from megengine.test import assertTensorClose


def test_conv_transpose2d():
    """Compare ConvTranspose2d with a naive NumPy scatter-add reference."""
    stride_h, stride_w = 3, 1
    pad_h, pad_w = 2, 0
    batch, in_ch, in_h, in_w = 4, 5, 8, 6
    kern_h, kern_w = 3, 4
    out_ch = 3
    use_bias = False

    # Output extent along each axis before the padding is cropped away.
    full_h = (in_h - 1) * stride_h + kern_h
    full_w = (in_w - 1) * stride_w + kern_w

    inp = np.random.normal(size=(batch, in_ch, in_h, in_w)).astype(np.float32)
    expected = np.zeros((batch, out_ch, full_h, full_w), dtype=np.float32)
    weight = np.random.normal(size=(in_ch, out_ch, kern_h, kern_w)).astype(np.float32)
    bias = np.random.normal(size=(1, out_ch, 1, 1)).astype(np.float32)

    # Reference: every input element adds its scaled kernel into the output.
    for n, ic, ih, iw in itertools.product(
        range(batch), range(in_ch), range(in_h), range(in_w)
    ):
        top, left = ih * stride_h, iw * stride_w
        expected[n, :, top : top + kern_h, left : left + kern_w] += (
            inp[n, ic, ih, iw] * weight[ic]
        )
    expected = expected[:, :, pad_h : full_h - pad_h, pad_w : full_w - pad_w]
    if use_bias:
        expected += bias

    # Module under test, loaded with the same weight (and bias when enabled).
    module = ConvTranspose2d(
        in_ch, out_ch, (kern_h, kern_w), (stride_h, stride_w), (pad_h, pad_w), bias=use_bias
    )
    module.weight = Parameter(weight, dtype=np.float32)
    if use_bias:
        module.bias = Parameter(bias, dtype=np.float32)
    actual = module(tensor(inp))

    assertTensorClose(expected, actual.numpy(), max_err=2e-6)


def test_local_conv2d():
    """Compare LocalConv2d with a naive NumPy reference.

    The reference only models ``groups == 1`` and ``dilation == 1``. It pads
    with the configured ``padding`` and writes each result at its output
    coordinate, so it stays valid if ``stride`` or ``padding`` are changed.
    """
    batch_size = 10
    in_channels = 4
    out_channels = 8
    input_height = 8
    input_width = 8
    kernel_size = 3
    stride = 1
    padding = 1
    dilation = 1
    groups = 1
    local_conv2d = LocalConv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        input_height=input_height,
        input_width=input_width,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )
    inputs = np.random.normal(
        size=(batch_size, in_channels, input_height, input_width)
    ).astype(np.float32)
    output_height = (input_height + padding * 2 - kernel_size) // stride + 1
    output_width = (input_width + padding * 2 - kernel_size) // stride + 1
    # One independent kernel per output position.
    weights = np.random.normal(
        size=(
            groups,
            output_height,
            output_width,
            in_channels // groups,
            kernel_size,
            kernel_size,
            out_channels // groups,
        )
    ).astype(np.float32)
    local_conv2d.weight = Parameter(weights)
    outputs = local_conv2d(tensor(inputs))

    # Naive NumPy reference. The pad width follows `padding`, not a literal 1.
    padded = np.pad(
        inputs,
        ((0, 0), (0, 0), (padding, padding), (padding, padding)),
        mode="constant",
    )
    expected = np.zeros(
        (batch_size, out_channels, output_height, output_width), dtype=np.float32,
    )
    for n, oc, oh, ow in itertools.product(
        *map(range, [batch_size, out_channels, output_height, output_width])
    ):
        # Top-left corner of the receptive field in the padded input.
        ih, iw = oh * stride, ow * stride
        # Index by the output position; (ih, iw) only equals it for stride 1.
        expected[n, oc, oh, ow] = np.sum(
            padded[n, :, ih : ih + kernel_size, iw : iw + kernel_size]
            * weights[0, oh, ow, :, :, :, oc]
        )

    assertTensorClose(outputs.numpy(), expected, max_err=1e-5)

+ 46
- 0
imperative/python/test/unit/module/test_external.py View File

@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os

import numpy as np
import pytest

import megengine as mge
from megengine import tensor
from megengine.module import Module


class MyModule(Module):
    """Module that forwards its inputs through a single CambriconSubgraph."""

    def __init__(self, data):
        # Imported inside the constructor, so importing this test file does
        # not require megengine.module.external.
        from megengine.module.external import CambriconSubgraph

        super().__init__()
        self.cambricon = CambriconSubgraph(data, "subnet0", True)

    def forward(self, inputs):
        """Return the output of the wrapped subgraph for ``inputs``."""
        return self.cambricon(inputs)


@pytest.mark.skip(reason="cambricon unimplemented")
def test_cambricon_module():
    """Load a serialized Cambricon model and run one forward pass (skipped)."""
    model_path = os.path.join(
        os.path.dirname(__file__), "CambriconRuntimeOprTest.MutableBatchSize.mlu"
    )
    with open(model_path, "rb") as fin:
        blob = fin.read()
    module = MyModule(blob)

    # A single float16 input placed on the first Cambricon device.
    placeholder = tensor(data=[], dtype=np.float16, device="cambricon0")
    inputs = [placeholder]
    placeholder.set_value(np.random.normal(size=(1, 64, 32, 32)).astype(np.float16))

    # Only checks that the forward pass runs; the prediction is not inspected.
    pred = module(inputs)

+ 27
- 0
imperative/python/test/unit/module/test_init.py View File

@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pytest

from megengine.module import Conv2d, Linear
from megengine.module.init import calculate_fan_in_and_fan_out


def test_calculate_fan_in_and_fan_out():
    """Check fan-in/fan-out for Linear and Conv2d weights, and the bias error."""
    linear = Linear(in_features=3, out_features=8)
    assert calculate_fan_in_and_fan_out(linear.weight) == (3, 8)

    # The bias tensor is rejected with ValueError.
    with pytest.raises(ValueError):
        calculate_fan_in_and_fan_out(linear.bias)

    conv = Conv2d(in_channels=2, out_channels=3, kernel_size=(5, 7))
    fan_in, fan_out = calculate_fan_in_and_fan_out(conv.weight)
    # Both values are scaled by the kernel area 5 * 7.
    assert fan_in == 2 * 5 * 7
    assert fan_out == 3 * 5 * 7

+ 614
- 0
imperative/python/test/unit/module/test_module.py View File

@@ -0,0 +1,614 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
import tempfile
from collections import OrderedDict
from io import BytesIO

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
from megengine import Buffer, Parameter, Tensor, tensor
from megengine.module import (
BatchNorm1d,
BatchNorm2d,
Conv2d,
Linear,
Module,
Sequential,
)
from megengine.quantization.quantize import quantize, quantize_qat
from megengine.test import assertTensorClose


class MLP(Module):
    """Two-layer perceptron: Linear(28, 50) -> ReLU -> Linear(50, 20)."""

    def __init__(self):
        super().__init__()
        self.dense0 = Linear(28, 50)
        self.dense1 = Linear(50, 20)

    def forward(self, x):
        """Apply the first layer, a ReLU, then the second layer."""
        hidden = F.relu(self.dense0(x))
        return self.dense1(hidden)


def has_gpu(num=1):
    """Return True when at least ``num`` GPU devices are available.

    The previous implementation probed ``mgb.comp_node``, but ``mgb`` is
    never imported in this file, so every call raised NameError. The device
    count is now queried through ``mge.get_device_count``.
    """
    return mge.get_device_count("gpu") >= num


def randomNp(*args):
    """Return a random float array whose shape is given by the int arguments."""
    # Every dimension must be a plain int.
    assert all(isinstance(dim, int) for dim in args)
    shape = args
    return np.random.random(shape)


def randomTorch(*args):
    """Return a random float32 torch tensor shaped by the int arguments."""
    # torch is imported lazily so the module loads without it installed.
    import torch  # pylint: disable=import-outside-toplevel

    assert all(isinstance(dim, int) for dim in args)
    values = randomNp(*args)
    return torch.tensor(values, dtype=torch.float32)


def graph_mode(*modes):
    """Build a decorator that runs a function once per requested graph mode.

    Args:
        modes: any combination of ``"eager"`` and ``"static"``.

    Returns:
        A decorator. The decorated function is called once for each requested
        mode and its return value is discarded (the wrapper returns None).

    Raises:
        ValueError: if a mode other than ``"eager"`` or ``"static"`` is given.
    """
    # Local import: functools is not imported at the top of this file.
    import functools

    requested = set(modes)
    if not requested.issubset({"eager", "static"}):
        raise ValueError("graph mode must be in (eager, static)")

    def decorator(func):
        # Keep the wrapped function's name and docstring for test reports.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if "eager" in requested:
                func(*args, **kwargs)
            if "static" in requested:
                # NOTE(review): `Graph` is not imported in the visible part of
                # this file; confirm where it should come from before relying
                # on the static mode.
                with Graph() as cg:
                    cg.set_option("eager_evaluation", False)
                    func(*args, **kwargs)

        return wrapper

    return decorator


def _default_compare_fn(x, y):
    """Default comparison: ``x.numpy()`` must be close to the expected ``y``."""
    assertTensorClose(x.numpy(), y)


def opr_test(
    cases,
    func,
    mode=("eager", "static", "dynamic_shape"),
    compare_fn=_default_compare_fn,
    ref_fn=None,
    **kwargs
):
    """Run an operator under several execution modes and check its results.

    Args:
        cases: list of dicts. Each dict must have an ``"input"`` entry and,
            when ``ref_fn`` is not given, an ``"output"`` entry. Use a list
            for multiple inputs or outputs. Exactly two cases are required
            when ``"dynamic_shape"`` is among the modes.
        func: the callable that runs the operator.
        mode: iterable of modes to exercise; a subset of ``"eager"``,
            ``"static"`` and ``"dynamic_shape"``.
        compare_fn: called as ``compare_fn(result, expected)`` for every
            result/expected pair.
        ref_fn: optional callable producing the expected outputs from the
            inputs; it takes precedence over the case's ``"output"`` entry.
        kwargs: additional keyword arguments forwarded to ``func``.

    Raises:
        ValueError: on an unknown mode, an empty ``cases`` list, a wrong
            number of cases for the dynamic-shape test, a non-callable
            ``func``, or a case lacking input or expected output.

    Simple example:

        dtype = np.float32
        cases = [{"input": [10, 20]}, {"input": [20, 30]}]
        opr_test(cases,
                 F.eye,
                 ref_fn=lambda n, m: np.eye(n, m).astype(dtype),
                 dtype=dtype)

    """

    def check_results(results, expected):
        # The builtin `tuple` replaces `typing.Tuple`, which was never imported.
        if not isinstance(results, tuple):
            results = (results,)
        for r, e in zip(results, expected):
            compare_fn(r, e)

    def get_trace_fn(func, enabled, symbolic):
        # NOTE(review): `jit` is not imported in the visible part of this
        # file; confirm the intended import before enabling any mode.
        jit.trace.enabled = enabled
        return jit.trace(func, symbolic=symbolic)

    def get_param(cases, idx):
        case = cases[idx]
        inp = case.get("input", None)
        outp = case.get("output", None)
        if inp is None:
            raise ValueError("the test case should have input")
        # The builtin `list` replaces `typing.List`, which was never imported.
        if not isinstance(inp, list):
            inp = (inp,)
        else:
            inp = tuple(inp)
        if ref_fn is not None and callable(ref_fn):
            outp = ref_fn(*inp)
        if outp is None:
            raise ValueError("the test case should have output or reference function")
        if not isinstance(outp, list):
            outp = (outp,)
        else:
            outp = tuple(outp)

        return inp, outp

    modes = set(mode)
    if not modes.issubset({"eager", "static", "dynamic_shape"}):
        raise ValueError("opr test mode must be in (eager, static, dynamic_shape)")

    if len(cases) == 0:
        raise ValueError("should give one case at least")

    if "dynamic_shape" in modes:
        if len(cases) != 2:
            raise ValueError("should give 2 cases for dynamic shape test")

    if not callable(func):
        raise ValueError("the input func should be callable")

    inp, outp = get_param(cases, 0)

    def run(*args, **kwargs):
        return func(*args, **kwargs)

    if "eager" in modes:
        f = get_trace_fn(run, False, False)
        results = f(*inp, **kwargs)
        check_results(results, outp)

    if "static" in modes or "dynamic_shape" in modes:
        f = get_trace_fn(run, True, True)
        results = f(*inp, **kwargs)
        check_results(results, outp)
        if "dynamic_shape" in modes:
            # Reuse the traced function with the second case's inputs.
            inp, outp = get_param(cases, 1)
            results = f(*inp, **kwargs)
            check_results(results, outp)


class MyModule(Module):
    """Fixture module: a nested module, a BatchNorm2d, a parameter and a buffer."""

    class InnerModule(Module):
        """Child module holding a single BatchNorm2d."""

        def __init__(self):
            super().__init__()
            self.bn = BatchNorm2d(4)

        def forward(self, x):
            return self.bn(x)

    def __init__(self):
        super().__init__()
        self.i = self.InnerModule()
        self.bn = BatchNorm2d(4)
        self.param = Parameter(np.ones(1, dtype=np.float32))
        self.buff = Buffer(np.ones(1, dtype=np.float32))

    def forward(self, x):
        """Run the inner module first, then this module's own batch norm."""
        return self.bn(self.i(x))


def test_module_api():
    """Exercise Module traversal (children, modules, buffers, parameters)
    and the propagation of the training flag through train/eval/apply."""
    m = MyModule()

    def training_flags():
        # Snapshot of the flag on the root, its bn, the inner module and its bn.
        return (m.training, m.bn.training, m.i.training, m.i.bn.training)

    # --- children and modules -------------------------------------------
    named_children = [("bn", m.bn), ("i", m.i)]
    assert list(m.children()) == [child for _, child in named_children]
    assert list(m.named_children()) == named_children

    module_order = [("", m), ("bn", m.bn), ("i", m.i), ("i.bn", m.i.bn)]
    assert list(m.modules()) == [mod for _, mod in module_order]
    assert list(m.named_modules()) == module_order
    assert list(m.named_modules(prefix="x")) == [
        ("x", m),
        ("x.bn", m.bn),
        ("x.i", m.i),
        ("x.i.bn", m.i.bn),
    ]

    # --- buffers ----------------------------------------------------------
    named_buffers = [
        ("bn.running_mean", m.bn.running_mean),
        ("bn.running_var", m.bn.running_var),
        ("buff", m.buff),
        ("i.bn.running_mean", m.i.bn.running_mean),
        ("i.bn.running_var", m.i.bn.running_var),
    ]
    assert list(m.buffers()) == [buf for _, buf in named_buffers]
    assert list(m.buffers(recursive=False)) == [m.buff]
    assert list(m.named_buffers()) == named_buffers

    # --- parameters -------------------------------------------------------
    named_parameters = [
        ("bn.bias", m.bn.bias),
        ("bn.weight", m.bn.weight),
        ("i.bn.bias", m.i.bn.bias),
        ("i.bn.weight", m.i.bn.weight),
        ("param", m.param),
    ]
    assert list(m.parameters()) == [param for _, param in named_parameters]
    assert list(m.named_parameters()) == named_parameters

    # --- training flag ----------------------------------------------------
    m.eval()
    assert training_flags() == (False, False, False, False)

    # train() on a child does not affect its parent or siblings.
    m.bn.train()
    assert (m.training, m.bn.training, m.i.bn.training) == (False, True, False)

    # train() on the inner module reaches its own children.
    m.eval()
    m.i.train()
    assert training_flags() == (False, False, True, True)

    m.eval()
    m.train()
    assert (m.training, m.bn.training, m.i.bn.training) == (True, True, True)

    def disable_training(module):
        module.training = False

    m.apply(disable_training)
    assert (m.bn.training, m.i.bn.training) == (False, False)


def test_module_api_reuse_submodule():
    """A submodule reachable under two attribute names is listed only once."""
    m = MyModule()
    # Alias the inner module under a second attribute name.
    m.h = m.i  # pylint: disable=attribute-defined-outside-init

    assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]

    # The shared module is reported under the name "h" here.
    expected_named = [("", m), ("bn", m.bn), ("h", m.i), ("h.bn", m.i.bn)]
    assert list(m.named_modules()) == expected_named


def test_module_api_iterable_stability():
    """Repeated traversals of the same module yield the same ordering."""
    m = MyModule()
    baseline = list(m.modules())
    attempts = 0
    while attempts < 100:
        assert list(m.modules()) == baseline
        attempts += 1


def test_module_api_hooks():
    """Forward pre-hooks and forward hooks fire once per module and may
    rewrite inputs and outputs; removed hooks no longer fire."""
    net = MyModule()
    counts = {"pre": 0, "post": 0}
    hooks = []

    def pre_hook(module, inputs):
        # Count the call and shift every input by one.
        counts["pre"] += 1
        return tuple(inp + 1 for inp in inputs)

    def post_hook(module, inputs, outputs):
        # Count the call and shift the output by one, in place.
        counts["post"] += 1
        outputs += 1
        return outputs

    def add_pre_hook(module):
        hooks.append(module.register_forward_pre_hook(pre_hook))

    def add_post_hook(module):
        hooks.append(module.register_forward_hook(post_hook))

    net.apply(add_pre_hook)
    net.apply(add_post_hook)

    shape = (1, 4, 1, 1)
    x = tensor(np.zeros(shape, dtype=np.float32))
    y = net(x)

    assert counts["pre"] == 4
    assert counts["post"] == 4

    # Replay the hooked forward pass by hand with the functional batch norm.
    mean1 = Parameter(np.zeros(shape), dtype=np.float32)
    var1 = Parameter(np.ones(shape), dtype=np.float32)
    bn1 = F.batch_norm2d(x + 3, mean1, var1, training=True)
    assertTensorClose(
        net.i.bn.running_mean.numpy(), mean1.numpy(),
    )
    mean2 = Parameter(np.zeros(shape), dtype=np.float32)
    var2 = Parameter(np.ones(shape), dtype=np.float32)
    bn2 = F.batch_norm2d(bn1 + 3, mean2, var2, training=True)
    assertTensorClose(
        net.bn.running_mean.numpy(), mean2.numpy(),
    )
    assertTensorClose((bn2 + 2).numpy(), y.numpy())

    # After removal the counters must stay unchanged.
    assert len(hooks) == 8
    for handle in hooks:
        handle.remove()
    y = net(x)
    assert counts["pre"] == 4
    assert counts["post"] == 4


class MyModule2(Module):
    """Fixture module whose children sit inside nested list/dict/tuple containers."""

    class InnerModule(Module):
        """Child module with one BatchNorm2d and a plain dict attribute."""

        def __init__(self):
            super().__init__()
            self.bn = BatchNorm2d(4)
            # Plain dict with bool keys and int values (no tensors or modules).
            self.test_bool_key = {True: 1, False: 0}

        def forward(self, x):
            # Return the result; the original computed it and returned None.
            return self.bn(x)

    def __init__(self):
        super().__init__()
        self.bn = BatchNorm2d(4)
        # Modules nested inside a list, a dict (with a non-module value "z")
        # and a tuple.
        self.a = [
            BatchNorm2d(4),
            {"x": BatchNorm2d(4), "y": [BatchNorm2d(4), self.InnerModule()], "z": 0},
            (self.InnerModule(),),
        ]

    def forward(self, x):
        """Identity; the tests in this file only inspect the module structure."""
        return x


def test_expand_structure():
    """Check ``named_modules`` descends into list/dict/tuple attributes.

    Names are built from the container path (index or key joined by dots).
    """
    net = MyModule2()
    inner_in_dict = net.a[1]["y"][1]
    inner_in_tuple = net.a[2][0]

    expected = [
        ("", net),
        ("a.0", net.a[0]),
        ("a.1.x", net.a[1]["x"]),
        ("a.1.y.0", net.a[1]["y"][0]),
        ("a.1.y.1", inner_in_dict),
        ("a.1.y.1.bn", inner_in_dict.bn),
        ("a.2.0", inner_in_tuple),
        ("a.2.0.bn", inner_in_tuple.bn),
        ("bn", net.bn),
    ]
    assert list(net.named_modules()) == expected


def test_flatten_others():
    """Check ``_flatten`` yields nothing for objects that are neither Tensor nor Module."""

    def is_plain_object(obj):
        return not isinstance(obj, (Tensor, Module))

    net = MyModule2()
    leftovers = list(net._flatten(with_key=True, predicate=is_plain_object))
    assert len(leftovers) == 0


def test_flatten_with_parent():
    """Check ``with_parent=True`` appends each module's owning module.

    The root has parent ``None``; modules held in containers report the module
    that owns the container as their parent.
    """
    net = MyModule2()
    inner_in_dict = net.a[1]["y"][1]
    inner_in_tuple = net.a[2][0]

    expected_named = [
        ("", net, None),
        ("a.0", net.a[0], net),
        ("a.1.x", net.a[1]["x"], net),
        ("a.1.y.0", net.a[1]["y"][0], net),
        ("a.1.y.1", inner_in_dict, net),
        ("a.1.y.1.bn", inner_in_dict.bn, inner_in_dict),
        ("a.2.0", inner_in_tuple, net),
        ("a.2.0.bn", inner_in_tuple.bn, inner_in_tuple),
        ("bn", net.bn, net),
    ]
    assert list(net.named_modules(with_parent=True)) == expected_named

    # The unnamed traversal must report the same (module, parent) pairs.
    expected_plain = [(module, parent) for _, module, parent in expected_named]
    assert list(net.modules(with_parent=True)) == expected_plain


class MyModule3(Module):
    """Fixture module that nests submodules inside a ``Sequential`` container."""

    class InnerModule(Module):
        """Leaf fixture wrapping a single batch-norm layer."""

        def __init__(self):
            super().__init__()
            self.bn = BatchNorm2d(4)

        def forward(self, x):
            """Apply the batch-norm layer to ``x`` and return the result."""
            # Return the normalised tensor; assigning it to a local and falling
            # off the end would make the module silently return None.
            return self.bn(x)

    def __init__(self):
        super().__init__()
        self.bn = BatchNorm2d(4)
        self.seq = Sequential(BatchNorm2d(4), self.InnerModule(),)

    def forward(self, x):
        """Identity forward; this fixture is only used for traversal tests."""
        return x


def test_module_api_with_sequential():
    """Check ``named_modules`` lists a ``Sequential`` and its indexed children."""
    net = MyModule3()
    container = net.seq

    expected = [
        ("", net),
        ("bn", net.bn),
        ("seq", container),
        ("seq.0", container[0]),
        ("seq.1", container[1]),
        ("seq.1.bn", container[1].bn),
    ]
    assert list(net.named_modules()) == expected


def test_sequential_named_children():
    """Check the child names of a ``Sequential`` built from an ``OrderedDict``.

    The children are reported under ``layer_values.<index>`` rather than under
    the keys of the dictionary.
    """
    layers = OrderedDict(
        [
            ("name0", Linear(20, 10)),
            ("name1", Linear(10, 5)),
            ("name2", Linear(5, 1)),
        ]
    )
    children = list(Sequential(layers).named_children())

    expected_names = ("layer_values.0", "layer_values.1", "layer_values.2")
    for position, expected_name in enumerate(expected_names):
        assert children[position][0] == expected_name


def test_state_dict():
    """Round-trip a ``state_dict`` and check strict / non-strict loading.

    * non-strict loading tolerates an unknown ``"extra"`` key and reproduces
      the original predictions;
    * strict loading raises ``KeyError`` for an unexpected key;
    * strict loading raises ``KeyError`` for a missing key.
    """
    data_shape = (2, 28)
    data = tensor([])
    data.set_value(np.random.random(data_shape))
    source_net = MLP()
    pred0 = source_net(data)

    with BytesIO() as buffer:
        mge.save(source_net.state_dict(), buffer)
        buffer.seek(0)
        state_dict = mge.load(buffer)

    # An unknown key is ignored when strict=False ...
    state_dict["extra"] = None
    restored_net = MLP()
    restored_net.load_state_dict(state_dict, strict=False)
    pred1 = restored_net(data)
    assertTensorClose(pred0.numpy(), pred1.numpy(), max_err=5e-6)

    # ... but rejected by the default (strict) mode.
    with pytest.raises(KeyError):
        restored_net.load_state_dict(state_dict)

    # A missing key is rejected as well.
    del state_dict["extra"]
    del state_dict["dense0.bias"]
    with pytest.raises(KeyError):
        restored_net.load_state_dict(state_dict)


class AssertModule(Module):
    """Fixture that stores a tensor under a non-str dictionary key."""

    def __init__(self):
        super().__init__()
        # A Tensor value behind a bool key; test_assert_message expects
        # flattening this module to fail with an AssertionError.
        bad_mapping = {True: tensor([]), False: 0}
        self.error_tensor_key = bad_mapping

    def forward(self, x):
        """Identity forward; the fixture is never evaluated on data here."""
        return x


def test_assert_message():
    """Check the assertion message raised for a Tensor stored under a bool key."""
    expected_message = "keys for Tensor and Module must be str, error key: True"
    module = AssertModule()
    with pytest.raises(AssertionError, match=expected_message):
        list(module._flatten())


class Simple(Module):
    """Fixture with two 3x3 convolutions that share one weight parameter."""

    def __init__(self):
        super().__init__()
        self.conv0 = Conv2d(1, 1, kernel_size=3, bias=False)
        self.conv1 = Conv2d(1, 1, kernel_size=3, bias=False)
        # Tie the weights: both layers now reference the same object.
        shared_weight = self.conv0.weight
        self.conv1.weight = shared_weight

    def forward(self, inputs):
        """Intentionally empty; tests call the sub-layers directly."""
        return None


def test_shared_param():
    """Check parameter sharing across save/load round trips.

    Saving the whole network keeps the two convolutions tied to one weight
    object; saving the two layers separately yields independent weight objects
    that still produce the same output.
    """

    def roundtrip(obj):
        # Serialise into an in-memory buffer and read the object back.
        with BytesIO() as buffer:
            mge.save(obj, buffer)
            buffer.seek(0)
            return mge.load(buffer)

    net = Simple()
    assert net.conv0.weight is net.conv1.weight
    data = tensor(np.random.random((1, 1, 8, 8)).astype(np.float32))
    assertTensorClose(net.conv0(data).numpy(), net.conv1(data).numpy())

    # Whole-network round trip: sharing must be preserved.
    restored_net = roundtrip(net)
    assert restored_net.conv0.weight is restored_net.conv1.weight
    assertTensorClose(
        restored_net.conv0(data).numpy(), restored_net.conv1(data).numpy()
    )

    # Per-layer round trips: separate objects, identical values.
    conv0 = roundtrip(net.conv0)
    conv1 = roundtrip(net.conv1)
    assert conv0.weight is not conv1.weight
    assertTensorClose(conv0(data).numpy(), conv1(data).numpy())


def test_pickle_module():
    """Check a module serialised before or after its first forward pass.

    Both restored copies must predict the same values as the original module.
    """

    def roundtrip(module):
        # Serialise into an in-memory buffer and read the module back.
        with BytesIO() as buffer:
            mge.save(module, buffer)
            buffer.seek(0)
            return mge.load(buffer)

    data_shape = (2, 28)
    data = tensor([])
    data.set_value(np.random.random(data_shape))
    mlp = MLP()

    # pickle before forward
    restored_before = roundtrip(mlp)
    pred0 = restored_before(data)

    pred1 = mlp(data)

    # pickle after forward
    restored_after = roundtrip(mlp)
    pred2 = restored_after(data)

    assertTensorClose(pred0.numpy(), pred1.numpy(), max_err=5e-6)
    assertTensorClose(pred0.numpy(), pred2.numpy(), max_err=5e-6)


@pytest.mark.skip(reason="under development")
def test_dump_model():
    """Dump the output of an MLP to a temporary file (currently skipped)."""
    data_shape = (2, 28)
    data = tensor([])
    data.set_value(np.random.random(data_shape))
    prediction = MLP()(data)

    # delete=False: the file is removed by hand in the ``finally`` clause.
    tmp_file = tempfile.NamedTemporaryFile(delete=False)
    tmp_path = tmp_file.name
    try:
        mge.dump(prediction, tmp_path)
    finally:
        tmp_file.close()
        os.unlink(tmp_path)


def test_load_quantized():
    """Check that loading a checkpoint restores quantized weights.

    The weights of a quantized MLP are re-quantized with new scales after the
    checkpoint is taken; loading the checkpoint must bring back the original
    predictions.
    """
    from megengine.core.tensor import dtype

    def requantize_weight(layer, scale):
        # Replace the layer weight by a qint8 copy using the given scale.
        layer.weight = Parameter(layer.weight.astype(dtype.qint8(scale)).numpy())

    data_shape = (2, 28)
    data = tensor(np.random.random(data_shape), dtype="float32")
    data = data.astype(dtype.qint8(0.1))

    mlp = MLP()
    quantize_qat(mlp)
    quantize(mlp)
    requantize_weight(mlp.dense0, 0.001)
    requantize_weight(mlp.dense1, 0.0002)
    mlp.eval()
    pred0 = mlp(data)

    with BytesIO() as buffer:
        mge.save(mlp.state_dict(), buffer)
        buffer.seek(0)
        checkpoint = mge.load(buffer)

    # change mlp weight.
    requantize_weight(mlp.dense0, 0.00001)
    requantize_weight(mlp.dense1, 0.2)

    mlp.load_state_dict(checkpoint)
    pred1 = mlp(data)

    assertTensorClose(
        pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), max_err=5e-6
    )

+ 91
- 0
imperative/python/test/unit/module/test_qat.py View File

@@ -0,0 +1,91 @@
from itertools import product

import numpy as np

from megengine import tensor
from megengine.module import (
Conv2d,
ConvBn2d,
ConvRelu2d,
DequantStub,
Module,
QuantStub,
)
from megengine.quantization.quantize import disable_fake_quant, quantize_qat
from megengine.test import assertTensorClose


def test_qat_convbn2d():
    """Compare ``ConvBn2d`` with its QAT counterpart while fake quant is disabled.

    For every (groups, bias) combination the outputs and the batch-norm running
    statistics must agree in training mode, and the outputs must agree in
    evaluation mode.
    """
    in_channels = 32
    out_channels = 64
    kernel_size = 3

    def check_case(groups, bias):
        float_module = ConvBn2d(
            in_channels, out_channels, kernel_size, groups=groups, bias=bias
        )
        float_module.train()
        qat_module = quantize_qat(float_module, inplace=False)
        disable_fake_quant(qat_module)
        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))

        # Training mode: outputs and running statistics must match.
        normal_outputs = float_module(inputs)
        qat_outputs = qat_module(inputs)
        assertTensorClose(normal_outputs.numpy(), qat_outputs.numpy(), max_err=5e-6)
        assertTensorClose(
            float_module.bn.running_mean.numpy(),
            qat_module.bn.running_mean.numpy(),
            max_err=5e-8,
        )
        assertTensorClose(
            float_module.bn.running_var.numpy(),
            qat_module.bn.running_var.numpy(),
            max_err=5e-7,
        )

        # Evaluation mode: outputs must match.
        float_module.eval()
        normal_outputs = float_module(inputs)
        qat_module.eval()
        qat_outputs = qat_module(inputs)
        assertTensorClose(normal_outputs.numpy(), qat_outputs.numpy(), max_err=5e-6)

    for groups, bias in product([1, 4], [True, False]):
        check_case(groups, bias)


def test_qat_conv():
    """Compare a conv / conv-relu network with its QAT copy (fake quant disabled).

    The float network and the QAT network must produce matching outputs in
    both training and evaluation mode for every (groups, bias) combination.
    """
    in_channels = 32
    out_channels = 64
    kernel_size = 3

    class TestNet(Module):
        """quant stub -> Conv2d -> ConvRelu2d -> dequant stub."""

        def __init__(self, groups, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.conv = Conv2d(
                in_channels, out_channels, kernel_size, groups=groups, bias=bias
            )
            self.conv_relu = ConvRelu2d(
                out_channels, in_channels, kernel_size, groups=groups, bias=bias
            )

        def forward(self, inp):
            quantized = self.quant(inp)
            features = self.conv_relu(self.conv(quantized))
            return self.dequant(features)

    def assert_same_outputs(float_net, qat_net, batch):
        normal_outputs = float_net(batch)
        qat_outputs = qat_net(batch)
        assertTensorClose(normal_outputs.numpy(), qat_outputs.numpy())

    # The same random batch is reused for every configuration.
    inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
    for groups, bias in product([1, 4], [True, False]):
        net = TestNet(groups, bias)
        net.train()
        qat_net = quantize_qat(net, inplace=False)
        disable_fake_quant(qat_net)
        assert_same_outputs(net, qat_net, inputs)

        net.eval()
        qat_net.eval()
        assert_same_outputs(net, qat_net, inputs)

+ 89
- 0
imperative/python/test/unit/module/test_tensor.py View File

@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
from megengine import Buffer, Parameter
from megengine.module import Conv2d
from megengine.test import assertTensorClose


def test_set_value():
    """Check ``Parameter.set_value`` replaces the stored value."""
    shape = (2, 3)
    initial = np.random.random(shape).astype(np.float32)
    param = Parameter(initial)

    replacement = np.random.random(shape).astype(np.float32)
    param.set_value(replacement)
    assertTensorClose(param.numpy(), replacement, max_err=5e-6)

    # Value with a mismatching shape, kept for the pending check below.
    v2 = np.random.random((3, 3)).astype(np.float32)
    # TODO: add this
    # with pytest.raises(ValueError):
    #     param.set_value(v2)
    assertTensorClose(param.numpy(), replacement, max_err=5e-6)


@pytest.mark.skip(reason="fill unsupported")
def test_fill():
    """Check ``Buffer.fill`` with an int and a float value (currently skipped)."""
    shape = (2, 3)
    buf = Buffer(np.zeros(shape, dtype=np.float32))

    buf.fill(3)
    assertTensorClose(buf.numpy(), np.full(shape, 3, dtype=np.float32))

    buf.fill(124.568)
    assertTensorClose(buf.numpy(), np.full(shape, 124.568, dtype=np.float32))


# TODO: remove or rewrite following test
# def test_attach():
# p_ = np.random.random((2, 3)).astype(np.float32)

# with Graph() as g:
# g.set_option('eager_evaluation', False)
# p = Parameter(p_)
# v = p * 2
# f = compile(v, None)

# out, = f()
# assertTensorClose(out, p_ * 2)

# F.add_update(p, p)
# out, = f()
# assertTensorClose(out, p_ * 4)


# TODO: remove or rewrite following test
# def test_module_attach():
# v = np.random.random((1, 3, 64, 64)).astype(np.float32)
# net = Conv2d(3, 16, 3)

# with Graph() as g:
# g.set_option('eager_evaluation', False)

# data0 = Input("data")
# f = compile(net(data0), None)

# out0, = f(data=v)

# data1 = Input("data", value=v)
# out1 = net(data1)

# assertTensorClose(out0, out1.numpy())


# def test_shape_warning():
# with Graph() as cg:
# cg.set_option("eager_evaluation", False)
# b = Buffer(np.ones((2, 3)).astype(np.float32))
# with pytest.warns(None) as record:
# print(b.shape)
# if len(record) != 0:
# raise ValueError(
# "Getting the shape of a constant Tensor should throw no Warning"
# )

+ 0
- 112
imperative/python/test/unit/test_module.py View File

@@ -1,112 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import platform

import pytest

import megengine as mge
import megengine.distributed as dist
from megengine import tensor
from megengine.distributed.group import Group
from megengine.distributed.helper import get_device_count_by_fork
from megengine.module import SyncBatchNorm
from megengine.test import assertTensorClose


@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 4, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_syncbn():
    """Check ``SyncBatchNorm`` across four GPU worker processes.

    A reference output and reference running statistics are computed with
    NumPy over the full batch.  The batch is split along the last axis into
    one slice per rank; each worker runs all steps on its slice and must match
    the corresponding slice of the reference output from the last step, and
    the reference running mean / variance.
    """
    import numpy as np
    import multiprocessing as mp
    from megengine.distributed.group import Server
    from megengine.core._trace_option import use_tensor_shape

    if use_tensor_shape():  # XXX: fix sync bn if use_tensor_shape
        return

    nr_chan = 8
    nr_ranks = 4
    data_shape = (3, nr_chan, 4, nr_ranks * 8)
    momentum = 0.9
    eps = 1e-5
    stat_shape = (1, nr_chan, 1, 1)
    running_mean = np.zeros(stat_shape, dtype=np.float32)
    running_var = np.ones(stat_shape, dtype=np.float32)
    steps = 4
    server = Server(0)
    port = server.py_server_port

    def worker(rank, data, yv_expect, running_mean, running_var):
        dist.init_process_group("localhost", port, nr_ranks, rank, rank)
        group = Group(list(range(nr_ranks)))
        bn = SyncBatchNorm(nr_chan, eps=eps, momentum=momentum, group=group)
        data_tensor = None
        for step in range(steps):
            # Create the device tensor once, then only refresh its value.
            if data_tensor is not None:
                data_tensor.set_value(data[step])
            else:
                data_tensor = tensor(data[step], device=f"gpu{rank}:0")
            yv = bn(data_tensor)

        # Only the result of the last step is compared.
        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
        assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
        assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6)

    # Reference statistics computed over the full (unsplit) batch.
    rows = data_shape[0] * data_shape[2] * data_shape[3]
    xv = []
    for step in range(steps):
        sample = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
        xv.append(sample)
        # Channels last, then one row per (n, h, w) position.
        per_channel = np.transpose(sample, [0, 2, 3, 1]).reshape((rows, nr_chan))

        mean = np.mean(per_channel, axis=0).reshape(stat_shape)

        var_biased = np.var(per_channel, axis=0).reshape(stat_shape)
        sd = np.sqrt(var_biased + eps)

        var_unbiased = np.var(per_channel, axis=0, ddof=1).reshape(stat_shape)
        running_mean = running_mean * momentum + mean * (1 - momentum)
        running_var = running_var * momentum + var_unbiased * (1 - momentum)

        yv_expect = (sample - mean) / sd

    # data[rank][step]: slice of eight columns along the last axis per rank.
    data = [
        [xv[step][:, :, :, rank * 8 : rank * 8 + 8] for step in range(steps)]
        for rank in range(nr_ranks)
    ]

    procs = []
    for rank in range(nr_ranks):
        worker_args = (
            rank,
            data[rank],
            yv_expect[:, :, :, rank * 8 : rank * 8 + 8],
            running_mean,
            running_var,
        )
        proc = mp.Process(target=worker, args=worker_args)
        proc.start()
        procs.append(proc)

    for proc in procs:
        proc.join(10)
        assert proc.exitcode == 0


def test_module_conv2d():
    """Smoke test: ``Conv2d`` can be imported and constructed."""
    from megengine.module.conv import Conv2d

    layer = Conv2d(2, 3, 1)

Loading…
Cancel
Save