GitOrigin-RevId: 662618954b
release-1.10
@@ -71,7 +71,7 @@ public: | |||||
MGE_WIN_DECLSPEC_FUC Result get(const Key& key); | MGE_WIN_DECLSPEC_FUC Result get(const Key& key); | ||||
void clear(); | |||||
MGE_WIN_DECLSPEC_FUC void clear(); | |||||
private: | private: | ||||
struct Hash { | struct Hash { | ||||
@@ -9,7 +9,7 @@ | |||||
import os | import os | ||||
from contextlib import contextmanager | from contextlib import contextmanager | ||||
from ._imperative_rt.core2 import get_option, set_option | |||||
from ._imperative_rt.core2 import _clear_algorithm_cache, get_option, set_option | |||||
__compute_mode = "default" | __compute_mode = "default" | ||||
__conv_format = "default" | __conv_format = "default" | ||||
@@ -44,6 +44,9 @@ def benchmark_kernel(mod): | |||||
@benchmark_kernel.setter | @benchmark_kernel.setter | ||||
def benchmark_kernel(mod, option: bool): | def benchmark_kernel(mod, option: bool): | ||||
global _benchmark_kernel | global _benchmark_kernel | ||||
# try different strategy, then clear algorithm cache | |||||
if option != _benchmark_kernel: | |||||
_clear_algorithm_cache() | |||||
_benchmark_kernel = option | _benchmark_kernel = option | ||||
@@ -9,6 +9,7 @@ | |||||
import os | import os | ||||
from ..core import _config | from ..core import _config | ||||
from ..core._imperative_rt.core2 import _clear_algorithm_cache | |||||
from ..core.ops import builtin | from ..core.ops import builtin | ||||
from ..logger import get_logger | from ..logger import get_logger | ||||
from ..utils.deprecation import deprecated | from ..utils.deprecation import deprecated | ||||
@@ -52,7 +53,6 @@ def set_execution_strategy(option): | |||||
* "HEURISTIC": uses heuristic to choose the fastest algorithm. | * "HEURISTIC": uses heuristic to choose the fastest algorithm. | ||||
* "PROFILE": runs possible algorithms on a real device to find the best one. | * "PROFILE": runs possible algorithms on a real device to find the best one. | ||||
* "REPRODUCIBLE": uses algorithms that are reproducible. | * "REPRODUCIBLE": uses algorithms that are reproducible. | ||||
* "OPTIMIZED": uses algorithms that are optimized. | |||||
The default strategy is "HEURISTIC", these options can be combined to | The default strategy is "HEURISTIC", these options can be combined to | ||||
form a combination option, e.g. PROFILE_REPRODUCIBLE is a combination | form a combination option, e.g. PROFILE_REPRODUCIBLE is a combination | ||||
@@ -70,22 +70,25 @@ def set_execution_strategy(option): | |||||
It can also be set through the environment variable ``MEGENGINE_EXECUTION_STRATEGY``. | It can also be set through the environment variable ``MEGENGINE_EXECUTION_STRATEGY``. | ||||
""" | """ | ||||
_benchmark_kernel = False | |||||
_deterministic_kernel = False | |||||
if isinstance(option, Strategy): | if isinstance(option, Strategy): | ||||
_config._benchmark_kernel = ( | |||||
_benchmark_kernel = ( | |||||
True if option & _valid_string_option["PROFILE"] != Strategy(0) else False | True if option & _valid_string_option["PROFILE"] != Strategy(0) else False | ||||
) | ) | ||||
_config._deterministic_kernel = ( | |||||
_deterministic_kernel = ( | |||||
True | True | ||||
if option & _valid_string_option["REPRODUCIBLE"] != Strategy(0) | if option & _valid_string_option["REPRODUCIBLE"] != Strategy(0) | ||||
else False | else False | ||||
) | ) | ||||
if _benchmark_kernel != _config._benchmark_kernel: | |||||
_clear_algorithm_cache() | |||||
_config._benchmark_kernel = _benchmark_kernel | |||||
_config._deterministic_kernel = _deterministic_kernel | |||||
return | return | ||||
assert isinstance(option, str) | assert isinstance(option, str) | ||||
_config._benchmark_kernel = False | |||||
_config._deterministic_kernel = False | |||||
for opt in option.split("_"): | for opt in option.split("_"): | ||||
if not opt in _valid_string_option: | if not opt in _valid_string_option: | ||||
raise ValueError( | raise ValueError( | ||||
@@ -93,10 +96,12 @@ def set_execution_strategy(option): | |||||
_valid_string_option.keys() | _valid_string_option.keys() | ||||
) | ) | ||||
) | ) | ||||
_config._benchmark_kernel |= _valid_string_option[opt] == Strategy.PROFILE | |||||
_config._deterministic_kernel |= ( | |||||
_valid_string_option[opt] == Strategy.REPRODUCIBLE | |||||
) | |||||
_benchmark_kernel |= _valid_string_option[opt] == Strategy.PROFILE | |||||
_deterministic_kernel |= _valid_string_option[opt] == Strategy.REPRODUCIBLE | |||||
if _benchmark_kernel != _config._benchmark_kernel: | |||||
_clear_algorithm_cache() | |||||
_config._benchmark_kernel = _benchmark_kernel | |||||
_config._deterministic_kernel = _deterministic_kernel | |||||
@deprecated(version="1.3", reason="use get_execution_strategy() instead") | @deprecated(version="1.3", reason="use get_execution_strategy() instead") | ||||
@@ -107,6 +112,3 @@ def get_conv_execution_strategy() -> str: | |||||
@deprecated(version="1.3", reason="use set_execution_strategy() instead") | @deprecated(version="1.3", reason="use set_execution_strategy() instead") | ||||
def set_conv_execution_strategy(option: str): | def set_conv_execution_strategy(option: str): | ||||
return set_execution_strategy(option) | return set_execution_strategy(option) | ||||
set_execution_strategy(os.getenv("MEGENGINE_EXECUTION_STRATEGY", "HEURISTIC")) |
@@ -26,6 +26,7 @@ | |||||
#include "megbrain/opr/io.h" | #include "megbrain/opr/io.h" | ||||
#include "megbrain/plugin/profiler.h" | #include "megbrain/plugin/profiler.h" | ||||
#include "megbrain/utils/stats.h" | #include "megbrain/utils/stats.h" | ||||
#include "megdnn/algorithm_cache.h" | |||||
#include "./common.h" | #include "./common.h" | ||||
#include "./grad.h" | #include "./grad.h" | ||||
@@ -1428,6 +1429,8 @@ void init_tensor(py::module m) { | |||||
return set_amp_prec_dtype(false, dtype_name); | return set_amp_prec_dtype(false, dtype_name); | ||||
}); | }); | ||||
m.def("_clear_algorithm_cache", [] { megdnn::AlgorithmCache::instance().clear(); }); | |||||
py::register_exception<TraceError>(m, "TraceError"); | py::register_exception<TraceError>(m, "TraceError"); | ||||
} | } | ||||
@@ -1,289 +0,0 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import os | |||||
import re | |||||
import subprocess | |||||
import sys | |||||
import numpy as np | |||||
import pytest | |||||
import megengine as mge | |||||
import megengine.autodiff as ad | |||||
import megengine.functional as F | |||||
from megengine import jit | |||||
from megengine.core._trace_option import set_symbolic_shape | |||||
from megengine.core.ops import builtin | |||||
from megengine.core.tensor.utils import make_shape_tuple | |||||
from megengine.functional.debug_param import set_execution_strategy | |||||
from megengine.jit import SublinearMemoryConfig | |||||
from megengine.module import ( | |||||
AdaptiveAvgPool2d, | |||||
AvgPool2d, | |||||
BatchNorm2d, | |||||
Conv2d, | |||||
Linear, | |||||
Module, | |||||
) | |||||
from megengine.optimizer import SGD | |||||
from megengine.tensor import Tensor | |||||
Strategy = builtin.ops.Convolution.Strategy | |||||
def get_gpu_name():
    """Return the name of the first GPU reported by ``nvidia-smi``.

    Runs ``nvidia-smi --query-gpu=gpu_name --format=csv,noheader`` and keeps
    only the first line of its output.

    Returns:
        str: the first output line, or the literal string ``"None"`` when the
        command cannot be run, exits with an error, or prints non-ASCII bytes.
    """
    try:
        raw = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"]
        )
        return raw.decode("ascii").split("\n")[0]
    except Exception:
        # Best-effort probe: any failure simply means "no GPU name available".
        # Catching Exception instead of using a bare ``except`` keeps
        # KeyboardInterrupt and SystemExit propagating to the caller.
        return "None"
