GitOrigin-RevId: 662618954b
release-1.10
@@ -71,7 +71,7 @@ public:
    MGE_WIN_DECLSPEC_FUC Result get(const Key& key);
    void clear();
    MGE_WIN_DECLSPEC_FUC void clear();
private:
    struct Hash {
@@ -9,7 +9,7 @@
import os
from contextlib import contextmanager
from ._imperative_rt.core2 import get_option, set_option
from ._imperative_rt.core2 import _clear_algorithm_cache, get_option, set_option
__compute_mode = "default"
__conv_format = "default"
@@ -44,6 +44,9 @@ def benchmark_kernel(mod):
@benchmark_kernel.setter
def benchmark_kernel(mod, option: bool):
    global _benchmark_kernel
    # try different strategy, then clear algorithm cache
    if option != _benchmark_kernel:
        _clear_algorithm_cache()
    _benchmark_kernel = option
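A minimal usage sketch of the behavior added above, via the public `megengine.config` wrapper that the updated tests below rely on: flipping `benchmark_kernel` clears the algorithm cache, so the next run re-selects kernels under the new policy instead of reusing stale choices.

```python
import megengine as mge

# sketch: enable fast-run benchmarking for one experiment, then restore.
# Changing the flag invalidates previously cached algorithm choices.
old = mge.config.benchmark_kernel
mge.config.benchmark_kernel = True
try:
    pass  # run the profiled workload here
finally:
    mge.config.benchmark_kernel = old
```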
@@ -9,6 +9,7 @@
import os
from ..core import _config
from ..core._imperative_rt.core2 import _clear_algorithm_cache
from ..core.ops import builtin
from ..logger import get_logger
from ..utils.deprecation import deprecated
@@ -52,7 +53,6 @@ def set_execution_strategy(option):
    * "HEURISTIC": uses heuristic to choose the fastest algorithm.
    * "PROFILE": runs possible algorithms on a real device to find the best one.
    * "REPRODUCIBLE": uses algorithms that are reproducible.
    * "OPTIMIZED": uses algorithms that are optimized.
    The default strategy is "HEURISTIC", these options can be combined to
    form a combination option, e.g. PROFILE_REPRODUCIBLE is a combination
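Both spellings are accepted by `set_execution_strategy`; a short sketch of equivalent combined options (names taken from the docstring above and the tests in this patch):

```python
from megengine.core.ops import builtin
from megengine.functional.debug_param import set_execution_strategy

Strategy = builtin.ops.Convolution.Strategy

# equivalent requests for profiled, reproducible algorithm selection
set_execution_strategy("PROFILE_REPRODUCIBLE")
set_execution_strategy(Strategy.PROFILE | Strategy.REPRODUCIBLE)
```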
@@ -70,22 +70,25 @@ def set_execution_strategy(option):
    It can also be set through the environment variable ``MEGENGINE_EXECUTION_STRATEGY``.
    """
    _benchmark_kernel = False
    _deterministic_kernel = False
    if isinstance(option, Strategy):
        _config._benchmark_kernel = (
        _benchmark_kernel = (
            True if option & _valid_string_option["PROFILE"] != Strategy(0) else False
        )
        _config._deterministic_kernel = (
        _deterministic_kernel = (
            True
            if option & _valid_string_option["REPRODUCIBLE"] != Strategy(0)
            else False
        )
        if _benchmark_kernel != _config._benchmark_kernel:
            _clear_algorithm_cache()
        _config._benchmark_kernel = _benchmark_kernel
        _config._deterministic_kernel = _deterministic_kernel
        return
    assert isinstance(option, str)
    _config._benchmark_kernel = False
    _config._deterministic_kernel = False
    for opt in option.split("_"):
        if not opt in _valid_string_option:
            raise ValueError(
@@ -93,10 +96,12 @@ def set_execution_strategy(option):
                    _valid_string_option.keys()
                )
            )
        _config._benchmark_kernel |= _valid_string_option[opt] == Strategy.PROFILE
        _config._deterministic_kernel |= (
            _valid_string_option[opt] == Strategy.REPRODUCIBLE
        )
        _benchmark_kernel |= _valid_string_option[opt] == Strategy.PROFILE
        _deterministic_kernel |= _valid_string_option[opt] == Strategy.REPRODUCIBLE
    if _benchmark_kernel != _config._benchmark_kernel:
        _clear_algorithm_cache()
    _config._benchmark_kernel = _benchmark_kernel
    _config._deterministic_kernel = _deterministic_kernel
@deprecated(version="1.3", reason="use get_execution_strategy() instead")
@@ -107,6 +112,3 @@ def get_conv_execution_strategy() -> str:
@deprecated(version="1.3", reason="use set_execution_strategy() instead")
def set_conv_execution_strategy(option: str):
    return set_execution_strategy(option)
set_execution_strategy(os.getenv("MEGENGINE_EXECUTION_STRATEGY", "HEURISTIC"))
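The docstring above also mentions the ``MEGENGINE_EXECUTION_STRATEGY`` environment variable; a sketch of selecting the default strategy that way, assuming the variable is read when megengine is imported:

```python
import os

# select the default execution strategy via the environment
os.environ.setdefault("MEGENGINE_EXECUTION_STRATEGY", "HEURISTIC_REPRODUCIBLE")
import megengine  # noqa: E402  # picks up the strategy at import time
```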
@@ -26,6 +26,7 @@
#include "megbrain/opr/io.h"
#include "megbrain/plugin/profiler.h"
#include "megbrain/utils/stats.h"
#include "megdnn/algorithm_cache.h"
#include "./common.h"
#include "./grad.h"
@@ -1428,6 +1429,8 @@ void init_tensor(py::module m) {
        return set_amp_prec_dtype(false, dtype_name);
    });
    m.def("_clear_algorithm_cache", [] { megdnn::AlgorithmCache::instance().clear(); });
    py::register_exception<TraceError>(m, "TraceError");
}
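The new binding is a private helper; a sketch of calling it directly from Python, mirroring what the config layer above does when a strategy flag changes:

```python
from megengine.core._imperative_rt.core2 import _clear_algorithm_cache

_clear_algorithm_cache()  # drops all cached algorithm choices for the process
```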
@@ -1,289 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import os
import re
import subprocess
import sys

import numpy as np
import pytest

import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
from megengine import jit
from megengine.core._trace_option import set_symbolic_shape
from megengine.core.ops import builtin
from megengine.core.tensor.utils import make_shape_tuple
from megengine.functional.debug_param import set_execution_strategy
from megengine.jit import SublinearMemoryConfig
from megengine.module import (
    AdaptiveAvgPool2d,
    AvgPool2d,
    BatchNorm2d,
    Conv2d,
    Linear,
    Module,
)
from megengine.optimizer import SGD
from megengine.tensor import Tensor

Strategy = builtin.ops.Convolution.Strategy


def get_gpu_name():
    try:
        gpu_info = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"]
        )
        gpu_info = gpu_info.decode("ascii").split("\n")[0]
    except:
        gpu_info = "None"
    return gpu_info


def get_cpu_name():
    cpu_info = "None"
    try:
        cpu_info = subprocess.check_output(["cat", "/proc/cpuinfo"]).decode("ascii")
        for line in cpu_info.split("\n"):
            if "model name" in line:
                return re.sub(".*model name.*:", "", line, 1).strip()
    except:
        pass
    return cpu_info


def get_xpu_name():
    if mge.is_cuda_available():
        return get_gpu_name()
    else:
        return get_cpu_name()


class MnistNet(Module):
    def __init__(self, has_bn=False, use_adaptive_pooling=False):
        super().__init__()
        self.conv0 = Conv2d(1, 20, kernel_size=5, bias=True)
        if use_adaptive_pooling:
            self.pool0 = AdaptiveAvgPool2d(12)
        else:
            self.pool0 = AvgPool2d(2)
        self.conv1 = Conv2d(20, 20, kernel_size=5, bias=True)
        self.pool1 = AvgPool2d(2)
        self.fc0 = Linear(20 * 4 * 4, 500, bias=True)
        self.fc1 = Linear(500, 10, bias=True)
        self.bn0 = None
        self.bn1 = None
        if has_bn:
            self.bn0 = BatchNorm2d(20)
            self.bn1 = BatchNorm2d(20)

    def forward(self, x):
        x = self.conv0(x)
        if self.bn0:
            x = self.bn0(x)
        x = F.relu(x)
        x = self.pool0(x)
        x = self.conv1(x)
        if self.bn1:
            x = self.bn1(x)
        x = F.relu(x)
        x = self.pool1(x)
        x = F.flatten(x, 1)
        x = self.fc0(x)
        x = F.relu(x)
        x = self.fc1(x)
        return x


def train(data, label, net, opt, gm):
    with gm:
        pred = net(data)
        loss = F.nn.cross_entropy(pred, label)
        gm.backward(loss)
    return loss


def update_model(model_path):
    """
    Update the dumped model with test cases for new reference values.

    The model with pre-trained weights is trained for one iter with the test data attached.
    The loss and updated net state dict is dumped.

    .. code-block:: python

        from test_correctness import update_model
        update_model('mnist_model_with_test.mge') # for gpu
        update_model('mnist_model_with_test_cpu.mge') # for cpu
    """
    net = MnistNet(has_bn=True)
    checkpoint = mge.load(model_path)
    net.load_state_dict(checkpoint["net_init"])
    lr = checkpoint["sgd_lr"]
    opt = SGD(net.parameters(), lr=lr)
    gm = ad.GradManager().attach(net.parameters())

    data = Tensor(checkpoint["data"], dtype=np.float32)
    label = Tensor(checkpoint["label"], dtype=np.int32)

    opt.clear_grad()
    loss = train(data, label, net, opt, gm)
    opt.step()

    xpu_name = get_xpu_name()

    checkpoint.update(
        {"net_updated": net.state_dict(), "loss": loss.numpy(), "xpu": xpu_name}
    )
    mge.save(checkpoint, model_path)


def run_train(
    model_path,
    use_jit,
    use_symbolic,
    sublinear_memory_config=None,
    max_err=None,
    use_adaptive_pooling=False,
):
    """
    Load the model with test cases and run the training for one iter.
    The loss and updated weights are compared with reference value to verify the correctness.

    Dump a new file with updated result by calling update_model
    if you think the test fails due to numerical rounding errors instead of bugs.
    Please think twice before you do so.
    """
    net = MnistNet(has_bn=True, use_adaptive_pooling=use_adaptive_pooling)
    checkpoint = mge.load(model_path)
    net.load_state_dict(checkpoint["net_init"])
    lr = checkpoint["sgd_lr"]
    opt = SGD(net.parameters(), lr=lr)
    gm = ad.GradManager().attach(net.parameters())

    data = Tensor(checkpoint["data"], dtype=np.float32)
    label = Tensor(checkpoint["label"], dtype=np.int32)

    if max_err is None:
        max_err = 1e-5

    train_func = train
    if use_jit:
        train_func = jit.trace(
            train_func,
            symbolic=use_symbolic,
            sublinear_memory_config=sublinear_memory_config,
        )

    opt.clear_grad()
    loss = train_func(data, label, net, opt, gm)
    opt.step()

    np.testing.assert_allclose(loss.numpy(), checkpoint["loss"], atol=max_err)

    for param, param_ref in zip(
        net.state_dict().items(), checkpoint["net_updated"].items()
    ):
        assert param[0] == param_ref[0]
        if "bn" in param[0]:
            ref = param_ref[1].reshape(param[1].shape)
            np.testing.assert_allclose(param[1], ref, atol=max_err)
        else:
            np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)


def run_eval(
    model_path,
    use_symbolic,
    sublinear_memory_config=None,
    max_err=None,
    use_adaptive_pooling=False,
):
""" | |||
Load the model with test cases and run the training for one iter. | |||
The loss and updated weights are compared with reference value to verify the correctness. | |||
Dump a new file with updated result by calling update_model | |||
if you think the test fails due to numerical rounding errors instead of bugs. | |||
Please think twice before you do so. | |||
""" | |||
    net = MnistNet(has_bn=True, use_adaptive_pooling=use_adaptive_pooling)
    checkpoint = mge.load(model_path)
    net.load_state_dict(checkpoint["net_init"])
    data = Tensor(checkpoint["data"], dtype=np.float32)

    def eval_fun(data, *, net=None):
        pred = net(data)
        return pred

    refer_value = eval_fun(data, net=net)
    eval_fun = jit.trace(eval_fun, symbolic=use_symbolic)

    for _ in range(3):
        new_value = eval_fun(data, net=net)
        np.testing.assert_allclose(new_value.numpy(), refer_value.numpy(), atol=max_err)


@pytest.mark.skip(reason="close it when cu111 ci")
def test_correctness():
    if mge.is_cuda_available():
        model_name = "mnist_model_with_test.mge"
    else:
        model_name = "mnist_model_with_test_cpu.mge"
    model_path = os.path.join(os.path.dirname(__file__), model_name)

    set_execution_strategy(Strategy.HEURISTIC | Strategy.REPRODUCIBLE)

    run_train(model_path, False, False, max_err=1e-5)
    run_train(model_path, True, False, max_err=1e-5)
    run_train(model_path, True, True, max_err=1e-5)

    # sublinear
    config = SublinearMemoryConfig(genetic_nr_iter=10)
    run_train(
        model_path, True, True, sublinear_memory_config=config, max_err=1e-5,
    )

    run_eval(model_path, False, max_err=1e-7)
    run_eval(model_path, True, max_err=1e-7)


@pytest.mark.skip(reason="close it when cu111 ci")
def test_correctness_use_adaptive_pooling():
    if mge.is_cuda_available():
        model_name = "mnist_model_with_test.mge"
    else:
        model_name = "mnist_model_with_test_cpu.mge"
    model_path = os.path.join(os.path.dirname(__file__), model_name)

    set_execution_strategy("HEURISTIC_REPRODUCIBLE")

    run_train(model_path, False, False, max_err=1e-5, use_adaptive_pooling=True)
    run_train(model_path, True, False, max_err=1e-5, use_adaptive_pooling=True)
    run_train(model_path, True, True, max_err=1e-5, use_adaptive_pooling=True)

    # sublinear
    config = SublinearMemoryConfig(genetic_nr_iter=10)
    run_train(
        model_path,
        True,
        True,
        sublinear_memory_config=config,
        max_err=1e-5,
        use_adaptive_pooling=True,
    )

    run_eval(model_path, False, max_err=1e-7, use_adaptive_pooling=True)
    run_eval(model_path, True, max_err=1e-7, use_adaptive_pooling=True)
@@ -7,11 +7,8 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import os
import platform
import re
import subprocess
import sys
from math import ceil

import numpy as np
import pytest
@@ -20,8 +17,6 @@ import megengine as mge
import megengine.autodiff as ad
import megengine.distributed as dist
import megengine.functional as F
from megengine.device import get_default_device, set_default_device
from megengine.functional.debug_param import set_execution_strategy
from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
from megengine.optimizer import SGD
from megengine.tensor import Tensor
@@ -198,5 +193,7 @@ def run_test(
def test_dp_correctness():
    model_name = "mnist_model_with_test.mge"
    model_path = os.path.join(os.path.dirname(__file__), model_name)
    set_execution_strategy("HEURISTIC_REPRODUCIBLE")
    old = mge.config.deterministic_kernel
    mge.config.deterministic_kernel = True
    run_test(model_path, False, False, max_err=5e-5)
    mge.config.deterministic_kernel = old
@@ -11,21 +11,9 @@ import itertools
import numpy as np
import pytest

import megengine as mge
import megengine.module as M
from megengine import Parameter, tensor
from megengine.functional.debug_param import (
    get_execution_strategy,
    set_execution_strategy,
)
from megengine.module import ConvTranspose2d, ConvTranspose3d, LocalConv2d
@pytest.fixture
def reproducible():
    old = get_execution_strategy()
    set_execution_strategy("HEURISTIC_REPRODUCIBLE")
    yield
    set_execution_strategy(old)
from megengine import tensor

# NOTE: test in module for convenience. should really test in functional
@@ -33,7 +21,9 @@ def reproducible():
    "name",
    ["Conv1d", "Conv2d", "Conv3d", "ConvTranspose2d", "ConvTranspose3d", "LocalConv2d"],
)
def test_conv_dtype_promotion(name, reproducible):
def test_conv_dtype_promotion(name):
    old = mge.config.deterministic_kernel
    mge.config.deterministic_kernel = True
    N, Ci, Co, K = 2, 16, 32, 3
    S = (7,) * int(name[-2])
    if "Local" in name:
@@ -42,3 +32,4 @@ def test_conv_dtype_promotion(name, reproducible):
    m = getattr(M, name)(Ci, Co, K)
    x = tensor(np.random.random(size=(N, Ci) + S).astype("float16"))
    np.testing.assert_equal(m(x).numpy(), m(x.astype("float32")).numpy())
    mge.config.deterministic_kernel = old
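If fixture-style setup is still wanted, the removed `reproducible` fixture could be rebuilt on top of the new config flag (a hypothetical sketch, not part of this change):

```python
import pytest

import megengine as mge


@pytest.fixture
def deterministic():
    # hypothetical replacement for the removed `reproducible` fixture,
    # expressed with the config flag instead of execution strategies
    old = mge.config.deterministic_kernel
    mge.config.deterministic_kernel = True
    yield
    mge.config.deterministic_kernel = old
```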
@@ -255,9 +255,8 @@ def test_conv_bias_int4():
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")

@pytest.mark.require_ngpu(1)
@pytest.mark.skipif(
    get_cuda_compute_capability(0) < 61,
    get_device_count("gpu") > 0 and get_cuda_compute_capability(0) < 61,
    reason="does not support int8 when gpu compute capability less than 6.1",
)
def test_conv_transpose2d():
@@ -5,6 +5,7 @@ import platform
import numpy as np
import pytest

import megengine as mge
import megengine.core.tensor.dtype as dtype
import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
@@ -18,10 +19,6 @@ from megengine.device import (
    get_device_count,
    is_cuda_available,
)
from megengine.functional.debug_param import (
    get_execution_strategy,
    set_execution_strategy,
)
from megengine.functional.external import tensorrt_runtime_opr
from megengine.jit.tracing import trace
from megengine.tensor import Tensor
@@ -110,25 +107,30 @@ def test_matinv():
@pytest.mark.parametrize(
    "execution_strategy", ["HEURISTIC_REPRODUCIBLE", "PROFILE_REPRODUCIBLE"]
    "benchmark_kernel, max_err", [(False, None), (True, 1e-5)],
)
def test_matmul(execution_strategy):
def test_matmul(monkeypatch, benchmark_kernel, max_err):
    if get_device_count("gpu") == 0 and benchmark_kernel:
        return
    monkeypatch.setenv("MGE_FASTRUN_CACHE_TYPE", "MEMORY")
    old1, old2 = (
        mge.config.benchmark_kernel,
        mge.config.deterministic_kernel,
    )
    mge.config.benchmark_kernel = benchmark_kernel
    mge.config.deterministic_kernel = True

    @trace(symbolic=True, capture_as_const=True)
    def fwd(data1, data2):
        return F.matmul(data1, data2)

    old = get_execution_strategy()
    set_execution_strategy(execution_strategy)
    max_err = None
    if execution_strategy == "PROFILE_REPRODUCIBLE":
        max_err = 1e-5
    data1 = Tensor(np.random.random((32, 64)))
    data2 = Tensor(np.random.random((64, 16)))
    result = fwd(data1, data2)
    check_pygraph_dump(fwd, [data1, data2], [result], max_err=max_err)
    set_execution_strategy(old)
    mge.config.benchmark_kernel = old1
    mge.config.deterministic_kernel = old2
    monkeypatch.delenv("MGE_FASTRUN_CACHE_TYPE", raising=False)

def test_batchmatmul():
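The `MGE_FASTRUN_CACHE_TYPE=MEMORY` setting above keeps fast-run profiling results in memory so they do not persist across test runs; the same knob can be used in a one-off script (a sketch, assuming the variable is read at process start):

```python
import os

# keep fast-run profiling results in memory only for this process
os.environ["MGE_FASTRUN_CACHE_TYPE"] = "MEMORY"
```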
@@ -290,9 +292,8 @@ def test_deformable_ps_roi_pooling():
    check_pygraph_dump(fwd, [inp, rois, trans], [result])

@pytest.mark.require_ngpu(1)
@pytest.mark.skipif(
    get_cuda_compute_capability(0) < 61,
    get_device_count("gpu") > 0 and get_cuda_compute_capability(0) < 61,
    reason="does not support int8 when gpu compute capability less than 6.1",
)
def test_convbias():