Compare commits

...

10 Commits

Author SHA1 Message Date
  温娟 0bed6c0fd9 chore(release): bump version 3 years ago
  Megvii Engine Team 9488cd1c35 fix(lite): fix lite cpu default not work 3 years ago
  Megvii Engine Team 518c7f3781 fix(imperative/src): fix empty_tensor bug of rng 3 years ago
  Megvii Engine Team cca38c4e8d fix(mge): fix fastpath check 3 years ago
  Megvii Engine Team b9e850a479 test(imperative): check env values after each pytest 3 years ago
  Megvii Engine Team d984be591d fix(imperative): restrict value converts to symbolvar 3 years ago
  Megvii Engine Team 5bf311637a fix(mge): fix infer output attrs fallible 3 years ago
  Megvii Engine Team 94960ecf42 fix(imperative): restrict using convert_inputs in py_apply 3 years ago
  Megvii Engine Team 09dab38748 feat(cuda): support int1 simplewq conv 3 years ago
  温娟 6554e26248 chore(release): bump version 3 years ago
35 changed files with 439 additions and 59 deletions
1. dnn/src/common/convolution.cpp (+2, -1)
2. dnn/src/cuda/conv_bias/algo.cpp (+2, -1)
3. dnn/src/cuda/conv_bias/algo.h (+20, -0)
4. dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp (+2, -0)
5. dnn/src/cuda/conv_bias/helper.cpp (+3, -0)
6. dnn/src/cuda/conv_bias/opr_impl.cpp (+5, -0)
7. dnn/src/cuda/conv_bias/opr_impl.h (+1, -0)
8. dnn/src/cuda/conv_bias/simple_int1.cpp (+145, -0)
9. dnn/src/cuda/convolution/forward/algos.cpp (+4, -0)
10. dnn/src/naive/conv_bias/opr_impl.cpp (+3, -0)
11. dnn/src/naive/convolution/helper.h (+9, -0)
12. dnn/test/cuda/conv_bias.cpp (+26, -0)
13. imperative/python/megengine/functional/nn.py (+1, -1)
14. imperative/python/megengine/functional/tensor.py (+0, -1)
15. imperative/python/megengine/functional/utils.py (+4, -1)
16. imperative/python/megengine/functional/vision.py (+6, -6)
17. imperative/python/src/ops.cpp (+7, -0)
18. imperative/python/src/tensor.cpp (+15, -3)
19. imperative/python/src/tensor_utils.cpp (+4, -4)
20. imperative/python/test/conftest.py (+65, -6)
21. imperative/python/test/run.sh (+1, -1)
22. imperative/python/test/unit/functional/test_elemwise.py (+5, -0)
23. imperative/python/test/unit/functional/test_functional.py (+15, -15)
24. imperative/python/test/unit/functional/test_tensor.py (+2, -1)
25. imperative/python/test/unit/jit/test_tracing.py (+0, -1)
26. imperative/python/test/unit/random/test_rng.py (+0, -3)
27. imperative/python/test/unit/traced_module/test_preprocess_1.py (+3, -2)
28. imperative/python/test/unit/traced_module/test_preprocess_2.py (+3, -2)
29. imperative/src/impl/interpreter/interpreter_impl.cpp (+0, -1)
30. imperative/src/impl/ops/reduce.cpp (+9, -4)
31. imperative/src/impl/ops/rng.cpp (+6, -1)
32. imperative/src/impl/ops/tensor_manip.cpp (+3, -0)
33. lite/src/mge/network_impl.cpp (+15, -2)
34. lite/test/test_network.cpp (+51, -0)
35. src/core/include/megbrain/version.h (+2, -2)

dnn/src/common/convolution.cpp (+2, -1)

@@ -561,7 +561,8 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(
src.enumv() == DTypeEnum::QuantizedS8 ||
src.enumv() == DTypeEnum::Quantized8Asymm ||
src.enumv() == DTypeEnum::QuantizedS4 ||
src.enumv() == DTypeEnum::Quantized4Asymm) {
src.enumv() == DTypeEnum::Quantized4Asymm ||
src.enumv() == DTypeEnum::QuantizedS1) {
supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(src, filter)));
bool cond_dst = dst.valid() && (dst.enumv() == src.enumv() ||
((dst.enumv() == DTypeEnum::QuantizedS4 ||


dnn/src/cuda/conv_bias/algo.cpp (+2, -1)

@@ -25,7 +25,7 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
non_cudnn_algos.push_back(&matmul);
non_cudnn_algos.push_back(&matmul8x8x32);
non_cudnn_algos.push_back(&batched_matmul);
non_cudnn_algos.push_back(&int1_simple);
fill_cudnn_algos();
for (auto&& algo : cudnn_conv_bias_activations) {
all_algos.push_back(&algo);
@@ -45,6 +45,7 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
conv_algos.push_back(&matmul8x8x32);
conv_algos.push_back(&batched_matmul);
conv_algos.push_back(&group);
conv_algos.push_back(&int1_simple);

for (auto&& algo : conv_algos) {
all_algos.push_back(algo);


dnn/src/cuda/conv_bias/algo.h (+20, -0)

@@ -87,6 +87,7 @@ public:
CUDA_FALLBACK_NCHW_INT4,
CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32,
CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16,
CUDA_SIMPLE_INT1,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

@@ -1089,6 +1090,24 @@ private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};

class ConvBiasForwardImpl::AlgoSimpleInt1 final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;

std::vector<SearchItem> get_subopr_list(
const TensorLayoutArray& layouts, const OperatorBase* opr) const override;

const char* name() const override { return "CONVBIAS_SIMPLE_INT1"; }

AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

MEGDNN_DECL_ALGO_TYPE(CUDA_SIMPLE_INT1)
private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};

class ConvBiasForwardImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;
@@ -1132,6 +1151,7 @@ public:
std::vector<AlgoFloat16NCHWHMMAImplicitBatchedGemm> f16_implicit_bmm;
AlgoGroupConvGeneral group;
AlgoBFloat16 bfloat16;
AlgoSimpleInt1 int1_simple;

AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo);



dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp (+2, -0)

@@ -30,6 +30,8 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
return false;
}
}
if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS1)
return false;
if ((args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4)


dnn/src/cuda/conv_bias/helper.cpp (+3, -0)

@@ -134,6 +134,9 @@ void ConvBiasDesc::set_conv(
namespace conv_bias {

bool is_cudnn_supported(const BiasForwardSizeArgs& args) {
if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS1)
return false;

if ((args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4)


dnn/src/cuda/conv_bias/opr_impl.cpp (+5, -0)

@@ -221,6 +221,11 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
return &sm_algo_pack.fallback_nchw_qs8;
}

if (sm_algo_pack.int1_simple.is_available_attribute(
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.int1_simple;
}

if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
return megdnn::get_algo_match_attribute<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,


dnn/src/cuda/conv_bias/opr_impl.h (+1, -0)

@@ -72,6 +72,7 @@ public:
class AlgoInt4Int4NHWCIMMAImplicitGemm;
class AlgoUInt4Int4NHWCIMMAImplicitGemm;
class AlgoBFloat16;
class AlgoSimpleInt1;
// The following algorithms are suitable for channel wise convolution
class AlgoFloat32NCHWFMAImplicitBatchedGemm;
class AlgoFloat16NCHWHMMAImplicitBatchedGemm;


dnn/src/cuda/conv_bias/simple_int1.cpp (+145, -0)

@@ -0,0 +1,145 @@
/**
* \file dnn/src/cuda/conv_bias/simple_int1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/common/algo_base.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.cuh"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
using namespace conv_bias;

namespace {
std::pair<TensorLayoutArray, ConvBiasForwardImpl::Param> sub_opr_config(
const TensorLayoutArray& layouts, const ConvBiasForwardImpl* opr) {
megdnn_assert(layouts.size() >= 3);
std::pair<TensorLayoutArray, ConvBiasForwardImpl::Param> ret;
ret.first = layouts;
auto change_dtype = [](TensorLayout& layout) {
if (layout.dtype.enumv() == DTypeEnum::QuantizedS1 ||
layout.dtype.enumv() == DTypeEnum::QuantizedS32) {
layout.dtype = dtype::Float32();
}
};
change_dtype(ret.first[0]);
change_dtype(ret.first[1]);
change_dtype(ret.first[2]);
change_dtype(ret.first[3]);
change_dtype(ret.first[4]);

ret.second = opr->param();
ret.second.compute_mode = ConvBiasForwardImpl::Param::ComputeMode::DEFAULT;
return ret;
}

std::pair<TensorLayoutArray, std::unique_ptr<ConvBiasForward>> prepare_sub_opr(
const ConvBiasForwardImpl::AlgoBase::SizeArgs& args) {
auto convbias_opr = args.handle->create_operator<ConvBias>();
auto&& config = sub_opr_config(
{*args.src_layout, *args.filter_layout, *args.bias_layout, *args.z_layout,
*args.dst_layout},
args.opr);
convbias_opr->param() = config.second;

return {config.first, std::move(convbias_opr)};
}
} // namespace

std::vector<Algorithm::SearchItem> ConvBiasForwardImpl::AlgoSimpleInt1::get_subopr_list(
const TensorLayoutArray& layouts, const OperatorBase* opr) const {
auto&& config =
sub_opr_config(layouts, static_cast<const ConvBiasForwardImpl*>(opr));

std::string param_str;
Algorithm::serialize_write_pod(config.second, param_str);
return {{Algorithm::OprType::CONVBIAS_FORWARD, param_str, config.first}};
}

bool ConvBiasForwardImpl::AlgoSimpleInt1::is_available(const SizeArgs& args) const {
if (args.src_layout->dtype.valid() && args.filter_layout->dtype.valid() &&
args.bias_layout->dtype.valid() && args.z_layout->dtype.valid() &&
args.dst_layout->dtype.valid()) {
auto config = prepare_sub_opr(args);

return args.src_layout->dtype.enumv() == args.filter_layout->dtype.enumv() &&
args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS1 &&
get_algorithm(
static_cast<ConvBiasForwardImpl*>(config.second.get()),
config.first[0], config.first[1], config.first[2],
config.first[3], config.first[4]);
} else {
return false;
}
}

WorkspaceBundle ConvBiasForwardImpl::AlgoSimpleInt1::get_workspace_bundle(
void* ptr, const SizeArgs& args) const {
auto config = prepare_sub_opr(args);

SmallVector<size_t> sizes;
auto get_workspace = [&sizes](const TensorLayout& src, const TensorLayout& dst) {
if (src.dtype != dst.dtype) {
sizes.push_back(dst.span().dist_byte());
}
};
get_workspace(*args.src_layout, config.first[0]);
get_workspace(*args.filter_layout, config.first[1]);
get_workspace(*args.bias_layout, config.first[2]);
get_workspace(*args.z_layout, config.first[3]);
get_workspace(*args.dst_layout, config.first[4]);
sizes.push_back(config.second->get_workspace_in_bytes(
config.first[0], config.first[1], config.first[2], config.first[3],
config.first[4], nullptr));

return {ptr, std::move(sizes)};
}

size_t ConvBiasForwardImpl::AlgoSimpleInt1::get_workspace_in_bytes(
const SizeArgs& args) const {
return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}

void ConvBiasForwardImpl::AlgoSimpleInt1::exec(const ExecArgs& args) const {
TensorND fsrc_tensor = *args.src_tensor;
TensorND ffilter_tensor = *args.filter_tensor;
TensorND fbias_tensor = *args.bias_tensor;
TensorND fz_tensor = *args.z_tensor;
TensorND fdst_tensor = *args.dst_tensor;
auto config = prepare_sub_opr(args);

auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
CompTypeCvter<dtype::QuantizedS1, dtype::Float32> cvter(args.handle, &bundle);
{
cvter.src_to_comp_type(*args.src_tensor, fsrc_tensor)
.src_to_comp_type(*args.filter_tensor, ffilter_tensor);
}
WorkspaceBundle dst_bundle = {
bundle.get(2),
{bundle.get_size(2), bundle.get_size(3), bundle.get_size(4),
bundle.get_size(5)}};
CompTypeCvter<dtype::QuantizedS32, dtype::Float32> dst_cvter(
args.handle, &dst_bundle);
{
dst_cvter.src_to_comp_type(*args.bias_tensor, fbias_tensor)
.src_to_comp_type(*args.z_tensor, fz_tensor)
.src_to_comp_type(*args.dst_tensor, fdst_tensor);
}
config.second->exec(
fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor, fdst_tensor, nullptr,
dst_cvter.workspace());

{ dst_cvter.comp_to_dst_type(fdst_tensor, *args.dst_tensor); }
}

// vim: syntax=cpp.doxygen
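
Taken as a whole, the new AlgoSimpleInt1 is not a dedicated 1-bit kernel but a type-conversion fallback: QuantizedS1/QuantizedS32 layouts are rewritten to Float32, an ordinary float conv_bias sub-operator does the work, and the result is converted back through the workspace bundle. A minimal NumPy sketch of that dequantize, compute-in-float, requantize pattern (names, shapes and the example code values are illustrative assumptions, not taken from the kernel):

```python
import numpy as np

def simple_int1_fallback(src_q, flt_q, src_scale, flt_scale):
    """Sketch: widen quantized codes to float32, compute, then requantize."""
    src_f = src_q.astype(np.float32) * src_scale          # CompTypeCvter: QuantizedS1 -> Float32
    flt_f = flt_q.astype(np.float32) * flt_scale
    acc_f = float(np.dot(src_f.ravel(), flt_f.ravel()))   # stands in for the Float32 conv_bias sub-opr
    out_scale = src_scale * flt_scale                      # QuantizedS32 output scale
    return int(round(acc_f / out_scale))                   # comp_to_dst_type: Float32 -> QuantizedS32

print(simple_int1_fallback(np.array([1, 0, 1]), np.array([1, 1, 1]), 1.0, 1.0))  # -> 2
```

The out_scale = src_scale * flt_scale rule is the same one the convolution.cpp and forward/algos.cpp hunks add for QuantizedS1 bias/dst dtype deduction.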

dnn/src/cuda/convolution/forward/algos.cpp (+4, -0)

@@ -44,6 +44,10 @@ std::pair<TensorLayoutArray, ConvBiasForward::Param> sub_opr_config(
src.dtype.param<dtype::Quantized4Asymm>().scale *
filter.dtype.param<dtype::Quantized4Asymm>().scale);
} else if (src.dtype.enumv() == DTypeEnum::QuantizedS1) {
bias_type = dtype::QuantizedS32(
src.dtype.param<dtype::QuantizedS1>().scale *
filter.dtype.param<dtype::QuantizedS1>().scale);
} else {
megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT);
bias_type = src.dtype;


dnn/src/naive/conv_bias/opr_impl.cpp (+3, -0)

@@ -278,6 +278,9 @@ void ConvBiasForwardImpl::exec(
DISPATCH_RAW(
Quantized4Asymm, QuantizedS4, QuantizedS32, QuantizedS32, DEFAULT,
(convolution::forward_bias<dt_quint4, dt_qint4, dt_qint32, dt_qint32>))
DISPATCH_RAW(
QuantizedS1, QuantizedS1, QuantizedS32, QuantizedS32, FLOAT32,
(convolution::forward_bias<dt_qint1, dt_qint1, dt_qint32, dt_qint32>))
#if !MEGDNN_DISABLE_FLOAT16
DISPATCH(Float16, Float16)
DISPATCH_RAW(


dnn/src/naive/convolution/helper.h (+9, -0)

@@ -84,6 +84,15 @@ inline void StrategyFwd::on(
d += cast(s) * cast(f);
}

template <>
inline void StrategyFwd::on(
dt_qint1& s, dt_qint1& f, dt_qint32& d, DType, DType, DType) {
auto cast = [](const dt_qint1& val) {
return dt_qint32(static_cast<int32_t>(val.as_int8()));
};
d += cast(s) * cast(f);
}

struct StrategyBwdData {
template <typename st, typename ft, typename dt>
static void on(st& s, ft& f, dt& d, DType, DType, DType) {


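The naive reference path added alongside the CUDA algorithm is the usual widen-and-accumulate rule: each dt_qint1 value is cast to int32 and multiply-accumulated into the QuantizedS32 destination, just like the other low-bit dtypes handled above it. A tiny NumPy sketch of that accumulation (illustrative only):

```python
import numpy as np

def strategy_fwd_qint1(src_codes, flt_codes):
    # mirrors the StrategyFwd::on specialization: d += int32(s) * int32(f)
    d = np.int32(0)
    for s, f in zip(src_codes, flt_codes):
        d += np.int32(s) * np.int32(f)
    return d

print(strategy_fwd_qint1([1, 0, 1, 1], [1, 1, 0, 1]))  # -> 2
```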
dnn/test/cuda/conv_bias.cpp (+26, -0)

@@ -133,6 +133,32 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_BF16) {
}
}

TEST_F(CUDA, CONV_BIAS_FORWARD_QS1) {
require_compute_capability(6, 1);

UniformIntRNG int_rng{1, 1};
Checker<ConvBiasForward> checker(handle_cuda());
checker.set_before_exec_callback(AlgoChecker<ConvBiasForward>(
ExecutionPolicyAlgoName{"CONVBIAS_SIMPLE_INT1", {{"MATMUL", {}}}}));

ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW;
param.compute_mode = param::Convolution::ComputeMode::FLOAT32;
{
auto src_shape = TensorShape{20, 2, 224, 224};
auto filter_shape = TensorShape{20, 2, 3, 3};
checker.set_dtype(0, dtype::QuantizedS1(1.0f))
.set_dtype(1, dtype::QuantizedS1(1.0f))
.set_dtype(2, dtype::QuantizedS32(1.0f))
.set_dtype(3, dtype::QuantizedS32(1.0f))
.set_dtype(4, dtype::QuantizedS32(1.0f))
.set_rng(0, &int_rng)
.set_rng(1, &int_rng)
.set_param(param)
.execs({src_shape, filter_shape, {}, {}, {}});
}
}

TEST_F(CUDA, CONV_BIAS_FORWARD_QS8) {
require_compute_capability(6, 1);



imperative/python/megengine/functional/nn.py (+1, -1)

@@ -1509,7 +1509,7 @@ def sync_batch_norm(
"""
_eps_mode = eps_mode.lower()
assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode)
if _eps_mode == "additive" and not (is_distributed() or training):
if _eps_mode == "additive" and not (is_distributed() and training):
return batch_norm(
inp,
running_mean,

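The nn.py change flips the fallback condition of sync_batch_norm's "additive" eps mode: the plain batch_norm shortcut should be taken unless the module is both running distributed and training. A standalone sketch of the corrected guard (stub flags only, to make the truth table explicit):

```python
def falls_back_to_batch_norm(eps_mode, is_distributed, training):
    # after the fix: skip the sync path unless distributed *and* training
    return eps_mode.lower() == "additive" and not (is_distributed and training)

assert falls_back_to_batch_norm("additive", False, True)       # single-process training: fall back
assert falls_back_to_batch_norm("additive", True, False)       # distributed inference: fall back
assert not falls_back_to_batch_norm("additive", True, True)    # distributed training: keep sync BN
```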

imperative/python/megengine/functional/tensor.py (+0, -1)

@@ -1244,7 +1244,6 @@ def tile(inp: Tensor, reps: Iterable[int]):
inp = _tile_one_dim(inp, rep, i)

if l_reps > l_shape:
shape = inp.shape
extra = reps[:-l_shape]
extra_ones = ones_like(extra)
base_shape = concat([extra_ones, shape])


imperative/python/megengine/functional/utils.py (+4, -1)

@@ -53,7 +53,10 @@ def _assert_equal(
"""
err = (
abs(expect - actual)
/ maximum(minimum(abs(expect), abs(actual)), Tensor(1.0, dtype="float32"))
/ maximum(
minimum(abs(expect), abs(actual)),
Tensor(1.0, dtype="float32", device=expect.device),
)
).max()
result = apply(AssertEqual(maxerr=maxerr, verbose=verbose), expect, actual, err)[0]
_sync() # sync interpreter to get exception

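In _assert_equal the clamp constant Tensor(1.0) is now created on expect.device, so comparing tensors that live off the default device no longer mixes devices inside the elemwise chain. The call pattern is unchanged; a small hedged example (the default maxerr and the exact exception type are not shown in this diff):

```python
import megengine as mge
from megengine.functional.utils import _assert_equal

expect = mge.tensor([1.0, 2.0, 3.0])
actual = mge.tensor([1.0, 2.0, 3.0])
_assert_equal(expect, actual)            # passes: relative error is 0
# _assert_equal(expect, actual + 1.0)    # would raise once the error exceeds maxerr
```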

imperative/python/megengine/functional/vision.py (+6, -6)

@@ -660,16 +660,16 @@ def interpolate(
if mode != "linear":
wscale = (iw - 1.0) / (ow - 1.0)
row0 = concat(
[wscale, Tensor([0, 0], dtype="float32", device=inp.device)], axis=0
).reshape(1, 3)
row1 = concat(
[
Tensor(0, dtype="float32", device=inp.device),
hscale,
Tensor(0, dtype="float32", device=inp.device),
Tensor(wscale, dtype="float32", device=inp.device),
Tensor([0, 0], dtype="float32", device=inp.device),
],
axis=0,
).reshape(1, 3)
zeros = Tensor([0], dtype="float32", device=inp.device)
row1 = concat(
[zeros, Tensor(hscale, dtype="float32", device=inp.device), zeros], axis=0,
).reshape(1, 3)
weight = concat(
[row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
axis=0,

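The vision.py edit does not change the warp matrix interpolate() builds, only where its scalar pieces are allocated: every element is now a tensor on inp.device. For reference, the matrix is a diagonal scale plus the homogeneous row; a NumPy sketch using the wscale formula visible in this hunk (hscale is assumed symmetric, its definition sits outside the hunk):

```python
import numpy as np

def warp_matrix(iw, ih, ow, oh):
    wscale = (iw - 1.0) / (ow - 1.0)   # formula from the surrounding context
    hscale = (ih - 1.0) / (oh - 1.0)   # assumed to mirror wscale
    return np.array([[wscale, 0.0, 0.0],
                     [0.0, hscale, 0.0],
                     [0.0, 0.0, 1.0]], dtype=np.float32)

print(warp_matrix(iw=2, ih=2, ow=4, oh=4))
```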

imperative/python/src/ops.cpp (+7, -0)

@@ -557,7 +557,14 @@ void init_ops(py::module m) {
m.def(
"delete_rng_handle",
[](size_t handle) {
if (mgb::imperative::python::interpreter_for_py->check_available()) {
mgb::imperative::python::interpreter_for_py->sync();
}
mgb::CompNode::sync_all();
mgb::CompNode::foreach ([](mgb::CompNode cn) {
auto err = cn.check_async_error();
mgb_assert(!err, "%s", err->what());
});
py_task_q.wait_all_task_finish();
rng::delete_handle(handle);
},

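The extra synchronisation covers the moment a Python-side RNG object releases its handle: kernels still queued against that handle must finish before delete_rng_handle runs. A hedged sketch of the lifecycle this protects (RNG is MegEngine's public per-handle generator; treat the exact signature as an assumption):

```python
import megengine.random as rand

rng = rand.RNG(seed=123)        # allocates a dedicated RNG handle
x = rng.uniform(size=(4, 4))    # asynchronous kernel that uses the handle
del rng                         # delete_rng_handle now syncs pending work first
print(x.numpy().shape)          # (4, 4)
```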

imperative/python/src/tensor.cpp (+15, -3)

@@ -169,7 +169,8 @@ PyObject* py_apply(
}
HostTensorND ht(target_cn);
ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype);
if (PyArray_Check(args[i])) { // non scaler
if (PyArray_Check(args[i]) || PyList_Check(args[i])) { // non scaler
// py_tuple is not allowed here because of tracing
return imperative::apply(
CreateTensor(CreateTensor::Const, target_cn, ht.layout()),
HostStorage::make(ht.storage()))[0];
@@ -189,8 +190,14 @@ PyObject* py_apply(
if (is_symbol_var[i]) {
symbol_var_idx = i;
tensors[i] = context.symvar2val(args[i]);
} else {
} else if (
DTypePromoteCfg::convert_input_enabled &&
op->same_type<Elemwise>()) {
tensors[i] = convert_pyinput_to_tensor(i);
} else {
PyErr_SetString(
PyExc_TypeError, "py_apply expects tensor as inputs");
return nullptr;
}
}
auto outputs = imperative::apply(*op, tensors);
@@ -205,8 +212,13 @@ PyObject* py_apply(
for (size_t i = 0; i < nargs; ++i) {
if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
tensors[i] = tw->m_tensor->data();
} else {
} else if (
DTypePromoteCfg::convert_input_enabled &&
op->same_type<Elemwise>()) {
tensors[i] = convert_pyinput_to_tensor(i);
} else {
PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs");
return nullptr;
}
}


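At the Python level the py_apply restriction means only elemwise ops keep the implicit convert-inputs behaviour; other ops must receive real Tensors and now fail with a TypeError instead of being converted silently. A short sketch of what still works after the change (assuming convert_inputs is left at its enabled default):

```python
import megengine as mge

x = mge.tensor([[5, 4, 3], [4, 2, 6]], dtype="int32")

# Elemwise ops still convert plain Python operands; this mirrors the
# floor-div case added to test_elemwise.py later in this diff.
print((x // [1, 2, 1]).numpy())
print((x + 1).numpy())

# Non-elemwise ops given raw Python objects are expected to raise
# TypeError("py_apply expects tensor as inputs").
```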

imperative/python/src/tensor_utils.cpp (+4, -4)

@@ -957,14 +957,14 @@ std::tuple<std::vector<int32_t>, bool> tuple2vector(py::object shape) {
}

bool enable_fastpath(py::handle inp) {
// FIXME: the way to judge whether it is in traced module is inaccurate
auto&& tm_tr = TransformationManager::get_instance()
.segments[TransformationManager::Segment::ModuleTrace];
if (!TensorWrapper::try_cast(inp.ptr()) ||
TransformationManager::get_instance()
.segments[TransformationManager::Segment::Trace]
.size() > 0 ||
TransformationManager::get_instance()
.segments[TransformationManager::Segment::ModuleTrace]
.size() > 0) {
(tm_tr.size() > 0 &&
reinterpret_cast<ModuleTraceTransformation*>(tm_tr[0].get())->enabled())) {
return false;
}
return true;


imperative/python/test/conftest.py (+65, -6)

@@ -11,13 +11,17 @@ import sys

import pytest

import megengine.functional
import megengine.module
from megengine import Parameter
from megengine.core._imperative_rt.core2 import sync
from megengine.core import _config as config
from megengine.core import _trace_option as trace_option
from megengine.core import get_option
from megengine.core._imperative_rt.core2 import (
_get_amp_dtype_autocast,
_get_amp_high_prec_dtype,
_get_amp_low_prec_dtype,
_get_convert_inputs,
)
from megengine.core.tensor import amp
from megengine.device import get_device_count
from megengine.jit import trace as _trace
from megengine.module import Linear, Module

sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))

@@ -41,3 +45,58 @@ def skip_distributed(request):
platform.system()
)
)


@pytest.fixture(autouse=True)
def run_around_tests():
env_vars1 = {
"symbolic_shape": trace_option.use_symbolic_shape(),
"async_level": get_option("async_level"),
"enable_drop": get_option("enable_drop"),
"max_recompute_time": get_option("max_recompute_time"),
"catch_worker_execption": get_option("catch_worker_execption"),
"enable_host_compute": get_option("enable_host_compute"),
# "record_computing_path": get_option("record_computing_path"),
"disable_memory_forwarding": get_option("disable_memory_forwarding"),
"enable_dtr_auto_drop": get_option("enable_dtr_auto_drop"),
"enable_dtr_sqrt_sampling": get_option("enable_dtr_sqrt_sampling"),
"dtr_eviction_threshold": get_option("dtr_eviction_threshold"),
"dtr_evictee_minimum_size": get_option("dtr_evictee_minimum_size"),
"benchmark_kernel": config.benchmark_kernel,
"deterministic_kernel": config.deterministic_kernel,
"compute_mode": config._compute_mode,
"conv_format": config._conv_format,
"amp_enabled": amp.enabled,
"convert_inputs": _get_convert_inputs(),
"amp_dtype_autocast": _get_amp_dtype_autocast(),
"amp_high_prec_dtype": _get_amp_high_prec_dtype(),
"amp_low_prec_dtype": _get_amp_low_prec_dtype(),
}
yield
env_vars2 = {
"symbolic_shape": trace_option.use_symbolic_shape(),
"async_level": get_option("async_level"),
"enable_drop": get_option("enable_drop"),
"max_recompute_time": get_option("max_recompute_time"),
"catch_worker_execption": get_option("catch_worker_execption"),
"enable_host_compute": get_option("enable_host_compute"),
# "record_computing_path": get_option("record_computing_path"),
"disable_memory_forwarding": get_option("disable_memory_forwarding"),
"enable_dtr_auto_drop": get_option("enable_dtr_auto_drop"),
"enable_dtr_sqrt_sampling": get_option("enable_dtr_sqrt_sampling"),
"dtr_eviction_threshold": get_option("dtr_eviction_threshold"),
"dtr_evictee_minimum_size": get_option("dtr_evictee_minimum_size"),
"benchmark_kernel": config.benchmark_kernel,
"deterministic_kernel": config.deterministic_kernel,
"compute_mode": config._compute_mode,
"conv_format": config._conv_format,
"amp_enabled": amp.enabled,
"convert_inputs": _get_convert_inputs(),
"amp_dtype_autocast": _get_amp_dtype_autocast(),
"amp_high_prec_dtype": _get_amp_high_prec_dtype(),
"amp_low_prec_dtype": _get_amp_low_prec_dtype(),
}
for key in env_vars1:
assert (
env_vars1[key] == env_vars2[key]
), "{} have been changed after test".format(key)

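The fixture snapshots the same option set before and after each test, which is why the dictionary literal appears twice. A possible follow-up (not part of this diff) is to collect the values through one helper; a minimal sketch assuming the same imports as the fixture above:

```python
import pytest
from megengine.core import get_option

def _snapshot_env():
    # hypothetical helper: one place to list the tracked options
    return {
        "async_level": get_option("async_level"),
        "enable_drop": get_option("enable_drop"),
        # ... remaining keys exactly as in run_around_tests above
    }

@pytest.fixture(autouse=True)
def run_around_tests():
    before = _snapshot_env()
    yield
    after = _snapshot_env()
    for key, value in before.items():
        assert value == after[key], "{} have been changed after test".format(key)
```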
imperative/python/test/run.sh (+1, -1)

@@ -37,7 +37,7 @@ if [[ "$TEST_PLAT" =~ "local" ]]; then
PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -s -v $test_dirs -m 'not isolated_distributed'
if [[ "$TEST_PLAT" =~ "cuda" ]]; then
echo "test GPU pytest now"
PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -s -v $test_dirs -m 'isolated_distributed'
PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -s -v $test_dirs -m 'isolated_distributed' --ignore=./integration/test_dtr.py
fi
else
cd $(dirname "${BASH_SOURCE[0]}")/..


imperative/python/test/unit/functional/test_elemwise.py (+5, -0)

@@ -77,6 +77,11 @@ def test_div():
np.floor_divide(np.array([-5, -7], dtype=np.int32), 2),
)

np.testing.assert_allclose(
(tensor([[5, 4, 3], [4, 2, 6]]) // [1, 2, 1]).numpy(),
np.floor_divide(np.array([[5, 4, 3], [4, 2, 6]], dtype=np.int32), [1, 2, 1]),
)


def test_clamp():
"""Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and


imperative/python/test/unit/functional/test_functional.py (+15, -15)

@@ -206,31 +206,31 @@ def test_interpolate():
def linear_interpolate():
inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))

out = F.vision.interpolate(inp, scale_factor=2.0, mode="linear")
out2 = F.vision.interpolate(inp, 4, mode="linear")

np.testing.assert_allclose(
out.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32)
)
np.testing.assert_allclose(
out2.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32)
test_func = lambda inp: F.vision.interpolate(
inp, scale_factor=2.0, mode="linear"
)
ref_func = lambda inp: F.vision.interpolate(inp, 4, mode="linear").numpy()

cases = [{"input": inp}]
opr_test(cases, test_func, ref_fn=ref_func, test_trace=True)

def many_batch_interpolate():
inp = tensor(np.arange(1, 9, dtype=np.float32).reshape(2, 1, 2, 2))

out = F.vision.interpolate(inp, [4, 4])
out2 = F.vision.interpolate(inp, scale_factor=2.0)
test_func = lambda inp: F.vision.interpolate(inp, scale_factor=2.0)
ref_func = lambda inp: F.vision.interpolate(inp, [4, 4]).numpy()

np.testing.assert_allclose(out.numpy(), out2.numpy())
cases = [{"input": inp}]
opr_test(cases, test_func, ref_fn=ref_func, test_trace=True)

def assign_corner_interpolate():
inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))

out = F.vision.interpolate(inp, [4, 4], align_corners=True)
out2 = F.vision.interpolate(inp, scale_factor=2.0, align_corners=True)
test_func = lambda inp: F.vision.interpolate(inp, [4, 4])
ref_func = lambda inp: F.vision.interpolate(inp, scale_factor=2.0).numpy()

np.testing.assert_allclose(out.numpy(), out2.numpy())
cases = [{"input": inp}]
opr_test(cases, test_func, ref_fn=ref_func, test_trace=True)

def error_shape_linear_interpolate():
inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
@@ -248,7 +248,7 @@ def test_interpolate():
many_batch_interpolate()
assign_corner_interpolate()
error_shape_linear_interpolate()
inappropriate_scale_linear_interpolate()
# inappropriate_scale_linear_interpolate()


def _save_to(self, name="grad"):


imperative/python/test/unit/functional/test_tensor.py (+2, -1)

@@ -831,7 +831,8 @@ def test_repeat(shape, repeats, axis, is_varnode):
((2,), (2,)),
((2, 3, 4, 5), (1, 1, 1, 1)),
((2, 3, 4, 5), (1, 2, 3, 4)),
((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
# FIXME: tile does not support ndim 7
# ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
],
)
@pytest.mark.parametrize("is_varnode", [True])


imperative/python/test/unit/jit/test_tracing.py (+0, -1)

@@ -21,7 +21,6 @@ import megengine.optimizer as optim
import megengine.utils.comp_graph_tools as cgtools
from megengine import Parameter, tensor
from megengine.autodiff import GradManager
from megengine.core._trace_option import set_symbolic_shape
from megengine.core.ops import builtin as ops
from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.utils import isscalar


imperative/python/test/unit/random/test_rng.py (+0, -3)

@@ -39,8 +39,6 @@ from megengine.random import uniform
get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_gaussian_op():
# FIXME: remove this sync
mge.core.set_option("async_level", 0)
set_global_seed(1024)
shape = (
8,
@@ -516,4 +514,3 @@ def test_rng_empty_tensor(is_symbolic):
np.testing.assert_equal(out.numpy().shape, (0,))
if is_symbolic is None:
break
mge.core.set_option("async_level", 2)

imperative/python/test/unit/traced_module/test_preprocess_1.py (+3, -2)

@@ -10,8 +10,6 @@ from megengine.core._trace_option import set_symbolic_shape
from megengine.jit import trace
from megengine.traced_module import trace_module

set_symbolic_shape(True)


class Main(M.Module):
def forward(self, x):
@@ -61,6 +59,7 @@ class Net(M.Module):


def test_preprocess():
saved = set_symbolic_shape(True)
module = Main()
data = F.ones((1, 14, 8, 8), dtype=np.uint8)
traced_module = trace_module(module, data)
@@ -88,3 +87,5 @@ def test_preprocess():
y,
atol=1e-6,
)

set_symbolic_shape(saved)

imperative/python/test/unit/traced_module/test_preprocess_2.py (+3, -2)

@@ -11,8 +11,6 @@ from megengine.core._trace_option import set_symbolic_shape
from megengine.jit import trace
from megengine.traced_module import trace_module

set_symbolic_shape(True)


class Main(M.Module):
def forward(self, x):
@@ -64,6 +62,7 @@ class Net(M.Module):


def test_preprocess():
saved = set_symbolic_shape(True)
batch_size = 2
module = Main()
data = mge.tensor(
@@ -92,3 +91,5 @@ def test_preprocess():
infer_cg.run(inp_dict={"data": data.numpy(), "quad": quad.numpy()}).values()
)[0]
np.testing.assert_allclose(expect, actual)

set_symbolic_shape(saved)

imperative/src/impl/interpreter/interpreter_impl.cpp (+0, -1)

@@ -717,7 +717,6 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) {
ptr->to_contiguous_inplace();
}
dest->desc.layout = ptr->layout();
dest->desc.comp_node = ptr->comp_node();
dest->memory = ptr->blob()->size();
dest->ptr = std::move(ptr);


imperative/src/impl/ops/reduce.cpp (+9, -4)

@@ -175,10 +175,9 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
megdnn::Workspace dnn_wk;

auto wk_size = dnn_op.op->get_workspace_in_bytes(src, layout);
if (wk_size != 0) {
auto wk = Blob::make(comp_node, wk_size);
dnn_wk.raw_ptr = wk->storage().get();
dnn_wk.size = wk_size;
if (wk_size) {
TensorLayout w_layout({wk_size}, dtype::Byte());
dnn_wk = dnn_op.create_workspace(w_layout);
}

DeviceTensorND out =
@@ -205,6 +204,12 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
size_t size = inputs.size();
SmallVector<LogicalTensorDesc> dests(size);

for (size_t i = 0; i < size; i++) {
if (inputs[i].layout.ndim == 0) {
return {{{TensorLayout(inputs[0].layout.dtype), inputs[0].comp_node}},
false};
}
}
if (size > 1) {
auto [output_descs, validated] =
proxy_graph_detail::infer_output_attrs_fallible(def, inputs);


imperative/src/impl/ops/rng.cpp (+6, -1)

@@ -548,6 +548,7 @@ Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
template <typename Op>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
bool success = inputs[0].layout.ndim != 0;
LogicalTensorDesc dest;
auto&& xxx_rng_def = def.cast_final_safe<Op>();
size_t nr_inp = inputs.size();
@@ -558,7 +559,11 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
xxx_rng_def.dyn_typeinfo()->name, nr_inp);
}
dest.comp_node = inputs[0].comp_node;
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
if (success) {
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
} else {
dest.layout = TensorLayout(inputs[0].layout.dtype);
}
return {{dest}, inputs[0].layout.ndim != 0};
}


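Together with the reduce.cpp and tensor_manip.cpp hunks, this lets shape inference return an un-validated descriptor instead of asserting when an input layout is still unknown or empty; the test_rng.py diff earlier in this section drops its async_level workaround for the same reason. The empty-shape case it covers looks roughly like this (sketch based on that test):

```python
import megengine.random as rand

out = rand.uniform(size=(0,))    # zero-sized request: layout inferred without asserting
print(out.numpy().shape)         # -> (0,)
```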

imperative/src/impl/ops/tensor_manip.cpp (+3, -0)

@@ -115,6 +115,9 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
TensorShapeArray src(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
src[i] = inputs[i].layout;
if (!src[i].ndim) {
return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
}
}
megdnn::Elemwise::deduce_shape(src, shp);
}


lite/src/mge/network_impl.cpp (+15, -2)

@@ -67,10 +67,15 @@ void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) {
void NetworkImplDft::application_config() {
auto device_type = m_user_config->device_type;
m_compnode_locator.type = to_compnode_locator(device_type).type;
m_compnode_locator.device = m_user_config->device_id;
//! when the device id is not configured, configure it
if (m_compnode_locator.device == -1) {
m_compnode_locator.device = m_user_config->device_id;
}
if (m_nr_threads > 1 && device_type == LiteDeviceType::LITE_CPU) {
m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
m_compnode_locator.device = m_user_config->device_id;
if (m_compnode_locator.device == -1) {
m_compnode_locator.device = m_user_config->device_id;
}
}
//! model options
#define ConfigOption(mge_name, lite_name) \
@@ -155,11 +160,13 @@ void NetworkImplDft::set_cpu_inplace_mode() {
m_is_cpu_inplace_mode = true;
if (m_compnode_locator.type == mgb::CompNode::DeviceType::CPU) {
m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
m_user_config->device_id = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
} else {
LITE_ASSERT(
m_compnode_locator.type == CompNode::DeviceType::MULTITHREAD,
"cpu inplace mode is only avaliable in CPU.");
m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
m_user_config->device_id = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
}
}

@@ -170,6 +177,12 @@ void NetworkImplDft::set_cpu_threads_number(size_t nr_threads) {
if (nr_threads > 1) {
m_nr_threads = nr_threads;
m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
if (m_is_cpu_inplace_mode) {
m_compnode_locator.device =
mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
m_user_config->device_id =
mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
}
m_compnode_locator.nr_threads = nr_threads;
}
}


lite/test/test_network.cpp (+51, -0)

@@ -216,6 +216,57 @@ TEST(TestNetWork, BasicInplaceAndSingleThreadAffinity) {
compare_lite_tensor<float>(output_tensor, result_mgb);
}

namespace {
void test_multi_thread(bool multi_thread_compnode) {
Config config;
auto lite_tensor = get_input_data("./input_data.npy");
std::string model_path = "./shufflenet.mge";

size_t nr_threads = 2;
std::vector<std::thread::id> thread_ids(nr_threads);
auto runner = [&](size_t i) {
std::shared_ptr<Network> network = std::make_shared<Network>(config);
Runtime::set_cpu_inplace_mode(network);
if (multi_thread_compnode) {
Runtime::set_cpu_threads_number(network, 2);
}

network->load_model(model_path);
Runtime::set_runtime_thread_affinity(network, [&thread_ids, i](int id) {
if (id == 0) {
thread_ids[i] = std::this_thread::get_id();
}
});

std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
auto src_ptr = lite_tensor->get_memory_ptr();
auto src_layout = lite_tensor->get_layout();
input_tensor->reset(src_ptr, src_layout);

network->forward();
network->wait();
std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
};
std::vector<std::thread> threads;
for (size_t i = 0; i < nr_threads; i++) {
threads.emplace_back(runner, i);
}
for (size_t i = 0; i < nr_threads; i++) {
threads[i].join();
}
ASSERT_NE(thread_ids[0], thread_ids[1]);
}

} // namespace

TEST(TestNetWork, InplaceAndUserMultithreadThread) {
test_multi_thread(false);
}

TEST(TestNetWork, InplaceAndMultithread) {
test_multi_thread(true);
}

TEST(TestNetWork, NetworkShareWeights) {
Config config;
auto lite_tensor = get_input_data("./input_data.npy");


src/core/include/megbrain/version.h (+2, -2)

@@ -14,8 +14,8 @@
#include "megbrain_build_config.h"

#define MGE_MAJOR 1
#define MGE_MINOR 8
#define MGE_PATCH 0
#define MGE_MINOR 9
#define MGE_PATCH 1

// for rc version, could be like "rc1", "rc2", etc
#define MGE_EXTRA_NAME ""