def get_cpu_name():
    """Return the CPU model name listed in ``/proc/cpuinfo``.

    The first line containing ``model name`` is used; everything up to and
    including its last colon is removed and the remainder is stripped.

    Returns:
        str: the model name, or the literal string ``"None"`` when the file
        cannot be read, is not ASCII, or has no ``model name`` line.
    """
    try:
        # Read the file directly instead of spawning ``cat``.
        with open("/proc/cpuinfo", encoding="ascii") as cpuinfo:
            content = cpuinfo.read()
    except (OSError, UnicodeError):
        return "None"
    for line in content.splitlines():
        if "model name" in line:
            return re.sub(r".*model name.*:", "", line, count=1).strip()
    # No matching line: report "None" rather than leaking the whole
    # /proc/cpuinfo dump as the device name.
    return "None"
def get_xpu_name():
    """Name the device in use: the GPU when CUDA is available, else the CPU."""
    return get_gpu_name() if mge.is_cuda_available() else get_cpu_name()
class MnistNet(Module):
    """Two convolution stages followed by two linear layers.

    Each stage is convolution -> optional batch norm -> ReLU -> pooling.
    The flattened result feeds a 500-unit linear layer and a 10-way output.

    Args:
        has_bn: when true, a ``BatchNorm2d(20)`` follows each convolution.
        use_adaptive_pooling: when true, the first pooling layer is
            ``AdaptiveAvgPool2d(12)`` instead of ``AvgPool2d(2)``.
    """

    def __init__(self, has_bn=False, use_adaptive_pooling=False):
        super().__init__()
        # Attribute assignment order is kept identical to the original
        # layout (conv0, pool0, conv1, pool1, fc0, fc1, bn0, bn1).
        self.conv0 = Conv2d(1, 20, kernel_size=5, bias=True)
        self.pool0 = AdaptiveAvgPool2d(12) if use_adaptive_pooling else AvgPool2d(2)
        self.conv1 = Conv2d(20, 20, kernel_size=5, bias=True)
        self.pool1 = AvgPool2d(2)
        self.fc0 = Linear(20 * 4 * 4, 500, bias=True)
        self.fc1 = Linear(500, 10, bias=True)
        self.bn0 = BatchNorm2d(20) if has_bn else None
        self.bn1 = BatchNorm2d(20) if has_bn else None

    def forward(self, x):
        """Run both conv stages, flatten from axis 1, then apply the linear head."""
        stages = (
            (self.conv0, self.bn0, self.pool0),
            (self.conv1, self.bn1, self.pool1),
        )
        for conv, bn, pool in stages:
            x = conv(x)
            if bn:
                x = bn(x)
            x = pool(F.relu(x))
        hidden = F.relu(self.fc0(F.flatten(x, 1)))
        return self.fc1(hidden)
