GitOrigin-RevId: e690bc42b0
tags/v1.0.0-rc1
@@ -209,7 +209,7 @@ def conv_transpose2d(
         dilate_w=dilate_w,
         strategy=get_conv_execution_strategy(),
     )
-    (output,) = apply(op, inp, weight)
+    (output,) = apply(op, weight, inp)
     if bias is not None:
         output += bias
     return output
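
Background for this swap: conv_transpose2d appears to dispatch to the convolution backward-data operator, whose first input is the filter (inferred from this hunk), hence apply(op, weight, inp). As a reminder of why the op is called "transposed", a 1-D numpy sketch (illustrative only, not part of the diff):

    import numpy as np

    # Convolution as a matrix: transposed convolution is multiplication by C.T.
    k = np.array([1.0, 2.0, 3.0])             # kernel, K = 3
    n_in = 5                                  # input length, stride 1, no padding
    n_out = n_in - len(k) + 1                 # valid conv output length = 3
    C = np.zeros((n_out, n_in))
    for i in range(n_out):
        C[i, i : i + len(k)] = k              # each row applies the kernel once

    x = np.random.randn(n_out)
    y = np.zeros(n_in)                        # scatter form of transposed conv
    for i in range(n_out):
        y[i : i + len(k)] += x[i] * k
    assert np.allclose(C.T @ x, y)            # identical to the matrix transpose
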
@@ -241,7 +241,7 @@ def local_conv2d(
         pad_w=pad_w,
         dilate_h=dilate_h,
         dilate_w=dilate_w,
-        strategy=get_conv_execution_strategy(),
+        # strategy=get_conv_execution_strategy(),
     )
     (output,) = apply(op, inp, weight)
     if bias is not None:
@@ -724,7 +724,7 @@ def sync_batch_norm(
     """
     assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode)
     _channels = input.shape[1]
-    _ndim = len(input.shape)
+    _ndim = input.ndim
     _param_shape = (1, _channels) + (1,) * (_ndim - 2)
     if training:
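
For reference, `_param_shape` expands the per-channel parameters so they broadcast against the input; a quick sketch of what it evaluates to (illustrative):

    # _param_shape = (1, C) + (1,) * (ndim - 2), per the line above.
    for ndim, channels in [(2, 8), (3, 8), (4, 8)]:
        print(ndim, (1, channels) + (1,) * (ndim - 2))
    # 2 (1, 8)         broadcasts against an (N, C) input
    # 3 (1, 8, 1)      broadcasts against an (N, C, L) input
    # 4 (1, 8, 1, 1)   broadcasts against an (N, C, H, W) input
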
@@ -12,7 +12,7 @@ import numpy as np
 from ..distributed.group import WORLD, Group
 from ..functional import batch_norm2d, sync_batch_norm
-from ..tensor_nn import Buffer, Parameter
+from ..tensor_nn import Buffer, Parameter, Tensor
 from . import init
 from .module import Module
@@ -74,12 +74,12 @@ class _BatchNorm(Module):
         _ndims = len(inp.shape)
         if _ndims != 4:
-            origin_shape = inp.shapeof()
+            origin_shape = inp.shape
             if _ndims == 2:
-                n, c = inp.shapeof(0), inp.shapeof(1)
+                n, c = inp.shape[0], inp.shape[1]
                 new_shape = (n, c, 1, 1)
             elif _ndims == 3:
-                n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2)
+                n, c, h = inp.shape[0], inp.shape[1], inp.shape[2]
                 new_shape = (n, c, h, 1)
             inp = inp.reshape(new_shape)
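
The reshape to NCHW is statistics-preserving: per-channel mean and variance over an (N, C) tensor equal those over its (N, C, 1, 1) view, which is why lower-rank inputs can be routed through the 4-D kernel. A quick numpy check (illustrative):

    import numpy as np

    x2d = np.random.randn(3, 8).astype(np.float32)   # (N, C)
    x4d = x2d.reshape(3, 8, 1, 1)                    # (N, C, 1, 1) view
    assert np.allclose(x2d.mean(axis=0), x4d.mean(axis=(0, 2, 3)))
    assert np.allclose(x2d.var(axis=0), x4d.var(axis=(0, 2, 3)))
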
@@ -127,7 +127,7 @@ class SyncBatchNorm(_BatchNorm):
         affine=True,
         track_running_stats=True,
         freeze=False,
-        group: Optional[Group] = None,
+        group: Optional[Group] = WORLD,
     ) -> None:
         super().__init__(
             num_features, eps, momentum, affine, track_running_stats, freeze
@@ -145,13 +145,16 @@ class SyncBatchNorm(_BatchNorm):
         _ndims = len(inp.shape)
         if _ndims != 4:
-            origin_shape = inp.shapeof()
+            new_shape = Tensor([1, 1, 1, 1], device=inp.device)
+            origin_shape = inp.shape
             if _ndims == 2:
-                n, c = inp.shape[0], inp.shape[1]
-                new_shape = (n, c, 1, 1)
+                new_shape[:2] = origin_shape[:2]
             elif _ndims == 3:
-                n, c, h = inp.shape[0], inp.shape[1], inp.shape[2]
-                new_shape = (n, c, h, 1)
+                new_shape[:3] = origin_shape[:3]
+            else:
+                raise ValueError(
+                    "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
+                )
             inp = inp.reshape(new_shape)
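
Why the tuple became a Tensor: with symbolic shapes, origin_shape is not a concrete tuple that can be unpacked into Python ints, so the target shape is built as a device tensor and its leading entries are filled by slice assignment. A rough numpy analogue of the new control flow (illustrative only):

    import numpy as np

    def target_shape(origin_shape):
        new_shape = np.array([1, 1, 1, 1])   # stand-in for Tensor([1, 1, 1, 1])
        ndims = len(origin_shape)
        if ndims == 2:
            new_shape[:2] = origin_shape[:2]
        elif ndims == 3:
            new_shape[:3] = origin_shape[:3]
        else:
            raise ValueError("expected 2D, 3D or 4D input (got {}D input)".format(ndims))
        return new_shape

    print(target_shape((3, 8)))      # [3 8 1 1]
    print(target_shape((3, 8, 4)))   # [3 8 4 1]
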
@@ -376,7 +376,13 @@ class LocalConv2d(Conv2d):
     def forward(self, inp):
         return local_conv2d(
-            inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode
+            inp,
+            self.weight,
+            None,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.conv_mode,
         )
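
The inserted None fills what is evidently a new bias slot in local_conv2d; the positional order below is inferred from this call site alone, not stated elsewhere in the diff. A hedged usage sketch under that assumption:

    import numpy as np
    from megengine import Parameter, tensor
    from megengine.functional import local_conv2d  # assumed import path

    # Presumed signature: local_conv2d(inp, weight, bias, stride, padding, dilation, conv_mode)
    inp = tensor(np.random.randn(2, 4, 8, 8).astype(np.float32))
    # weight layout (groups, OH, OW, IC//g, KH, KW, OC//g), as in the tests below
    weight = Parameter(np.random.randn(1, 8, 8, 4, 3, 3, 8).astype(np.float32))
    out = local_conv2d(inp, weight, None, stride=1, padding=1, dilation=1)
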
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+
+import megengine as mge
+from megengine.module import LeakyReLU
+from megengine.test import assertTensorClose
+
+
+def test_leaky_relu():
+    data = np.array([-8, -12, 6, 10]).astype(np.float32)
+    negative_slope = 0.1
+    leaky_relu = LeakyReLU(negative_slope)
+    output = leaky_relu(mge.tensor(data))
+    np_output = np.maximum(0, data) + negative_slope * np.minimum(0, data)
+    assertTensorClose(output.numpy(), np_output, max_err=0)
@@ -0,0 +1,419 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import multiprocessing as mp
+import platform
+
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine.distributed as dist
+from megengine import tensor
+from megengine.core._trace_option import use_tensor_shape
+from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
+from megengine.tensor import Tensor
+from megengine.test import assertTensorClose
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
+)
+@pytest.mark.skipif(use_tensor_shape(), reason="syncbn does not support symbolic shape")
+@pytest.mark.isolated_distributed
+def test_syncbn():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 4, 16)
+    momentum = 0.9
+    eps = 1e-5
+    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
+    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
+    steps = 4
+    nr_ranks = 2
+    server = dist.Server(0)
+    port = server.py_server_port
+
+    def worker(rank, data, yv_expect, running_mean, running_var):
+        if mge.get_device_count("gpu") < nr_ranks:
+            return
+        dist.init_process_group("localhost", port, nr_ranks, rank, rank)
+        bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
+        data_tensor = tensor([])
+        for i in range(steps):
+            data_tensor.set_value(data[i])
+            yv = bn(data_tensor)
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+        assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
+        assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6)
+
+    xv = []
+    for i in range(steps):
+        xv.append(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
+        xv_transposed = np.transpose(xv[i], [0, 2, 3, 1]).reshape(
+            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
+        )
+
+        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
+
+        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
+        sd = np.sqrt(var_biased + eps)
+
+        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
+        running_mean = running_mean * momentum + mean * (1 - momentum)
+        running_var = running_var * momentum + var_unbiased * (1 - momentum)
+
+        yv_expect = (xv[i] - mean) / sd
+
+    data = []
+    for i in range(nr_ranks):
+        data.append([])
+        for j in range(steps):
+            data[i].append(xv[j][:, :, :, i * 8 : i * 8 + 8])
+
+    procs = []
+    for rank in range(nr_ranks):
+        p = mp.Process(
+            target=worker,
+            args=(
+                rank,
+                data[rank],
+                yv_expect[:, :, :, rank * 8 : rank * 8 + 8],
+                running_mean,
+                running_var,
+            ),
+        )
+        p.start()
+        procs.append(p)
+
+    for p in procs:
+        p.join(10)
+        assert p.exitcode == 0
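
The width-axis split across ranks works because per-channel statistics are reduced over N*H*W jointly: two ranks each holding half the width and synchronizing reproduce the full-batch statistics. A small numpy check of that invariant (illustrative, not part of the test):

    import numpy as np

    x = np.random.randn(3, 8, 4, 16).astype(np.float32)   # (N, C, H, W)
    left, right = x[:, :, :, :8], x[:, :, :, 8:]          # per-rank slices
    # For equal-sized slices, the global per-channel mean is the mean of the slice means.
    slice_mean = (left.mean(axis=(0, 2, 3)) + right.mean(axis=(0, 2, 3))) / 2
    assert np.allclose(x.mean(axis=(0, 2, 3)), slice_mean, atol=1e-6)
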
+
+
+def test_batchnorm():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 4)
+    momentum = 0.9
+    bn = BatchNorm1d(nr_chan, momentum=momentum)
+    running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32)
+    running_var = np.ones((1, nr_chan, 1), dtype=np.float32)
+    data = tensor([])
+    for i in range(3):
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
+        xv_transposed = np.transpose(xv, [0, 2, 1]).reshape(
+            (data_shape[0] * data_shape[2], nr_chan)
+        )
+
+        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1))
+        sd = np.sqrt(var_biased + bn.eps)
+
+        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1))
+        running_mean = running_mean * momentum + mean * (1 - momentum)
+        running_var = running_var * momentum + var_unbiased * (1 - momentum)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+        assertTensorClose(
+            running_mean.reshape(-1), bn.running_mean.numpy().reshape(-1), max_err=5e-6
+        )
+        assertTensorClose(
+            running_var.reshape(-1), bn.running_var.numpy().reshape(-1), max_err=5e-6
+        )
+
+    # test set 'training' flag to False
+    mean_backup = bn.running_mean.numpy()
+    var_backup = bn.running_var.numpy()
+    bn.training = False
+    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+    data.set_value(xv)
+    yv1 = bn(data)
+    yv2 = bn(data)
+    assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
+    assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0)
+    assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0)
+    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
+    assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)
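
The running statistics follow an exponential moving average, running = running * momentum + batch_stat * (1 - momentum); note that momentum here weights the previous running value. Worked numerically (illustrative):

    momentum = 0.9
    running_mean = 0.0
    for batch_mean in [1.0, 1.0, 1.0]:
        running_mean = running_mean * momentum + batch_mean * (1 - momentum)
    print(running_mean)   # 0.271 == 1 - 0.9 ** 3, converging toward 1.0
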
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
+)
+@pytest.mark.skipif(use_tensor_shape(), reason="syncbn does not support symbolic shape")
+@pytest.mark.isolated_distributed
+def test_syncbn1d():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 4)
+    momentum = 0.9
+    bn = SyncBatchNorm(nr_chan, momentum=momentum)
+    running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32)
+    running_var = np.ones((1, nr_chan, 1), dtype=np.float32)
+    data = tensor([])
+    for i in range(3):
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
+        xv_transposed = np.transpose(xv, [0, 2, 1]).reshape(
+            (data_shape[0] * data_shape[2], nr_chan)
+        )
+
+        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1))
+        sd = np.sqrt(var_biased + bn.eps)
+
+        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1))
+        running_mean = running_mean * momentum + mean * (1 - momentum)
+        running_var = running_var * momentum + var_unbiased * (1 - momentum)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+        assertTensorClose(
+            running_mean.reshape(-1), bn.running_mean.numpy().reshape(-1), max_err=5e-6
+        )
+        assertTensorClose(
+            running_var.reshape(-1), bn.running_var.numpy().reshape(-1), max_err=5e-6
+        )
+
+    # test set 'training' flag to False
+    mean_backup = bn.running_mean.numpy()
+    var_backup = bn.running_var.numpy()
+    bn.training = False
+    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+    data.set_value(xv)
+    yv1 = bn(data)
+    yv2 = bn(data)
+    assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
+    assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0)
+    assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0)
+    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
+    assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)
+
+
+def test_batchnorm2d():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 16, 16)
+    momentum = 0.9
+    bn = BatchNorm2d(nr_chan, momentum=momentum)
+    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
+    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
+    data = tensor([])
+    for i in range(3):
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
+            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
+        )
+
+        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
+
+        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
+        sd = np.sqrt(var_biased + bn.eps)
+
+        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
+        running_mean = running_mean * momentum + mean * (1 - momentum)
+        running_var = running_var * momentum + var_unbiased * (1 - momentum)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+        assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
+        assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6)
+
+    # test set 'training' flag to False
+    mean_backup = bn.running_mean.numpy()
+    var_backup = bn.running_var.numpy()
+    bn.training = False
+    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+    data.set_value(xv)
+    yv1 = bn(data)
+    yv2 = bn(data)
+    assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
+    assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0)
+    assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0)
+    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
+    assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
+)
+@pytest.mark.skipif(use_tensor_shape(), reason="syncbn does not support symbolic shape")
+@pytest.mark.isolated_distributed
+def test_syncbn2d():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 16, 16)
+    momentum = 0.9
+    bn = SyncBatchNorm(nr_chan, momentum=momentum)
+    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
+    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
+    data = tensor([])
+    for i in range(3):
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
+            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
+        )
+
+        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
+
+        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
+        sd = np.sqrt(var_biased + bn.eps)
+
+        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
+        running_mean = running_mean * momentum + mean * (1 - momentum)
+        running_var = running_var * momentum + var_unbiased * (1 - momentum)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+        assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
+        assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6)
+
+    # test set 'training' flag to False
+    mean_backup = bn.running_mean.numpy()
+    var_backup = bn.running_var.numpy()
+    bn.training = False
+    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+    data.set_value(xv)
+    yv1 = bn(data)
+    yv2 = bn(data)
+    assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
+    assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0)
+    assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0)
+    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
+    assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)
+
+
+def test_batchnorm_no_stats():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 4)
+    bn = BatchNorm1d(8, track_running_stats=False)
+    data = tensor([])
+    for i in range(4):
+        if i == 2:
+            bn.training = False
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
+        var = np.var(
+            np.transpose(xv, [0, 2, 1]).reshape(
+                (data_shape[0] * data_shape[2], nr_chan)
+            ),
+            axis=0,
+        ).reshape((1, nr_chan, 1))
+        sd = np.sqrt(var + bn.eps)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
+)
+@pytest.mark.skipif(use_tensor_shape(), reason="syncbn does not support symbolic shape")
+@pytest.mark.isolated_distributed
+def test_syncbn_no_stats():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 4)
+    bn = SyncBatchNorm(8, track_running_stats=False)
+    data = tensor([])
+    for i in range(4):
+        if i == 2:
+            bn.training = False
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
+        var = np.var(
+            np.transpose(xv, [0, 2, 1]).reshape(
+                (data_shape[0] * data_shape[2], nr_chan)
+            ),
+            axis=0,
+        ).reshape((1, nr_chan, 1))
+        sd = np.sqrt(var + bn.eps)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+
+
+def test_batchnorm2d_no_stats():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 16, 16)
+    bn = BatchNorm2d(8, track_running_stats=False)
+    data = tensor([])
+    for i in range(4):
+        if i == 2:
+            bn.training = False
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
+            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
+        )
+
+        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
+        var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
+        sd = np.sqrt(var + bn.eps)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
+)
+@pytest.mark.skipif(use_tensor_shape(), reason="syncbn does not support symbolic shape")
+@pytest.mark.isolated_distributed
+def test_syncbn2d_no_stats():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 16, 16)
+    bn = SyncBatchNorm(8, track_running_stats=False)
+    data = tensor([])
+    for i in range(4):
+        if i == 2:
+            bn.training = False
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
+            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
+        )
+
+        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
+        var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
+        sd = np.sqrt(var + bn.eps)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
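
All four _no_stats variants pin down the same contract: with track_running_stats=False there are no running buffers, so even after bn.training = False the output is normalized by the current batch's statistics. A numpy restatement of that expectation (illustrative):

    import numpy as np

    x = np.random.randn(3, 8, 4).astype(np.float32)
    mean = x.mean(axis=(0, 2), keepdims=True)
    sd = np.sqrt(x.var(axis=(0, 2), keepdims=True) + 1e-5)
    y = (x - mean) / sd
    assert abs(y.mean()) < 1e-4 and abs(y.std() - 1) < 1e-2
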
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import itertools
+
+import numpy as np
+
+from megengine import Parameter, tensor
+from megengine.module import ConvTranspose2d, LocalConv2d
+from megengine.test import assertTensorClose
+
+
+def test_conv_transpose2d():
+    SH, SW = 3, 1
+    PH, PW = 2, 0
+    N, IC, IH, IW = 4, 5, 8, 6
+    KH, KW = 3, 4
+    OC = 3
+    BIAS = False
+
+    def getsize(inp, kern, stride):
+        return (inp - 1) * stride + kern
+
+    OH = getsize(IH, KH, SH)
+    OW = getsize(IW, KW, SW)
+
+    inp = np.random.normal(size=(N, IC, IH, IW)).astype(np.float32)
+    out = np.zeros((N, OC, OH, OW), dtype=np.float32)
+    weight = np.random.normal(size=(IC, OC, KH, KW)).astype(np.float32)
+    bias = np.random.normal(size=(1, OC, 1, 1)).astype(np.float32)
+
+    # naive calculation using numpy
+    for n, ic, ih, iw in itertools.product(*map(range, [N, IC, IH, IW])):
+        oh, ow = ih * SH, iw * SW
+        out[n, :, oh : oh + KH, ow : ow + KW] += inp[n, ic, ih, iw] * weight[ic]
+    out = out[:, :, PH : OH - PH, PW : OW - PW]
+    if BIAS:
+        out += bias
+
+    # megengine conv_transpose2d calculation
+    conv_transpose2d = ConvTranspose2d(IC, OC, (KH, KW), (SH, SW), (PH, PW), bias=BIAS)
+    conv_transpose2d.weight = Parameter(weight, dtype=np.float32)
+    if BIAS:
+        conv_transpose2d.bias = Parameter(bias, dtype=np.float32)
+
+    y = conv_transpose2d(tensor(inp))
+    assertTensorClose(out, y.numpy(), max_err=2e-6)
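
The naive loop scatters each input pixel into a KHxKW window, so the pre-crop output size is (in - 1) * stride + kernel, and padding then crops symmetrically. With the constants above (worked arithmetic):

    SH, SW, PH, PW = 3, 1, 2, 0
    IH, IW, KH, KW = 8, 6, 3, 4
    OH = (IH - 1) * SH + KH           # 24 before cropping
    OW = (IW - 1) * SW + KW           # 9 before cropping
    print(OH - 2 * PH, OW - 2 * PW)   # 20 9, the shape ConvTranspose2d produces
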
+
+
+def test_local_conv2d():
+    batch_size = 10
+    in_channels = 4
+    out_channels = 8
+    input_height = 8
+    input_width = 8
+    kernel_size = 3
+    stride = 1
+    padding = 1
+    dilation = 1
+    groups = 1
+    local_conv2d = LocalConv2d(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        input_height=input_height,
+        input_width=input_width,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        groups=groups,
+    )
+    inputs = np.random.normal(
+        size=(batch_size, in_channels, input_height, input_width)
+    ).astype(np.float32)
+    output_height = (input_height + padding * 2 - kernel_size) // stride + 1
+    output_width = (input_width + padding * 2 - kernel_size) // stride + 1
+    weights = np.random.normal(
+        size=(
+            groups,
+            output_height,
+            output_width,
+            in_channels // groups,
+            kernel_size,
+            kernel_size,
+            out_channels // groups,
+        )
+    ).astype(np.float32)
+    local_conv2d.weight = Parameter(weights)
+    outputs = local_conv2d(tensor(inputs))
+
+    # naive calculation using numpy
+    # only test output_height == input_height, output_width == input_width, group == 1
+    inputs = np.pad(inputs, ((0, 0), (0, 0), (1, 1), (1, 1)))
+    expected = np.zeros(
+        (batch_size, out_channels, output_height, output_width), dtype=np.float32,
+    )
+    for n, oc, oh, ow in itertools.product(
+        *map(range, [batch_size, out_channels, output_height, output_width])
+    ):
+        ih, iw = oh * stride, ow * stride
+        expected[n, oc, ih, iw] = np.sum(
+            inputs[n, :, ih : ih + kernel_size, iw : iw + kernel_size]
+            * weights[0, oh, ow, :, :, :, oc]
+        )
+    assertTensorClose(outputs.numpy(), expected, max_err=1e-5)
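
LocalConv2d is an untied ("locally connected") convolution: it keeps a separate filter per output position, which is why the weight carries output_height and output_width axes. Shape bookkeeping for the constants above (illustrative):

    import numpy as np

    groups, in_channels, out_channels = 1, 4, 8
    kernel_size, output_height, output_width = 3, 8, 8
    weight_shape = (
        groups, output_height, output_width,
        in_channels // groups, kernel_size, kernel_size, out_channels // groups,
    )
    print(np.prod(weight_shape))                           # 18432 parameters
    print(in_channels * kernel_size ** 2 * out_channels)   # 288 for a shared-filter Conv2d
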
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import os
+
+import numpy as np
+import pytest
+
+import megengine as mge
+from megengine import tensor
+from megengine.module import Module
+
+
+class MyModule(Module):
+    def __init__(self, data):
+        from megengine.module.external import CambriconSubgraph
+
+        super().__init__()
+        self.cambricon = CambriconSubgraph(data, "subnet0", True)
+
+    def forward(self, inputs):
+        out = self.cambricon(inputs)
+        return out
+
+
+@pytest.mark.skip(reason="cambricon unimplemented")
+def test_cambricon_module():
+    model = "CambriconRuntimeOprTest.MutableBatchSize.mlu"
+    model = os.path.join(os.path.dirname(__file__), model)
+    with open(model, "rb") as f:
+        data = f.read()
+    m = MyModule(data)
+    inputs = []
+    inputs.append(tensor(data=[], dtype=np.float16, device="cambricon0"))
+    inputs[0].set_value(np.random.normal(size=(1, 64, 32, 32)).astype(np.float16))
+
+    def inference(inps):
+        pred = m(inps)
+        return pred
+
+    pred = inference(inputs)
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import pytest
+
+from megengine.module import Conv2d, Linear
+from megengine.module.init import calculate_fan_in_and_fan_out
+
+
+def test_calculate_fan_in_and_fan_out():
+    l = Linear(in_features=3, out_features=8)
+    fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    assert fanin == 3
+    assert fanout == 8
+
+    with pytest.raises(ValueError):
+        calculate_fan_in_and_fan_out(l.bias)
+
+    l = Conv2d(in_channels=2, out_channels=3, kernel_size=(5, 7))
+    fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    assert fanin == 2 * 5 * 7
+    assert fanout == 3 * 5 * 7
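
For conv weights, fan_in counts the connections feeding one output unit and fan_out the connections leaving one input unit, each spanning the kernel window. The asserted values restated (illustrative):

    in_channels, out_channels, kh, kw = 2, 3, 5, 7
    print(in_channels * kh * kw)    # fan_in  = 70
    print(out_channels * kh * kw)   # fan_out = 105
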
@@ -0,0 +1,614 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import os
+import tempfile
+from collections import OrderedDict
+from io import BytesIO
+from typing import List, Tuple  # used by the opr_test helper below
+
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine.functional as F
+from megengine import Buffer, Parameter, Tensor, tensor
+from megengine.module import (
+    BatchNorm1d,
+    BatchNorm2d,
+    Conv2d,
+    Linear,
+    Module,
+    Sequential,
+)
+from megengine.quantization.quantize import quantize, quantize_qat
+from megengine.test import assertTensorClose
+
+
+class MLP(Module):
+    def __init__(self):
+        super().__init__()
+        self.dense0 = Linear(28, 50)
+        self.dense1 = Linear(50, 20)
+
+    def forward(self, x):
+        x = self.dense0(x)
+        x = F.relu(x)
+        x = self.dense1(x)
+        return x
+
+
+def has_gpu(num=1):
+    # NOTE: depends on the legacy ``mgb`` binding, which is not imported in
+    # this file; the helper is kept verbatim and appears unused below.
+    try:
+        mgb.comp_node("gpu{}".format(num - 1))
+    except mgb.MegBrainError:
+        return False
+
+    return True
+
+
+def randomNp(*args):
+    for arg in args:
+        assert isinstance(arg, int)
+    return np.random.random(args)
+
+
+def randomTorch(*args):
+    import torch  # pylint: disable=import-outside-toplevel
+
+    for arg in args:
+        assert isinstance(arg, int)
+    return torch.tensor(randomNp(*args), dtype=torch.float32)
+
+
+def graph_mode(*modes):
+    if not set(modes).issubset({"eager", "static"}):
+        raise ValueError("graph mode must be in (eager, static)")
+
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            if "eager" in set(modes):
+                func(*args, **kwargs)
+            if "static" in set(modes):
+                # NOTE: ``Graph`` comes from the legacy graph-mode API and is
+                # not importable in this revision; the decorator is stale.
+                with Graph() as cg:
+                    cg.set_option("eager_evaluation", False)
+                    func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def _default_compare_fn(x, y):
+    assertTensorClose(x.numpy(), y)
+
+
+def opr_test(
+    cases,
+    func,
+    mode=("eager", "static", "dynamic_shape"),
+    compare_fn=_default_compare_fn,
+    ref_fn=None,
+    **kwargs
+):
+    """
+    mode: the modes to test, a subset of ("eager", "static", "dynamic_shape");
+        all three are tested by default.
+    func: the function that runs the opr.
+    compare_fn: the function comparing result and expected; assertTensorClose
+        is used if None.
+    ref_fn: the function generating expected data; each case must carry
+        "output" if None.
+    cases: a list of dict cases; exactly 2 cases are required for the
+        dynamic-shape test. Each dict must have "input", and also "output"
+        when ref_fn is None. Use lists for multiple inputs or outputs.
+    kwargs: additional kwargs for the opr func.
+
+    simple example:
+
+        dtype = np.float32
+        cases = [{"input": [10, 20]}, {"input": [20, 30]}]
+        opr_test(cases,
+                 F.eye,
+                 ref_fn=lambda n, m: np.eye(n, m).astype(dtype),
+                 dtype=dtype)
+    """
+
+    def check_results(results, expected):
+        if not isinstance(results, Tuple):
+            results = (results,)
+        for r, e in zip(results, expected):
+            compare_fn(r, e)
+
+    def get_trace_fn(func, enabled, symbolic):
+        # NOTE: relies on the legacy ``jit.trace`` toggle; ``jit`` is not
+        # imported in this file (``from megengine import jit`` in the old API).
+        jit.trace.enabled = enabled
+        return jit.trace(func, symbolic=symbolic)
+
+    def get_param(cases, idx):
+        case = cases[idx]
+        inp = case.get("input", None)
+        outp = case.get("output", None)
+        if inp is None:
+            raise ValueError("the test case should have input")
+        if not isinstance(inp, List):
+            inp = (inp,)
+        else:
+            inp = tuple(inp)
+        if ref_fn is not None and callable(ref_fn):
+            outp = ref_fn(*inp)
+        if outp is None:
+            raise ValueError("the test case should have output or reference function")
+        if not isinstance(outp, List):
+            outp = (outp,)
+        else:
+            outp = tuple(outp)
+        return inp, outp
+
+    if not set(mode).issubset({"eager", "static", "dynamic_shape"}):
+        raise ValueError("opr test mode must be in (eager, static, dynamic_shape)")
+    if len(cases) == 0:
+        raise ValueError("should give one case at least")
+    if "dynamic_shape" in set(mode):
+        if len(cases) != 2:
+            raise ValueError("should give 2 cases for dynamic shape test")
+    if not callable(func):
+        raise ValueError("the input func should be callable")
+
+    inp, outp = get_param(cases, 0)
+
+    def run(*args, **kwargs):
+        return func(*args, **kwargs)
+
+    if "eager" in set(mode):
+        f = get_trace_fn(run, False, False)
+        results = f(*inp, **kwargs)
+        check_results(results, outp)
+    if "static" in set(mode) or "dynamic_shape" in set(mode):
+        f = get_trace_fn(run, True, True)
+        results = f(*inp, **kwargs)
+        check_results(results, outp)
+    if "dynamic_shape" in set(mode):
+        inp, outp = get_param(cases, 1)
+        results = f(*inp, **kwargs)
+        check_results(results, outp)
+
+
+class MyModule(Module):
+    class InnerModule(Module):
+        def __init__(self):
+            super().__init__()
+            self.bn = BatchNorm2d(4)
+
+        def forward(self, x):
+            return self.bn(x)
+
+    def __init__(self):
+        super().__init__()
+        self.i = self.InnerModule()
+        self.bn = BatchNorm2d(4)
+        self.param = Parameter(np.ones(1, dtype=np.float32))
+        self.buff = Buffer(np.ones(1, dtype=np.float32))
+
+    def forward(self, x):
+        x = self.i(x)
+        x = self.bn(x)
+        return x
+
+
+def test_module_api():
+    m = MyModule()
+    assert list(m.children()) == [m.bn, m.i]
+    assert list(m.named_children()) == [("bn", m.bn), ("i", m.i)]
+    assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
+    assert list(m.named_modules()) == [
+        ("", m),
+        ("bn", m.bn),
+        ("i", m.i),
+        ("i.bn", m.i.bn),
+    ]
+    assert list(m.named_modules(prefix="x")) == [
+        ("x", m),
+        ("x.bn", m.bn),
+        ("x.i", m.i),
+        ("x.i.bn", m.i.bn),
+    ]
+    assert list(m.buffers()) == [
+        m.bn.running_mean,
+        m.bn.running_var,
+        m.buff,
+        m.i.bn.running_mean,
+        m.i.bn.running_var,
+    ]
+    assert list(m.buffers(recursive=False)) == [m.buff]
+    assert list(m.named_buffers()) == [
+        ("bn.running_mean", m.bn.running_mean),
+        ("bn.running_var", m.bn.running_var),
+        ("buff", m.buff),
+        ("i.bn.running_mean", m.i.bn.running_mean),
+        ("i.bn.running_var", m.i.bn.running_var),
+    ]
+    assert list(m.parameters()) == [
+        m.bn.bias,
+        m.bn.weight,
+        m.i.bn.bias,
+        m.i.bn.weight,
+        m.param,
+    ]
+    assert list(m.named_parameters()) == [
+        ("bn.bias", m.bn.bias),
+        ("bn.weight", m.bn.weight),
+        ("i.bn.bias", m.i.bn.bias),
+        ("i.bn.weight", m.i.bn.weight),
+        ("param", m.param),
+    ]
+    m.eval()
+    assert (
+        m.training == False
+        and m.bn.training == False
+        and m.i.training == False
+        and m.i.bn.training == False
+    )
+    m.bn.train()
+    assert m.training == False and m.bn.training == True and m.i.bn.training == False
+    m.eval()
+    m.i.train()
+    assert (
+        m.training == False
+        and m.bn.training == False
+        and m.i.training == True
+        and m.i.bn.training == True
+    )
+    m.eval()
+    m.train()
+    assert m.training == True and m.bn.training == True and m.i.bn.training == True
+
+    def fn(m):
+        m.training = False
+
+    m.apply(fn)
+    assert m.bn.training == False and m.i.bn.training == False
+
+
+def test_module_api_reuse_submodule():
+    m = MyModule()
+    m.h = m.i  # pylint: disable=attribute-defined-outside-init
+    assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
+    assert list(m.named_modules()) == [
+        ("", m),
+        ("bn", m.bn),
+        ("h", m.i),
+        ("h.bn", m.i.bn),
+    ]
+
+
+def test_module_api_iterable_stability():
+    m = MyModule()
+    l = list(m.modules())
+    for _ in range(100):
+        assert list(m.modules()) == l
+
+
+def test_module_api_hooks():
+    net = MyModule()
+    pre_hook_num = 0
+    post_hook_num = 0
+    hooks = []
+
+    def pre_hook(module, inputs):
+        nonlocal pre_hook_num
+        pre_hook_num += 1
+        modified_inputs = tuple(inp + 1 for inp in inputs)
+        return modified_inputs
+
+    def post_hook(module, inputs, outputs):
+        nonlocal post_hook_num
+        post_hook_num += 1
+        outputs += 1
+        return outputs
+
+    net.apply(lambda module: hooks.append(module.register_forward_pre_hook(pre_hook)))
+    net.apply(lambda module: hooks.append(module.register_forward_hook(post_hook)))
+    shape = (1, 4, 1, 1)
+    x = tensor(np.zeros(shape, dtype=np.float32))
+    y = net(x)
+    assert pre_hook_num == 4
+    assert post_hook_num == 4
+
+    mean1 = Parameter(np.zeros(shape), dtype=np.float32)
+    bn1 = F.batch_norm2d(
+        x + 3, mean1, Parameter(np.ones(shape), dtype=np.float32), training=True
+    )
+    assertTensorClose(
+        net.i.bn.running_mean.numpy(), mean1.numpy(),
+    )
+    mean2 = Parameter(np.zeros(shape), dtype=np.float32)
+    bn2 = F.batch_norm2d(
+        bn1 + 3, mean2, Parameter(np.ones(shape), dtype=np.float32), training=True
+    )
+    assertTensorClose(
+        net.bn.running_mean.numpy(), mean2.numpy(),
+    )
+    assertTensorClose((bn2 + 2).numpy(), y.numpy())
+
+    assert len(hooks) == 8
+    for handler in hooks:
+        handler.remove()
+    y = net(x)
+    assert pre_hook_num == 4
+    assert post_hook_num == 4
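
The arithmetic behind those assertions: each of the four modules carries one pre-hook (+1 on inputs) and one post-hook (+1 on outputs). On the way in, x gains +1 at net, net.i, and net.i.bn, so the inner BN sees x + 3; between the two BNs the inner post-hooks and the outer BN's pre-hook contribute another +3; and the final two post-hooks yield y == bn2 + 2. The same mechanics with plain callables (framework-free sketch):

    def with_hooks(f):
        return lambda x: f(x + 1) + 1          # pre-hook +1, post-hook +1

    inner = with_hooks(lambda x: x * 10)       # stand-in for a wrapped submodule
    outer = with_hooks(lambda x: inner(x) * 10)
    # outer(0): +1 -> inner sees 2 -> 20 -> +1 -> *10 = 210 -> +1 = 211
    assert outer(0) == 211
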
+
+
+class MyModule2(Module):
+    class InnerModule(Module):
+        def __init__(self):
+            super().__init__()
+            self.bn = BatchNorm2d(4)
+            self.test_bool_key = {True: 1, False: 0}
+
+        def forward(self, x):
+            x = self.bn(x)
+
+    def __init__(self):
+        super().__init__()
+        self.bn = BatchNorm2d(4)
+        self.a = [
+            BatchNorm2d(4),
+            {"x": BatchNorm2d(4), "y": [BatchNorm2d(4), self.InnerModule()], "z": 0},
+            (self.InnerModule(),),
+        ]
+
+    def forward(self, x):
+        return x
+
+
+def test_expand_structure():
+    m = MyModule2()
+    assert list(m.named_modules()) == [
+        ("", m),
+        ("a.0", m.a[0]),
+        ("a.1.x", m.a[1]["x"]),
+        ("a.1.y.0", m.a[1]["y"][0]),
+        ("a.1.y.1", m.a[1]["y"][1]),
+        ("a.1.y.1.bn", m.a[1]["y"][1].bn),
+        ("a.2.0", m.a[2][0]),
+        ("a.2.0.bn", m.a[2][0].bn),
+        ("bn", m.bn),
+    ]
+
+
+def test_flatten_others():
+    def be_others(obj):
+        return not isinstance(obj, (Tensor, Module))
+
+    m = MyModule2()
+    assert len(list(m._flatten(with_key=True, predicate=be_others))) == 0
+
+
+def test_flatten_with_parent():
+    m = MyModule2()
+    assert list(m.named_modules(with_parent=True)) == [
+        ("", m, None),
+        ("a.0", m.a[0], m),
+        ("a.1.x", m.a[1]["x"], m),
+        ("a.1.y.0", m.a[1]["y"][0], m),
+        ("a.1.y.1", m.a[1]["y"][1], m),
+        ("a.1.y.1.bn", m.a[1]["y"][1].bn, m.a[1]["y"][1]),
+        ("a.2.0", m.a[2][0], m),
+        ("a.2.0.bn", m.a[2][0].bn, m.a[2][0]),
+        ("bn", m.bn, m),
+    ]
+    assert list(m.modules(with_parent=True)) == [
+        (m, None),
+        (m.a[0], m),
+        (m.a[1]["x"], m),
+        (m.a[1]["y"][0], m),
+        (m.a[1]["y"][1], m),
+        (m.a[1]["y"][1].bn, m.a[1]["y"][1]),
+        (m.a[2][0], m),
+        (m.a[2][0].bn, m.a[2][0]),
+        (m.bn, m),
+    ]
+
+
+class MyModule3(Module):
+    class InnerModule(Module):
+        def __init__(self):
+            super().__init__()
+            self.bn = BatchNorm2d(4)
+
+        def forward(self, x):
+            x = self.bn(x)
+
+    def __init__(self):
+        super().__init__()
+        self.bn = BatchNorm2d(4)
+        self.seq = Sequential(BatchNorm2d(4), self.InnerModule(),)
+
+    def forward(self, x):
+        return x
+
+
+def test_module_api_with_sequential():
+    m = MyModule3()
+    assert list(m.named_modules()) == [
+        ("", m),
+        ("bn", m.bn),
+        ("seq", m.seq),
+        ("seq.0", m.seq[0]),
+        ("seq.1", m.seq[1]),
+        ("seq.1.bn", m.seq[1].bn),
+    ]
+
+
+def test_sequential_named_children():
+    modules = OrderedDict()
+    modules["name0"] = Linear(20, 10)
+    modules["name1"] = Linear(10, 5)
+    modules["name2"] = Linear(5, 1)
+    m = Sequential(modules)
+    l = list(m.named_children())
+    assert l[0][0] == "layer_values.0"
+    assert l[1][0] == "layer_values.1"
+    assert l[2][0] == "layer_values.2"
+
+
+def test_state_dict():
+    data_shape = (2, 28)
+    data = tensor([])
+    data.set_value(np.random.random(data_shape))
+    mlp = MLP()
+    pred0 = mlp(data)
+
+    with BytesIO() as fout:
+        mge.save(mlp.state_dict(), fout)
+        fout.seek(0)
+        state_dict = mge.load(fout)
+        state_dict["extra"] = None
+        mlp1 = MLP()
+        mlp1.load_state_dict(state_dict, strict=False)
+        pred1 = mlp1(data)
+        assertTensorClose(pred0.numpy(), pred1.numpy(), max_err=5e-6)
+        with pytest.raises(KeyError):
+            mlp1.load_state_dict(state_dict)
+        del state_dict["extra"]
+        del state_dict["dense0.bias"]
+        with pytest.raises(KeyError):
+            mlp1.load_state_dict(state_dict)
+
+
+class AssertModule(Module):
+    def __init__(self):
+        super().__init__()
+        self.error_tensor_key = {True: tensor([]), False: 0}
+
+    def forward(self, x):
+        return x
+
+
+def test_assert_message():
+    m = AssertModule()
+    with pytest.raises(
+        AssertionError, match="keys for Tensor and Module must be str, error key: True"
+    ):
+        list(m._flatten())
+
+
+class Simple(Module):
+    def __init__(self):
+        super().__init__()
+        self.conv0 = Conv2d(1, 1, kernel_size=3, bias=False)
+        self.conv1 = Conv2d(1, 1, kernel_size=3, bias=False)
+        self.conv1.weight = self.conv0.weight
+
+    def forward(self, inputs):
+        pass
+
+
+def test_shared_param():
+    net = Simple()
+    assert net.conv0.weight is net.conv1.weight
+    data = tensor(np.random.random((1, 1, 8, 8)).astype(np.float32))
+    assertTensorClose(net.conv0(data).numpy(), net.conv1(data).numpy())
+
+    with BytesIO() as f:
+        mge.save(net, f)
+        f.seek(0)
+        net1 = mge.load(f)
+    assert net1.conv0.weight is net1.conv1.weight
+    assertTensorClose(net1.conv0(data).numpy(), net1.conv1(data).numpy())
+
+    with BytesIO() as f:
+        mge.save(net.conv0, f)
+        f.seek(0)
+        conv0 = mge.load(f)
+    with BytesIO() as f:
+        mge.save(net.conv1, f)
+        f.seek(0)
+        conv1 = mge.load(f)
+    assert conv0.weight is not conv1.weight
+    assertTensorClose(conv0(data).numpy(), conv1(data).numpy())
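
The tail of this test documents a pickling subtlety rather than a bug: one mge.save call serializes a single object graph, so shared parameters stay shared after loading, whereas two separate save calls use independent pickle memo tables and the reloaded copies no longer alias. The same behavior with plain pickle (illustrative):

    import pickle

    shared = [1, 2, 3]
    a, b = {"w": shared}, {"w": shared}
    a2, b2 = pickle.loads(pickle.dumps((a, b)))
    assert a2["w"] is b2["w"]           # one dump: aliasing survives
    a3 = pickle.loads(pickle.dumps(a))
    b3 = pickle.loads(pickle.dumps(b))
    assert a3["w"] is not b3["w"]       # separate dumps: aliasing lost
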
+
+
+def test_pickle_module():
+    data_shape = (2, 28)
+    data = tensor([])
+    data.set_value(np.random.random(data_shape))
+    mlp = MLP()
+
+    # pickle before forward
+    with BytesIO() as fout:
+        mge.save(mlp, fout)
+        fout.seek(0)
+        mlp1 = mge.load(fout)
+        pred0 = mlp1(data)
+
+    pred1 = mlp(data)
+
+    # pickle after forward
+    with BytesIO() as fout:
+        mge.save(mlp, fout)
+        fout.seek(0)
+        mlp1 = mge.load(fout)
+        pred2 = mlp1(data)
+
+    assertTensorClose(pred0.numpy(), pred1.numpy(), max_err=5e-6)
+    assertTensorClose(pred0.numpy(), pred2.numpy(), max_err=5e-6)
+
+
+@pytest.mark.skip(reason="under development")
+def test_dump_model():
+    data_shape = (2, 28)
+    data = tensor([])
+    data.set_value(np.random.random(data_shape))
+    mlp = MLP()
+    pred = mlp(data)
+    f = tempfile.NamedTemporaryFile(delete=False)
+    f_name = f.name
+    try:
+        mge.dump(pred, f_name)
+    finally:
+        f.close()
+        os.unlink(f_name)
+
+
+def test_load_quantized():
+    from megengine.core.tensor import dtype
+
+    data_shape = (2, 28)
+    data = tensor(np.random.random(data_shape), dtype="float32")
+    data = data.astype(dtype.qint8(0.1))
+    mlp = MLP()
+    quantize_qat(mlp)
+    quantize(mlp)
+    mlp.dense0.weight = Parameter(mlp.dense0.weight.astype(dtype.qint8(0.001)).numpy())
+    mlp.dense1.weight = Parameter(mlp.dense1.weight.astype(dtype.qint8(0.0002)).numpy())
+    mlp.eval()
+    pred0 = mlp(data)
+
+    with BytesIO() as fout:
+        mge.save(mlp.state_dict(), fout)
+        fout.seek(0)
+        checkpoint = mge.load(fout)
+        # change mlp weight.
+        mlp.dense0.weight = Parameter(
+            mlp.dense0.weight.astype(dtype.qint8(0.00001)).numpy()
+        )
+        mlp.dense1.weight = Parameter(
+            mlp.dense1.weight.astype(dtype.qint8(0.2)).numpy()
+        )
+        mlp.load_state_dict(checkpoint)
+        pred1 = mlp(data)
+
+    assertTensorClose(
+        pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), max_err=5e-6
+    )
@@ -0,0 +1,91 @@
+from itertools import product
+
+import numpy as np
+
+from megengine import tensor
+from megengine.module import (
+    Conv2d,
+    ConvBn2d,
+    ConvRelu2d,
+    DequantStub,
+    Module,
+    QuantStub,
+)
+from megengine.quantization.quantize import disable_fake_quant, quantize_qat
+from megengine.test import assertTensorClose
+
+
+def test_qat_convbn2d():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+    for groups, bias in product([1, 4], [True, False]):
+        module = ConvBn2d(
+            in_channels, out_channels, kernel_size, groups=groups, bias=bias
+        )
+        module.train()
+        qat_module = quantize_qat(module, inplace=False)
+        disable_fake_quant(qat_module)
+        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+        normal_outputs = module(inputs)
+        qat_outputs = qat_module(inputs)
+        assertTensorClose(normal_outputs.numpy(), qat_outputs.numpy(), max_err=5e-6)
+        assertTensorClose(
+            module.bn.running_mean.numpy(),
+            qat_module.bn.running_mean.numpy(),
+            max_err=5e-8,
+        )
+        assertTensorClose(
+            module.bn.running_var.numpy(),
+            qat_module.bn.running_var.numpy(),
+            max_err=5e-7,
+        )
+        module.eval()
+        normal_outputs = module(inputs)
+        qat_module.eval()
+        qat_outputs = qat_module(inputs)
+        assertTensorClose(normal_outputs.numpy(), qat_outputs.numpy(), max_err=5e-6)
+
+
+def test_qat_conv():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+
+    class TestNet(Module):
+        def __init__(self, groups, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.conv = Conv2d(
+                in_channels, out_channels, kernel_size, groups=groups, bias=bias
+            )
+            self.conv_relu = ConvRelu2d(
+                out_channels, in_channels, kernel_size, groups=groups, bias=bias
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.conv(out)
+            out = self.conv_relu(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+    for groups, bias in product([1, 4], [True, False]):
+        net = TestNet(groups, bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
+        assertTensorClose(normal_outputs.numpy(), qat_outputs.numpy())
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
+        assertTensorClose(normal_outputs.numpy(), qat_outputs.numpy())
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import copy
+
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine.functional as F
+from megengine import Buffer, Parameter
+from megengine.module import Conv2d
+from megengine.test import assertTensorClose
+
+
+def test_set_value():
+    v0 = np.random.random((2, 3)).astype(np.float32)
+    param = Parameter(v0)
+    v1 = np.random.random((2, 3)).astype(np.float32)
+    param.set_value(v1)
+    assertTensorClose(param.numpy(), v1, max_err=5e-6)
+    v2 = np.random.random((3, 3)).astype(np.float32)
+    # TODO: add this
+    # with pytest.raises(ValueError):
+    #     param.set_value(v2)
+    assertTensorClose(param.numpy(), v1, max_err=5e-6)
+
+
+@pytest.mark.skip(reason="fill unsupported")
+def test_fill():
+    a = Buffer(np.zeros((2, 3), dtype=np.float32))
+    a.fill(3)
+    assertTensorClose(a.numpy(), np.full((2, 3), 3, dtype=np.float32))
+    a.fill(124.568)
+    assertTensorClose(a.numpy(), np.full((2, 3), 124.568, dtype=np.float32))
+
+
+# TODO: remove or rewrite following test
+# def test_attach():
+#     p_ = np.random.random((2, 3)).astype(np.float32)
+#     with Graph() as g:
+#         g.set_option('eager_evaluation', False)
+#         p = Parameter(p_)
+#         v = p * 2
+#         f = compile(v, None)
+#         out, = f()
+#         assertTensorClose(out, p_ * 2)
+#         F.add_update(p, p)
+#         out, = f()
+#         assertTensorClose(out, p_ * 4)
+
+
+# TODO: remove or rewrite following test
+# def test_module_attach():
+#     v = np.random.random((1, 3, 64, 64)).astype(np.float32)
+#     net = Conv2d(3, 16, 3)
+#     with Graph() as g:
+#         g.set_option('eager_evaluation', False)
+#         data0 = Input("data")
+#         f = compile(net(data0), None)
+#         out0, = f(data=v)
+#     data1 = Input("data", value=v)
+#     out1 = net(data1)
+#     assertTensorClose(out0, out1.numpy())
+
+
+# def test_shape_warning():
+#     with Graph() as cg:
+#         cg.set_option("eager_evaluation", False)
+#         b = Buffer(np.ones((2, 3)).astype(np.float32))
+#         with pytest.warns(None) as record:
+#             print(b.shape)
+#         if len(record) != 0:
+#             raise ValueError(
+#                 "Getting the shape of a constant Tensor should throw no Warning"
+#             )
@@ -1,112 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import platform
-
-import pytest
-
-import megengine as mge
-import megengine.distributed as dist
-from megengine import tensor
-from megengine.distributed.group import Group
-from megengine.distributed.helper import get_device_count_by_fork
-from megengine.module import SyncBatchNorm
-from megengine.test import assertTensorClose
-
-
-@pytest.mark.skipif(
-    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
-)
-@pytest.mark.skipif(
-    platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
-)
-@pytest.mark.skipif(get_device_count_by_fork("gpu") < 4, reason="need more gpu device")
-@pytest.mark.isolated_distributed
-def test_syncbn():
-    import numpy as np
-    import multiprocessing as mp
-    from megengine.distributed.group import Server
-    from megengine.core._trace_option import use_tensor_shape
-
-    if use_tensor_shape():  # XXX: fix sync bn if use_tensor_shape
-        return
-    nr_chan = 8
-    nr_ranks = 4
-    data_shape = (3, nr_chan, 4, nr_ranks * 8)
-    momentum = 0.9
-    eps = 1e-5
-    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
-    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
-    steps = 4
-
-    server = Server(0)
-    port = server.py_server_port
-
-    def worker(rank, data, yv_expect, running_mean, running_var):
-        dist.init_process_group("localhost", port, nr_ranks, rank, rank)
-        group = Group([i for i in range(nr_ranks)])
-        bn = SyncBatchNorm(nr_chan, eps=eps, momentum=momentum, group=group)
-        data_tensor = None
-        for i in range(steps):
-            if data_tensor is None:
-                data_tensor = tensor(data[i], device=f"gpu{rank}:0")
-            else:
-                data_tensor.set_value(data[i])
-            yv = bn(data_tensor)
-
-        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
-        assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
-        assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6)
-
-    xv = []
-    for i in range(steps):
-        xv.append(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
-        xv_transposed = np.transpose(xv[i], [0, 2, 3, 1]).reshape(
-            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
-        )
-
-        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
-
-        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
-        sd = np.sqrt(var_biased + eps)
-
-        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
-        running_mean = running_mean * momentum + mean * (1 - momentum)
-        running_var = running_var * momentum + var_unbiased * (1 - momentum)
-
-        yv_expect = (xv[i] - mean) / sd
-
-    data = []
-    for i in range(nr_ranks):
-        data.append([])
-        for j in range(steps):
-            data[i].append(xv[j][:, :, :, i * 8 : i * 8 + 8])
-
-    procs = []
-    for rank in range(nr_ranks):
-        p = mp.Process(
-            target=worker,
-            args=(
-                rank,
-                data[rank],
-                yv_expect[:, :, :, rank * 8 : rank * 8 + 8],
-                running_mean,
-                running_var,
-            ),
-        )
-        p.start()
-        procs.append(p)
-
-    for p in procs:
-        p.join(10)
-        assert p.exitcode == 0
-
-
-def test_module_conv2d():
-    from megengine.module.conv import Conv2d
-
-    conv = Conv2d(2, 3, 1)