@@ -561,7 +561,8 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(
src.enumv() == DTypeEnum::QuantizedS8 ||
src.enumv() == DTypeEnum::Quantized8Asymm ||
src.enumv() == DTypeEnum::QuantizedS4 ||
src.enumv() == DTypeEnum::Quantized4Asymm) {
src.enumv() == DTypeEnum::Quantized4Asymm ||
src.enumv() == DTypeEnum::QuantizedS1) {
supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(src, filter)));
bool cond_dst = dst.valid() && (dst.enumv() == src.enumv() ||
((dst.enumv() == DTypeEnum::QuantizedS4 ||
@@ -25,7 +25,7 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
non_cudnn_algos.push_back(&matmul);
non_cudnn_algos.push_back(&matmul8x8x32);
non_cudnn_algos.push_back(&batched_matmul);
non_cudnn_algos.push_back(&int1_simple);
fill_cudnn_algos();
for (auto&& algo : cudnn_conv_bias_activations) {
all_algos.push_back(&algo);
@@ -45,6 +45,7 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
conv_algos.push_back(&matmul8x8x32);
conv_algos.push_back(&batched_matmul);
conv_algos.push_back(&group);
conv_algos.push_back(&int1_simple);
for (auto&& algo : conv_algos) {
all_algos.push_back(algo);
@@ -87,6 +87,7 @@ public:
CUDA_FALLBACK_NCHW_INT4,
CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32,
CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16,
CUDA_SIMPLE_INT1,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
@@ -1089,6 +1090,24 @@ private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};
class ConvBiasForwardImpl::AlgoSimpleInt1 final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
std::vector<SearchItem> get_subopr_list(
const TensorLayoutArray& layouts, const OperatorBase* opr) const override;
const char* name() const override { return "CONVBIAS_SIMPLE_INT1"; }
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
MEGDNN_DECL_ALGO_TYPE(CUDA_SIMPLE_INT1)
private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};
class ConvBiasForwardImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;
@@ -1132,6 +1151,7 @@ public:
std::vector<AlgoFloat16NCHWHMMAImplicitBatchedGemm> f16_implicit_bmm;
AlgoGroupConvGeneral group;
AlgoBFloat16 bfloat16;
AlgoSimpleInt1 int1_simple;
AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo);
@@ -30,6 +30,8 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
return false;
}
}
if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS1)
return false;
if ((args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4)
@@ -134,6 +134,9 @@ void ConvBiasDesc::set_conv(
namespace conv_bias {
bool is_cudnn_supported(const BiasForwardSizeArgs& args) {
if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS1)
return false;
if ((args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4)
@@ -221,6 +221,11 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
return &sm_algo_pack.fallback_nchw_qs8;
}
if (sm_algo_pack.int1_simple.is_available_attribute(
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.int1_simple;
}
if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
return megdnn::get_algo_match_attribute<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
@@ -72,6 +72,7 @@ public:
class AlgoInt4Int4NHWCIMMAImplicitGemm;
class AlgoUInt4Int4NHWCIMMAImplicitGemm;
class AlgoBFloat16;
class AlgoSimpleInt1;
// The following algorithms are suitable for channel wise convolution
class AlgoFloat32NCHWFMAImplicitBatchedGemm;
class AlgoFloat16NCHWHMMAImplicitBatchedGemm;
@@ -0,0 +1,145 @@
/**
* \file dnn/src/cuda/conv_bias/simple_int1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied.
*/
#include "src/common/algo_base.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.cuh"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
using namespace conv_bias;
namespace {
std::pair<TensorLayoutArray, ConvBiasForwardImpl::Param> sub_opr_config(
const TensorLayoutArray& layouts, const ConvBiasForwardImpl* opr) {
megdnn_assert(layouts.size() >= 3);
std::pair<TensorLayoutArray, ConvBiasForwardImpl::Param> ret;
ret.first = layouts;
auto change_dtype = [](TensorLayout& layout) {
if (layout.dtype.enumv() == DTypeEnum::QuantizedS1 ||
layout.dtype.enumv() == DTypeEnum::QuantizedS32) {
layout.dtype = dtype::Float32();
}
};
change_dtype(ret.first[0]);
change_dtype(ret.first[1]);
change_dtype(ret.first[2]);
change_dtype(ret.first[3]);
change_dtype(ret.first[4]);
ret.second = opr->param();
ret.second.compute_mode = ConvBiasForwardImpl::Param::ComputeMode::DEFAULT;
return ret;
}
std::pair<TensorLayoutArray, std::unique_ptr<ConvBiasForward>> prepare_sub_opr(
const ConvBiasForwardImpl::AlgoBase::SizeArgs& args) {
auto convbias_opr = args.handle->create_operator<ConvBias>();
auto&& config = sub_opr_config(
{*args.src_layout, *args.filter_layout, *args.bias_layout, *args.z_layout,
*args.dst_layout},
args.opr);
convbias_opr->param() = config.second;
return {config.first, std::move(convbias_opr)};
}
} // namespace
std::vector<Algorithm::SearchItem> ConvBiasForwardImpl::AlgoSimpleInt1::get_subopr_list(
const TensorLayoutArray& layouts, const OperatorBase* opr) const {
auto&& config =
sub_opr_config(layouts, static_cast<const ConvBiasForwardImpl*>(opr));
std::string param_str;
Algorithm::serialize_write_pod(config.second, param_str);
return {{Algorithm::OprType::CONVBIAS_FORWARD, param_str, config.first}};
}
bool ConvBiasForwardImpl::AlgoSimpleInt1::is_available(const SizeArgs& args) const {
if (args.src_layout->dtype.valid() && args.filter_layout->dtype.valid() &&
args.bias_layout->dtype.valid() && args.z_layout->dtype.valid() &&
args.dst_layout->dtype.valid()) {
auto config = prepare_sub_opr(args);
return args.src_layout->dtype.enumv() == args.filter_layout->dtype.enumv() &&
args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS1 &&
get_algorithm(
static_cast<ConvBiasForwardImpl*>(config.second.get()),
config.first[0], config.first[1], config.first[2],
config.first[3], config.first[4]);
} else {
return false;
}
}
WorkspaceBundle ConvBiasForwardImpl::AlgoSimpleInt1::get_workspace_bundle(
void* ptr, const SizeArgs& args) const {
auto config = prepare_sub_opr(args);
SmallVector<size_t> sizes;
auto get_workspace = [&sizes](const TensorLayout& src, const TensorLayout& dst) {
if (src.dtype != dst.dtype) {
sizes.push_back(dst.span().dist_byte());
}
};
get_workspace(*args.src_layout, config.first[0]);
get_workspace(*args.filter_layout, config.first[1]);
get_workspace(*args.bias_layout, config.first[2]);
get_workspace(*args.z_layout, config.first[3]);
get_workspace(*args.dst_layout, config.first[4]);
sizes.push_back(config.second->get_workspace_in_bytes(
config.first[0], config.first[1], config.first[2], config.first[3],
config.first[4], nullptr));
return {ptr, std::move(sizes)};
}
size_t ConvBiasForwardImpl::AlgoSimpleInt1::get_workspace_in_bytes(
const SizeArgs& args) const {
return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}
void ConvBiasForwardImpl::AlgoSimpleInt1::exec(const ExecArgs& args) const {
TensorND fsrc_tensor = *args.src_tensor;
TensorND ffilter_tensor = *args.filter_tensor;
TensorND fbias_tensor = *args.bias_tensor;
TensorND fz_tensor = *args.z_tensor;
TensorND fdst_tensor = *args.dst_tensor;
auto config = prepare_sub_opr(args);
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
CompTypeCvter<dtype::QuantizedS1, dtype::Float32> cvter(args.handle, &bundle);
{
cvter.src_to_comp_type(*args.src_tensor, fsrc_tensor)
.src_to_comp_type(*args.filter_tensor, ffilter_tensor);
}
WorkspaceBundle dst_bundle = {
bundle.get(2),
{bundle.get_size(2), bundle.get_size(3), bundle.get_size(4),
bundle.get_size(5)}};
CompTypeCvter<dtype::QuantizedS32, dtype::Float32> dst_cvter(
args.handle, &dst_bundle);
{
dst_cvter.src_to_comp_type(*args.bias_tensor, fbias_tensor)
.src_to_comp_type(*args.z_tensor, fz_tensor)
.src_to_comp_type(*args.dst_tensor, fdst_tensor);
}
config.second->exec(
fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor, fdst_tensor, nullptr,
dst_cvter.workspace());
{ dst_cvter.comp_to_dst_type(fdst_tensor, *args.dst_tensor); }
}
// vim: syntax=cpp.doxygen
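Note: the new dnn/src/cuda/conv_bias/simple_int1.cpp above never computes on 1-bit data directly. It widens the QuantizedS1 src/filter and the QuantizedS32 bias/z/dst to Float32, lets an ordinary Float32 ConvBias sub-operator (picked via get_algorithm) do the work, and converts the result back. A rough numpy illustration of that widen/compute/narrow round trip -- illustrative only, with stand-in values and hypothetical scales, not the MegDNN implementation:

import numpy as np

src_scale, filter_scale = 1.0, 1.0
q_src = np.array([1, -1, 1], dtype=np.int8)           # stand-in QuantizedS1 payloads
q_filter = np.array([1, 1, -1], dtype=np.int8)

f_src = q_src.astype(np.float32) * src_scale          # "src_to_comp_type"
f_filter = q_filter.astype(np.float32) * filter_scale
f_out = np.dot(f_src, f_filter)                       # the float sub-opr's job

dst_scale = src_scale * filter_scale                  # QuantizedS32 output scale
q_out = np.rint(f_out / dst_scale).astype(np.int32)   # "comp_to_dst_type"
print(q_out)                                          # -> -1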
@@ -44,6 +44,10 @@ std::pair<TensorLayoutArray, ConvBiasForward::Param> sub_opr_config(
src.dtype.param<dtype::Quantized4Asymm>().scale *
filter.dtype.param<dtype::Quantized4Asymm>().scale);
} else if (src.dtype.enumv() == DTypeEnum::QuantizedS1) {
bias_type = dtype::QuantizedS32(
src.dtype.param<dtype::QuantizedS1>().scale *
filter.dtype.param<dtype::QuantizedS1>().scale);
} else {
megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT);
bias_type = src.dtype;
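The branch added above gives QuantizedS1 the same bias-dtype rule as the other quantized paths: the bias accumulates in QuantizedS32 whose scale is the product of the source and filter scales. With hypothetical scale values:

src_scale, filter_scale = 0.5, 0.25
bias_scale = src_scale * filter_scale  # dtype::QuantizedS32(src_scale * filter_scale)
print(bias_scale)                      # 0.125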
@@ -278,6 +278,9 @@ void ConvBiasForwardImpl::exec(
DISPATCH_RAW(
Quantized4Asymm, QuantizedS4, QuantizedS32, QuantizedS32, DEFAULT,
(convolution::forward_bias<dt_quint4, dt_qint4, dt_qint32, dt_qint32>))
DISPATCH_RAW(
QuantizedS1, QuantizedS1, QuantizedS32, QuantizedS32, FLOAT32,
(convolution::forward_bias<dt_qint1, dt_qint1, dt_qint32, dt_qint32>))
#if !MEGDNN_DISABLE_FLOAT16
DISPATCH(Float16, Float16)
DISPATCH_RAW(
@@ -84,6 +84,15 @@ inline void StrategyFwd::on(
d += cast(s) * cast(f);
}
template <>
inline void StrategyFwd::on(
dt_qint1& s, dt_qint1& f, dt_qint32& d, DType, DType, DType) {
auto cast = [](const dt_qint1& val) {
return dt_qint32(static_cast<int32_t>(val.as_int8()));
};
d += cast(s) * cast(f);
}
struct StrategyBwdData {
template <typename st, typename ft, typename dt>
static void on(st& s, ft& f, dt& d, DType, DType, DType) {
@@ -133,6 +133,32 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_BF16) {
}
}
TEST_F(CUDA, CONV_BIAS_FORWARD_QS1) {
require_compute_capability(6, 1);
UniformIntRNG int_rng{1, 1};
Checker<ConvBiasForward> checker(handle_cuda());
checker.set_before_exec_callback(AlgoChecker<ConvBiasForward>(
ExecutionPolicyAlgoName{"CONVBIAS_SIMPLE_INT1", {{"MATMUL", {}}}}));
ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW;
param.compute_mode = param::Convolution::ComputeMode::FLOAT32;
{
auto src_shape = TensorShape{20, 2, 224, 224};
auto filter_shape = TensorShape{20, 2, 3, 3};
checker.set_dtype(0, dtype::QuantizedS1(1.0f))
.set_dtype(1, dtype::QuantizedS1(1.0f))
.set_dtype(2, dtype::QuantizedS32(1.0f))
.set_dtype(3, dtype::QuantizedS32(1.0f))
.set_dtype(4, dtype::QuantizedS32(1.0f))
.set_rng(0, &int_rng)
.set_rng(1, &int_rng)
.set_param(param)
.execs({src_shape, filter_shape, {}, {}, {}});
}
}
TEST_F(CUDA, CONV_BIAS_FORWARD_QS8) {
require_compute_capability(6, 1);
@@ -1509,7 +1509,7 @@ def sync_batch_norm(
"""
_eps_mode = eps_mode.lower()
assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode)
if _eps_mode == "additive" and not (is_distributed() or training):
if _eps_mode == "additive" and not (is_distributed() and training):
return batch_norm(
inp,
running_mean,
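The one-character change above flips when sync_batch_norm falls back to the plain batch_norm kernel: previously the fallback fired only when the process was neither distributed nor training, while with the fix it fires unless the process is both distributed and training, the only case where cross-process statistics are actually needed. A quick enumeration of the two predicates makes the difference explicit:

for is_distributed in (False, True):
    for training in (False, True):
        old_fallback = not (is_distributed or training)   # before the fix
        new_fallback = not (is_distributed and training)  # after the fix
        print(is_distributed, training, old_fallback, new_fallback)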
@@ -1244,7 +1244,6 @@ def tile(inp: Tensor, reps: Iterable[int]):
inp = _tile_one_dim(inp, rep, i)
if l_reps > l_shape:
shape = inp.shape
extra = reps[:-l_shape]
extra_ones = ones_like(extra)
base_shape = concat([extra_ones, shape])
@@ -53,7 +53,10 @@ def _assert_equal(
"""
err = (
abs(expect - actual)
/ maximum(minimum(abs(expect), abs(actual)), Tensor(1.0, dtype="float32"))
/ maximum(
minimum(abs(expect), abs(actual)),
Tensor(1.0, dtype="float32", device=expect.device),
)
).max()
result = apply(AssertEqual(maxerr=maxerr, verbose=verbose), expect, actual, err)[0]
_sync() # sync interpreter to get exception
@@ -660,16 +660,16 @@ def interpolate(
if mode != "linear":
wscale = (iw - 1.0) / (ow - 1.0)
row0 = concat(
[wscale, Tensor([0, 0], dtype="float32", device=inp.device)], axis=0
).reshape(1, 3)
row1 = concat(
[
Tensor(0, dtype="float32", device=inp.device),
hscale,
Tensor(0, dtype="float32", device=inp.device),
Tensor(wscale, dtype="float32", device=inp.device),
Tensor([0, 0], dtype="float32", device=inp.device),
],
axis=0,
).reshape(1, 3)
zeros = Tensor([0], dtype="float32", device=inp.device)
row1 = concat(
[zeros, Tensor(hscale, dtype="float32", device=inp.device), zeros], axis=0,
).reshape(1, 3)
weight = concat(
[row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
axis=0,
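For reference, the rows assembled above form the warp matrix of the resize in homogeneous coordinates (row0 = [wscale, 0, 0], row1 = [0, hscale, 0], plus the constant [0, 0, 1] row); the refactor only changes how the constants are built, creating every element on inp.device so no implicit host-to-device copy is triggered. The same matrix written out in numpy, with example scale values:

import numpy as np

hscale, wscale = 0.5, 0.5  # e.g. wscale = (iw - 1.0) / (ow - 1.0) as in the code above
weight = np.array([[wscale, 0.0, 0.0],
                   [0.0, hscale, 0.0],
                   [0.0, 0.0, 1.0]], dtype=np.float32)
print(weight)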
@@ -557,7 +557,14 @@ void init_ops(py::module m) {
m.def(
"delete_rng_handle",
[](size_t handle) {
if (mgb::imperative::python::interpreter_for_py->check_available()) {
mgb::imperative::python::interpreter_for_py->sync();
}
mgb::CompNode::sync_all();
mgb::CompNode::foreach ([](mgb::CompNode cn) {
auto err = cn.check_async_error();
mgb_assert(!err, "%s", err->what());
});
py_task_q.wait_all_task_finish();
rng::delete_handle(handle);
},
@@ -169,7 +169,8 @@ PyObject* py_apply(
}
HostTensorND ht(target_cn);
ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype);
if (PyArray_Check(args[i])) { // non scaler
if (PyArray_Check(args[i]) || PyList_Check(args[i])) { // non scaler
// py_tuple is not allowed here because of tracing
return imperative::apply(
CreateTensor(CreateTensor::Const, target_cn, ht.layout()),
HostStorage::make(ht.storage()))[0];
@@ -189,8 +190,14 @@ PyObject* py_apply(
if (is_symbol_var[i]) {
symbol_var_idx = i;
tensors[i] = context.symvar2val(args[i]);
} else {
} else if (
DTypePromoteCfg::convert_input_enabled &&
op->same_type<Elemwise>()) {
tensors[i] = convert_pyinput_to_tensor(i);
} else {
PyErr_SetString(
PyExc_TypeError, "py_apply expects tensor as inputs");
return nullptr;
}
}
auto outputs = imperative::apply(*op, tensors);
@@ -205,8 +212,13 @@ PyObject* py_apply(
for (size_t i = 0; i < nargs; ++i) {
if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
tensors[i] = tw->m_tensor->data();
} else {
} else if (
DTypePromoteCfg::convert_input_enabled &&
op->same_type<Elemwise>()) {
tensors[i] = convert_pyinput_to_tensor(i);
} else {
PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs");
return nullptr;
}
}
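Both py_apply hunks tighten input handling the same way: a non-tensor argument is converted only when DTypePromoteCfg::convert_input_enabled is on and the op is an Elemwise; anything else now raises a TypeError instead of being silently wrapped. From the Python side the intended effect is roughly the following (illustrative only, assuming a standard MegEngine install):

import megengine as mge
import megengine.functional as F

x = mge.tensor([1.0, 2.0, 3.0])
print(F.add(x, 1))  # Elemwise: the Python scalar is converted and added
# Passing non-tensor inputs to a non-elementwise op is now rejected with
# "py_apply expects tensor as inputs" (a TypeError), per the hunks above.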
@@ -957,14 +957,14 @@ std::tuple<std::vector<int32_t>, bool> tuple2vector(py::object shape) {
}
bool enable_fastpath(py::handle inp) {
// FIXME: the way to judge whether it is in traced module is inaccurate
auto&& tm_tr = TransformationManager::get_instance()
.segments[TransformationManager::Segment::ModuleTrace];
if (!TensorWrapper::try_cast(inp.ptr()) ||
TransformationManager::get_instance()
.segments[TransformationManager::Segment::Trace]
.size() > 0 ||
TransformationManager::get_instance()
.segments[TransformationManager::Segment::ModuleTrace]
.size() > 0) {
(tm_tr.size() > 0 &&
reinterpret_cast<ModuleTraceTransformation*>(tm_tr[0].get())->enabled())) {
return false;
}
return true;
@@ -11,13 +11,17 @@ import sys
import pytest
import megengine.functional
import megengine.module
from megengine import Parameter
from megengine.core._imperative_rt.core2 import sync
from megengine.core import _config as config
from megengine.core import _trace_option as trace_option
from megengine.core import get_option
from megengine.core._imperative_rt.core2 import (
_get_amp_dtype_autocast,
_get_amp_high_prec_dtype,
_get_amp_low_prec_dtype,
_get_convert_inputs,
)
from megengine.core.tensor import amp
from megengine.device import get_device_count
from megengine.jit import trace as _trace
from megengine.module import Linear, Module
sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))
@@ -41,3 +45,58 @@ def skip_distributed(request):
platform.system()
)
)
@pytest.fixture(autouse=True)
def run_around_tests():
    env_vars1 = {
        "symbolic_shape": trace_option.use_symbolic_shape(),
        "async_level": get_option("async_level"),
        "enable_drop": get_option("enable_drop"),
        "max_recompute_time": get_option("max_recompute_time"),
        "catch_worker_execption": get_option("catch_worker_execption"),
        "enable_host_compute": get_option("enable_host_compute"),
        # "record_computing_path": get_option("record_computing_path"),
        "disable_memory_forwarding": get_option("disable_memory_forwarding"),
        "enable_dtr_auto_drop": get_option("enable_dtr_auto_drop"),
        "enable_dtr_sqrt_sampling": get_option("enable_dtr_sqrt_sampling"),
        "dtr_eviction_threshold": get_option("dtr_eviction_threshold"),
        "dtr_evictee_minimum_size": get_option("dtr_evictee_minimum_size"),
        "benchmark_kernel": config.benchmark_kernel,
        "deterministic_kernel": config.deterministic_kernel,
        "compute_mode": config._compute_mode,
        "conv_format": config._conv_format,
        "amp_enabled": amp.enabled,
        "convert_inputs": _get_convert_inputs(),
        "amp_dtype_autocast": _get_amp_dtype_autocast(),
        "amp_high_prec_dtype": _get_amp_high_prec_dtype(),
        "amp_low_prec_dtype": _get_amp_low_prec_dtype(),
    }
    yield
    env_vars2 = {
        "symbolic_shape": trace_option.use_symbolic_shape(),
        "async_level": get_option("async_level"),
        "enable_drop": get_option("enable_drop"),
        "max_recompute_time": get_option("max_recompute_time"),
        "catch_worker_execption": get_option("catch_worker_execption"),
        "enable_host_compute": get_option("enable_host_compute"),
        # "record_computing_path": get_option("record_computing_path"),
        "disable_memory_forwarding": get_option("disable_memory_forwarding"),
        "enable_dtr_auto_drop": get_option("enable_dtr_auto_drop"),
        "enable_dtr_sqrt_sampling": get_option("enable_dtr_sqrt_sampling"),
        "dtr_eviction_threshold": get_option("dtr_eviction_threshold"),
        "dtr_evictee_minimum_size": get_option("dtr_evictee_minimum_size"),
        "benchmark_kernel": config.benchmark_kernel,
        "deterministic_kernel": config.deterministic_kernel,
        "compute_mode": config._compute_mode,
        "conv_format": config._conv_format,
        "amp_enabled": amp.enabled,
        "convert_inputs": _get_convert_inputs(),
        "amp_dtype_autocast": _get_amp_dtype_autocast(),
        "amp_high_prec_dtype": _get_amp_high_prec_dtype(),
        "amp_low_prec_dtype": _get_amp_low_prec_dtype(),
    }
    for key in env_vars1:
        assert (
            env_vars1[key] == env_vars2[key]
        ), "{} have been changed after test".format(key)
@@ -37,7 +37,7 @@ if [[ "$TEST_PLAT" =~ "local" ]]; then | |||
PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -s -v $test_dirs -m 'not isolated_distributed' | |||
if [[ "$TEST_PLAT" =~ "cuda" ]]; then | |||
echo "test GPU pytest now" | |||
PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -s -v $test_dirs -m 'isolated_distributed' | |||
PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -s -v $test_dirs -m 'isolated_distributed' --ignore=./integration/test_dtr.py | |||
fi | |||
else | |||
cd $(dirname "${BASH_SOURCE[0]}")/.. | |||
@@ -77,6 +77,11 @@ def test_div(): | |||
np.floor_divide(np.array([-5, -7], dtype=np.int32), 2), | |||
) | |||
np.testing.assert_allclose( | |||
(tensor([[5, 4, 3], [4, 2, 6]]) // [1, 2, 1]).numpy(), | |||
np.floor_divide(np.array([[5, 4, 3], [4, 2, 6]], dtype=np.int32), [1, 2, 1]), | |||
) | |||
def test_clamp(): | |||
"""Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and | |||
@@ -206,31 +206,31 @@ def test_interpolate(): | |||
def linear_interpolate(): | |||
inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2)) | |||
out = F.vision.interpolate(inp, scale_factor=2.0, mode="linear") | |||
out2 = F.vision.interpolate(inp, 4, mode="linear") | |||
np.testing.assert_allclose( | |||
out.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32) | |||
) | |||
np.testing.assert_allclose( | |||
out2.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32) | |||
test_func = lambda inp: F.vision.interpolate( | |||
inp, scale_factor=2.0, mode="linear" | |||
) | |||
ref_func = lambda inp: F.vision.interpolate(inp, 4, mode="linear").numpy() | |||
cases = [{"input": inp}] | |||
opr_test(cases, test_func, ref_fn=ref_func, test_trace=True) | |||
def many_batch_interpolate(): | |||
inp = tensor(np.arange(1, 9, dtype=np.float32).reshape(2, 1, 2, 2)) | |||
out = F.vision.interpolate(inp, [4, 4]) | |||
out2 = F.vision.interpolate(inp, scale_factor=2.0) | |||
test_func = lambda inp: F.vision.interpolate(inp, scale_factor=2.0) | |||
ref_func = lambda inp: F.vision.interpolate(inp, [4, 4]).numpy() | |||
np.testing.assert_allclose(out.numpy(), out2.numpy()) | |||
cases = [{"input": inp}] | |||
opr_test(cases, test_func, ref_fn=ref_func, test_trace=True) | |||
def assign_corner_interpolate(): | |||
inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2)) | |||
out = F.vision.interpolate(inp, [4, 4], align_corners=True) | |||
out2 = F.vision.interpolate(inp, scale_factor=2.0, align_corners=True) | |||
test_func = lambda inp: F.vision.interpolate(inp, [4, 4]) | |||
ref_func = lambda inp: F.vision.interpolate(inp, scale_factor=2.0).numpy() | |||
np.testing.assert_allclose(out.numpy(), out2.numpy()) | |||
cases = [{"input": inp}] | |||
opr_test(cases, test_func, ref_fn=ref_func, test_trace=True) | |||
def error_shape_linear_interpolate(): | |||
inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2)) | |||
@@ -248,7 +248,7 @@ def test_interpolate(): | |||
many_batch_interpolate() | |||
assign_corner_interpolate() | |||
error_shape_linear_interpolate() | |||
inappropriate_scale_linear_interpolate() | |||
# inappropriate_scale_linear_interpolate() | |||
def _save_to(self, name="grad"): | |||
@@ -831,7 +831,8 @@ def test_repeat(shape, repeats, axis, is_varnode): | |||
((2,), (2,)), | |||
((2, 3, 4, 5), (1, 1, 1, 1)), | |||
((2, 3, 4, 5), (1, 2, 3, 4)), | |||
((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)), | |||
# FIXME: tile does not support ndim 7 | |||
# ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)), | |||
], | |||
) | |||
@pytest.mark.parametrize("is_varnode", [True]) | |||
@@ -21,7 +21,6 @@ import megengine.optimizer as optim | |||
import megengine.utils.comp_graph_tools as cgtools | |||
from megengine import Parameter, tensor | |||
from megengine.autodiff import GradManager | |||
from megengine.core._trace_option import set_symbolic_shape | |||
from megengine.core.ops import builtin as ops | |||
from megengine.core.ops.builtin import Elemwise | |||
from megengine.core.tensor.utils import isscalar | |||
@@ -39,8 +39,6 @@ from megengine.random import uniform | |||
get_device_count("xpu") <= 2, reason="xpu counts need > 2", | |||
) | |||
def test_gaussian_op(): | |||
# FIXME: remove this sync | |||
mge.core.set_option("async_level", 0) | |||
set_global_seed(1024) | |||
shape = ( | |||
8, | |||
@@ -516,4 +514,3 @@ def test_rng_empty_tensor(is_symbolic): | |||
np.testing.assert_equal(out.numpy().shape, (0,)) | |||
if is_symbolic is None: | |||
break | |||
mge.core.set_option("async_level", 2) |
@@ -10,8 +10,6 @@ from megengine.core._trace_option import set_symbolic_shape
from megengine.jit import trace
from megengine.traced_module import trace_module
set_symbolic_shape(True)
class Main(M.Module):
def forward(self, x):
@@ -61,6 +59,7 @@ class Net(M.Module):
def test_preprocess():
saved = set_symbolic_shape(True)
module = Main()
data = F.ones((1, 14, 8, 8), dtype=np.uint8)
traced_module = trace_module(module, data)
@@ -88,3 +87,5 @@ def test_preprocess():
y,
atol=1e-6,
)
set_symbolic_shape(saved)
@@ -11,8 +11,6 @@ from megengine.core._trace_option import set_symbolic_shape
from megengine.jit import trace
from megengine.traced_module import trace_module
set_symbolic_shape(True)
class Main(M.Module):
def forward(self, x):
@@ -64,6 +62,7 @@ class Net(M.Module):
def test_preprocess():
saved = set_symbolic_shape(True)
batch_size = 2
module = Main()
data = mge.tensor(
@@ -92,3 +91,5 @@ def test_preprocess():
infer_cg.run(inp_dict={"data": data.numpy(), "quad": quad.numpy()}).values()
)[0]
np.testing.assert_allclose(expect, actual)
set_symbolic_shape(saved)
@@ -717,7 +717,6 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) {
ptr->to_contiguous_inplace();
}
dest->desc.layout = ptr->layout();
dest->desc.comp_node = ptr->comp_node();
dest->memory = ptr->blob()->size();
dest->ptr = std::move(ptr);
@@ -175,10 +175,9 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
megdnn::Workspace dnn_wk;
auto wk_size = dnn_op.op->get_workspace_in_bytes(src, layout);
if (wk_size != 0) {
auto wk = Blob::make(comp_node, wk_size);
dnn_wk.raw_ptr = wk->storage().get();
dnn_wk.size = wk_size;
if (wk_size) {
TensorLayout w_layout({wk_size}, dtype::Byte());
dnn_wk = dnn_op.create_workspace(w_layout);
}
DeviceTensorND out =
@@ -205,6 +204,12 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
size_t size = inputs.size();
SmallVector<LogicalTensorDesc> dests(size);
for (size_t i = 0; i < size; i++) {
if (inputs[i].layout.ndim == 0) {
return {{{TensorLayout(inputs[0].layout.dtype), inputs[0].comp_node}},
false};
}
}
if (size > 1) {
auto [output_descs, validated] =
proxy_graph_detail::infer_output_attrs_fallible(def, inputs);
@@ -548,6 +548,7 @@ Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
template <typename Op>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
bool success = inputs[0].layout.ndim != 0;
LogicalTensorDesc dest;
auto&& xxx_rng_def = def.cast_final_safe<Op>();
size_t nr_inp = inputs.size();
@@ -558,7 +559,11 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
xxx_rng_def.dyn_typeinfo()->name, nr_inp);
}
dest.comp_node = inputs[0].comp_node;
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
if (success) {
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
} else {
dest.layout = TensorLayout(inputs[0].layout.dtype);
}
return {{dest}, inputs[0].layout.ndim != 0};
}
@@ -115,6 +115,9 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
TensorShapeArray src(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
src[i] = inputs[i].layout;
if (!src[i].ndim) {
return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
}
}
megdnn::Elemwise::deduce_shape(src, shp);
}
@@ -67,10 +67,15 @@ void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) {
void NetworkImplDft::application_config() {
auto device_type = m_user_config->device_type;
m_compnode_locator.type = to_compnode_locator(device_type).type;
m_compnode_locator.device = m_user_config->device_id;
//! when the device id is not configured, configure it
if (m_compnode_locator.device == -1) {
m_compnode_locator.device = m_user_config->device_id;
}
if (m_nr_threads > 1 && device_type == LiteDeviceType::LITE_CPU) {
m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
m_compnode_locator.device = m_user_config->device_id;
if (m_compnode_locator.device == -1) {
m_compnode_locator.device = m_user_config->device_id;
}
}
//! model options
#define ConfigOption(mge_name, lite_name) \
@@ -155,11 +160,13 @@ void NetworkImplDft::set_cpu_inplace_mode() {
m_is_cpu_inplace_mode = true;
if (m_compnode_locator.type == mgb::CompNode::DeviceType::CPU) {
m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
m_user_config->device_id = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
} else {
LITE_ASSERT(
m_compnode_locator.type == CompNode::DeviceType::MULTITHREAD,
"cpu inplace mode is only avaliable in CPU.");
m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
m_user_config->device_id = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
}
}
@@ -170,6 +177,12 @@ void NetworkImplDft::set_cpu_threads_number(size_t nr_threads) {
if (nr_threads > 1) {
m_nr_threads = nr_threads;
m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
if (m_is_cpu_inplace_mode) {
m_compnode_locator.device =
mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
m_user_config->device_id =
mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
}
m_compnode_locator.nr_threads = nr_threads;
}
}
@@ -216,6 +216,57 @@ TEST(TestNetWork, BasicInplaceAndSingleThreadAffinity) {
compare_lite_tensor<float>(output_tensor, result_mgb);
}
namespace {
void test_multi_thread(bool multi_thread_compnode) {
Config config;
auto lite_tensor = get_input_data("./input_data.npy");
std::string model_path = "./shufflenet.mge";
size_t nr_threads = 2;
std::vector<std::thread::id> thread_ids(nr_threads);
auto runner = [&](size_t i) {
std::shared_ptr<Network> network = std::make_shared<Network>(config);
Runtime::set_cpu_inplace_mode(network);
if (multi_thread_compnode) {
Runtime::set_cpu_threads_number(network, 2);
}
network->load_model(model_path);
Runtime::set_runtime_thread_affinity(network, [&thread_ids, i](int id) {
if (id == 0) {
thread_ids[i] = std::this_thread::get_id();
}
});
std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
auto src_ptr = lite_tensor->get_memory_ptr();
auto src_layout = lite_tensor->get_layout();
input_tensor->reset(src_ptr, src_layout);
network->forward();
network->wait();
std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
};
std::vector<std::thread> threads;
for (size_t i = 0; i < nr_threads; i++) {
threads.emplace_back(runner, i);
}
for (size_t i = 0; i < nr_threads; i++) {
threads[i].join();
}
ASSERT_NE(thread_ids[0], thread_ids[1]);
}
} // namespace
TEST(TestNetWork, InplaceAndUserMultithreadThread) {
test_multi_thread(false);
}
TEST(TestNetWork, InplaceAndMultithread) {
test_multi_thread(true);
}
TEST(TestNetWork, NetworkShareWeights) {
Config config;
auto lite_tensor = get_input_data("./input_data.npy");
@@ -14,8 +14,8 @@
#include "megbrain_build_config.h"
#define MGE_MAJOR 1
#define MGE_MINOR 8
#define MGE_PATCH 0
#define MGE_MINOR 9
#define MGE_PATCH 1
// for rc version, could be like "rc1", "rc2", etc
#define MGE_EXTRA_NAME ""