def train(data, label, net, opt, gm):
    """Run one forward/backward pass under ``gm`` and return the loss.

    The cross-entropy loss of ``net(data)`` against ``label`` is computed and
    back-propagated inside the ``gm`` context. ``opt`` is accepted but not
    used here; callers step the optimizer themselves.
    """
    with gm:
        loss = F.nn.cross_entropy(net(data), label)
        gm.backward(loss)
    return loss
def update_model(model_path):
    """Regenerate the reference values stored in a dumped test model.

    The checkpoint at ``model_path`` supplies initial weights, the SGD
    learning rate and one batch of data/labels. The network is trained for a
    single iteration, then the loss, the updated state dict and the device
    name are written back into the same file.

    .. code-block:: python

        from test_correctness import update_model
        update_model('mnist_model_with_test.mge') # for gpu
        update_model('mnist_model_with_test_cpu.mge') # for cpu
    """
    net = MnistNet(has_bn=True)
    ckpt = mge.load(model_path)
    net.load_state_dict(ckpt["net_init"])

    optimizer = SGD(net.parameters(), lr=ckpt["sgd_lr"])
    grad_manager = ad.GradManager().attach(net.parameters())

    batch = Tensor(ckpt["data"], dtype=np.float32)
    target = Tensor(ckpt["label"], dtype=np.int32)

    # One full optimisation step: reset grads, forward/backward, update.
    optimizer.clear_grad()
    loss = train(batch, target, net, optimizer, grad_manager)
    optimizer.step()

    device_name = get_xpu_name()
    ckpt.update(net_updated=net.state_dict(), loss=loss.numpy(), xpu=device_name)
    mge.save(ckpt, model_path)
