
feat(dnn/cuda): add conv u4xs4 sass kernel

GitOrigin-RevId: 4defcf5f1f
release-1.5
Megvii Engine Team 4 years ago
parent commit 4a802d21ca
21 changed files with 447 additions and 543 deletions
  1. +3    -1    dnn/src/common/conv_bias.cpp
  2. +4    -2    dnn/src/common/convolution.cpp
  3. +11   -2    dnn/src/common/tensor_format.cpp
  4. +1    -0    dnn/src/common/utils.cpp
  5. +2    -1    dnn/src/cuda/conv_bias/algo.h
  6. +6    -6    dnn/src/cuda/conv_bias/quint4x4x32_wmma.cpp
  7. +5    -6    dnn/src/cuda/conv_bias/quint4x4x32_wmma/activation_u4.cu
  8. +4    -6    dnn/src/cuda/conv_bias/quint4x4x32_wmma/activation_u4.cuh
  9. +1    -1    dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_data.cu
 10. +1    -1    dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_data.cuh
 11. +0    -100  dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_filter.cu
 12. +0    -48   dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_filter.cuh
 13. +114  -0    dnn/src/cuda/conv_bias/reduce_with_scale_filter.cu
 14. +52   -0    dnn/src/cuda/conv_bias/reduce_with_scale_filter.cuh
 15. +0    -312  dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp
 16. +44   -6    dnn/src/naive/conv_bias/opr_impl.cpp
 17. +48   -0    dnn/src/naive/lowbit_utils.cpp
 18. +4    -0    dnn/src/naive/lowbit_utils.h
 19. +20   -3    dnn/test/common/conv_bias.cpp
 20. +62   -46   dnn/test/cuda/conv_test_utils.cpp
 21. +65   -2    dnn/test/naive/conv_bias.cpp

+3 -1   dnn/src/common/conv_bias.cpp

@@ -35,7 +35,9 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
         const TensorLayout& bias, const TensorLayout& z,
         const TensorLayout& dst, size_t workspace_in_bytes,
         const PreprocessedFilter* preprocessed_filter) {
-    megdnn_assert(src.dtype.enumv() == filter.dtype.enumv());
+    megdnn_assert((src.dtype.enumv() == filter.dtype.enumv()) ||
+                  (src.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
+                   filter.dtype.enumv() == DTypeEnum::QuantizedS4));
     // check compatibility of bias's scale
     if (src.dtype.category() == DTypeCategory::QUANTIZED) {
         if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) {


+4 -2   dnn/src/common/convolution.cpp

@@ -598,8 +598,10 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
     megdnn_assert_contiguous(src);
     megdnn_assert_contiguous(filter);
     megdnn_assert(src.ndim >= 3_z, "%s", errmsg().c_str());
-    megdnn_assert(src.dtype.enumv() == filter.dtype.enumv(), "%s",
-                  errmsg().c_str());
+    megdnn_assert(((src.dtype.enumv() == filter.dtype.enumv()) ||
+                   (src.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
+                    filter.dtype.enumv() == DTypeEnum::QuantizedS4)),
+                  "%s", errmsg().c_str());
     check_or_deduce_dtype_fwd(src.dtype, filter.dtype, dst.dtype);
     size_t img_dim;
     if (param().format == Param::Format::NCHW ||


+11 -2   dnn/src/common/tensor_format.cpp

@@ -488,6 +488,10 @@ void LowbitsAlignedTensorFormatBase::assert_valid(
"bad stride:%s, %zu", layout.to_string().c_str(), "bad stride:%s, %zu", layout.to_string().c_str(),
layout.stride[i]); layout.stride[i]);
} }
if (!has_dim_unity_stride &&
(int)layout.stride[layout.ndim - 1] ==
round_up(1, (int)m_align_size_in_elements))
has_dim_unity_stride = true;
megdnn_assert(layout.ndim == 0 || has_dim_unity_stride, megdnn_assert(layout.ndim == 0 || has_dim_unity_stride,
"innermost dim not contiguous"); "innermost dim not contiguous");
} }
@@ -546,7 +550,12 @@ bool LowbitsAlignedTensorFormatBase::is_contiguous_spec(
     assert_valid(layout);
     ptrdiff_t expected = 1;
     for (int i = static_cast<int>(layout.ndim) - 1; i >= 0; --i) {
-        if (layout.shape[i] != 1 && layout.stride[i] != expected)
+        bool is_valid_stride =
+                (layout.stride[i] == expected) ||
+                (expected == 1 &&
+                 (int)layout.stride[i] ==
+                         round_up(1, (int)m_align_size_in_elements));
+        if (layout.shape[i] != 1 && !is_valid_stride)
             return false;
         auto multiplier = layout.shape[i];
         if (i == static_cast<int>(layout.ndim) - 1)
@@ -568,7 +577,7 @@ TensorLayout LowbitsAlignedTensorFormatBase::collapse_contiguous_spec(
             res.stride[0] = 1;
             return res;
         }
-        if (res.shape[i] == 1 && res.stride[i] != 1) {
+        if (res.shape[i] == 1) {
             res.remove_axis_inplace(i);
         }
     }


+1 -0   dnn/src/common/utils.cpp

@@ -232,6 +232,7 @@ float megdnn::mul_scale(DType lhs, DType rhs) {
             (rhs.enumv() == DTypeTrait<dt2>::enumv))                          \
         return lhs.param<dt1>().scale * rhs.param<dt2>().scale;
     cb_binary(::megdnn::dtype::QuantizedS8, ::megdnn::dtype::QuantizedS16)
+    cb_binary(::megdnn::dtype::Quantized4Asymm, ::megdnn::dtype::QuantizedS4)
 #undef cb_binary

     megdnn_assert(lhs.enumv() == rhs.enumv());
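Note: the new cb_binary entry lets mul_scale combine the scales of a Quantized4Asymm src and a QuantizedS4 filter, which is what the bias-scale check in conv_bias relies on. A trivial illustration of that bookkeeping, using the scale values that appear in this commit's tests (0.1 x 0.1 filter/src, 0.01 bias); this is plain arithmetic, not the megdnn API:

#include <cassert>
#include <cmath>

int main() {
    // The int32 accumulator of a quantized convolution carries
    // src_scale * filter_scale, so the bias must be quantized with that product.
    float src_scale = 0.1f, filter_scale = 0.1f, bias_scale = 0.01f;
    assert(std::fabs(src_scale * filter_scale - bias_scale) < 1e-6f);
    return 0;
}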


+2 -1   dnn/src/cuda/conv_bias/algo.h

@@ -66,7 +66,8 @@ public:
         CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW4_DOTPROD_INT8,
         CUDA_IMPLICIT_GEMM_SASS_NCHW32_IMMA_INT8,
         CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW32_IMMA_INT8,
-        CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_INT4,
+        CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_INT4_INT4,
+        CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_UINT4_INT4,
     };
     using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;




+6 -6   dnn/src/cuda/conv_bias/quint4x4x32_wmma.cpp

@@ -15,7 +15,7 @@


#include "./quint4x4x32_wmma/activation_u4.cuh" #include "./quint4x4x32_wmma/activation_u4.cuh"
#include "./quint4x4x32_wmma/reduce_with_scale_data.cuh" #include "./quint4x4x32_wmma/reduce_with_scale_data.cuh"
#include "./quint4x4x32_wmma/reduce_with_scale_filter.cuh"
#include "./reduce_with_scale_filter.cuh"
#include "./quint4x4x32_wmma/wmma_conv_integer_u4.cuh" #include "./quint4x4x32_wmma/wmma_conv_integer_u4.cuh"


using namespace megdnn; using namespace megdnn;
@@ -75,7 +75,7 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::get_workspace_bundle(
     // for reduce filter
     {
         size_t A = OC, B = IC * FH * FW / 8, C = 1;
-        ws_size_zp_filter += _do_dispatch_reduce_workspace_in_bytes(A, B, C);
+        ws_size_zp_filter += do_dispatch_reduce_workspace_in_bytes(A, B, C);
     }
     size_t ws_size_zp_data = N * OH * OW * sizeof(int32_t);
     size_t ws_size_relayout_filter = get_workspace_in_bytes_do_conv(args);
@@ -135,11 +135,11 @@ void ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::exec(
     int32_t zp_data_filter = zp_data * zp_filter * FH * FW * IC;
     auto&& stream = cuda_stream(handle);
     // zp filter
-    _do_dispatch_reduce_with_scale_filter_u4(
+    do_dispatch_reduce_with_scale_filter_4bit<false>(
             static_cast<uint8_t*>(args.filter_tensor->raw_ptr), -zp_data, OC,
             FH * FW * IC / 8, ws_zp_filter.ptr<int32_t>(), stream);
     // zp data
-    _do_dispatch_reduce_with_scale_data_u4(
+    do_dispatch_reduce_with_scale_data_u4(
             ws_zp_data.ptr<int32_t>(),
             static_cast<uint8_t*>(args.src_tensor->raw_ptr), N, IH, IW, OH, OW,
             PH, PW, FH, FW, SH, SW, IC, -zp_filter,
@@ -173,12 +173,12 @@ void ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::exec(
             args.bias_tensor->compatible_ptr<int32_t>(), s0, s1, s2, s3};
     auto&& param = args.opr->param();
     if (param.nonlineMode == Param::NonlineMode::RELU) {
-        _do_dispatch_activation_u4<ActivationRELU>(
+        do_dispatch_activation_u4<ActivationRELU>(
                 args.dst_tensor->compatible_ptr<int32_t>(), visitor,
                 ws_zp_data.ptr<int32_t>(), ws_zp_filter.ptr<int32_t>(),
                 zp_data_filter, N, OC, OH, OW, stream);
     } else if (param.nonlineMode == Param::NonlineMode::IDENTITY) {
-        _do_dispatch_activation_u4<ActivationIdentity>(
+        do_dispatch_activation_u4<ActivationIdentity>(
                 args.dst_tensor->compatible_ptr<int32_t>(), visitor,
                 ws_zp_data.ptr<int32_t>(), ws_zp_filter.ptr<int32_t>(),
                 zp_data_filter, N, OC, OH, OW, stream);


+5 -6   dnn/src/cuda/conv_bias/quint4x4x32_wmma/activation_u4.cu

@@ -87,11 +87,10 @@ __global__ void kern_activation_u4(int32_t* dst, const int32_t* zp_data,
 } // namespace

 template <typename ActivationOp>
-void _do_dispatch_activation_u4(int32_t* dst, BiasVisitor visitor,
-                                const int32_t* zp_data,
-                                const int32_t* zp_filter,
-                                int32_t zp_data_filter, int batch_size, int co,
-                                int ho, int wo, cudaStream_t stream) {
+void do_dispatch_activation_u4(int32_t* dst, BiasVisitor visitor,
+                               const int32_t* zp_data, const int32_t* zp_filter,
+                               int32_t zp_data_filter, int batch_size, int co,
+                               int ho, int wo, cudaStream_t stream) {
     void (*fptr)(int32_t*, const int32_t*, const int32_t*, int32_t, int, int OC,
                  int, int, BiasVisitor) = kern_activation_u4<ActivationOp>;
     dim3 grids{0, 0, 0};
@@ -105,7 +104,7 @@ void _do_dispatch_activation_u4(int32_t* dst, BiasVisitor visitor,
 }

 #define INST(_op)                                                             \
-    template void _do_dispatch_activation_u4<_op>(                            \
+    template void do_dispatch_activation_u4<_op>(                             \
             int32_t * dst, BiasVisitor visitor, const int32_t* zp_data,       \
             const int32_t* zp_filter, int32_t zp_data_filter, int batch_size, \
             int co, int ho, int wo, cudaStream_t stream);


+4 -6   dnn/src/cuda/conv_bias/quint4x4x32_wmma/activation_u4.cuh

@@ -82,12 +82,10 @@ struct ActivationIdentity {
 } // namespace activation_u4

 template <typename ActivationOp>
-void _do_dispatch_activation_u4(int32_t* dst,
-                                activation_u4::BiasVisitor visitor,
-                                const int32_t* zp_data,
-                                const int32_t* zp_filter,
-                                int32_t zp_data_filter, int batch_size, int co,
-                                int ho, int wo, cudaStream_t stream);
+void do_dispatch_activation_u4(int32_t* dst, activation_u4::BiasVisitor visitor,
+                               const int32_t* zp_data, const int32_t* zp_filter,
+                               int32_t zp_data_filter, int batch_size, int co,
+                               int ho, int wo, cudaStream_t stream);

 } // namespace cuda
 } // namespace megdnn


+1 -1   dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_data.cu

@@ -444,7 +444,7 @@ reduce_in_spatial_block_and_along_input_channel_with_scale_u4_large_channels(


 } // namespace

-void megdnn::cuda::_do_dispatch_reduce_with_scale_data_u4(
+void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4(
         int32_t* dst, const uint8_t* src, int batch_size, int ih, int iw,
         int oh, int ow, int ph, int pw, int fh, int fw, int sh, int sw, int ic,
         int32_t scale, uint8_t zp_data, cudaStream_t stream) {


+1 -1   dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_data.cuh

@@ -37,7 +37,7 @@


 namespace megdnn {
 namespace cuda {
-void _do_dispatch_reduce_with_scale_data_u4(
+void do_dispatch_reduce_with_scale_data_u4(
         int32_t* dst, const uint8_t* src, int batch_size, int ih, int iw,
         int oh, int ow, int ph, int pw, int fh, int fw, int sh, int sw, int ic,
         int32_t scale, uint8_t zp_data, cudaStream_t stream);


+0 -100   dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_filter.cu (deleted)

@@ -1,100 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
* \file dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_filter.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "./reduce_with_scale_filter.cuh"
#include "src/cuda/reduce_helper.cuh"

using namespace megdnn;
using namespace cuda;

namespace {

struct ReduceWithScaleUInt4Op {
typedef int32_t wtype;
const uint8_t* src;
int32_t* dst;
int32_t scale;
static const wtype INIT = 0;

#if MEGDNN_CC_CUDA
__host__ __device__ void write(uint32_t idx, wtype val) {
dst[idx] = val * scale;
}

__host__ __device__ static wtype apply(wtype a, wtype b) { return a + b; }

__device__ wtype read(uint32_t idx) {
constexpr uint32_t subbytes_per_pixel = 8;
const uint32_t* sptr =
(const uint32_t*)(src + subbytes_per_pixel * idx / 2);
uint32_t val = *sptr;
int32_t ret = 0;
#pragma unroll
for (int j = 0; j < 8; j++) {
uint8_t cur = (val & 0xF);
ret += cur;
val = (val >> 4);
}
return ret;
}
#endif
};

} // namespace

void megdnn::cuda::_do_dispatch_reduce_with_scale_filter_u4(
const uint8_t* src, int32_t scale, uint32_t rows, uint32_t cols,
int32_t* dst, cudaStream_t stream) {
// rows = OC
// cols is measured in pixels, i.e. IC * FH * FW / 8, a pixel consists of 8
// subbyte data,
ReduceWithScaleUInt4Op op;
op.src = src;
op.scale = scale;
op.dst = dst;
static_cast<void>(op);
static_cast<void>(stream);
static_cast<void>(rows);
static_cast<void>(cols);
run_reduce<ReduceWithScaleUInt4Op, false>(dst + rows, rows, cols, 1, stream,
op);
}

size_t megdnn::cuda::_do_dispatch_reduce_workspace_in_bytes(size_t A, size_t B,
size_t C) {
return get_reduce_workspace_in_bytes<ReduceWithScaleUInt4Op>(A, B, C);
}

// vim: ft=cpp syntax=cuda.doxygen

+0 -48   dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_filter.cuh (deleted)

@@ -1,48 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
* \file dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_filter.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/cuda/utils.cuh"

namespace megdnn {
namespace cuda {
void _do_dispatch_reduce_with_scale_filter_u4(const uint8_t* src, int32_t scale,
uint32_t rows, uint32_t cols,
int32_t* dst,
cudaStream_t stream);
size_t _do_dispatch_reduce_workspace_in_bytes(size_t A, size_t B, size_t C);
} // namespace cuda
} // namespace megdnn

// vim: ft=cpp syntax=cuda.doxygen

+114 -0   dnn/src/cuda/conv_bias/reduce_with_scale_filter.cu (new)

@@ -0,0 +1,114 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
*modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
*notice, this list of conditions and the following disclaimer in the
*documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its
*contributors may be used to endorse or promote products derived from this
*software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
* \file dnn/src/cuda/conv_bias/reduce_with_scale_filter.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "./reduce_with_scale_filter.cuh"
#include "src/cuda/reduce_helper.cuh"
#include "src/cuda/integer_subbyte_utils.cuh"

using namespace megdnn;
using namespace cuda;

namespace {

template <bool signedness>
struct ReduceWithScaleInt4Op {
typedef int32_t wtype;
const uint8_t* src;
int32_t* dst;
int32_t scale;
static const wtype INIT = 0;

#if MEGDNN_CC_CUDA
__host__ __device__ void write(uint32_t idx, wtype val) {
dst[idx] = val * scale;
}

__host__ __device__ static wtype apply(wtype a, wtype b) { return a + b; }

__device__ wtype read(uint32_t idx) {
constexpr uint32_t subbytes_per_pixel = 8;
const uint32_t* sptr =
(const uint32_t*)(src + subbytes_per_pixel * idx / 2);
uint32_t val = *sptr;
int32_t ret = 0;
#pragma unroll
for (int j = 0; j < 8; j++) {
ret += integer_subbyte::unpack_integer_4bits<signedness>(val,
(j << 2));
}
return ret;
}
#endif
};

} // namespace

template <bool signedness>
void megdnn::cuda::do_dispatch_reduce_with_scale_filter_4bit(
const uint8_t* src, int32_t scale, uint32_t rows, uint32_t cols,
int32_t* dst, cudaStream_t stream) {
// rows = OC
// cols is measured in pixels, i.e. IC * FH * FW / 8, a pixel consists of 8
// subbyte data,
ReduceWithScaleInt4Op<signedness> op;
op.src = src;
op.scale = scale;
op.dst = dst;
static_cast<void>(op);
static_cast<void>(stream);
static_cast<void>(rows);
static_cast<void>(cols);
run_reduce<ReduceWithScaleInt4Op<signedness>, false>(dst + rows, rows, cols,
1, stream, op);
}

#define INST(signedness) \
template void \
megdnn::cuda::do_dispatch_reduce_with_scale_filter_4bit<signedness>( \
const uint8_t* src, int32_t scale, uint32_t rows, uint32_t cols, \
int32_t* dst, cudaStream_t stream)
INST(false);
INST(true);
#undef INST

size_t megdnn::cuda::do_dispatch_reduce_workspace_in_bytes(size_t A, size_t B,
size_t C) {
return get_reduce_workspace_in_bytes<ReduceWithScaleInt4Op<false>>(A, B, C);
}

// vim: ft=cpp syntax=cuda.doxygen
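Note: ReduceWithScaleInt4Op::read above sums eight 4-bit values packed in one 32-bit word, with the template parameter selecting signed or unsigned interpretation. A minimal host-side sketch of that per-word reduction, assuming the same packing (eight nibbles per uint32_t, lowest nibble first) and not using the actual integer_subbyte helpers:

#include <cstdint>
#include <cstdio>

// Unpack one 4-bit lane from a packed 32-bit word. For the signed case the
// nibble is sign-extended from 4 bits; for the unsigned case it is used as-is.
template <bool signedness>
int32_t unpack_4bit(uint32_t word, int lane) {
    uint32_t nib = (word >> (lane * 4)) & 0xF;
    if (signedness)
        return static_cast<int32_t>(nib << 28) >> 28;  // arithmetic sign-extend
    return static_cast<int32_t>(nib);
}

// Sum the eight 4-bit lanes of a word and apply a scale, mirroring the
// per-"pixel" reduction done by read()/write() of the reduce op above.
template <bool signedness>
int32_t reduce_word(uint32_t word, int32_t scale) {
    int32_t acc = 0;
    for (int lane = 0; lane < 8; ++lane)
        acc += unpack_4bit<signedness>(word, lane);
    return acc * scale;
}

int main() {
    uint32_t word = 0x21F0u;  // lanes: 0x0, 0xF, 0x1, 0x2, rest zero
    printf("unsigned sum: %d\n", reduce_word<false>(word, 1));  // 0+15+1+2 = 18
    printf("signed sum:   %d\n", reduce_word<true>(word, 1));   // 0-1+1+2  = 2
    return 0;
}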

+52 -0   dnn/src/cuda/conv_bias/reduce_with_scale_filter.cuh (new)

@@ -0,0 +1,52 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
*modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
*notice, this list of conditions and the following disclaimer in the
*documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its
*contributors may be used to endorse or promote products derived from this
*software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
* \file dnn/src/cuda/conv_bias/reduce_with_scale_filter.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/cuda/utils.cuh"

namespace megdnn {
namespace cuda {
template <bool signedness>
void do_dispatch_reduce_with_scale_filter_4bit(const uint8_t* src,
int32_t scale, uint32_t rows,
uint32_t cols, int32_t* dst,
cudaStream_t stream);
size_t do_dispatch_reduce_workspace_in_bytes(size_t A, size_t B, size_t C);
} // namespace cuda
} // namespace megdnn

// vim: ft=cpp syntax=cuda.doxygen

+0 -312   dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp (deleted)

@@ -1,312 +0,0 @@
/**
* \file dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "./algo.h"
#include "src/cuda/conv_bias/sass_helper.cuh"
#include "src/cuda/sass_loader.h"
#include "src/cuda/utils.h"
#include "src/common/conv_bias.h"

using namespace megdnn;
using namespace cuda;
using namespace sass;

namespace {
#if !MEGDNN_TEGRA_X1
// all stride are in bytes
void compute_conv2d_offset(size_t fh, size_t fw, size_t ics, size_t ihs,
Conv2dConstantOffset& constant_offset) {
constexpr int interleaved = 64;
constexpr int size_bits = 4;
constexpr int threablock_k = 128;
constexpr int inc_step = threablock_k / interleaved;
size_t i = 0;
int* s32 = reinterpret_cast<int*>(&(constant_offset.c_offset[0]));
for (; i < inc_step; i++) {
int c = i / (fh * fw);
int khkw = i % (fh * fw);
int kh = khkw / fw;
int kw = khkw % fw;
s32[2 * i] = c * ics + kh * ihs + kw * interleaved * size_bits / 8;
int8_t* s8 = reinterpret_cast<int8_t*>(&(s32[2 * i + 1]));
s8[0] = kh;
s8[1] = kw;
s8[2] = -kh;
s8[3] = -kw;
}
for (; i < (inc_step + fh * fw * inc_step); i++) {
int c = i / (fh * fw);
int khkw = i % (fh * fw);
int kh = khkw / fw;
int kw = khkw % fw;
s32[2 * i] = c * ics + kh * ihs + kw * interleaved * size_bits / 8;
int8_t* s8 = reinterpret_cast<int8_t*>(&(s32[2 * i + 1]));
s8[0] = kh;
s8[1] = kw;
s8[2] = -kh;
s8[3] = -kw;
int i_ = i - inc_step;
c = i_ / (fh * fw);
khkw = i_ % (fh * fw);
kh = khkw / fw;
kw = khkw % fw;
s32[2 * i] -= c * ics + kh * ihs + kw * interleaved * size_bits / 8;
}
}
#endif
}; // namespace

std::string ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::kernel_key(
const SizeArgs& args) const {
std::string kernel_key;
using NonlineMode = Param::NonlineMode;
auto&& param = args.opr->param();
if (args.z_layout->ndim > 0) {
kernel_key =
ssprintf("%s_conv_bias_int4_fuse_z_imma8832_ldg16_%ux%u",
current_device_arch_name(), m_tile_nhw, m_tile_oc);
} else {
kernel_key =
ssprintf("%s_conv_bias_int4_imma8832_ldg16_%ux%u",
current_device_arch_name(), m_tile_nhw, m_tile_oc);
}
if (param.nonlineMode == NonlineMode::H_SWISH) {
kernel_key += "_hswish";
} else {
megdnn_assert(param.nonlineMode == NonlineMode::RELU ||
param.nonlineMode == NonlineMode::IDENTITY);
kernel_key += "_relu";
}
return kernel_key;
}

bool ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::is_available(
const SizeArgs& args) const {
if (args.bias_layout->ndim <= 0)
return false;
using Param = param::ConvBias;
using Format = Param::Format;
using Sparse = Param::Sparse;
using Mode = Param::Mode;
bool available = true;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
if (!check_bias_share_in_channel(*(args.bias_layout), param.format))
return false;
if (param.format != Format::NCHW64)
return false;
UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout),
param);
// TODO support group conv
available &= param.sparse == Sparse::DENSE;
// mode must be cross correlation
available &= param.mode == Mode::CROSS_CORRELATION;
// check data type
auto src_dtype = args.src_layout->dtype,
filter_dtype = args.filter_layout->dtype,
bias_dtype = args.bias_layout->dtype,
dst_dtype = args.dst_layout->dtype;
available &= (src_dtype.enumv() == DTypeEnum::QuantizedS4 &&
filter_dtype.enumv() == DTypeEnum::QuantizedS4 &&
bias_dtype.enumv() == DTypeEnum::QuantizedS32 &&
dst_dtype.enumv() == DTypeEnum::QuantizedS4);
// TODO: support dialtion
available &= dh == 1 && dw == 1;
// ensure precomputed offsets are positive integers
available &= hi >= fh && wi >= fw;
// only support sm_75 or later, platform should have tensorcore int8
// support
available &= is_compute_capability_required(7, 5);
// param buffer size is 4K, use 3K to store precomputed offset, fh * fw <=
// (3*1024/4/2/2) - 1
available &= fh * fw <= 191;
return available;
}

size_t
ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::get_workspace_in_bytes(
const SizeArgs& args) const {
if (args.preprocessed_filter == nullptr) {
return args.filter_layout->span().dist_byte() +
args.bias_layout->span().dist_byte();
}
return 0_z;
}

void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec(
const ExecArgs& args) const {
#if MEGDNN_TEGRA_X1
megdnn_throw("sass kernel is disabled at compile time for TX1");
#else
using Format = Param::Format;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout),
param);
auto&& stream = cuda_stream(args.opr->handle());
constexpr int interleaved = 64;

void* bias_ptr = nullptr;
void* filter_ptr = nullptr;
if (args.preprocessed_filter) {
megdnn_assert(args.preprocessed_filter->tensors.size() == 2);
filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr;
bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr;
} else {
// reorder filter and bias
filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr);
bias_ptr =
reinterpret_cast<void*>(args.workspace.raw_ptr +
args.filter_layout->span().dist_byte());
if (args.z_layout->ndim > 0) {
reorder_imma_filter_bias<4, 64>(
reinterpret_cast<int8_t*>(filter_ptr),
reinterpret_cast<int32_t*>(bias_ptr),
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr),
args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
stream);
} else {
reorder_imma_filter_bias<4, 64, true>(
reinterpret_cast<int8_t*>(filter_ptr),
reinterpret_cast<int32_t*>(bias_ptr),
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr),
args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
stream);
}
}

uint32_t u32_n = n, u32_ci = ci, u32_hi = hi, u32_wi = wi, u32_fh = fh,
u32_fw = fw, u32_sh = sh, u32_sw = sw, u32_ph = ph, u32_pw = pw,
u32_co = co, u32_ho = ho, u32_wo = wo;
Conv2dInt4Param kern_param(u32_n, u32_ci, u32_hi, u32_wi, u32_fh, u32_fw,
u32_sh, u32_sw, u32_ph, u32_pw, u32_co, u32_ho,
u32_wo, interleaved);

Conv2dConstantOffset kern_coffset;
compute_conv2d_offset(fh, fw, kern_param.ics, kern_param.ihs, kern_coffset);
// The starting address of Turing param buffer is c[0x0][0x160]
kern_coffset.c_offset_param.begin = param_buffer_start_address();
kern_coffset.c_offset_param.size = 16 * (1 + fh * fw);
kern_coffset.c_offset_param.max = 16 * fh * fw;
kern_coffset.c_offset_param.rewind = 16 * (1 - fh * fw);

auto kern_key = kernel_key(args);
float src_scale = args.src_layout->dtype.param<dtype::QuantizedS4>().scale,
filter_scale =
args.filter_layout->dtype.param<dtype::QuantizedS4>().scale,
bias_scale =
args.bias_layout->dtype.param<dtype::QuantizedS32>().scale,
dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale;
float alpha = src_scale * filter_scale / dst_scale,
beta = bias_scale / dst_scale;
float inv_dst_scale = 1.f / dst_scale;

unsigned int tx = m_threads, ty = 1;
unsigned int gridx = div_ceil<unsigned int>(
static_cast<unsigned int>(n * ho * wo), m_tile_nhw);
unsigned int gridy =
div_ceil<unsigned int>(static_cast<unsigned int>(co), m_tile_oc);
void* src_ptr = const_cast<void*>(args.src_tensor->raw_ptr);
void* dst_ptr = const_cast<void*>(args.dst_tensor->raw_ptr);

using NonlineMode = Param::NonlineMode;
auto&& kernel = SASSKernelLoader::instance().get_kernel(kern_key, kern_key);
if (args.z_layout->ndim > 0) {
void* z_ptr = const_cast<void*>(args.z_tensor->raw_ptr);
float z_scale = args.z_layout->dtype.param<dtype::QuantizedS4>().scale;
float gamma = z_scale / dst_scale;
std::vector<void*> params = {&src_ptr, &filter_ptr, &bias_ptr, &z_ptr,
&dst_ptr, &alpha, &beta, &gamma};
kern_coffset.c_offset_param.begin +=
sizeof(src_ptr) + sizeof(filter_ptr) + sizeof(bias_ptr) +
sizeof(z_ptr) + sizeof(dst_ptr) + sizeof(alpha) + sizeof(beta) +
sizeof(gamma);

uint32_t relu = param.nonlineMode == NonlineMode::RELU ? 1 : 0;
if (param.nonlineMode == NonlineMode::H_SWISH) {
params.push_back(&dst_scale);
params.push_back(&inv_dst_scale);
kern_coffset.c_offset_param.begin +=
sizeof(dst_scale) + sizeof(inv_dst_scale);
} else {
params.push_back(&relu);
kern_coffset.c_offset_param.begin += sizeof(relu);
}
params.push_back(&kern_param);
kern_coffset.c_offset_param.begin += sizeof(kern_param);
kern_coffset.c_offset_param.begin +=
sizeof(kern_coffset.c_offset_param);
kern_coffset.c_offset_param.max += kern_coffset.c_offset_param.begin;
params.push_back(&kern_coffset);
cucheck(cuLaunchKernel(kernel, gridx, gridy, 1, tx, ty, 1, 0, stream,
params.data(), 0));
} else {
std::vector<void*> params = {&src_ptr, &filter_ptr, &bias_ptr,
&dst_ptr, &alpha, &beta};

kern_coffset.c_offset_param.begin +=
sizeof(src_ptr) + sizeof(filter_ptr) + sizeof(bias_ptr) +
sizeof(dst_ptr) + sizeof(alpha) + sizeof(beta);

uint32_t relu = param.nonlineMode == NonlineMode::RELU ? 1 : 0;
if (param.nonlineMode == NonlineMode::H_SWISH) {
params.push_back(&dst_scale);
params.push_back(&inv_dst_scale);
kern_coffset.c_offset_param.begin +=
sizeof(dst_scale) + sizeof(inv_dst_scale);
} else {
params.push_back(&relu);
kern_coffset.c_offset_param.begin += sizeof(relu);
}
params.push_back(&kern_param);
kern_coffset.c_offset_param.begin += sizeof(kern_param);
kern_coffset.c_offset_param.begin +=
sizeof(kern_coffset.c_offset_param);
kern_coffset.c_offset_param.max += kern_coffset.c_offset_param.begin;
params.push_back(&kern_coffset);
cucheck(cuLaunchKernel(kernel, gridx, gridy, 1, tx, ty, 1, 0, stream,
params.data(), 0));
}
after_kernel_launch();
#endif
}

size_t ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::
get_preprocess_workspace_in_bytes(const SizeArgs& args) const {
return 0_z;
}

SmallVector<TensorLayout> ConvBiasForwardImpl::
AlgoSASSInt4NCHW64IMMAImplicitGemm::deduce_preprocessed_filter_layout(
const SizeArgs& args) const {
return {args.filter_layout->collapse_contiguous(),
args.bias_layout->collapse_contiguous()};
}

void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec_preprocess(
const ExecArgs& args) const {
using Format = Param::Format;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout),
param);
auto&& stream = cuda_stream(args.opr->handle());
reorder_imma_filter_bias<4, 64>(
reinterpret_cast<int8_t*>(
args.preprocessed_filter->tensors[0].raw_ptr),
args.preprocessed_filter->tensors[1].compatible_ptr<int32_t>(),
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr),
args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
stream);
}

// vim: syntax=cpp.doxygen

+44 -6   dnn/src/naive/conv_bias/opr_impl.cpp

@@ -161,6 +161,38 @@ void forward_bias<dt_qint4, dt_qint4, dt_qint32, dt_qint32>(
     forward_bias<dt_qint8, dt_qint8, dt_qint32, dt_qint32>(
             new_src, new_flt, bias, dst, nullptr, new_filter_meta);
 }
+
+template <>
+void forward_bias<dt_quint4, dt_qint4, dt_qint32, dt_qint32>(
+        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
+        _megdnn_tensor_out dst, dt_byte* workspace_ptr,
+        const ConvBiasForward::CanonizedFilterMeta& filter_meta) {
+    auto convert_layout_src = [](const TensorLayout& layout) {
+        auto ret = layout;
+        auto param = layout.dtype.param<dtype::Quantized4Asymm>();
+        ret.dtype = dtype::QuantizedS8(param.scale);
+        ret.format = TensorFormat(ret.dtype);
+        ret.init_contiguous_stride();
+        return ret;
+    };
+    auto convert_layout_flt = [](const TensorLayout& layout) {
+        auto ret = layout;
+        auto param = layout.dtype.param<dtype::QuantizedS4>();
+        ret.dtype = dtype::QuantizedS8(param.scale);
+        ret.format = TensorFormat(ret.dtype);
+        ret.init_contiguous_stride();
+        return ret;
+    };
+    TensorND new_src = {workspace_ptr, convert_layout_src(src.layout)};
+    TensorND new_flt = {workspace_ptr + new_src.layout.span().dist_byte(),
+                        convert_layout_flt(filter.layout)};
+    uint4_to_int8(src, new_src);
+    int4_to_int8(filter, new_flt);
+    auto new_filter_meta = filter_meta;
+    new_filter_meta.dtype = new_flt.layout.dtype;
+    forward_bias<dt_qint8, dt_qint8, dt_qint32, dt_qint32>(
+            new_src, new_flt, bias, dst, nullptr, new_filter_meta);
+}
 } // namespace convolution

 size_t ConvBiasForwardImpl::get_workspace_in_bytes(const TensorLayout& src,
@@ -211,9 +243,10 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                           TensorLayout{dst.layout, bias.layout.dtype}};
         workspace_ptr += sfb.layout.span().dist_byte();
     }
-#define DISPATCH_RAW(in_dt, bias_dt, out_dt, cmode, func)                     \
+#define DISPATCH_RAW(in_dt, flt_dt, bias_dt, out_dt, cmode, func)             \
     else if (src.layout.dtype.enumv() == DTypeTrait<dtype::in_dt>::enumv &&   \
-             filter.layout.dtype.enumv() == DTypeTrait<dtype::in_dt>::enumv &&\
+             filter.layout.dtype.enumv() ==                                   \
+                     DTypeTrait<dtype::flt_dt>::enumv &&                      \
              bias.layout.dtype.enumv() == DTypeTrait<dtype::bias_dt>::enumv &&\
              sfb.layout.dtype.enumv() == DTypeTrait<dtype::out_dt>::enumv &&  \
              param().compute_mode == Param::ComputeMode::cmode) {             \
@@ -222,7 +255,7 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
     }
 #define DISPATCH(in_dt, out_dt)                                           \
     DISPATCH_RAW(                                                         \
-            in_dt, out_dt, out_dt, DEFAULT,                               \
+            in_dt, in_dt, out_dt, out_dt, DEFAULT,                        \
             (convolution::forward_bias<DTypeTrait<dtype::in_dt>::ctype,   \
                                        DTypeTrait<dtype::in_dt>::ctype,   \
                                        DTypeTrait<dtype::out_dt>::ctype,  \
@@ -236,16 +269,21 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
     DISPATCH(QuantizedS8, Float32)
     DISPATCH(Quantized8Asymm, QuantizedS32)
     DISPATCH(Quantized4Asymm, QuantizedS32)
-    DISPATCH_RAW(QuantizedS8, QuantizedS32, QuantizedS32, FLOAT32,
+    DISPATCH_RAW(QuantizedS8, QuantizedS8, QuantizedS32, QuantizedS32,
+                 FLOAT32,
                  (convolution::forward_bias<dt_int8, dt_int8, dt_int32,
                                             dt_int32>))
     DISPATCH(QuantizedS4, QuantizedS32)
+    DISPATCH_RAW(Quantized4Asymm, QuantizedS4, QuantizedS32, QuantizedS32,
+                 DEFAULT,
+                 (convolution::forward_bias<dt_quint4, dt_qint4, dt_qint32,
+                                            dt_qint32>))
 #if !MEGDNN_DISABLE_FLOAT16
     DISPATCH(Float16, Float16)
-    DISPATCH_RAW(Float16, Float16, Float16, FLOAT32,
+    DISPATCH_RAW(Float16, Float16, Float16, Float16, FLOAT32,
                  (convolution::forward_bias<dt_float16, dt_float16,
                                             dt_float16, dt_float32>))
-    DISPATCH_RAW(BFloat16, BFloat16, BFloat16, FLOAT32,
+    DISPATCH_RAW(BFloat16, BFloat16, BFloat16, BFloat16, FLOAT32,
                  (convolution::forward_bias<dt_bfloat16, dt_bfloat16,
                                             dt_bfloat16, dt_float32>))
 #endif
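Note: the new naive path converts the Quantized4Asymm src and QuantizedS4 filter to QuantizedS8 and reuses the qint8 reference kernel. This works because an asymmetric 4-bit value u with zero point z and scale s represents s * (u - z); re-encoding q = u - z as a signed 8-bit value with the same scale represents the same real number. A plain arithmetic illustration (not the megdnn implementation):

#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
    float s = 0.1f;   // shared scale, as in uint4_to_int8 above
    uint8_t z = 8;    // Quantized4Asymm zero point used in the tests
    for (uint8_t u = 0; u < 16; ++u) {
        float real_u4 = s * (static_cast<int>(u) - z);       // dequantized qu4
        int8_t q = static_cast<int8_t>(static_cast<int>(u) - z);
        float real_s8 = s * q;                                // dequantized qs8
        assert(std::fabs(real_u4 - real_s8) < 1e-6f);
    }
    return 0;
}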


+48 -0   dnn/src/naive/lowbit_utils.cpp

@@ -57,6 +57,54 @@ void megdnn::naive::uint8_to_uint4(const TensorND& in, const TensorND& out) {
 }
 }

+void megdnn::naive::uint4_to_int8(const TensorND& in, const TensorND& out) {
+    auto in_ptr = static_cast<uint8_t*>(in.raw_ptr) + in.layout.span().low_byte;
+    auto out_ptr = out.compatible_ptr<int8_t>() + out.layout.span().low_byte;
+    const auto& ly = in.layout;
+    int8_t zero_point =
+            (int8_t)ly.dtype.param<dtype::Quantized4Asymm>().zero_point;
+    auto dim_in = ly.shape[ly.ndim - 1];
+    auto elems = ly.total_nr_elems();
+    auto dim_out = elems / dim_in;
+    auto stride_out = div_ceil(dim_in, 2_z);
+    for (size_t i = 0; i < dim_out; ++i) {
+        for (size_t j = 0; j < dim_in; j += 2) {
+            uint8_t val = in_ptr[j / 2];
+            out_ptr[j] = (int8_t)(val & 0xF) - zero_point;
+            if (j + 1 < dim_in)
+                out_ptr[j + 1] = (int8_t)((val >> 4) & 0xF) - zero_point;
+        }
+        in_ptr += stride_out;
+        out_ptr += dim_in;
+    }
+}
+
+void megdnn::naive::int8_to_uint4(const TensorND& in, const TensorND& out) {
+    auto in_ptr = static_cast<int8_t*>(in.raw_ptr) + in.layout.span().low_byte;
+    auto out_ptr =
+            static_cast<uint8_t*>(out.raw_ptr) + out.layout.span().low_byte;
+    auto zero_point =
+            out.layout.dtype.param<dtype::Quantized4Asymm>().zero_point;
+    const auto& ly = in.layout;
+    auto dim_in = ly.shape[ly.ndim - 1];
+    auto elems = ly.total_nr_elems();
+    auto dim_out = elems / dim_in;
+    auto stride_out = div_ceil(dim_in, 2_z);
+    for (size_t i = 0; i < dim_out; ++i) {
+        for (size_t j = 0; j < dim_in; j += 2) {
+            uint8_t a = (uint8_t)std::max((int32_t)in_ptr[j] + zero_point, 0);
+            uint8_t b = 0;
+            if (j + 1 < dim_in)
+                b = (uint8_t)std::max((int32_t)in_ptr[j + 1] + zero_point, 0);
+            a = std::min(a, DTypeTrait<dtype::Quantized4Asymm>::max());
+            b = std::min(b, DTypeTrait<dtype::Quantized4Asymm>::max());
+            out_ptr[j / 2] = a + (b << 4);
+        }
+        in_ptr += dim_in;
+        out_ptr += stride_out;
+    }
+}
+
 // ==================================qint4======================================
 void megdnn::naive::int4_to_int8(const TensorND& in, const TensorND& out) {
     auto in_ptr = static_cast<int8_t*>(in.raw_ptr) + in.layout.span().low_byte;
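Note: uint4_to_int8 and int8_to_uint4 above store two asymmetric-quantized values per byte, low nibble first, subtracting or re-adding the zero point and clamping on the way back. A minimal standalone round-trip of that packing scheme (zero_point = 8 is just an example value matching this commit's tests, not a fixed constant):

#include <algorithm>
#include <cassert>
#include <cstdint>

// Unpack one of the two 4-bit values in a byte and remove the zero point.
int8_t unpack_u4(uint8_t byte, int idx, uint8_t zero_point) {
    uint8_t nib = (idx == 0) ? (byte & 0xF) : ((byte >> 4) & 0xF);
    return (int8_t)nib - (int8_t)zero_point;
}

// Re-add the zero point, clamp to the 4-bit range [0, 15], and pack two values.
uint8_t pack_u4(int8_t lo, int8_t hi, uint8_t zero_point) {
    auto clamp = [](int v) { return (uint8_t)std::min(std::max(v, 0), 15); };
    return clamp(lo + zero_point) | (clamp(hi + zero_point) << 4);
}

int main() {
    uint8_t zp = 8;
    int8_t lo = -3, hi = 5;                  // signed values as after uint4_to_int8
    uint8_t packed = pack_u4(lo, hi, zp);    // 0xD5: low nibble 5, high nibble 13
    assert(unpack_u4(packed, 0, zp) == lo);
    assert(unpack_u4(packed, 1, zp) == hi);
    return 0;
}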


+4 -0   dnn/src/naive/lowbit_utils.h

@@ -20,6 +20,10 @@ void uint4_to_uint8(const TensorND& in, const TensorND& out);


 void uint8_to_uint4(const TensorND& in, const TensorND& out);

+void uint4_to_int8(const TensorND& in, const TensorND& out);
+
+void int8_to_uint4(const TensorND& in, const TensorND& out);
+
 void int4_to_int8(const TensorND& in, const TensorND& out);

 void int8_to_int4(const TensorND& in, const TensorND& out);


+20 -3   dnn/test/common/conv_bias.cpp

@@ -733,19 +733,33 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype,
                      param::ConvBias::Format format,
                      const std::vector<TestArg>& args, bool fuse_z,
                      bool stable_test) {
-    megdnn_assert(src_dtype.enumv() == filter_dtype.enumv());
+    megdnn_assert((src_dtype.enumv() == filter_dtype.enumv()) ||
+                  (src_dtype.enumv() == DTypeEnum::Quantized4Asymm &&
+                   filter_dtype.enumv() == DTypeEnum::QuantizedS4));
     Checker<ConvBiasForward> checker(handle, !stable_test);
     if (algo) {
         checker.set_before_exec_callback(
                 ConvBiasAlgoChecker<ConvBiasForward>(algo));
     }
     std::unique_ptr<RNG> rng;
+    std::unique_ptr<RNG> flt_rng;
     std::unique_ptr<RNG> bias_rng;
     std::unique_ptr<RNG> const_rng;
     std::unique_ptr<RNG> zero_rng;
     // TODO: check range of rng
     if (src_dtype.enumv() == DTypeEnum::QuantizedS8) {
         rng = std::make_unique<UniformIntRNG>(-3, 3);
+        flt_rng = std::make_unique<UniformIntRNG>(-3, 3);
+        const_rng = std::make_unique<UniformIntRNG>(1, 1);
+        zero_rng = std::make_unique<UniformIntRNG>(0, 0);
+        megdnn_assert(bias_dtype.enumv() == DTypeEnum::QuantizedS32);
+        bias_rng = std::make_unique<UniformIntRNG>(-50, 50);
+        checker.set_epsilon(1 + 1e-3)
+                .set_max_avg_error(1e-1)
+                .set_max_avg_biased_error(1e-3);
+    } else if (src_dtype.enumv() == DTypeEnum::Quantized4Asymm) {
+        rng = std::make_unique<UniformIntRNG>(0, 6);
+        flt_rng = std::make_unique<UniformIntRNG>(-3, 3);
         const_rng = std::make_unique<UniformIntRNG>(1, 1);
         zero_rng = std::make_unique<UniformIntRNG>(0, 0);
         megdnn_assert(bias_dtype.enumv() == DTypeEnum::QuantizedS32);
@@ -755,6 +769,7 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype,
                 .set_max_avg_biased_error(1e-3);
     } else if (src_dtype.enumv() == DTypeEnum::QuantizedS4) {
         rng = std::make_unique<UniformIntRNG>(-3, 3);
+        flt_rng = std::make_unique<UniformIntRNG>(-3, 3);
         const_rng = std::make_unique<UniformIntRNG>(1, 1);
         zero_rng = std::make_unique<UniformIntRNG>(0, 0);
         megdnn_assert(bias_dtype.enumv() == DTypeEnum::QuantizedS32);
@@ -764,11 +779,13 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype,
                 .set_max_avg_biased_error(1e-3);
     } else if (src_dtype.enumv() == DTypeEnum::Float16) {
         rng = std::make_unique<NormalRNG>(2.f);
+        flt_rng = std::make_unique<NormalRNG>(2.f);
         megdnn_assert(bias_dtype.enumv() == DTypeEnum::Float16);
         bias_rng = std::make_unique<NormalRNG>(2.f);
         checker.set_epsilon(1e-2);
     } else if (src_dtype.enumv() == DTypeEnum::Float32) {
         rng = std::make_unique<NormalRNG>(2.f);
+        flt_rng = std::make_unique<NormalRNG>(2.f);
         megdnn_assert(bias_dtype.enumv() == DTypeEnum::Float32);
         bias_rng = std::make_unique<NormalRNG>(2.f);
     }
@@ -819,9 +836,9 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype,
         }
         return z;
     };
-    megdnn_assert(rng != nullptr && bias_rng != nullptr);
+    megdnn_assert(rng != nullptr && flt_rng != nullptr && bias_rng != nullptr);
     checker.set_rng(0, rng.get())
-            .set_rng(1, rng.get())
+            .set_rng(1, flt_rng.get())
             .set_rng(2, bias_rng.get())
             .set_rng(3, rng.get());
     if (stable_test) {


+62 -46   dnn/test/cuda/conv_test_utils.cpp

@@ -257,7 +257,9 @@ void benchmark_target_algo_with_cudnn_tsc(
         param::ConvBias::Format change_cudnn_format,
         DType change_cudnn_src_dtype, DType change_cudnn_filter_dtype,
         DType change_cudnn_bias_dtype, DType change_cudnn_dst_dtype) {
-    megdnn_assert(src_dtype.enumv() == filter_dtype.enumv());
+    megdnn_assert((src_dtype.enumv() == filter_dtype.enumv()) ||
+                  (src_dtype.enumv() == DTypeEnum::Quantized4Asymm &&
+                   filter_dtype.enumv() == DTypeEnum::QuantizedS4));
     CUBenchmarker<ConvBiasForward> benchmarker(handle);
     CUBenchmarker<ConvBiasForward> benchmarker_cudnn(handle);
     size_t RUNS = 200;
@@ -299,30 +301,30 @@ void benchmark_target_algo_with_cudnn_tsc(
     using Param = ConvBias::Param;
     using Format = Param::Format;
     // helper function to change format
-    auto get_tensor_shape = [](TensorShape shape,
+    auto get_tensor_shape = [](TensorShape shape, DType dtype,
                                Format format) -> TensorShape {
         TensorShape ret;
         if (format == Format::NCHW4) {
             ret = static_cast<TensorShape>(
-                    TensorLayout{shape, dtype::Int8()}
+                    TensorLayout{shape, dtype}
                             .reshape({shape[0], shape[1] / 4, 4, shape[2],
                                       shape[3]})
                             .dimshuffle({0, 1, 3, 4, 2}));
         } else if (format == Format::NCHW32) {
             ret = static_cast<TensorShape>(
-                    TensorLayout{shape, dtype::Int8()}
+                    TensorLayout{shape, dtype}
                             .reshape({shape[0], shape[1] / 32, 32, shape[2],
                                       shape[3]})
                             .dimshuffle({0, 1, 3, 4, 2}));
         } else if (format == Format::NCHW64) {
             ret = static_cast<TensorShape>(
-                    TensorLayout{shape, dtype::QuantizedS4(1.f)}
+                    TensorLayout{shape, dtype}
                             .reshape({shape[0], shape[1] / 64, 64, shape[2],
                                       shape[3]})
                             .dimshuffle({0, 1, 3, 4, 2}));
         } else if (format == Format::CHWN4) {
             ret = static_cast<TensorShape>(
-                    TensorLayout{shape, dtype::Int8()}
+                    TensorLayout{shape, dtype}
                             .reshape({shape[0], shape[1] / 4, 4, shape[2],
                                       shape[3]})
                             .dimshuffle({1, 3, 4, 0, 2}));
@@ -370,21 +372,24 @@ void benchmark_target_algo_with_cudnn_tsc(
         if (algo) {
             time_in_ms =
                     algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>,
-                                   CUTimer>(benchmarker,
-                                            {get_tensor_shape(src, format),
-                                             get_tensor_shape(filter, format),
-                                             get_tensor_shape(bias, format),
-                                             {},
-                                             {}},
-                                            algo) /
+                                   CUTimer>(
+                            benchmarker,
+                            {get_tensor_shape(src, src_dtype, format),
+                             get_tensor_shape(filter, filter_dtype, format),
+                             get_tensor_shape(bias, bias_dtype, format),
+                             {},
+                             {}},
+                            algo) /
                     RUNS;
         } else {
-            time_in_ms = benchmarker.execs({get_tensor_shape(src, format),
-                                            get_tensor_shape(filter, format),
-                                            get_tensor_shape(bias, format),
-                                            {},
-                                            {}}) /
-                         RUNS;
+            time_in_ms =
+                    benchmarker.execs(
+                            {get_tensor_shape(src, src_dtype, format),
+                             get_tensor_shape(filter, filter_dtype, format),
+                             get_tensor_shape(bias, bias_dtype, format),
+                             {},
+                             {}}) /
+                    RUNS;
         }
         float time_in_ms_cudnn = 0;
         if (with_cudnn) {
@@ -393,9 +398,11 @@ void benchmark_target_algo_with_cudnn_tsc(
                         algo_benchmark<ConvBiasForward,
                                        OprProxy<ConvBiasForward>, CUTimer>(
                                 benchmarker_cudnn,
-                                {get_tensor_shape(src, format_cudnn),
-                                 get_tensor_shape(filter, format_cudnn),
-                                 get_tensor_shape(bias, format_cudnn),
+                                {get_tensor_shape(src, src_dtype, format_cudnn),
+                                 get_tensor_shape(filter, filter_dtype,
+                                                  format_cudnn),
+                                 get_tensor_shape(bias, bias_dtype,
+                                                  format_cudnn),
                                  {},
                                  {}},
                                 change_cudnn_algo) /
@@ -403,9 +410,11 @@ void benchmark_target_algo_with_cudnn_tsc(
             } else {
                 time_in_ms_cudnn =
                         benchmarker_cudnn.execs(
-                                {get_tensor_shape(src, format_cudnn),
-                                 get_tensor_shape(filter, format_cudnn),
-                                 get_tensor_shape(bias, format_cudnn),
+                                {get_tensor_shape(src, src_dtype, format_cudnn),
+                                 get_tensor_shape(filter, filter_dtype,
+                                                  format_cudnn),
+                                 get_tensor_shape(bias, bias_dtype,
+                                                  format_cudnn),
                                  {},
                                  {}}) /
                         RUNS;
@@ -426,21 +435,24 @@ void benchmark_target_algo_with_cudnn_tsc(
         if (algo) {
             time_in_ms =
                     algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>,
-                                   CUTimer>(benchmarker,
-                                            {get_tensor_shape(src, format),
-                                             get_tensor_shape(filter, format),
-                                             get_tensor_shape(bias, format),
-                                             get_tensor_shape(z, format),
-                                             {}},
-                                            algo) /
+                                   CUTimer>(
+                            benchmarker,
+                            {get_tensor_shape(src, src_dtype, format),
+                             get_tensor_shape(filter, filter_dtype, format),
+                             get_tensor_shape(bias, bias_dtype, format),
+                             get_tensor_shape(z, src_dtype, format),
+                             {}},
+                            algo) /
                     RUNS;
         } else {
-            time_in_ms = benchmarker.execs({get_tensor_shape(src, format),
-                                            get_tensor_shape(filter, format),
-                                            get_tensor_shape(bias, format),
-                                            get_tensor_shape(z, format),
-                                            {}}) /
-                         RUNS;
+            time_in_ms =
+                    benchmarker.execs(
+                            {get_tensor_shape(src, src_dtype, format),
+                             get_tensor_shape(filter, filter_dtype, format),
+                             get_tensor_shape(bias, bias_dtype, format),
+                             get_tensor_shape(z, src_dtype, format),
+                             {}}) /
+                    RUNS;
         }
         time_in_ms_cudnn = 0;
         if (with_cudnn) {
@@ -449,20 +461,24 @@ void benchmark_target_algo_with_cudnn_tsc(
                         algo_benchmark<ConvBiasForward,
                                        OprProxy<ConvBiasForward>, CUTimer>(
                                 benchmarker_cudnn,
-                                {get_tensor_shape(src, format_cudnn),
-                                 get_tensor_shape(filter, format_cudnn),
-                                 get_tensor_shape(bias, format_cudnn),
-                                 get_tensor_shape(z, format_cudnn),
+                                {get_tensor_shape(src, src_dtype, format_cudnn),
+                                 get_tensor_shape(filter, filter_dtype,
+                                                  format_cudnn),
+                                 get_tensor_shape(bias, bias_dtype,
+                                                  format_cudnn),
+                                 get_tensor_shape(z, src_dtype, format_cudnn),
                                  {}},
                                 change_cudnn_algo) /
                         RUNS;
             } else {
                 time_in_ms_cudnn =
                         benchmarker_cudnn.execs(
-                                {get_tensor_shape(src, format_cudnn),
-                                 get_tensor_shape(filter, format_cudnn),
-                                 get_tensor_shape(bias, format_cudnn),
-                                 get_tensor_shape(z, format_cudnn),
+                                {get_tensor_shape(src, src_dtype, format_cudnn),
+                                 get_tensor_shape(filter, filter_dtype,
+                                                  format_cudnn),
+                                 get_tensor_shape(bias, bias_dtype,
+                                                  format_cudnn),
+                                 get_tensor_shape(z, src_dtype, format_cudnn),
                                  {}}) /
                         RUNS;
             }


+65 -2   dnn/test/naive/conv_bias.cpp

@@ -746,6 +746,45 @@ TEST_F(NAIVE, CONV_BIAS_QUANTIZED4) {


     checker.set_param(param).exect(Testcase{input, filter, bias, z, {}},
                                    Testcase{{}, {}, {}, {}, output});
+
+    // test qu4 x q4
+
+    for (size_t i = 0; i < input_values.size(); i++) {
+        input_values[i] = input_values[i] + 8;
+    }
+
+    for (size_t i = 0; i < z_values.size(); i++) {
+        z_values[i] = z_values[i] + 8;
+    }
+
+    std::vector<int> output_uint4;
+    auto dtype_qu4 = dtype::Quantized4Asymm(0.01, 8);
+    for (size_t i = 0; i < output_values.size(); i++) {
+        int result =
+                static_cast<int>(dtype_qu4.param()
+                                         .quantize(output_values.at(i) * 0.01)
+                                         .as_uint8());
+        output_uint4.push_back(result);
+    }
+
+    auto input_qu4 = TensorValueLowbit4(
+            {1, 1, 4, 4}, dtype::Quantized4Asymm(0.1, 8), input_values);
+
+    auto filter_q4 = TensorValueLowbit4({3, 1, 3, 3}, dtype::QuantizedS4(0.1),
+                                        filter_values);
+
+    auto bias_s32 = GenTensorValue({1, 3, 1, 1}, dtype::QuantizedS32(0.01),
+                                   bias_values);
+
+    auto z_qu4 = TensorValueLowbit4({1, 3, 2, 2},
+                                    dtype::Quantized4Asymm(0.01, 8), z_values);
+
+    auto output_qu4 = TensorValueLowbit4(
+            {1, 3, 2, 2}, dtype::Quantized4Asymm(0.01, 8), output_uint4);
+
+    checker.set_param(param).exect(
+            Testcase{input_qu4, filter_q4, bias_s32, z_qu4, {}},
+            Testcase{{}, {}, {}, {}, output_qu4});
 }

 TEST_F(NAIVE, CONV_BIAS_NCHW64_Q4) {
@@ -3329,7 +3368,7 @@ TEST_F(NAIVE, CONV_BIAS_NCHW64_Q4) {


     auto input_64 = TensorValueLowbit4({1, 1, 4, 4, 64},
                                        dtype::QuantizedS4(0.1), input_values);
-    auto fliter_64 = TensorValueLowbit4({64, 1, 3, 3, 64},
+    auto filter_64 = TensorValueLowbit4({64, 1, 3, 3, 64},
                                        dtype::QuantizedS4(0.1), filter_values);
     auto bias1_64 =
             GenTensorValue({1, 1, 1, 1, 64}, dtype::QuantizedS32(0.01), bias_1);
@@ -3338,7 +3377,31 @@ TEST_F(NAIVE, CONV_BIAS_NCHW64_Q4) {
             {1, 1, 2, 2, 64}, dtype::QuantizedS4(1), output_values);

     checker.set_param(param).exect(
-            Testcase{input_64, fliter_64, bias1_64, {}, {}},
+            Testcase{input_64, filter_64, bias1_64, {}, {}},
             Testcase{{}, {}, {}, {}, output_64});
+
+    // test qu4 x q4
+
+    for (size_t i = 0; i < input_values.size(); i++) {
+        input_values[i] = input_values[i] + 8;
+    }
+
+    for (size_t i = 0; i < output_values.size(); i++) {
+        output_values[i] = output_values[i] + 8;
+    }
+
+    auto input_qu4_64 = TensorValueLowbit4(
+            {1, 1, 4, 4, 64}, dtype::Quantized4Asymm(0.1, 8), input_values);
+    auto filter_q4_64 = TensorValueLowbit4(
+            {64, 1, 3, 3, 64}, dtype::QuantizedS4(0.1), filter_values);
+    auto bias_64 =
+            GenTensorValue({1, 1, 1, 1, 64}, dtype::QuantizedS32(0.01), bias_1);
+
+    auto output_q4_64 = TensorValueLowbit4(
+            {1, 1, 2, 2, 64}, dtype::Quantized4Asymm(1, 8), output_values);
+
+    checker.set_param(param).exect(
+            Testcase{input_qu4_64, filter_q4_64, bias_64, {}, {}},
+            Testcase{{}, {}, {}, {}, output_q4_64});
 }
 // vim: syntax=cpp.doxygen
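Note: the expected qu4 outputs in both tests above follow the usual asymmetric 4-bit quantization rule, result = clamp(round(x / scale) + zero_point, 0, 15); with zero_point 8 and matching scales, an in-range integer accumulator value v simply becomes v + 8, which is why the test data is shifted by 8. A small standalone illustration of that formula (not the dtype::Quantized4Asymm API itself):

#include <algorithm>
#include <cmath>
#include <cstdio>

int quantize_u4(float x, float scale, int zero_point) {
    int q = static_cast<int>(std::round(x / scale)) + zero_point;
    return std::min(std::max(q, 0), 15);  // clamp to the 4-bit range
}

int main() {
    printf("%d\n", quantize_u4(-3 * 0.01f, 0.01f, 8));  // -3 -> 5
    printf("%d\n", quantize_u4(6 * 0.01f, 0.01f, 8));   //  6 -> 14
    return 0;
}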