def run_train(
    model_path,
    use_jit,
    use_symbolic,
    sublinear_memory_config=None,
    max_err=None,
    use_adaptive_pooling=False,
):
    """Train for one iteration and compare against stored reference values.

    The checkpoint at ``model_path`` provides the initial weights, learning
    rate, input batch and the expected loss / updated weights. When
    ``use_jit`` is set, the training step is wrapped in ``jit.trace`` with the
    given ``use_symbolic`` and ``sublinear_memory_config``. ``max_err`` is the
    absolute tolerance and defaults to ``1e-5``.

    If a failure is caused by numerical rounding rather than a bug, the
    reference file can be regenerated with ``update_model``. Please think
    twice before doing so.
    """
    net = MnistNet(has_bn=True, use_adaptive_pooling=use_adaptive_pooling)
    checkpoint = mge.load(model_path)
    net.load_state_dict(checkpoint["net_init"])
    opt = SGD(net.parameters(), lr=checkpoint["sgd_lr"])
    gm = ad.GradManager().attach(net.parameters())

    data = Tensor(checkpoint["data"], dtype=np.float32)
    label = Tensor(checkpoint["label"], dtype=np.int32)

    tolerance = 1e-5 if max_err is None else max_err

    if use_jit:
        step_fn = jit.trace(
            train,
            symbolic=use_symbolic,
            sublinear_memory_config=sublinear_memory_config,
        )
    else:
        step_fn = train

    opt.clear_grad()
    loss = step_fn(data, label, net, opt, gm)
    opt.step()

    np.testing.assert_allclose(loss.numpy(), checkpoint["loss"], atol=tolerance)

    current_items = net.state_dict().items()
    reference_items = checkpoint["net_updated"].items()
    for (name, value), (ref_name, ref_value) in zip(current_items, reference_items):
        assert name == ref_name
        if "bn" in name:
            # Reference batch-norm entries are reshaped to the live shape
            # before comparing (presumably the stored layout differs --
            # TODO confirm).
            ref_value = ref_value.reshape(value.shape)
        np.testing.assert_allclose(value, ref_value, atol=tolerance)
def run_eval(
    model_path,
    use_symbolic,
    sublinear_memory_config=None,
    max_err=None,
    use_adaptive_pooling=False,
):
    """Check that a traced forward pass matches the untraced one.

    The network is initialised from ``checkpoint["net_init"]`` in the file at
    ``model_path`` and run once on the stored data without tracing to get a
    reference output. The same function is then wrapped in ``jit.trace`` and
    called three times; every result must match the reference.

    Args:
        model_path: path of the dumped checkpoint to load.
        use_symbolic: forwarded to ``jit.trace`` as ``symbolic``.
        sublinear_memory_config: accepted for signature parity with
            ``run_train``; not used by this function.
        max_err: absolute tolerance for the comparison. Defaults to ``1e-5``
            (the same fallback as ``run_train``) because passing ``None`` on
            to ``assert_allclose`` as ``atol`` raises ``TypeError``.
        use_adaptive_pooling: forwarded to ``MnistNet``.
    """
    net = MnistNet(has_bn=True, use_adaptive_pooling=use_adaptive_pooling)
    checkpoint = mge.load(model_path)
    net.load_state_dict(checkpoint["net_init"])
    data = Tensor(checkpoint["data"], dtype=np.float32)

    if max_err is None:
        max_err = 1e-5

    def eval_fun(data, *, net=None):
        return net(data)

    # Reference output from the plain (untraced) function.
    refer_value = eval_fun(data, net=net)
    eval_fun = jit.trace(eval_fun, symbolic=use_symbolic)

    # Several calls, so results after the first traced run are checked too.
    for _ in range(3):
        new_value = eval_fun(data, net=net)
        np.testing.assert_allclose(new_value.numpy(), refer_value.numpy(), atol=max_err)
@pytest.mark.skip(reason="close it when cu111 ci")
def test_correctness():
    """Run the train/eval correctness checks with the default pooling layer."""
    model_name = (
        "mnist_model_with_test.mge"
        if mge.is_cuda_available()
        else "mnist_model_with_test_cpu.mge"
    )
    model_path = os.path.join(os.path.dirname(__file__), model_name)
    set_execution_strategy(Strategy.HEURISTIC | Strategy.REPRODUCIBLE)

    # (use_jit, use_symbolic) combinations, in the original order.
    for use_jit, use_symbolic in ((False, False), (True, False), (True, True)):
        run_train(model_path, use_jit, use_symbolic, max_err=1e-5)

    # sublinear
    sublinear = SublinearMemoryConfig(genetic_nr_iter=10)
    run_train(
        model_path, True, True, sublinear_memory_config=sublinear, max_err=1e-5,
    )

    for use_symbolic in (False, True):
        run_eval(model_path, use_symbolic, max_err=1e-7)
@pytest.mark.skip(reason="close it when cu111 ci")
def test_correctness_use_adaptive_pooling():
    """Run the train/eval correctness checks with adaptive average pooling."""
    model_name = (
        "mnist_model_with_test.mge"
        if mge.is_cuda_available()
        else "mnist_model_with_test_cpu.mge"
    )
    model_path = os.path.join(os.path.dirname(__file__), model_name)
    set_execution_strategy("HEURISTIC_REPRODUCIBLE")

    # (use_jit, use_symbolic) combinations, in the original order.
    for use_jit, use_symbolic in ((False, False), (True, False), (True, True)):
        run_train(
            model_path, use_jit, use_symbolic, max_err=1e-5, use_adaptive_pooling=True
        )

    # sublinear
    sublinear = SublinearMemoryConfig(genetic_nr_iter=10)
    run_train(
        model_path,
        True,
        True,
        sublinear_memory_config=sublinear,
        max_err=1e-5,
        use_adaptive_pooling=True,
    )

    for use_symbolic in (False, True):
        run_eval(model_path, use_symbolic, max_err=1e-7, use_adaptive_pooling=True)
@@ -7,11 +7,8 @@ | |||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import os | import os | ||||
import platform | |||||
import re | import re | ||||
import subprocess | import subprocess | ||||
import sys | |||||
from math import ceil | |||||
import numpy as np | import numpy as np | ||||
import pytest | import pytest | ||||
@@ -20,8 +17,6 @@ import megengine as mge | |||||
import megengine.autodiff as ad | import megengine.autodiff as ad | ||||
import megengine.distributed as dist | import megengine.distributed as dist | ||||
import megengine.functional as F | import megengine.functional as F | ||||
from megengine.device import get_default_device, set_default_device | |||||
from megengine.functional.debug_param import set_execution_strategy | |||||
from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module | from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module | ||||
from megengine.optimizer import SGD | from megengine.optimizer import SGD | ||||
from megengine.tensor import Tensor | from megengine.tensor import Tensor | ||||
@@ -198,5 +193,7 @@ def run_test( | |||||
def test_dp_correctness(): | def test_dp_correctness(): | ||||
model_name = "mnist_model_with_test.mge" | model_name = "mnist_model_with_test.mge" | ||||
model_path = os.path.join(os.path.dirname(__file__), model_name) | model_path = os.path.join(os.path.dirname(__file__), model_name) | ||||
set_execution_strategy("HEURISTIC_REPRODUCIBLE") | |||||
old = mge.config.deterministic_kernel | |||||
mge.config.deterministic_kernel = True | |||||
run_test(model_path, False, False, max_err=5e-5) | run_test(model_path, False, False, max_err=5e-5) | ||||
mge.config.deterministic_kernel = old |
@@ -11,21 +11,9 @@ import itertools | |||||
import numpy as np | import numpy as np | ||||
import pytest | import pytest | ||||
import megengine as mge | |||||
import megengine.module as M | import megengine.module as M | ||||
from megengine import Parameter, tensor | |||||
from megengine.functional.debug_param import ( | |||||
get_execution_strategy, | |||||
set_execution_strategy, | |||||
) | |||||
from megengine.module import ConvTranspose2d, ConvTranspose3d, LocalConv2d | |||||
@pytest.fixture | |||||
def reproducible(): | |||||
old = get_execution_strategy() | |||||
set_execution_strategy("HEURISTIC_REPRODUCIBLE") | |||||
yield | |||||
set_execution_strategy(old) | |||||
from megengine import tensor | |||||
# NOTE: test in module for convenience. should really test in functional | # NOTE: test in module for convenience. should really test in functional | ||||
@@ -33,7 +21,9 @@ def reproducible(): | |||||
"name", | "name", | ||||
["Conv1d", "Conv2d", "Conv3d", "ConvTranspose2d", "ConvTranspose3d", "LocalConv2d"], | ["Conv1d", "Conv2d", "Conv3d", "ConvTranspose2d", "ConvTranspose3d", "LocalConv2d"], | ||||
) | ) | ||||
def test_conv_dtype_promotion(name, reproducible): | |||||
def test_conv_dtype_promotion(name): | |||||
old = mge.config.deterministic_kernel | |||||
mge.config.deterministic_kernel = True | |||||
N, Ci, Co, K = 2, 16, 32, 3 | N, Ci, Co, K = 2, 16, 32, 3 | ||||
S = (7,) * int(name[-2]) | S = (7,) * int(name[-2]) | ||||
if "Local" in name: | if "Local" in name: | ||||
@@ -42,3 +32,4 @@ def test_conv_dtype_promotion(name, reproducible): | |||||
m = getattr(M, name)(Ci, Co, K) | m = getattr(M, name)(Ci, Co, K) | ||||
x = tensor(np.random.random(size=(N, Ci) + S).astype("float16")) | x = tensor(np.random.random(size=(N, Ci) + S).astype("float16")) | ||||
np.testing.assert_equal(m(x).numpy(), m(x.astype("float32")).numpy()) | np.testing.assert_equal(m(x).numpy(), m(x.astype("float32")).numpy()) | ||||
mge.config.deterministic_kernel = old |
@@ -255,9 +255,8 @@ def test_conv_bias_int4(): | |||||
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu") | run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu") | ||||
@pytest.mark.require_ngpu(1) | |||||
@pytest.mark.skipif( | @pytest.mark.skipif( | ||||
get_cuda_compute_capability(0) < 61, | |||||
get_device_count("gpu") > 0 and get_cuda_compute_capability(0) < 61, | |||||
reason="does not support int8 when gpu compute capability less than 6.1", | reason="does not support int8 when gpu compute capability less than 6.1", | ||||
) | ) | ||||
def test_conv_transpose2d(): | def test_conv_transpose2d(): | ||||
@@ -5,6 +5,7 @@ import platform | |||||
import numpy as np | import numpy as np | ||||
import pytest | import pytest | ||||
import megengine as mge | |||||
import megengine.core.tensor.dtype as dtype | import megengine.core.tensor.dtype as dtype | ||||
import megengine.core.tensor.megbrain_graph as G | import megengine.core.tensor.megbrain_graph as G | ||||
import megengine.functional as F | import megengine.functional as F | ||||
@@ -18,10 +19,6 @@ from megengine.device import ( | |||||
get_device_count, | get_device_count, | ||||
is_cuda_available, | is_cuda_available, | ||||
) | ) | ||||
from megengine.functional.debug_param import ( | |||||
get_execution_strategy, | |||||
set_execution_strategy, | |||||
) | |||||
from megengine.functional.external import tensorrt_runtime_opr | from megengine.functional.external import tensorrt_runtime_opr | ||||
from megengine.jit.tracing import trace | from megengine.jit.tracing import trace | ||||
from megengine.tensor import Tensor | from megengine.tensor import Tensor | ||||
@@ -110,25 +107,30 @@ def test_matinv(): | |||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"execution_strategy", ["HEURISTIC_REPRODUCIBLE", "PROFILE_REPRODUCIBLE"] | |||||
"benchmark_kernel, max_err", [(False, None), (True, 1e-5)], | |||||
) | ) | ||||
def test_matmul(execution_strategy): | |||||
def test_matmul(monkeypatch, benchmark_kernel, max_err): | |||||
if get_device_count("gpu") == 0 and benchmark_kernel: | |||||
return | |||||
monkeypatch.setenv("MGE_FASTRUN_CACHE_TYPE", "MEMORY") | |||||
old1, old2 = ( | |||||
mge.config.benchmark_kernel, | |||||
mge.config.deterministic_kernel, | |||||
) | |||||
mge.config.benchmark_kernel = benchmark_kernel | |||||
mge.config.deterministic_kernel = True | |||||
@trace(symbolic=True, capture_as_const=True) | @trace(symbolic=True, capture_as_const=True) | ||||
def fwd(data1, data2): | def fwd(data1, data2): | ||||
return F.matmul(data1, data2) | return F.matmul(data1, data2) | ||||
old = get_execution_strategy() | |||||
set_execution_strategy(execution_strategy) | |||||
max_err = None | |||||
if execution_strategy == "PROFILE_REPRODUCIBLE": | |||||
max_err = 1e-5 | |||||
data1 = Tensor(np.random.random((32, 64))) | data1 = Tensor(np.random.random((32, 64))) | ||||
data2 = Tensor(np.random.random((64, 16))) | data2 = Tensor(np.random.random((64, 16))) | ||||
result = fwd(data1, data2) | result = fwd(data1, data2) | ||||
check_pygraph_dump(fwd, [data1, data2], [result], max_err=max_err) | check_pygraph_dump(fwd, [data1, data2], [result], max_err=max_err) | ||||
set_execution_strategy(old) | |||||
mge.config.benchmark_kernel = old1 | |||||
mge.config.deterministic_kernel = old2 | |||||
monkeypatch.delenv("MGE_FASTRUN_CACHE_TYPE", raising=False) | |||||
def test_batchmatmul(): | def test_batchmatmul(): | ||||
@@ -290,9 +292,8 @@ def test_deformable_ps_roi_pooling(): | |||||
check_pygraph_dump(fwd, [inp, rois, trans], [result]) | check_pygraph_dump(fwd, [inp, rois, trans], [result]) | ||||
@pytest.mark.require_ngpu(1) | |||||
@pytest.mark.skipif( | @pytest.mark.skipif( | ||||
get_cuda_compute_capability(0) < 61, | |||||
get_device_count("gpu") > 0 and get_cuda_compute_capability(0) < 61, | |||||
reason="does not support int8 when gpu compute capability less than 6.1", | reason="does not support int8 when gpu compute capability less than 6.1", | ||||
) | ) | ||||
def test_convbias(): | def test_convbias(): | ||||