GitOrigin-RevId: 581e31fc20
release-1.1
@@ -182,6 +182,48 @@ class WarpPerspectiveBackwardMat: public WarpPerspectiveBase { | |||
size_t workspace_in_bytes); | |||
}; | |||
class DctChannelSelectForward : public OperatorBase { | |||
DEF_OPR_PARAM(DctChannelSelect); | |||
DEF_OPR_IMPL(DctChannelSelectForward, OperatorBase, 3, 1); | |||
public: | |||
/** | |||
* \param[in] DctChannelSelectForward input, must be uint8 nchw tensor | |||
* \param[in] mask_offset input, must be int32 nchw tensor | |||
* \param[in] mask_val input, must be int32 nchw tensor | |||
* \param[dst] DctChannelSelectForward output, default fp32 nchw tensor | |||
* \param[out] workspace temporary workspace to perform forward | |||
*/ | |||
virtual void exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in mask_offset, | |||
_megdnn_tensor_in mask_val, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& src, | |||
const TensorLayout& mask_offset, | |||
const TensorLayout& mask_val, | |||
TensorLayout& dst); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& mask_offset, | |||
const TensorLayout& mask_val, | |||
const TensorLayout& dst) = 0; | |||
protected: | |||
void check_layout_fwd(const TensorLayout& src, | |||
const TensorLayout& mask_offset, | |||
const TensorLayout& mask_val, | |||
const TensorLayout& dst); | |||
void deduce_layout_fwd(const TensorLayout& src, | |||
const TensorLayout& mask_offset, | |||
const TensorLayout& mask_val, | |||
TensorLayout& dst); | |||
std::string param_msg() const; | |||
}; | |||
} // namespace megdnn | |||
#include "megdnn/internal/opr_header_epilogue.h" | |||
@@ -411,6 +411,9 @@ pdef('ElemwiseMultiType').add_enum( | |||
pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) | |||
# DctChannelSelect params: tensor format (alias of the ConvolutionV0 Format
# enum), optional fast implementation, and the DCT block size (default 8).
(pdef('DctChannelSelect', '2d discrete cosine transform')
 .add_enum_alias('Format', 'ConvolutionV0')
 .add_enum('FastImpl', 'NONE', 'FIX_32_MASK')
 .add_fields('int32', 'dct_block_size', 8))
(pdef('MatrixMul', version=0, is_legacy=True). | |||
add_fields('bool', 'transposeA', 'false', 'transposeB', 'false'). | |||
add_enum('DataType', | |||
@@ -0,0 +1,82 @@ | |||
/** | |||
* \file dnn/src/common/dct.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||
*/ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
void DctChannelSelectForward::deduce_layout_fwd(const TensorLayout& src, | |||
const TensorLayout& mask_offset, | |||
const TensorLayout& mask_val, | |||
TensorLayout& dst) { | |||
const size_t dct_block = param().dct_block_size; | |||
const size_t in = src.shape[0]; | |||
const size_t ic = src.shape[1]; | |||
const size_t ih = src.shape[2]; | |||
const size_t iw = src.shape[3]; | |||
check_layout_fwd(src, mask_offset, mask_val, dst); | |||
const size_t oh = ih / dct_block; | |||
const size_t ow = iw / dct_block; | |||
//! mask will be empty or (ic + 1) elements | |||
size_t oc = mask_offset.ndim > 0 && mask_offset[0] >= 2 | |||
? mask_val.shape[0] | |||
: ic * dct_block * dct_block; | |||
if (param().fastImpl == Param::FastImpl::FIX_32_MASK) { | |||
megdnn_assert(oc == 32, | |||
"Param::FastImpl::FIX_32_MASK oc must be 32, but %zu", | |||
oc); | |||
} | |||
if (param().format == Param::Format::NCHW) { | |||
dst = TensorLayout(TensorShape({in, oc, oh, ow}), dst.dtype); | |||
} else { | |||
megdnn_assert(param().format == Param::Format::NCHW4, | |||
"dct format must be nchw or nchw4"); | |||
megdnn_assert(oc % 4 == 0, "oc mod 4 == 0 in nchw4"); | |||
dst = TensorLayout(TensorShape({in, oc / 4, oh, ow, 4}), dst.dtype); | |||
} | |||
} | |||
void DctChannelSelectForward::deduce_layout(const TensorLayout& src, | |||
const TensorLayout& mask_offset, | |||
const TensorLayout& mask_val, | |||
TensorLayout& dst) { | |||
deduce_layout_fwd(src, mask_offset, mask_val, dst); | |||
} | |||
void DctChannelSelectForward::check_layout_fwd(const TensorLayout& src, | |||
const TensorLayout& mask_offset, | |||
const TensorLayout& mask_val, | |||
const TensorLayout& dst) { | |||
const size_t dct_block = param().dct_block_size; | |||
const size_t ih = src.shape[2]; | |||
const size_t iw = src.shape[3]; | |||
megdnn_assert(mask_offset.ndim == 0 || (mask_offset.ndim == 1 && | |||
(mask_offset.shape[0] == 0 || | |||
mask_offset.shape[0] >= 2) && | |||
mask_val.ndim == 1), | |||
"mask only support one valid dim"); | |||
megdnn_assert(mask_val.ndim <= 1, "only support one dim"); | |||
megdnn_assert(src.dtype.enumv() == DTypeEnum::Uint8, | |||
"src.dtype == dtype::Uint8"); | |||
megdnn_assert(dst.dtype.enumv() == DTypeEnum::Float32 || | |||
dst.dtype.enumv() == DTypeEnum::QuantizedS8, | |||
"dst.dtype == dtype::Float32 || dst.dtype.enumv() == " | |||
"DTypeEnum::QuantizedS8"); | |||
megdnn_assert(ih % dct_block == 0, "ih mod dctblock == 0"); | |||
megdnn_assert(iw % dct_block == 0, "iw mod dctblock == 0"); | |||
} | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -201,6 +201,7 @@ private: | |||
cb(RemapBackwardMat) \ | |||
cb(AdaptivePoolingForward) \ | |||
cb(AdaptivePoolingBackward) \ | |||
cb(DctChannelSelectForward) | |||
/*! | |||
* \brief specialize HandleImpl::create_operator for a single opr type; | |||
@@ -0,0 +1,429 @@ | |||
/** | |||
* \file dnn/src/cuda/dct/dct_channel_select.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||
*/ | |||
#include "megcore_cdefs.h" | |||
#include "src/cuda/dct/dct_channel_select.cuh" | |||
#include "src/cuda/error_info.cuh" | |||
namespace megdnn { | |||
namespace cuda { | |||
template <typename T> | |||
struct CudaPostProcess; | |||
template <> | |||
struct CudaPostProcess<float> { | |||
CudaPostProcess(float){}; | |||
static inline __device__ float func(float val) { return val; } | |||
}; | |||
template <> | |||
struct CudaPostProcess<int8_t> { | |||
CudaDTypeParamImpl<dt_qint8> m_type_cvt; | |||
CudaPostProcess(float scale) { m_type_cvt.inv_scale = 1.f / scale; }; | |||
inline __device__ int8_t func(float val) { | |||
return m_type_cvt.quantize(val).as_int8(); | |||
} | |||
}; | |||
template <uint32_t format> | |||
struct ChannelBlockHelper; | |||
template <> | |||
struct ChannelBlockHelper<dct::DctLayoutFormat::NCHW4> { | |||
static constexpr int channel_block = 4; | |||
}; | |||
template <> | |||
struct ChannelBlockHelper<dct::DctLayoutFormat::NCHW> { | |||
static constexpr int channel_block = 1; | |||
}; | |||
namespace dct { | |||
namespace { | |||
/*!
 * \brief Load eight consecutive uint8 pixels and widen them to float.
 *
 * The eight bytes are fetched with a single int2 load and then unpacked as
 * two uchar4 values.
 */
inline __device__ void load_row(float (&row_cache)[8], const uint8_t* src) {
    int2 packed = *((int2*)src);
    const uchar4* lo = (uchar4*)&(packed.x);
    const uchar4* hi = (uchar4*)&(packed.y);
    row_cache[0] = (float)(lo->x);
    row_cache[1] = (float)(lo->y);
    row_cache[2] = (float)(lo->z);
    row_cache[3] = (float)(lo->w);
    row_cache[4] = (float)(hi->x);
    row_cache[5] = (float)(hi->y);
    row_cache[6] = (float)(hi->z);
    row_cache[7] = (float)(hi->w);
}
/*!
 * \brief In-place 8-point 1D DCT on eight scalars passed by reference.
 *
 * All inputs are folded into sums/differences of mirrored pairs first, so
 * the outputs can be written back into the same variables afterwards.
 */
inline __device__ void fast_dct_1d_internel(float& src0, float& src1,
                                            float& src2, float& src3,
                                            float& src4, float& src5,
                                            float& src6, float& src7) {
    constexpr float norm = 0.3535533905932737f;  //!< sqrt(1 / 8)
    constexpr float c1 = 1.387039845322148f;  //!< sqrt2 * cos(pi * 1 / 16)
    constexpr float c2 = 1.306562964876377f;  //!< sqrt2 * cos(pi * 2 / 16)
    constexpr float c3 = 1.175875602419359f;  //!< sqrt2 * cos(pi * 3 / 16)
    constexpr float c5 = 0.785694958387102f;  //!< sqrt2 * cos(pi * 5 / 16)
    constexpr float c6 = 0.541196100146197f;  //!< sqrt2 * cos(pi * 6 / 16)
    constexpr float c7 = 0.275899379282943f;  //!< sqrt2 * cos(pi * 7 / 16)

    //! sums of mirrored pairs feed the even-indexed outputs
    const float sum07 = src0 + src7;
    const float sum16 = src1 + src6;
    const float sum25 = src2 + src5;
    const float sum34 = src3 + src4;

    //! differences of mirrored pairs feed the odd-indexed outputs
    const float dif07 = src0 - src7;
    const float dif61 = src6 - src1;
    const float dif25 = src2 - src5;
    const float dif43 = src4 - src3;

    const float outer_sum = sum07 + sum34;
    const float inner_sum = sum16 + sum25;
    const float outer_dif = sum07 - sum34;
    const float inner_dif = sum16 - sum25;

    src0 = norm * (outer_sum + inner_sum);
    src2 = norm * (c2 * outer_dif + c6 * inner_dif);
    src4 = norm * (outer_sum - inner_sum);
    src6 = norm * (c6 * outer_dif - c2 * inner_dif);

    src1 = norm * (c1 * dif07 - c3 * dif61 + c5 * dif25 - c7 * dif43);
    src3 = norm * (c3 * dif07 + c7 * dif61 - c1 * dif25 + c5 * dif43);
    src5 = norm * (c5 * dif07 + c1 * dif61 + c7 * dif25 - c3 * dif43);
    src7 = norm * (c7 * dif07 + c5 * dif61 + c3 * dif25 + c1 * dif43);
}
//! in-place 8-point DCT over one contiguous row of eight floats
inline __device__ void fast_dct_1d(float (&src)[8]) {
    fast_dct_1d_internel(src[0], src[1], src[2], src[3],
                         src[4], src[5], src[6], src[7]);
}
//! in-place 8-point DCT over column \p col of an 8x8 block
inline __device__ void fast_dct_1d_col(float (&src)[8][8], const int col) {
    fast_dct_1d_internel(src[0][col], src[1][col], src[2][col], src[3][col],
                         src[4][col], src[5][col], src[6][col], src[7][col]);
}
//! how the kernel chooses which DCT coefficients are written to the output
enum class MaskType {
    NO_MASK = 0,           //!< every coefficient is stored
    USER_DEFINE_MASK = 1,  //!< indices come from mask_offset / mask_val
    FIX_32_MASK = 2,       //!< built-in selection of 32 output channels
    MASK_END
};
template <const int dct_block, const int block_oh, const int block_ow, | |||
uint32_t format, MaskType mask_type, typename DstDtype, typename T2> | |||
struct StoreMask; | |||
template <const int dct_block, const int block_oh, const int block_ow, | |||
typename T2> | |||
struct StoreMask<dct_block, block_oh, block_ow, DctLayoutFormat::NCHW, | |||
MaskType::USER_DEFINE_MASK, float, T2> { | |||
static inline __device__ void func( | |||
const float (&thread_cache)[dct_block][dct_block], float* dst_tid, | |||
const int oc_stride, int channel_idx, const int* mask_offset, | |||
const int* mask_val, CudaPostProcess<T2>& quant_param, | |||
megcore::AsyncErrorInfo* error_info, void* error_tracker) { | |||
__shared__ float shared[dct_block][dct_block][block_oh][block_ow]; | |||
#pragma unroll | |||
for (int i = 0; i < dct_block; ++i) | |||
#pragma unroll | |||
for (int j = 0; j < dct_block; ++j) { | |||
shared[i][j][threadIdx.y][threadIdx.x] = thread_cache[i][j]; | |||
} | |||
const int store_channel_offset = mask_offset[channel_idx]; | |||
const int nr_store_channel = | |||
mask_offset[channel_idx + 1] - store_channel_offset; | |||
if (nr_store_channel < 0) { | |||
set_async_error_info(error_info, error_tracker, | |||
"nchw sub mask len must > 0"); | |||
} | |||
for (int store_channel_idx = 0; store_channel_idx < nr_store_channel; | |||
++store_channel_idx) { | |||
const int index = | |||
mask_val[store_channel_offset + store_channel_idx]; | |||
dst_tid[store_channel_idx * oc_stride] = | |||
shared[index / dct_block][index % dct_block][threadIdx.y] | |||
[threadIdx.x]; | |||
} | |||
} | |||
}; | |||
template <const int dct_block, const int block_oh, const int block_ow, | |||
typename T2> | |||
struct StoreMask<dct_block, block_oh, block_ow, DctLayoutFormat::NCHW4, | |||
MaskType::USER_DEFINE_MASK, int8_t, T2> { | |||
static inline __device__ void func( | |||
const float (&thread_cache)[dct_block][dct_block], int8_t* dst_tid, | |||
const int oc_stride, int channel_idx, const int* mask_offset, | |||
const int* mask_val, CudaPostProcess<T2>& quant_param, | |||
megcore::AsyncErrorInfo* error_info, void* error_tracker) { | |||
//! nchw4 channel_block is 4 | |||
constexpr int channel_block = | |||
ChannelBlockHelper<DctLayoutFormat::NCHW4>::channel_block; | |||
__shared__ float shared[dct_block][dct_block][block_oh][block_ow]; | |||
#pragma unroll | |||
for (int i = 0; i < dct_block; ++i) | |||
#pragma unroll | |||
for (int j = 0; j < dct_block; ++j) { | |||
shared[i][j][threadIdx.y][threadIdx.x] = thread_cache[i][j]; | |||
} | |||
const int store_channel_offset = mask_offset[channel_idx]; | |||
const int nr_store_channel = | |||
mask_offset[channel_idx + 1] - store_channel_offset; | |||
if (nr_store_channel % 4 != 0 || nr_store_channel < 0) { | |||
set_async_error_info(error_info, error_tracker, | |||
"nchw4 sub_mask_len mod 4 should be 0 and " | |||
"sub_mask_len must > 0"); | |||
} | |||
for (int store_channel_idx = 0; store_channel_idx < nr_store_channel; | |||
store_channel_idx += channel_block) { | |||
const int index0 = | |||
mask_val[store_channel_offset + store_channel_idx]; | |||
const int index1 = | |||
mask_val[store_channel_offset + store_channel_idx + 1]; | |||
const int index2 = | |||
mask_val[store_channel_offset + store_channel_idx + 2]; | |||
const int index3 = | |||
mask_val[store_channel_offset + store_channel_idx + 3]; | |||
const int store_c4_idx = store_channel_idx / channel_block; | |||
*(char4*)(&dst_tid[store_c4_idx * channel_block * oc_stride]) = { | |||
quant_param.func( | |||
shared[index0 / dct_block][index0 % dct_block] | |||
[threadIdx.y][threadIdx.x]), | |||
quant_param.func( | |||
shared[index1 / dct_block][index1 % dct_block] | |||
[threadIdx.y][threadIdx.x]), | |||
quant_param.func( | |||
shared[index2 / dct_block][index2 % dct_block] | |||
[threadIdx.y][threadIdx.x]), | |||
quant_param.func( | |||
shared[index3 / dct_block][index3 % dct_block] | |||
[threadIdx.y][threadIdx.x])}; | |||
} | |||
} | |||
}; | |||
template <const int dct_block, const int block_oh, const int block_ow, | |||
uint32_t format, typename DstDtype, typename T2> | |||
struct StoreMask<dct_block, block_oh, block_ow, format, MaskType::NO_MASK, | |||
DstDtype, T2> { | |||
static inline __device__ void func( | |||
const float (&thread_cache)[dct_block][dct_block], | |||
DstDtype* dst_tid, const int oc_stride, int channel_idx, | |||
const int* mask_offset, const int* mask_val, | |||
CudaPostProcess<T2>& quant_param, | |||
megcore::AsyncErrorInfo* error_info, void* error_tracker) { | |||
constexpr int channel_block = ChannelBlockHelper<format>::channel_block; | |||
#pragma unroll | |||
for (int i = 0; i < dct_block; i++) { | |||
#pragma unroll | |||
for (int j = 0; j < dct_block; j++) { | |||
dst_tid[(i * dct_block + j) / channel_block * channel_block * | |||
oc_stride + | |||
(i * dct_block + j) % channel_block] = | |||
quant_param.func(thread_cache[i][j]); | |||
} | |||
} | |||
} | |||
}; | |||
template <const int dct_block, const int block_oh, const int block_ow, | |||
typename T2> | |||
struct StoreMask<dct_block, block_oh, block_ow, DctLayoutFormat::NCHW, | |||
MaskType::FIX_32_MASK, float, T2> { | |||
static inline __device__ void func( | |||
const float (&thread_cache)[dct_block][dct_block], float* dst_tid, | |||
const int oc_stride, int channel_idx, const int* mask_offset, | |||
const int* mask_val, CudaPostProcess<T2>& quant_param, | |||
megcore::AsyncErrorInfo* error_info, void* error_tracker) { | |||
#define STORE(store_index, index) \ | |||
dst_tid[store_index * oc_stride] = \ | |||
thread_cache[index / dct_block][index % dct_block] | |||
STORE(0, 0); | |||
STORE(1, 1); | |||
STORE(2, 8); | |||
STORE(3, 16); | |||
STORE(4, 9); | |||
STORE(5, 2); | |||
STORE(6, 3); | |||
STORE(7, 10); | |||
if (channel_idx == 0) { | |||
STORE(8, 17); | |||
STORE(9, 24); | |||
STORE(10, 32); | |||
STORE(11, 25); | |||
STORE(12, 18); | |||
STORE(13, 11); | |||
STORE(14, 4); | |||
STORE(15, 5); | |||
} | |||
#undef STORE | |||
} | |||
}; | |||
template <const int dct_block, const int block_oh, const int block_ow, | |||
typename T2> | |||
struct StoreMask<dct_block, block_oh, block_ow, DctLayoutFormat::NCHW4, | |||
MaskType::FIX_32_MASK, int8_t, T2> { | |||
static inline __device__ void func( | |||
const float (&thread_cache)[dct_block][dct_block], int8_t* dst_tid, | |||
const int oc_stride, int channel_idx, const int* mask_offset, | |||
const int* mask_val, CudaPostProcess<T2>& quant_param, | |||
megcore::AsyncErrorInfo* error_info, void* error_tracker) { | |||
#define STORE(store_index, index0, index1, index2, index3) \ | |||
*(char4*)(&dst_tid[store_index * oc_stride]) = { \ | |||
quant_param.func( \ | |||
thread_cache[index0 / dct_block][index0 % dct_block]), \ | |||
quant_param.func( \ | |||
thread_cache[index1 / dct_block][index1 % dct_block]), \ | |||
quant_param.func( \ | |||
thread_cache[index2 / dct_block][index2 % dct_block]), \ | |||
quant_param.func( \ | |||
thread_cache[index3 / dct_block][index3 % dct_block])} | |||
STORE(0, 0, 1, 8, 16); | |||
STORE(4, 9, 2, 3, 10); | |||
if (channel_idx == 0) { | |||
STORE(8, 17, 24, 32, 25); | |||
STORE(12, 18, 11, 4, 5); | |||
} | |||
#undef STORE | |||
} | |||
}; | |||
template <const int dct_block, MaskType mask_type, const int ker_block_h, | |||
const int ker_block_w, uint32_t format, typename DstDtype, | |||
typename T2> | |||
__global__ void kern_dct(const uint8_t* src, DstDtype* dst, const int n, | |||
const int c, const int h, const int w, const int oh, | |||
const int ow, const int oc_stride, const int oc, | |||
const int* mask_offset, const int* mask_val, | |||
CudaPostProcess<T2> quant_param, | |||
megcore::AsyncErrorInfo* error_info, | |||
void* error_tracker) { | |||
constexpr int block_oh = ker_block_h / dct_block; | |||
constexpr int block_ow = ker_block_w / dct_block; | |||
const int channel_stride = h * w; | |||
const int oc_idx = blockIdx.z % c; | |||
const int oh_idx = blockIdx.y * block_oh + threadIdx.y; | |||
const int ow_idx = blockIdx.x * block_ow + threadIdx.x; | |||
float thread_cache[dct_block][dct_block]; | |||
const uint8_t* src_tid = | |||
src + blockIdx.z * channel_stride + | |||
(blockIdx.y * ker_block_h + threadIdx.y * dct_block) * w + | |||
(blockIdx.x * ker_block_w + threadIdx.x * dct_block); | |||
const int inner_channel_offset = | |||
(oh_idx * ow + ow_idx) * ChannelBlockHelper<format>::channel_block; | |||
DstDtype* dst_tid = | |||
dst + blockIdx.z * channel_stride + inner_channel_offset; | |||
if (mask_type != MaskType::NO_MASK) { | |||
const int batch_idx = blockIdx.z / c; | |||
const int batch_stride = oc_stride * oc; | |||
int out_channel_offset = 0; | |||
if (mask_type == MaskType::FIX_32_MASK) { | |||
//! trick out_channel_offset = {0, 16, 24}[oc_idx]; oc_idx = 0, 1, 2 | |||
out_channel_offset = 16 * oc_idx - 8 * (oc_idx >> 1); | |||
} else { | |||
out_channel_offset = mask_offset[oc_idx]; | |||
} | |||
dst_tid = dst + batch_idx * batch_stride + | |||
out_channel_offset * oc_stride + inner_channel_offset; | |||
} | |||
if (oh_idx < oh && ow_idx < ow) { | |||
load_row(thread_cache[0], src_tid + 0 * w); | |||
load_row(thread_cache[1], src_tid + 1 * w); | |||
load_row(thread_cache[2], src_tid + 2 * w); | |||
load_row(thread_cache[3], src_tid + 3 * w); | |||
load_row(thread_cache[4], src_tid + 4 * w); | |||
load_row(thread_cache[5], src_tid + 5 * w); | |||
load_row(thread_cache[6], src_tid + 6 * w); | |||
load_row(thread_cache[7], src_tid + 7 * w); | |||
//! TMP = A @ C.T | |||
fast_dct_1d(thread_cache[0]); | |||
fast_dct_1d(thread_cache[1]); | |||
fast_dct_1d(thread_cache[2]); | |||
fast_dct_1d(thread_cache[3]); | |||
fast_dct_1d(thread_cache[4]); | |||
fast_dct_1d(thread_cache[5]); | |||
fast_dct_1d(thread_cache[6]); | |||
fast_dct_1d(thread_cache[7]); | |||
//! TMP = C @ TMP | |||
fast_dct_1d_col(thread_cache, 0); | |||
fast_dct_1d_col(thread_cache, 1); | |||
fast_dct_1d_col(thread_cache, 2); | |||
fast_dct_1d_col(thread_cache, 3); | |||
fast_dct_1d_col(thread_cache, 4); | |||
fast_dct_1d_col(thread_cache, 5); | |||
fast_dct_1d_col(thread_cache, 6); | |||
fast_dct_1d_col(thread_cache, 7); | |||
StoreMask<dct_block, block_oh, block_ow, format, mask_type, DstDtype, | |||
T2>::func(thread_cache, dst_tid, oc_stride, oc_idx, | |||
mask_offset, mask_val, quant_param, error_info, | |||
error_tracker); | |||
} | |||
} | |||
} // namespace | |||
template <int dct_block, uint32_t format, typename DstDtype> | |||
void call_kern_dct(const uint8_t* d_src, DstDtype* d_dst, const int n, | |||
const int c, const int h, const int w, const int oc, | |||
bool fix_32_mask, const int* mask_offset, | |||
const int* mask_val, cudaStream_t stream, | |||
megcore::AsyncErrorInfo* error_info, void* error_tracker, | |||
float scale) { | |||
constexpr int ker_block_h = 32; | |||
constexpr int ker_block_w = 256; | |||
const int oh = h / dct_block; | |||
const int ow = w / dct_block; | |||
const int oc_stride = oh * ow; | |||
const dim3 block_dim(DIVUP(w, ker_block_w), DIVUP(h, ker_block_h), n * c); | |||
const dim3 thread_dim(DIVUP(ker_block_w, dct_block), | |||
DIVUP(ker_block_h, dct_block)); | |||
auto cuda_dtype_param = CudaPostProcess<DstDtype>(scale); | |||
if (fix_32_mask) { | |||
kern_dct<dct_block, MaskType::FIX_32_MASK, ker_block_h, ker_block_w, | |||
format><<<block_dim, thread_dim, 0, stream>>>( | |||
d_src, d_dst, n, c, h, w, oh, ow, oc_stride, oc, mask_offset, | |||
mask_val, cuda_dtype_param, error_info, error_tracker); | |||
} else if (mask_offset && mask_val) { | |||
kern_dct<dct_block, MaskType::USER_DEFINE_MASK, ker_block_h, | |||
ker_block_w, format><<<block_dim, thread_dim, 0, stream>>>( | |||
d_src, d_dst, n, c, h, w, oh, ow, oc_stride, oc, mask_offset, | |||
mask_val, cuda_dtype_param, error_info, error_tracker); | |||
} else { | |||
kern_dct<dct_block, MaskType::NO_MASK, ker_block_h, ker_block_w, format> | |||
<<<block_dim, thread_dim, 0, stream>>>( | |||
d_src, d_dst, n, c, h, w, oh, ow, oc_stride, oc, | |||
mask_offset, mask_val, cuda_dtype_param, error_info, | |||
error_tracker); | |||
} | |||
} | |||
//! explicit instantiation: 8x8 blocks, fp32 output in nchw
template void call_kern_dct<8, DctLayoutFormat::NCHW, float>(
        const uint8_t*, float*, const int, const int, const int, const int,
        const int, bool, const int*, const int*, cudaStream_t,
        megcore::AsyncErrorInfo*, void*, float);

//! explicit instantiation: 8x8 blocks, int8 output in nchw4
template void call_kern_dct<8, DctLayoutFormat::NCHW4, int8_t>(
        const uint8_t*, int8_t*, const int, const int, const int, const int,
        const int, bool, const int*, const int*, cudaStream_t,
        megcore::AsyncErrorInfo*, void*, float);
} // namespace dct | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,38 @@ | |||
/** | |||
* \file dnn/src/cuda/dct/dct_channel_select.cuh | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the | |||
"License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include <stdint.h> | |||
#include <cstdio> | |||
#include "src/common/opr_param_defs_enumv.cuh" | |||
#include "src/cuda/utils.cuh" | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace dct { | |||
using DctLayoutFormat = megdnn::param_enumv::DctChannelSelect::Format; | |||
template <int dct_block, uint32_t format, typename DstDtype> | |||
void call_kern_dct(const uint8_t* d_src, DstDtype* d_dst, const int n, | |||
const int c, const int h, const int w, const int oc, | |||
bool fix_32_mask, const int* mask_offset, | |||
const int* mask_val, cudaStream_t stream, | |||
megcore::AsyncErrorInfo* error_info, void* error_tracker, | |||
float scale = 1.f); | |||
} // namespace dct | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,73 @@ | |||
/** | |||
 * \file dnn/src/cuda/dct/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||
*/ | |||
#include "src/common/utils.h" | |||
#include "src/cuda/dct/dct_channel_select.cuh" | |||
#include "src/cuda/dct/opr_impl.h" | |||
#include "src/cuda/handle.h" | |||
#include "src/cuda/utils.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
/*!
 * \brief Run the DCT channel-select forward pass on the CUDA stream of the
 *      handle.
 *
 * Supported combinations are fp32 output in nchw and quantized int8 output
 * in nchw4, both with a DCT block size of 8. The mask tensors are only
 * forwarded to the kernel when mask_offset is one-dimensional with at least
 * two entries. No workspace is used.
 */
void DctChannelSelectForwardImpl::exec(_megdnn_tensor_in src,
                                       _megdnn_tensor_in mask_offset,
                                       _megdnn_tensor_in mask_val,
                                       _megdnn_tensor_out dst,
                                       _megdnn_workspace /*workspace*/) {
    auto stream = cuda_stream(this->handle());
    const int in = src.layout.shape[0];
    const int ic = src.layout.shape[1];
    const int ih = src.layout.shape[2];
    const int iw = src.layout.shape[3];
    int oc = dst.layout.shape[1];
    const bool with_fix_32_mask =
            param().fastImpl == Param::FastImpl::FIX_32_MASK;
    if (param().format == Param::Format::NCHW4) {
        megdnn_assert(dst.layout.ndim == 5 && dst.layout.shape[4] == 4,
                      "dst must be nchw4");
        //! shape[1] counts groups of four channels in nchw4
        oc = oc * 4;
    }
    megdnn_assert(!with_fix_32_mask || oc == 32, "only support specify mask");
    //! the FIX_32_MASK kernel hard-codes the output channel offsets
    //! {0, 16, 24} for input channels 0, 1, 2; any further input channel
    //! would be written past the 32 output channels
    megdnn_assert(!with_fix_32_mask || ic == 3,
                  "FIX_32_MASK only support 3 input channels, but got %d",
                  ic);
    megdnn_assert(param().dct_block_size == 8, "only support dct block = 8");
    auto error_info =
            concrete_handle(this->handle())->megcore_context().error_info;
    constexpr int dct_block = 8;
    const int* mask_offset_ptr = nullptr;
    const int* mask_val_ptr = nullptr;
    if (mask_offset.layout.ndim == 1 && mask_offset.layout.shape[0] >= 2) {
        mask_offset_ptr = mask_offset.ptr<int32_t>();
        mask_val_ptr = mask_val.ptr<int32_t>();
    }
    if (dst.layout.dtype.enumv() == DTypeEnum::Float32) {
        megdnn_assert(param().format == Param::Format::NCHW,
                      "fp32 only support nchw");
        dct::call_kern_dct<dct_block, dct::DctLayoutFormat::NCHW>(
                src.ptr<uint8_t>(), dst.ptr<float>(), in, ic, ih, iw, oc,
                with_fix_32_mask, mask_offset_ptr, mask_val_ptr, stream,
                error_info, m_error_tracker);
    } else {
        megdnn_assert(dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8,
                      "only support fp32 and qs8");
        megdnn_assert(param().format == Param::Format::NCHW4,
                      "qint8 only support nchw4");
        //! the quantization scale of dst is handed to the kernel
        dct::call_kern_dct<dct_block, dct::DctLayoutFormat::NCHW4>(
                src.ptr<uint8_t>(), (int8_t*)dst.raw_ptr, in, ic, ih, iw, oc,
                with_fix_32_mask, mask_offset_ptr, mask_val_ptr, stream,
                error_info, m_error_tracker,
                dst.layout.dtype.param<::megdnn::dtype::QuantizedS8>().scale);
    }
}
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,40 @@ | |||
/** | |||
* \file dnn/src/cuda/dct/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
class DctChannelSelectForwardImpl : public DctChannelSelectForward { | |||
public: | |||
using DctChannelSelectForward::DctChannelSelectForward; | |||
void* m_error_tracker = nullptr; | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in mask_offset, | |||
_megdnn_tensor_in mask_val, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout& /*src*/, | |||
const TensorLayout& /*mask_offset*/, | |||
const TensorLayout& /*mask_val*/, | |||
const TensorLayout& /*dst*/) { | |||
return 0; | |||
}; | |||
void set_error_tracker(void* tracker) override { | |||
m_error_tracker = tracker; | |||
} | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -26,6 +26,7 @@ | |||
#include "src/cuda/convpooling/opr_impl.h" | |||
#include "src/cuda/cumsum/opr_impl.h" | |||
#include "src/cuda/cvt_color/opr_impl.h" | |||
#include "src/cuda/dct/opr_impl.h" | |||
#include "src/cuda/deformable_conv/opr_impl.h" | |||
#include "src/cuda/deformable_ps_roi_pooling/opr_impl.h" | |||
#include "src/cuda/dot/opr_impl.h" | |||
@@ -0,0 +1,242 @@ | |||
/** | |||
* \file dnn/src/naive/dct/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||
*/ | |||
#include <cmath> | |||
#include "megdnn/basic_types.h" | |||
#include "megdnn/dtype.h" | |||
#include "midout.h" | |||
#include "src/naive/dct/opr_impl.h" | |||
#include "src/naive/handle.h" | |||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||
MIDOUT_DECL(megdnn_naive_dct_fwd) | |||
namespace megdnn { | |||
namespace naive { | |||
namespace { | |||
/*!
 * \brief Fill \p result with the block x block DCT coefficient matrix C,
 *      stored row-major.
 *
 * C[i][j] = alpha(i) * cos((2j + 1) * i * pi / (2 * block)), where alpha is
 * sqrt(1 / block) for the first row and sqrt(2 / block) for all others.
 */
static inline void generate_c_matrix(float* result, int block) {
    constexpr float pi = M_PI;
    for (int row = 0; row < block; ++row) {
        //! the scale factor only depends on the row
        float alpha = row == 0 ? sqrt(1.f / static_cast<float>(block))
                               : sqrt(2.f / static_cast<float>(block));
        float* out_row = result + row * block;
        for (int col = 0; col < block; ++col) {
            out_row[col] = alpha * cos((2.f * col + 1.f) * row * pi /
                                       static_cast<float>(2 * block));
        }
    }
}
/*!
 * \brief Reference matrix product c = op(a) * op(b) with an m x n result
 *      and inner dimension k.
 *
 * \p lda, \p ldb and \p ldc are the row strides of the stored matrices.
 * When \p trans_a / \p trans_b is set the corresponding matrix is read
 * transposed. Elements of \p b are converted to float before multiplying.
 */
template <typename T>
void matmul(int m, int n, int k, int lda, int ldb, int ldc, const float* a,
            const T* b, float* c, bool trans_a, bool trans_b) {
    for (int row = 0; row < m; ++row) {
        float* c_row = c + row * ldc;
        for (int col = 0; col < n; ++col) {
            float acc = 0.f;
            for (int p = 0; p < k; ++p) {
                const int a_pos = trans_a ? p * lda + row : row * lda + p;
                const int b_pos = trans_b ? col * ldb + p : p * ldb + col;
                const float lhs = a[a_pos];
                const float rhs = b[b_pos];
                acc += lhs * rhs;
            }
            c_row[col] = acc;
        }
    }
}
/*!
 * \brief Expand the flat (mask_offset, mask_val) pair into one index list
 *      per channel.
 *
 * Row i of the result holds mask_val[mask_offset[i] .. mask_offset[i + 1]).
 * An empty result is returned when mask_offset has fewer than two entries.
 */
std::vector<std::vector<int>> mask_offset_to_2dmask(
        _megdnn_tensor_in mask_offset, _megdnn_tensor_in mask_val) {
    std::vector<std::vector<int>> mask;
    const bool has_mask =
            mask_offset.layout.ndim > 0 && mask_offset.layout[0] >= 2;
    if (!has_mask) {
        return mask;
    }

    const int offset_len = mask_offset.layout.shape[0];
    const int32_t* offsets = mask_offset.ptr<int32_t>();
    const int32_t* values = mask_val.ptr<int32_t>();
    //! the last offset must equal the number of mask values
    megdnn_assert(
            mask_val.layout.shape[0] ==
                    static_cast<size_t>(offsets[offset_len - 1]),
            "check mask offset %zu != %zu", mask_val.layout.shape[0],
            static_cast<size_t>(offsets[offset_len - 1]));

    for (int channel = 0; channel + 1 < offset_len; ++channel) {
        const int begin = offsets[channel];
        const int len = offsets[channel + 1] - begin;
        mask.emplace_back();
        std::vector<int>& row = mask.back();
        for (int k = 0; k < len; ++k) {
            row.push_back(values[begin + k]);
        }
    }
    return mask;
}
inline bool is_layout_nchw4(const TensorLayout& layout) { | |||
if (layout.ndim == 5 && layout[4] == 4) { | |||
return true; | |||
} else { | |||
return false; | |||
} | |||
} | |||
template <typename T> | |||
using QuantizedCType = | |||
std::enable_if_t<DTypeTrait<T>::category == DTypeCategory::QUANTIZED, | |||
typename DTypeTrait<T>::ctype>; | |||
inline int8_t quant_float_2_int8(float val, DType dtype) { | |||
return dtype.param<::megdnn::dtype::QuantizedS8>().quantize(val).as_int8(); | |||
} | |||
template <param::DctChannelSelect::Format format, typename Dtype> | |||
inline void dct_output(Dtype* dst_ptr, const int oc_idx, const int img_size, | |||
float val, DType) { | |||
dst_ptr[oc_idx * img_size] = val; | |||
} | |||
template <> | |||
inline void dct_output<param::DctChannelSelect::Format::NCHW4>( | |||
int8_t* dst_ptr, const int oc_idx, const int img_size, float val, | |||
DType dtype) { | |||
dst_ptr[oc_idx / 4 * 4 * img_size + oc_idx % 4] = | |||
quant_float_2_int8(val, dtype); | |||
} | |||
//! Channel-packing factor that naive_dct uses when computing output
//! offsets; the generic case is 1.
template <param::DctChannelSelect::Format format> | |||
struct ChannleBlock { | |||
static constexpr int block = 1; | |||
}; | |||
//! NCHW4 packs four channels together, matching the group-of-four indexing
//! in the NCHW4 dct_output specialisation.
template <> | |||
struct ChannleBlock<param::DctChannelSelect::Format::NCHW4> { | |||
static constexpr int block = 4; | |||
}; | |||
template <param::DctChannelSelect::Format format, typename Dtype> | |||
void naive_dct(const uint8_t* src, Dtype* dst, int n, int c, int h, int w, | |||
int block, const std::vector<std::vector<int>>& mask, | |||
DType dtype) { | |||
constexpr int block_channel = ChannleBlock<format>::block; | |||
const int block_h = block; | |||
const int block_w = block; | |||
std::vector<float> c_matrix(block * block); | |||
std::vector<float> tmp(block * block); | |||
std::vector<float> tmp_result(block * block); | |||
generate_c_matrix(&c_matrix[0], block); | |||
megdnn_assert(h % block_h == 0, "h mod block_h == 0"); | |||
megdnn_assert(w % block_w == 0, "w mod block_w == 0"); | |||
const int oh = h / block_h; | |||
const int ow = w / block_w; | |||
const int o_img_size = oh * ow; | |||
std::vector<int> mask_offset; | |||
int mask_len_sum = 0; | |||
if (mask.size() > 0) { | |||
for (auto& sub_mask : mask) { | |||
mask_offset.push_back(mask_len_sum); | |||
mask_len_sum += sub_mask.size(); | |||
} | |||
} else { | |||
for (int c_idx = 0; c_idx < c; ++c_idx) { | |||
mask_offset.push_back(mask_len_sum); | |||
mask_len_sum += block_h * block_w; | |||
} | |||
} | |||
const size_t o_batch_stride = mask_len_sum * oh * ow; | |||
for (int n_idx = 0; n_idx < n; ++n_idx) { | |||
for (int c_idx = 0; c_idx < c; ++c_idx) { | |||
megdnn_assert(mask_offset[c_idx] % block_channel == 0, | |||
"%d mod %d == 0", mask_offset[c_idx], block_channel); | |||
const size_t src_offset = n_idx * c * h * w + c_idx * h * w; | |||
const uint8_t* src_channel = src + src_offset; | |||
const size_t dst_offset = n_idx * o_batch_stride + | |||
mask_offset[c_idx] / block_channel * oh * | |||
ow * block_channel; | |||
Dtype* dst_channel = dst + dst_offset; | |||
for (int oh_idx = 0; oh_idx < oh; ++oh_idx) { | |||
for (int ow_idx = 0; ow_idx < ow; ++ow_idx) { | |||
matmul(block, block, block, block, w, block, &c_matrix[0], | |||
&src_channel[oh_idx * block_h * w + | |||
ow_idx * block_w], | |||
&tmp[0], false, false); | |||
matmul(block, block, block, block, block, block, &tmp[0], | |||
&c_matrix[0], &tmp_result[0], false, true); | |||
Dtype* dst_start = dst_channel + | |||
(oh_idx * ow + ow_idx) * block_channel; | |||
if (mask.size() == 0) { | |||
for (int inner_h_idx = 0; inner_h_idx < block_h; | |||
++inner_h_idx) { | |||
for (int inner_w_idx = 0; inner_w_idx < block_w; | |||
++inner_w_idx) { | |||
const int oc_idx = | |||
inner_h_idx * block_w + inner_w_idx; | |||
dct_output<format>( | |||
dst_start, oc_idx, o_img_size, | |||
tmp_result[inner_h_idx * block + | |||
inner_w_idx], | |||
dtype); | |||
} | |||
} | |||
} else { | |||
//! with mask | |||
auto& sub_mask = mask[c_idx]; | |||
int dst_offset = 0; | |||
for (auto mask_idx : sub_mask) { | |||
dct_output<format>(dst_start, dst_offset, | |||
o_img_size, tmp_result[mask_idx], | |||
dtype); | |||
++dst_offset; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} // namespace | |||
/*!
 * Naive forward pass. The first four extents of src are read as
 * (n, c, h, w), the mask tensors are expanded with mask_offset_to_2dmask, and
 * the work is dispatched to naive_dct. A Float32 dst requires the NCHW
 * format; any other dst must be QuantizedS8 in NCHW4. The workspace is not
 * used.
 */
void DctChannelSelectForwardImpl::exec(_megdnn_tensor_in src,
                                       _megdnn_tensor_in mask_offset,
                                       _megdnn_tensor_in mask_val,
                                       _megdnn_tensor_out dst,
                                       _megdnn_workspace /*workspace*/) {
    MIDOUT_BEGIN(megdnn_naive_dct_fwd) {
        const int in = src.layout.shape[0];
        const int ic = src.layout.shape[1];
        const int ih = src.layout.shape[2];
        const int iw = src.layout.shape[3];
        megdnn_assert(dst.raw_ptr, "dst can not be nullptr");
        const int block = param().dct_block_size;
        auto mask = mask_offset_to_2dmask(mask_offset, mask_val);
        const bool dst_is_fp32 =
                dst.layout.dtype.enumv() == DTypeEnum::Float32;
        if (!dst_is_fp32) {
            //! quantized path: int8 output in NCHW4
            megdnn_assert(dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8,
                          "dst must be q8");
            megdnn_assert(is_layout_nchw4(dst.layout) &&
                                  param().format == Param::Format::NCHW4,
                          "dst must be nchw4");
            MEGDNN_DISPATCH_CPU_KERN_OPR(naive_dct<Param::Format::NCHW4>(
                    src.ptr<uint8_t>(), static_cast<int8_t*>(dst.raw_ptr), in,
                    ic, ih, iw, block, mask, dst.layout.dtype));
        } else {
            //! float path: plain NCHW output
            megdnn_assert(!is_layout_nchw4(dst.layout) &&
                                  param().format == Param::Format::NCHW,
                          "dst must be nchw");
            MEGDNN_DISPATCH_CPU_KERN_OPR(naive_dct<Param::Format::NCHW>(
                    src.ptr<uint8_t>(), dst.ptr<float>(), in, ic, ih, iw,
                    block, mask, dst.layout.dtype));
        }
    }
    MIDOUT_END();
}
} // namespace naive | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,34 @@ | |||
/** | |||
* \file dnn/src/naive/dct/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace naive { | |||
class DctChannelSelectForwardImpl : public DctChannelSelectForward { | |||
public: | |||
using DctChannelSelectForward::DctChannelSelectForward; | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in mask_offset, | |||
_megdnn_tensor_in mask_val, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout& /*src*/, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
}; | |||
}; | |||
} // namespace naive | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/naive/handle.h" | |||
@@ -29,6 +30,7 @@ | |||
#include "src/naive/convpooling/opr_impl.h" | |||
#include "src/naive/cumsum/opr_impl.h" | |||
#include "src/naive/cvt_color/opr_impl.h" | |||
#include "src/naive/dct/opr_impl.h" | |||
#include "src/naive/deformable_conv/opr_impl.h" | |||
#include "src/naive/deformable_ps_roi_pooling/opr_impl.h" | |||
#include "src/naive/dot/opr_impl.h" | |||
@@ -56,6 +58,7 @@ | |||
#include "src/naive/reduce/opr_impl.h" | |||
#include "src/naive/relayout/opr_impl.h" | |||
#include "src/naive/relayout_format/opr_impl.h" | |||
#include "src/naive/remap/opr_impl.h" | |||
#include "src/naive/repeat/opr_impl.h" | |||
#include "src/naive/resize/opr_impl.h" | |||
#include "src/naive/rng/opr_impl.h" | |||
@@ -76,7 +79,6 @@ | |||
#include "src/naive/warp_affine/opr_impl.h" | |||
#include "src/naive/warp_perspective/opr_impl.h" | |||
#include "src/naive/winograd_filter_preprocess/opr_impl.h" | |||
#include "src/naive/remap/opr_impl.h" | |||
static size_t g_image2d_pitch_alignment = 1; | |||
@@ -6,20 +6,21 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include <map> | |||
#include <memory> | |||
#include <vector> | |||
#include <regex> | |||
#include <vector> | |||
#include "megdnn/basic_types.h" | |||
#include "megdnn/tensor_format.h" | |||
#include "test/common/opr_algo_proxy.h" | |||
#include "test/common/opr_proxy.h" | |||
#include "test/common/rng.h" | |||
#include "test/common/timer.h" | |||
#include "test/common/opr_algo_proxy.h" | |||
namespace megdnn { | |||
namespace test { | |||
@@ -31,6 +32,7 @@ public: | |||
using TensorValueArray = TensorNDArray; | |||
using BeforeExecCallback = | |||
std::function<void(Opr*, const TensorValueArray&)>; | |||
using TensorsConstriant = std::function<void(TensorValueArray& tensors)>; | |||
BenchmarkerBase(Handle* handle, T timer) | |||
: m_timer(timer), | |||
@@ -54,6 +56,8 @@ public: | |||
} | |||
float exec(TensorLayoutArray layouts); | |||
float exect(const TensorValueArray& testcase_in); | |||
//! disabiguate overloaded exec | |||
float execs(const TensorShapeArray& shapes) { return exec(shapes); } | |||
float execl(const TensorLayoutArray& layouts) { return exec(layouts); } | |||
@@ -73,6 +77,11 @@ public: | |||
m_fmt[idx] = fmt; | |||
return *this; | |||
} | |||
BenchmarkerBase& set_tensors_constraint( | |||
const TensorsConstriant& tensor_constraint) { | |||
m_tensor_constraint = tensor_constraint; | |||
return *this; | |||
} | |||
TensorLayoutArray make_layouts(const TensorShapeArray& shapes) { | |||
TensorLayoutArray layouts(shapes.size()); | |||
for (size_t i = 0; i < shapes.size(); ++i) { | |||
@@ -142,6 +151,7 @@ private: | |||
std::unique_ptr<OprProxy<Opr>> m_proxy; | |||
BeforeExecCallback m_before_exec_callback; | |||
std::unique_ptr<Opr> m_opr; | |||
TensorsConstriant m_tensor_constraint; | |||
}; | |||
template <typename Opr, typename T> | |||
@@ -184,10 +194,16 @@ float BenchmarkerBase<Opr, T>::exec(TensorLayoutArray layouts) { | |||
auto rng = m_rng[i]; | |||
if (!rng) | |||
rng = m_default_rng.get(); | |||
auto size = tensor.layout.span().high_byte; | |||
rng->gen(tensor); | |||
} | |||
if (m_tensor_constraint) { | |||
m_tensor_constraint(tensors_cur_host); | |||
} | |||
for (size_t i = 0; i < tensors_cur_host.size(); ++i) { | |||
TensorND& tensor = tensors_cur_host[i]; | |||
if (tensor.layout.ndim == 0) | |||
continue; | |||
auto size = tensor.layout.span().high_byte; | |||
megdnn_memcpy_H2D(m_handle, tensors_cur[i].raw_ptr, tensor.raw_ptr, | |||
size); | |||
} | |||
@@ -243,6 +259,105 @@ float BenchmarkerBase<Opr, T>::exec(TensorLayoutArray layouts) { | |||
return time_in_ms; | |||
} | |||
template <typename Opr, typename T> | |||
float BenchmarkerBase<Opr, T>::exect(const TensorValueArray& testcase_in) { | |||
auto opr = this->opr(); | |||
opr->param() = m_param; | |||
TensorLayoutArray layouts; | |||
TensorNDArray tensors_cur_host; | |||
for (auto& inp : testcase_in) { | |||
layouts.push_back(inp.layout); | |||
tensors_cur_host.emplace_back(inp); | |||
} | |||
auto user_layouts = layouts; | |||
m_proxy->deduce_layout(opr, layouts); | |||
for (size_t i = 0; i < layouts.size(); ++i) | |||
if (user_layouts[i].ndim > 0) { | |||
auto run = [&]() { | |||
ASSERT_TRUE(layouts[i].eq_shape(user_layouts[i])) | |||
<< "User provided shape is " | |||
<< user_layouts[i].TensorShape::to_string() | |||
<< "\nExpected shape is " | |||
<< layouts[i].TensorShape::to_string(); | |||
}; | |||
run(); | |||
} | |||
auto allocate = [&layouts](Handle* handle) { | |||
TensorNDArray tensors(layouts.size()); | |||
auto trans_func = [handle](const TensorLayout& layout) { | |||
auto span = layout.span(); | |||
TensorND res; | |||
res.raw_ptr = static_cast<uint8_t*>( | |||
megdnn_malloc(handle, span.dist_byte())) + | |||
span.low_byte; | |||
res.layout = layout; | |||
return res; | |||
}; | |||
std::transform(layouts.begin(), layouts.end(), tensors.begin(), | |||
trans_func); | |||
return tensors; | |||
}; | |||
auto tensors_cur = allocate(m_handle); | |||
//! init | |||
for (size_t i = 0; i < tensors_cur_host.size(); ++i) { | |||
TensorND& tensor = tensors_cur_host[i]; | |||
auto size = tensor.layout.span().high_byte; | |||
if (tensor.layout.ndim == 0) | |||
continue; | |||
megdnn_memcpy_H2D(m_handle, tensors_cur[i].raw_ptr, tensor.raw_ptr, | |||
size); | |||
} | |||
if (m_before_exec_callback) { | |||
m_before_exec_callback(opr, tensors_cur); | |||
} | |||
//! run | |||
//! warm up | |||
m_proxy->exec(opr, tensors_cur); | |||
megcoreSynchronize(m_handle->megcore_computing_handle()); | |||
if (m_adaptive_secs) { | |||
//! find m_times for adaptive benchmarking | |||
m_times = 0; | |||
int cur_times = 1; | |||
auto remain_time = m_adaptive_secs * 1e6; | |||
while (remain_time > 0) { | |||
m_timer.reset(); | |||
m_timer.start(); | |||
for (int i = 0; i < cur_times; ++i) | |||
m_proxy->exec(opr, tensors_cur); | |||
megcoreSynchronize(m_handle->megcore_computing_handle()); | |||
m_timer.stop(); | |||
m_times += cur_times; | |||
auto this_run_time = m_timer.get_time_in_us(); | |||
remain_time -= this_run_time; | |||
cur_times = std::min( | |||
cur_times * 2, | |||
std::max<int>(1, remain_time / this_run_time * cur_times)); | |||
} | |||
} | |||
m_timer.reset(); | |||
m_timer.start(); | |||
for (size_t t = 0; t < m_times; ++t) | |||
m_proxy->exec(opr, tensors_cur); | |||
megcoreSynchronize(m_handle->megcore_computing_handle()); | |||
m_timer.stop(); | |||
auto time_in_ms = m_timer.get_time_in_us() / 1e3; | |||
if (m_display) { | |||
std::cout << "Total time is " << time_in_ms << "ms " | |||
<< "for " << m_times << " run(s)." << std::endl; | |||
} | |||
auto free = [](Handle* handle, TensorNDArray& tensors) { | |||
std::for_each(tensors.begin(), tensors.end(), | |||
[handle](const TensorND& tensor) { | |||
megdnn_free(handle, tensor.raw_ptr); | |||
}); | |||
}; | |||
free(m_handle, tensors_cur); | |||
if (m_adaptive_secs) | |||
time_in_ms /= m_times; | |||
return time_in_ms; | |||
} | |||
template <typename Opr, typename T = Timer> | |||
class Benchmarker; | |||
@@ -294,8 +294,6 @@ void CheckerHelper::do_exec_with_testcases(const TensorValueArray& testcase_in, | |||
ASSERT_TRUE(testcase_in[i].layout.ndim == 0 || | |||
testcase_out[i].layout.ndim == 0 || | |||
testcase_in[i].layout.eq_layout(testcase_out[i].layout)); | |||
ASSERT_TRUE(testcase_in[i].layout.ndim != 0 || | |||
testcase_out[i].layout.ndim != 0); | |||
layouts.emplace_back(testcase_in[i].layout.ndim > 0 | |||
? testcase_in[i].layout | |||
: testcase_out[i].layout); | |||
@@ -392,7 +392,8 @@ TensorND TensorValue(const TensorShape& shape, T dtype, | |||
tensor.layout = {shape, dtype}; | |||
tensor.raw_ptr = | |||
static_cast<dt_byte*>(malloc(tensor.layout.span().dist_byte())); | |||
megdnn_assert(values.size() == tensor.layout.total_nr_elems()); | |||
megdnn_assert(values.size() == tensor.layout.total_nr_elems(), "%zu == %zu", values.size(), | |||
tensor.layout.total_nr_elems()); | |||
auto ptr = tensor.ptr<typename DTypeTrait<T>::ctype>(); | |||
for (const auto& v : values) { | |||
*ptr++ = typename DTypeTrait<T>::ctype(v); | |||
@@ -0,0 +1,198 @@ | |||
/** | |||
* \file | |||
* dnn/test/common/dct_ref.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "test/common/dct_ref.h" | |||
namespace megdnn { | |||
namespace test { | |||
//! A predefined mask in flat form, as returned by get_fix_mask.
struct FixCase { | |||
//! segment boundaries into mask_val (first entry 0, last entry = total count)
std::vector<int> mask_offset; | |||
//! selected coefficient indices of all segments, concatenated
std::vector<int> mask_val; | |||
}; | |||
using Param = DctChannelSelectForward::Param; | |||
//! Return the hard-coded mask for FIX_32_MASK: three segments selecting 16, 8
//! and 8 coefficients (32 in total). Any other \p impl is rejected.
static inline FixCase get_fix_mask(Param::FastImpl impl) {
    megdnn_assert(impl == Param::FastImpl::FIX_32_MASK,
                  "only support gen FIX_32_MASK");
    FixCase fix_case;
    fix_case.mask_offset = {0, 16, 24, 32};
    fix_case.mask_val = {0,  1,  8,  16, 9, 2, 3, 10, 17, 24, 32,
                         25, 18, 11, 4,  5, 0, 1, 8,  16, 9,  2,
                         3,  10, 0,  1,  8, 16, 9, 2, 3,  10};
    return fix_case;
}
//! Build a tensor-constraint callback that overwrites the mask tensors
//! (indices 1 and 2 of the tensor array) with a mask that is valid for the
//! given input channel count \p ic, output channel count \p oc and \p param.
//! The batch size is unused; \p ih and \p iw are only checked for
//! divisibility by the DCT block size.
CheckerHelper::TensorsConstriant gen_dct_constriant( | |||
const size_t /* n */, const size_t ic, const size_t ih, const size_t iw, | |||
const size_t oc, Param param) { | |||
auto constraint = [=](CheckerHelper::TensorValueArray& tensors_orig) { | |||
const size_t block = param.dct_block_size; | |||
const int block_c = param.format == Param::Format::NCHW4 ? 4 : 1; | |||
megdnn_assert(oc % block_c == 0, "oc mod block_c must == 0"); | |||
// the testcase object only serves as scratch storage for the two mask
// vectors during this call
std::shared_ptr<DctTestcase> test_case_ptr = DctTestcase::make(); | |||
DctTestcase& test_case = *test_case_ptr.get(); | |||
// NOTE(review): rng is not referenced again inside this lambda
UniformIntRNG rng(0, 255); | |||
UniformIntRNG mask_rng(0, 64 / block_c - 1); | |||
// number of output channels when no mask is applied
const size_t no_mask_oc = ic * block * block; | |||
megdnn_assert(ih % block == 0, "%zu mod %zu == 0", ih, block); | |||
megdnn_assert(iw % block == 0, "%zu mod %zu == 0", iw, block); | |||
TensorND mask_offset; | |||
TensorND mask_val; | |||
std::vector<int>& mask_offset_vec = test_case.mask_offset_vec; | |||
std::vector<int>& mask_val_vec = test_case.mask_val_vec; | |||
UniformIntRNG rng_oc(0, oc); | |||
if (param.fastImpl == Param::FastImpl::FIX_32_MASK) { | |||
// fixed table; only meaningful for exactly 32 output channels
auto fix_32_mask = get_fix_mask(Param::FastImpl::FIX_32_MASK); | |||
mask_offset_vec = fix_32_mask.mask_offset; | |||
mask_val_vec = fix_32_mask.mask_val; | |||
megdnn_assert(oc == 32, "oc must eq 32"); | |||
} else if (no_mask_oc > oc) { | |||
// randomly split the oc output channels among the ic input channels;
// every share is a multiple of block_c and the last channel takes
// whatever remains
size_t remain_oc = oc; | |||
mask_offset_vec.resize(ic + 1); | |||
mask_val_vec.resize(oc); | |||
mask_offset_vec[0] = 0; | |||
for (size_t ic_idx = 0; ic_idx < ic; ++ic_idx) { | |||
size_t random_len = (int)rng_oc.gen_single_val() * block_c; | |||
size_t mask_len = (ic_idx == ic - 1) || (remain_oc == 0) | |||
? remain_oc | |||
: random_len % remain_oc; | |||
megdnn_assert(mask_len % block_c == 0, | |||
"mask_len mod block_c == 0, but %zu mod %d ", | |||
mask_len, block_c); | |||
const size_t oc_idx = mask_offset_vec[ic_idx]; | |||
remain_oc -= mask_len; | |||
mask_offset_vec[ic_idx + 1] = oc_idx + mask_len; | |||
for (size_t mask_idx = 0; mask_idx < mask_len; ++mask_idx) { | |||
mask_val_vec[oc_idx + mask_idx] = | |||
(int)mask_rng.gen_single_val(); | |||
} | |||
} | |||
} | |||
// in the remaining case both vectors stay empty (no mask)
mask_offset = TensorND(mask_offset_vec.data(), | |||
{{mask_offset_vec.size()}, dtype::Int32()}); | |||
mask_val = TensorND(mask_val_vec.data(), | |||
{{mask_val_vec.size()}, dtype::Int32()}); | |||
if (tensors_orig.size() > 1) { | |||
// the caller's mask tensors must already have the generated shapes;
// only their contents are replaced
megdnn_assert(tensors_orig.size() == 4, "tensors_orig.size() == 4"); | |||
megdnn_assert(mask_offset_vec.size() >= 2, | |||
"mask_offset_vec.size() >= 2"); | |||
megdnn_assert(tensors_orig[1].layout == mask_offset.layout, | |||
"tensors_orig[1].layout == mask_offset.layout"); | |||
megdnn_assert(tensors_orig[2].layout == mask_val.layout, | |||
"tensors_orig[2].layout == mask_val.layout"); | |||
auto naive_handle = create_cpu_handle(2, false); | |||
megdnn_memcpy_D2D(naive_handle.get(), tensors_orig[1].raw_ptr, | |||
mask_offset.raw_ptr, | |||
mask_offset.layout.span().dist_byte()); | |||
megdnn_memcpy_D2D(naive_handle.get(), tensors_orig[2].raw_ptr, | |||
mask_val.raw_ptr, | |||
mask_val.layout.span().dist_byte()); | |||
} | |||
}; | |||
return constraint; | |||
} | |||
//! Generate a complete DCT test case: random uint8 input of shape
//! (n, ic, ih, iw), an optional mask producing \p oc output channels, and
//! the expected output.
//!
//! \param param operator parameters (format, fastImpl, dct_block_size)
//! \param dst_dtype dtype requested for the output tensor
//! \param correct_result when true the naive operator is executed to fill
//!     the expected output; otherwise the output buffer keeps the contents
//!     left by resize()
//! \return the test case; its tensors point into vectors owned by the
//!     returned object, so it must outlive every use of those tensors
std::shared_ptr<DctTestcase> gen_dct_case(const size_t n, const size_t ic, | |||
const size_t ih, const size_t iw, | |||
const size_t oc, Param param, | |||
DType dst_dtype, | |||
bool correct_result) { | |||
const size_t block = param.dct_block_size; | |||
const int block_c = param.format == Param::Format::NCHW4 ? 4 : 1; | |||
megdnn_assert(oc % block_c == 0, "oc mod block_c must == 0"); | |||
std::shared_ptr<DctTestcase> test_case_ptr = DctTestcase::make(); | |||
DctTestcase& test_case = *test_case_ptr.get(); | |||
UniformIntRNG rng(0, 255); | |||
UniformIntRNG mask_rng(0, 64 / block_c - 1); | |||
const size_t input_elements = n * ic * ih * iw; | |||
// number of output channels when no mask is applied
const size_t no_mask_oc = ic * block * block; | |||
megdnn_assert(ih % block == 0, "%zu mod %zu == 0", ih, block); | |||
megdnn_assert(iw % block == 0, "%zu mod %zu == 0", iw, block); | |||
std::vector<uint8_t>& inp_vec = test_case.inp_vec; | |||
inp_vec.resize(input_elements); | |||
TensorShape input_shape{n, ic, ih, iw}; | |||
for (auto& elm : inp_vec) { | |||
elm = (uint8_t)rng.gen_single_val(); | |||
} | |||
auto src = TensorND(inp_vec.data(), {input_shape, dtype::Uint8()}); | |||
TensorND mask_offset; | |||
TensorND mask_val; | |||
std::vector<int>& mask_offset_vec = test_case.mask_offset_vec; | |||
std::vector<int>& mask_val_vec = test_case.mask_val_vec; | |||
UniformIntRNG rng_oc(0, oc); | |||
if (param.fastImpl == Param::FastImpl::FIX_32_MASK) { | |||
// fixed table; only meaningful for exactly 32 output channels
auto fix_32_mask = get_fix_mask(Param::FastImpl::FIX_32_MASK); | |||
mask_offset_vec = fix_32_mask.mask_offset; | |||
mask_val_vec = fix_32_mask.mask_val; | |||
megdnn_assert(oc == 32, "oc must eq 32"); | |||
} else if (no_mask_oc > oc) { | |||
// randomly split the oc output channels among the ic input channels;
// every share is a multiple of block_c and the last channel takes
// whatever remains
size_t remain_oc = oc; | |||
mask_offset_vec.resize(ic + 1); | |||
mask_val_vec.resize(oc); | |||
mask_offset_vec[0] = 0; | |||
for (size_t ic_idx = 0; ic_idx < ic; ++ic_idx) { | |||
size_t random_len = (int)rng_oc.gen_single_val() * block_c; | |||
size_t mask_len = (ic_idx == ic - 1) || (remain_oc == 0) | |||
? remain_oc | |||
: random_len % remain_oc; | |||
megdnn_assert(mask_len % block_c == 0, | |||
"mask_len mod block_c == 0, but %zu mod %d ", | |||
mask_len, block_c); | |||
const size_t oc_idx = mask_offset_vec[ic_idx]; | |||
remain_oc -= mask_len; | |||
mask_offset_vec[ic_idx + 1] = oc_idx + mask_len; | |||
for (size_t mask_idx = 0; mask_idx < mask_len; ++mask_idx) { | |||
mask_val_vec[oc_idx + mask_idx] = | |||
(int)mask_rng.gen_single_val(); | |||
} | |||
} | |||
} | |||
mask_offset = TensorND(mask_offset_vec.data(), | |||
{{mask_offset_vec.size()}, dtype::Int32()}); | |||
mask_val = TensorND(mask_val_vec.data(), | |||
{{mask_val_vec.size()}, dtype::Int32()}); | |||
// without a usable mask the two mask inputs are passed as empty tensors;
// the output slot only carries the requested dtype
if (mask_offset_vec.size() >= 2) { | |||
test_case.testcase_in = { | |||
src, mask_offset, mask_val, {nullptr, {{}, dst_dtype}}}; | |||
} else { | |||
test_case.testcase_in = {src, {}, {}, {nullptr, {{}, dst_dtype}}}; | |||
} | |||
auto naive_handle = create_cpu_handle(2, false); | |||
auto opr_naive = naive_handle->create_operator<DctChannelSelectForward>(); | |||
opr_naive->param() = param; | |||
using Proxy = OprProxy<DctChannelSelectForward>; | |||
Proxy naive_proxy; | |||
// let the naive operator deduce the output layout for the chosen dtype
TensorLayout temp_dst_layout; | |||
temp_dst_layout.dtype = dst_dtype; | |||
TensorLayoutArray layouts{src.layout, mask_offset.layout, mask_val.layout, | |||
temp_dst_layout}; | |||
naive_proxy.deduce_layout(opr_naive.get(), layouts); | |||
const size_t output_elements = layouts[3].total_nr_elems(); | |||
// the buffer is a float vector sized by element count, whatever dst_dtype is
std::vector<float>& output_vec = test_case.output_vec; | |||
output_vec.resize(output_elements); | |||
auto dst = TensorND(output_vec.data(), layouts[3]); | |||
DctTestcase::TensorValueArray testcase_naive; | |||
testcase_naive.emplace_back(test_case.testcase_in[0]); | |||
testcase_naive.emplace_back(test_case.testcase_in[1]); | |||
testcase_naive.emplace_back(test_case.testcase_in[2]); | |||
testcase_naive.emplace_back(dst); | |||
if (correct_result) { | |||
naive_proxy.exec(opr_naive.get(), testcase_naive); | |||
} | |||
test_case.testcase_out = {{}, {}, {}, dst}; | |||
return test_case_ptr; | |||
} | |||
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,52 @@ | |||
/** | |||
* \file | |||
* dnn/test/common/dct_ref.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include <math.h> | |||
#include <vector> | |||
#include "megdnn/dtype.h" | |||
#include "megdnn/oprs/nn.h" | |||
#include "test/common/checker.h" | |||
#include "test/common/opr_proxy.h" | |||
#include "test/common/rng.h" | |||
namespace megdnn { | |||
namespace test { | |||
using Param = DctChannelSelectForward::Param; | |||
//! A DCT test case together with the storage backing its tensors.
struct DctTestcase { | |||
using TensorValueArray = TensorNDArray; | |||
//! operator arguments fed to the checker / benchmarker
TensorValueArray testcase_in; | |||
//! expected values; only the output slot is populated by gen_dct_case
TensorValueArray testcase_out; | |||
//! backing storage for the source tensor
std::vector<uint8_t> inp_vec; | |||
//! backing storage for the mask_offset tensor
std::vector<int> mask_offset_vec; | |||
//! backing storage for the mask_val tensor
std::vector<int> mask_val_vec; | |||
//! backing storage for the output tensor
std::vector<float> output_vec; | |||
//! create an empty test case with shared ownership
static std::shared_ptr<DctTestcase> make() { | |||
return std::make_shared<DctTestcase>(); | |||
} | |||
}; | |||
//! Build a callback that fills the mask tensors of a DCT tensor array with a
//! mask valid for \p ic input channels, \p oc output channels and \p param.
CheckerHelper::TensorsConstriant gen_dct_constriant( | |||
const size_t n, const size_t ic, const size_t ih, const size_t iw, | |||
const size_t oc, Param param); | |||
//! Generate a DCT test case with random input of shape (n, ic, ih, iw) and
//! \p oc output channels. The output dtype defaults to Float32, and the
//! expected output is computed with the naive operator unless
//! \p correct_result is false.
std::shared_ptr<DctTestcase> gen_dct_case(const size_t n, const size_t ic, | |||
const size_t ih, const size_t iw, | |||
const size_t oc, Param param, | |||
DType dst_dtype = dtype::Float32(), | |||
bool correct_result = true); | |||
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -110,6 +110,7 @@ DEF(BatchConvBiasForward, 5, true, true); | |||
DEF(Remap, 3, true, true); | |||
DEF(RemapBackwardData, 3, true, false); | |||
DEF(RemapBackwardMat, 4, true, false); | |||
DEF(DctChannelSelectForward, 4, true, true); | |||
} // namespace test | |||
} // namespace megdnn | |||
@@ -0,0 +1,360 @@ | |||
/** | |||
* \file dnn/test/cuda/dct.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megdnn/oprs/nn.h" | |||
#include "test/common/benchmarker.h" | |||
#include "test/common/checker.h" | |||
#include "test/common/dct_ref.h" | |||
#include "test/common/rng.h" | |||
#include "test/cuda/fixture.h" | |||
namespace megdnn { | |||
namespace test { | |||
//! Unmasked DCT with default parameters over a grid of batch sizes, channel
//! counts and image sizes; mask and output tensors are left empty.
TEST_F(CUDA, DCT) {
    DctChannelSelectForward::Param param;
    Checker<DctChannelSelectForward> checker(handle_cuda());
    const size_t batches[] = {1, 3};
    const size_t channels[] = {1, 3};
    const size_t heights[] = {8, 16, 32, 512, 1024};
    const size_t widths[] = {8, 16, 32, 64, 128, 256, 512, 1024};
    for (size_t n : batches) {
        for (size_t ic : channels) {
            for (size_t ih : heights) {
                for (size_t iw : widths) {
                    auto& configured = checker.set_param(param)
                                               .set_dtype(0, dtype::Uint8())
                                               .set_dtype(1, dtype::Int32())
                                               .set_dtype(2, dtype::Int32());
                    configured.execs({TensorShape{n, ic, ih, iw}, {}, {}, {}});
                }
            }
        }
    }
}
//! Unmasked DCT in NCHW4 format with a QuantizedS8 output, over the same size
//! grid as the float test; epsilon is set to 1.
TEST_F(CUDA, DCT_QINT8) {
    DctChannelSelectForward::Param param;
    param.format = Param::Format::NCHW4;
    Checker<DctChannelSelectForward> checker(handle_cuda());
    const size_t batches[] = {1, 3};
    const size_t channels[] = {1, 3};
    const size_t heights[] = {8, 16, 32, 512, 1024};
    const size_t widths[] = {8, 16, 32, 64, 128, 256, 512, 1024};
    for (size_t n : batches) {
        for (size_t ic : channels) {
            for (size_t ih : heights) {
                for (size_t iw : widths) {
                    auto& configured =
                            checker.set_param(param)
                                    .set_dtype(0, dtype::Uint8())
                                    .set_dtype(1, dtype::Int32())
                                    .set_dtype(2, dtype::Int32())
                                    .set_dtype(3, dtype::QuantizedS8(10.f))
                                    .set_epsilon(1);
                    configured.execs({TensorShape{n, ic, ih, iw}, {}, {}, {}});
                }
            }
        }
    }
}
//! FIX_32_MASK fast path with float output on a 3 x 3 x 1024 x 768 input,
//! checked against the values produced by gen_dct_case.
TEST_F(CUDA, DCT_WITH_FIX_32_MASK) {
    using Param = DctChannelSelectForward::Param;
    Param param;
    param.fastImpl = Param::FastImpl::FIX_32_MASK;
    Checker<DctChannelSelectForward> checker(handle_cuda(), false);
    auto test_case = gen_dct_case(3, 3, 1024, 768, 32, param);
    auto& configured = checker.set_param(param);
    configured.exect(test_case->testcase_in, test_case->testcase_out);
}
//! FIX_32_MASK fast path in NCHW4 format with a QuantizedS8 output on a
//! 3 x 3 x 1024 x 768 input; epsilon is set to 1.
TEST_F(CUDA, DCT_WITH_FIX_32_MASK_QINT8) {
    using Param = DctChannelSelectForward::Param;
    Param param;
    param.fastImpl = Param::FastImpl::FIX_32_MASK;
    param.format = Param::Format::NCHW4;
    Checker<DctChannelSelectForward> checker(handle_cuda(), false);
    auto test_case =
            gen_dct_case(3, 3, 1024, 768, 32, param, dtype::QuantizedS8(10.f));
    auto& configured = checker.set_param(param).set_epsilon(1);
    configured.exect(test_case->testcase_in, test_case->testcase_out);
}
//! Masked DCT against hard-coded expected values. The 1 x 3 x 8 x 16 input
//! repeats the same 128 values for each of the three channels. The mask
//! offsets {0, 14, 22, 30} select 14, 8 and 8 coefficients for channels 0, 1
//! and 2, giving 30 output channels over a 1 x 2 grid of blocks.
TEST_F(CUDA, DCT_WITH_MASK) { | |||
Checker<DctChannelSelectForward> checker(handle_cuda(), false); | |||
DctChannelSelectForward::Param param; | |||
checker.set_param(param).exect( | |||
Testcase{TensorValue( | |||
{1, 3, 8, 16}, dtype::Uint8(), | |||
{109, 39, 30, 115, 71, 15, 206, 139, 221, 5, | |||
18, 16, 93, 185, 99, 102, 205, 172, 191, 29, | |||
185, 6, 47, 84, 0, 47, 105, 203, 251, 73, | |||
196, 83, 3, 211, 32, 181, 49, 111, 114, 83, | |||
148, 232, 77, 17, 35, 2, 154, 100, 41, 135, | |||
141, 206, 56, 91, 137, 199, 104, 192, 75, 122, | |||
78, 65, 184, 69, 91, 82, 2, 172, 194, 240, | |||
49, 145, 87, 210, 97, 190, 179, 93, 125, 105, | |||
181, 207, 148, 178, 133, 53, 25, 198, 238, 151, | |||
14, 120, 213, 195, 145, 20, 122, 107, 217, 185, | |||
65, 5, 115, 110, 82, 206, 163, 86, 2, 2, | |||
44, 125, 50, 38, 41, 106, 30, 5, 151, 243, | |||
238, 181, 232, 191, 161, 57, 23, 204, | |||
109, 39, 30, 115, 71, 15, 206, 139, 221, 5, | |||
18, 16, 93, 185, 99, 102, 205, 172, 191, 29, | |||
185, 6, 47, 84, 0, 47, 105, 203, 251, 73, | |||
196, 83, 3, 211, 32, 181, 49, 111, 114, 83, | |||
148, 232, 77, 17, 35, 2, 154, 100, 41, 135, | |||
141, 206, 56, 91, 137, 199, 104, 192, 75, 122, | |||
78, 65, 184, 69, 91, 82, 2, 172, 194, 240, | |||
49, 145, 87, 210, 97, 190, 179, 93, 125, 105, | |||
181, 207, 148, 178, 133, 53, 25, 198, 238, 151, | |||
14, 120, 213, 195, 145, 20, 122, 107, 217, 185, | |||
65, 5, 115, 110, 82, 206, 163, 86, 2, 2, | |||
44, 125, 50, 38, 41, 106, 30, 5, 151, 243, | |||
238, 181, 232, 191, 161, 57, 23, 204, | |||
109, 39, 30, 115, 71, 15, 206, 139, 221, 5, | |||
18, 16, 93, 185, 99, 102, 205, 172, 191, 29, | |||
185, 6, 47, 84, 0, 47, 105, 203, 251, 73, | |||
196, 83, 3, 211, 32, 181, 49, 111, 114, 83, | |||
148, 232, 77, 17, 35, 2, 154, 100, 41, 135, | |||
141, 206, 56, 91, 137, 199, 104, 192, 75, 122, | |||
78, 65, 184, 69, 91, 82, 2, 172, 194, 240, | |||
49, 145, 87, 210, 97, 190, 179, 93, 125, 105, | |||
181, 207, 148, 178, 133, 53, 25, 198, 238, 151, | |||
14, 120, 213, 195, 145, 20, 122, 107, 217, 185, | |||
65, 5, 115, 110, 82, 206, 163, 86, 2, 2, | |||
44, 125, 50, 38, 41, 106, 30, 5, 151, 243, | |||
238, 181, 232, 191, 161, 57, 23, 204}), | |||
TensorValue({4}, dtype::Int32(), {0, 14, 22, 30}), | |||
TensorValue({30}, dtype::Int32(), | |||
{8, 16, 9, 2, 3, 10, 17, 24, 32, 25, | |||
18, 11, 4, 5, 0, 1, 8, 16, 9, 2, | |||
3, 10, 0, 1, 8, 16, 9, 2, 3, 10}), | |||
{}}, | |||
Testcase{{}, | |||
{}, | |||
{}, | |||
TensorValue({1, 30, 1, 2}, dtype::Float32(), | |||
{-22.850792, -97.862236, -101.043236, | |||
-4.727012, 28.275675, -157.96654, | |||
42.1377, 45.06531, -149.77373, | |||
24.487143, -8.054966, -13.990831, | |||
-6.9395194, -3.9211385, 64.79172, | |||
-12.363858, -47.875, 59., | |||
56.271786, -62.725567, 120.522675, | |||
16.559765, 85.74334, 112.904495, | |||
99.375, 29.499973, 2.0220923, | |||
-19.681704, 890.12494, 941.25, | |||
-7.0498576, 99.47632, -22.850792, | |||
-97.862236, -101.043236, -4.727012, | |||
28.275675, -157.96654, 42.1377, | |||
45.06531, -149.77373, 24.487143, | |||
-8.054966, -13.990831, 890.12494, | |||
941.25, -7.0498576, 99.47632, | |||
-22.850792, -97.862236, -101.043236, | |||
-4.727012, 28.275675, -157.96654, | |||
42.1377, 45.06531, -149.77373, | |||
24.487143, -8.054966, -13.990831})}); | |||
} | |||
//! Randomly masked DCT with float output: for every size combination a
//! random output channel count in [1, ic * 64] is drawn and the case built by
//! gen_dct_case is checked.
TEST_F(CUDA, DCT_WITH_MASK2) {
    Checker<DctChannelSelectForward> checker(handle_cuda(), false);
    DctChannelSelectForward::Param param;
    UniformIntRNG rng_oc(0, 3 * 64);
    const size_t batches[] = {1, 3};
    const size_t channels[] = {1, 3};
    const size_t heights[] = {8, 16, 32, 512, 1024};
    const size_t widths[] = {8, 16, 32, 64, 128, 256, 512, 1024};
    for (size_t n : batches) {
        for (size_t ic : channels) {
            for (size_t ih : heights) {
                for (size_t iw : widths) {
                    const int drawn_oc =
                            static_cast<int>(rng_oc.gen_single_val());
                    const int max_oc = ic * 64;
                    //! map the draw into [1, max_oc]
                    const int mask_oc = (drawn_oc % max_oc) + 1;
                    auto test_case =
                            gen_dct_case(n, ic, ih, iw, mask_oc, param);
                    auto& configured = checker.set_param(param);
                    configured.exect(test_case->testcase_in,
                                     test_case->testcase_out);
                }
            }
        }
    }
}
/*!
 * Same sweep as DCT_WITH_MASK2, but with the NCHW4 format and a
 * QuantizedS8(10.f) test case; the masked channel count is rounded up to a
 * multiple of 4 and results are compared with an epsilon of 1.
 */
TEST_F(CUDA, DCT_WITH_MASK2_QINT8) {
    using Opr = DctChannelSelectForward;
    Checker<Opr> checker(handle_cuda(), false);
    Opr::Param param;
    param.format = Opr::Param::Format::NCHW4;
    UniformIntRNG rng_oc(0, 3 * 64);
    for (size_t n : {1, 3}) {
        for (size_t ic : {1, 3}) {
            for (size_t ih : {8, 16, 32, 512, 1024}) {
                for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
                    const int drawn =
                            static_cast<int>(rng_oc.gen_single_val());
                    const int channel_cap = ic * 64;
                    // Fold the draw into [1, ic * 64], then round up to the
                    // next multiple of 4.
                    const int unaligned_oc = drawn % channel_cap + 1;
                    const int mask_oc = (unaligned_oc + 3) / 4 * 4;
                    auto dct_case =
                            gen_dct_case(n, ic, ih, iw, mask_oc, param,
                                         dtype::QuantizedS8(10.f));
                    checker.set_param(param);
                    checker.set_epsilon(1);
                    checker.exect(dct_case->testcase_in,
                                  dct_case->testcase_out);
                }
            }
        }
    }
}
/*!
 * NCHW4 / QuantizedS8(10.f) sweep that runs the checker through exec() with
 * shapes only. When fewer than ic * 64 channels are selected, the mask
 * tensors are filled through the constraint returned by gen_dct_constriant();
 * otherwise the mask inputs are left empty and the constraint is cleared.
 */
TEST_F(CUDA, DCT_WITH_MASK2_QINT8_CONSTRAINT) {
    using Opr = DctChannelSelectForward;
    Opr::Param param;
    param.format = Opr::Param::Format::NCHW4;
    Checker<Opr> checker(handle_cuda(), false);
    checker.set_param(param);
    checker.set_dtype(0, dtype::Uint8());
    checker.set_dtype(1, dtype::Int32());
    checker.set_dtype(2, dtype::Int32());
    checker.set_dtype(3, dtype::QuantizedS8(10.f));
    checker.set_epsilon(1);
    UniformIntRNG rng_oc(0, 3 * 64);
    for (size_t n : {1, 3}) {
        for (size_t ic : {1, 3}) {
            for (size_t ih : {8, 16, 32, 512, 1024}) {
                for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
                    const int drawn =
                            static_cast<int>(rng_oc.gen_single_val());
                    const int channel_cap = ic * 64;
                    // Fold the draw into [1, ic * 64], then round up to the
                    // next multiple of 4.
                    const int unaligned_oc = drawn % channel_cap + 1;
                    const int mask_oc = (unaligned_oc + 3) / 4 * 4;
                    if (mask_oc >= channel_cap) {
                        // Every channel is kept: run without mask tensors.
                        checker.set_tensors_constraint({});
                        checker.exec(
                                {TensorShape{n, ic, ih, iw}, {}, {}, {}});
                    } else {
                        checker.set_tensors_constraint(gen_dct_constriant(
                                n, ic, ih, iw, mask_oc, param));
                        checker.exec({TensorShape{n, ic, ih, iw},
                                      TensorShape{ic + 1},
                                      TensorShape{(size_t)mask_oc},
                                      {}});
                    }
                }
            }
        }
    }
}
#if MEGDNN_WITH_BENCHMARK
/*!
 * Benchmarks DctChannelSelectForward on CUDA and prints one throughput line
 * per configuration: unmasked float output, 32-channel masks with and without
 * the FIX_32_MASK fast path, and the NCHW4 / QuantizedS8 variants (both from a
 * generated test case and from a tensor constraint).
 *
 * Every report uses the same work estimate: number of source elements * 32,
 * scaled by 1e-6 and divided by the measured time in milliseconds.
 */
TEST_F(CUDA, BENCHMARK_DCT) {
    using Param = DctChannelSelectForward::Param;

    // Shape-only benchmark: mask inputs are left empty.
    auto run = [&](const TensorShapeArray& shapes, Param param) {
        Benchmarker<DctChannelSelectForward> benchmarker(handle_cuda());
        benchmarker.set_param(param);
        benchmarker.set_dtype(0, dtype::Uint8())
                .set_dtype(1, dtype::Int32())
                .set_dtype(2, dtype::Int32());
        for (auto&& shape : shapes) {
            double computation = double(shape[0]) * shape[1] * shape[2] *
                                 shape[3] * 32.0 * 1e-6;
            auto time_ms = benchmarker.execs({shape, {}, {}, {}});
            printf("execute %s, %.4f Gops\n", shape.to_string().c_str(),
                   computation / time_ms);
        }
    };

    // Benchmark driven by a generated test case; the output dtype is taken
    // from the test case's expected output tensor.
    auto run_case = [&](const DctTestcase& testcase, Param param,
                        const std::string& comment = "") {
        Benchmarker<DctChannelSelectForward> benchmarker(handle_cuda());
        benchmarker.set_param(param);
        benchmarker.set_dtype(0, dtype::Uint8())
                .set_dtype(1, dtype::Int32())
                .set_dtype(2, dtype::Int32())
                .set_dtype(3, testcase.testcase_out[3].layout.dtype);
        auto src_shape = testcase.testcase_in[0].layout;
        double computation = double(src_shape[0]) * src_shape[1] *
                             src_shape[2] * src_shape[3] * 32.0 * 1e-6;
        auto time_ms = benchmarker.exect(testcase.testcase_in);
        printf("[%s] execute %s, %.4f Gops\n", comment.c_str(),
               src_shape.to_string().c_str(), computation / time_ms);
    };

    // Benchmark driven by shapes plus a tensor constraint that fills the
    // inputs. `comment` deliberately has no default argument: a defaulted
    // parameter may not be followed by a non-defaulted one, and the only call
    // site passes every argument explicitly.
    auto run_case_constraint =
            [&](const Benchmarker<DctChannelSelectForward>::TensorsConstriant&
                        constraint,
                Param param, const TensorShapeArray& shapes,
                const std::string& comment, DType output_dtype) {
                Benchmarker<DctChannelSelectForward> benchmarker(
                        handle_cuda());
                benchmarker.set_param(param)
                        .set_dtype(0, dtype::Uint8())
                        .set_dtype(1, dtype::Int32())
                        .set_dtype(2, dtype::Int32())
                        .set_dtype(3, output_dtype)
                        .set_tensors_constraint(constraint);
                auto src_shape = shapes[0];
                double computation = double(src_shape[0]) * src_shape[1] *
                                     src_shape[2] * src_shape[3] * 32.0 *
                                     1e-6;
                auto time_ms = benchmarker.exec(shapes);
                printf("[%s] execute %s, %.4f Gops\n", comment.c_str(),
                       src_shape.to_string().c_str(), computation / time_ms);
            };

    TensorShapeArray shapes = {
            {1, 3, 512, 512},
            {8, 3, 2176, 3840},
    };
    {
        Param param;
        run(shapes, param);
    }

    Param fix_32_param;
    fix_32_param.fastImpl = Param::FastImpl::FIX_32_MASK;
    {
        auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_param);
        run_case(*test_case, fix_32_param, "FIX_32_MASK");
    }
    {
        // Same generated inputs as above, executed with default params.
        Param param;
        auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_param);
        run_case(*test_case, param, "MASK 32");
    }
    {
        Param fix_32_nchw4_param;
        fix_32_nchw4_param.fastImpl = Param::FastImpl::FIX_32_MASK;
        fix_32_nchw4_param.format = Param::Format::NCHW4;
        auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_nchw4_param,
                                      dtype::QuantizedS8(10.f));
        run_case(*test_case, fix_32_nchw4_param, "FIX_32_MASK QINT8");
    }
    {
        Param fix_32_nchw4_param;
        fix_32_nchw4_param.fastImpl = Param::FastImpl::FIX_32_MASK;
        fix_32_nchw4_param.format = Param::Format::NCHW4;
        auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_nchw4_param,
                                      dtype::QuantizedS8(10.f));
        // Inputs are generated with FIX_32_MASK set, then the fast path is
        // switched off for the actual run.
        fix_32_nchw4_param.fastImpl = Param::FastImpl::NONE;
        run_case(*test_case, fix_32_nchw4_param, "MASK 32 QINT8");
    }
    {
        Param fix_32_nchw4_param;
        fix_32_nchw4_param.fastImpl = Param::FastImpl::FIX_32_MASK;
        fix_32_nchw4_param.format = Param::Format::NCHW4;
        // Named differently from the outer `shapes` to avoid shadowing it.
        TensorShapeArray constraint_shapes = {
                {8, 3, 2176, 3840}, {4}, {32}, {}};
        auto constraint =
                gen_dct_constriant(8, 3, 2176, 3840, 32, fix_32_nchw4_param);
        run_case_constraint(constraint, fix_32_nchw4_param, constraint_shapes,
                            "FIX_32_MASK QINT8 Constraint",
                            dtype::QuantizedS8(10.f));
    }
}
#endif
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,679 @@ | |||
/** | |||
* \file dnn/test/naive/dct.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megdnn/oprs/nn.h" | |||
#include "test/common/checker.h" | |||
#include "test/common/dct_ref.h" | |||
#include "test/common/rng.h" | |||
#include "test/common/tensor.h" | |||
#include "test/naive/fixture.h" | |||
namespace megdnn { | |||
namespace test { | |||
/*!
 * Runs the naive DctChannelSelectForward with default parameters and no mask
 * tensors on a {1, 1, 16, 16} uint8 input and compares the result with a
 * pre-computed {1, 64, 2, 2} float32 tensor.
 */
TEST_F(NAIVE, DCT) {
    using Opr = DctChannelSelectForward;
    Checker<Opr> checker(handle(), /* check_dispatch */ false);
    Opr::Param param;

    // Input image values.
    auto src = TensorValue(
            {1, 1, 16, 16}, dtype::Uint8(),
            {87,  155, 59,  161, 24,  200, 58,  3,   40,  43,
             156, 7,   176, 232, 226, 78,  73,  236, 185, 109,
             196, 169, 62,  32,  167, 180, 96,  157, 101, 53,
             150, 47,  26,  238, 218, 210, 204, 236, 249, 111,
             16,  35,  169, 204, 117, 16,  3,   147, 12,  233,
             135, 162, 58,  118, 184, 237, 90,  105, 156, 195,
             196, 104, 138, 19,  82,  62,  126, 140, 220, 171,
             206, 232, 105, 123, 2,   135, 137, 41,  26,  219,
             167, 245, 104, 103, 24,  144, 141, 210, 208, 114,
             169, 170, 22,  11,  69,  106, 236, 150, 57,  184,
             75,  241, 28,  175, 178, 186, 190, 124, 187, 116,
             112, 162, 214, 154, 207, 31,  43,  40,  15,  188,
             81,  197, 20,  199, 246, 132, 159, 111, 79,  95,
             148, 184, 171, 173, 203, 146, 150, 33,  178, 9,
             141, 49,  237, 222, 72,  5,   23,  38,  248, 82,
             93,  229, 70,  180, 149, 232, 245, 72,  196, 138,
             4,   31,  160, 30,  8,   109, 153, 252, 204, 126,
             15,  182, 145, 130, 179, 234, 21,  240, 144, 105,
             77,  116, 155, 232, 168, 99,  159, 92,  251, 223,
             119, 173, 166, 39,  228, 91,  34,  5,   62,  172,
             131, 164, 143, 10,  161, 165, 221, 214, 178, 110,
             185, 254, 152, 149, 46,  144, 173, 237, 76,  210,
             221, 45,  200, 113, 58,  20,  47,  135, 228, 80,
             91,  51,  238, 194, 222, 231, 174, 244, 139, 96,
             71,  25,  25,  62,  172, 181, 71,  27,  86,  0,
             121, 38,  199, 236, 93,  158});

    // Reference output values.
    auto expected = TensorValue(
            {1, 64, 2, 2}, dtype::Float32(),
            {1.10687500e+03,  9.59500000e+02,  8.98125000e+02,
             1.21912500e+03,  1.38846378e+01,  3.91629181e+01,
             -1.50343018e+02, -1.02085358e+02, 2.34341068e+01,
             -8.40960388e+01, -4.23510742e+01, 1.72630596e+01,
             -4.66624413e+01, -4.87857285e+01, -7.06332016e+01,
             6.31493912e+01,  -9.96249924e+01, 7.72499924e+01,
             7.46250153e+01,  5.81250114e+01,  -9.07061768e+01,
             -7.68266630e+00, -3.15778809e+01, -3.35406876e+01,
             8.55864143e+00,  -7.36760712e+01, 6.20557327e+01,
             -2.92043419e+01, -1.39985870e+02, 2.56675129e+01,
             5.21866226e+01,  1.07624054e+02,  -6.16851950e+00,
             -8.56008530e+01, 7.35654449e+01,  -2.56767311e+01,
             -2.09981880e+01, -6.22950821e+01, -1.31617493e+02,
             -6.30962448e+01, -2.21552780e+02, -4.79528542e+01,
             1.04179153e+02,  7.45253448e+01,  3.19730816e+01,
             1.24306192e+01,  -9.93905945e+01, -8.95680237e+01,
             -1.44870041e+02, -9.44738235e+01, -4.09417763e+01,
             4.50356903e+01,  -3.65339231e+00, 5.79474449e+01,
             -2.46253452e+01, 3.29394951e+01,  -1.09065903e+02,
             5.23808861e+01,  -1.00386992e+01, -7.92311325e+01,
             -1.44292374e+01, 5.74285736e+01,  2.28798485e+01,
             6.84826508e+01,  -1.49241837e+02, 9.35751495e+01,
             -4.02763329e+01, -6.63586197e+01, 2.15622040e+02,
             -7.83887939e+01, -8.06824951e+01, -2.51097183e+01,
             1.58941059e+01,  -5.66967869e+00, -1.53566467e+02,
             -4.33494377e+01, 8.12108078e+01,  1.21169144e+02,
             2.14673615e+02,  -3.72018318e+01, 2.45811577e+01,
             -1.27189613e+02, 4.98553581e+01,  -5.83694696e+00,
             -4.80477619e+00, -2.24601650e+01, -5.02191353e+00,
             5.16259460e+01,  1.07266571e+02,  -3.41748886e+01,
             -5.44621315e+01, 6.25573196e+01,  -4.24649086e+01,
             4.42625465e+01,  2.71147366e+01,  4.83264275e+01,
             -6.99711227e+01, -1.00299120e+01, 1.33173111e+02,
             2.48003254e+01,  -1.74687519e+01, 9.44530487e-01,
             1.35930038e+02,  6.72219162e+01,  4.53297043e+01,
             1.37072708e+02,  -7.73253784e+01, 6.12967606e+01,
             9.78184891e+01,  3.63894577e+01,  -1.64039135e+01,
             -6.67858887e+01, 5.27859840e+01,  -4.99117432e+01,
             8.77927475e+01,  -5.86666260e+01, 3.86430244e+01,
             2.17759323e+01,  8.34562683e+01,  3.06256886e+01,
             1.61030369e+01,  8.11268158e+01,  1.36932516e+01,
             -1.06112595e+02, -9.31621475e+01, 3.13674717e+01,
             -4.90609503e+00, 7.96453857e+01,  -1.02625000e+02,
             1.40000076e+01,  3.18749981e+01,  -1.08375000e+02,
             -5.44420319e+01, -1.50944397e+02, 5.29974670e+01,
             -1.44041641e+02, 4.86086197e+01,  -7.13610382e+01,
             3.06417294e+01,  7.20477829e+01,  -6.95384140e+01,
             1.25441925e+02,  -1.54897385e+01, 3.78566666e+01,
             4.23749886e+01,  -3.37500000e+01, -9.96250000e+01,
             -6.73750076e+01, 3.34241295e+01,  -6.24825974e+01,
             1.76387348e+01,  -6.45708389e+01, 1.70728874e+01,
             -5.73032570e+01, -1.71570969e+01, 1.84064590e+02,
             4.17566071e+01,  7.08248520e+00,  -2.59306641e+01,
             1.37766739e+02,  -2.16669798e+00, 6.03565750e+01,
             6.84421844e+01,  6.19825096e+01,  -1.44220114e+01,
             -3.12404213e+01, -2.50061111e+01, 6.73021851e+01,
             2.52050266e+01,  -8.35850677e+01, -4.70746574e+01,
             1.73889160e+01,  1.18955564e+01,  6.16792488e+00,
             -3.29667168e+01, 4.55779572e+01,  -4.17868996e+00,
             -9.40233841e+01, -9.77727051e+01, 1.74934635e+01,
             5.25992851e+01,  1.23662634e+01,  5.26129305e-01,
             4.69518929e+01,  -1.52657738e+01, 9.96897888e+01,
             -9.51726151e+01, 9.99432602e+01,  -1.75949844e+02,
             1.00472336e+02,  -5.89417953e+01, -1.72231483e+01,
             1.89282093e+01,  -8.17851868e+01, 7.22908936e+01,
             -9.06294174e+01, 2.46093607e+00,  -4.03946457e+01,
             2.17710762e+01,  -5.62999649e+01, 4.77665749e+01,
             -4.04248848e+01, 4.78787374e+00,  1.05557320e+02,
             -4.60584450e+01, -7.33774490e+01, -4.25107193e+01,
             1.71907139e+01,  -8.01314316e+01, 1.69647141e+01,
             -8.24824219e+01, 8.29206543e+01,  3.72900200e+01,
             3.77470016e+01,  6.70151443e+01,  1.79784470e+01,
             -4.01441078e+01, 6.29196739e+01,  7.60664597e+01,
             -5.59005699e+01, 8.81600475e+00,  -6.89491081e+00,
             -8.03825378e+01, -5.33856511e-01, 7.26196136e+01,
             -3.76809120e+01, -1.08401566e+02, 6.35455990e+00,
             -8.66767120e+01, -1.02679443e+02, -9.54313660e+00,
             -3.55650787e+01, -1.21355652e+02, 2.32628040e+01,
             3.94072838e+01,  1.24754738e+02,  9.51344986e+01,
             -5.84752541e+01, -4.65028038e+01, 6.00556993e+00,
             4.94889374e+01,  7.64868622e+01,  -1.49546280e+01,
             -3.70648766e+01, 5.55572205e+01,  -1.17196434e+02,
             9.20216217e+01,  3.29843826e+01,  3.25113411e+01,
             5.62059135e+01,  6.30202141e+01,  4.99030991e+01,
             2.85804024e+01,  -1.44606361e+01, 7.64952774e+01,
             -2.95697536e+01});

    checker.set_param(param).exect(Testcase{src, {}, {}, {}},
                                   Testcase{{}, {}, {}, expected});
}
/*!
 * Runs the naive DctChannelSelectForward in NCHW4 format without mask tensors
 * on a {1, 1, 16, 16} uint8 input and compares the result with a pre-computed
 * {1, 16, 2, 2, 4} QuantizedS8(10.f) tensor.
 */
TEST_F(NAIVE, DCT_INT8) {
    using Opr = DctChannelSelectForward;
    Checker<Opr> checker(handle(), /* check_dispatch */ false);
    Opr::Param param;
    param.format = Opr::Param::Format::NCHW4;

    // Input image values.
    auto src = TensorValue(
            {1, 1, 16, 16}, dtype::Uint8(),
            {113, 223, 229, 159, 249, 252, 89,  84,  45,  16,
             41,  72,  184, 236, 70,  184, 86,  172, 218, 211,
             47,  177, 18,  85,  174, 226, 37,  109, 38,  135,
             228, 195, 133, 238, 47,  246, 244, 118, 175, 143,
             34,  10,  28,  4,   82,  103, 89,  55,  235, 78,
             151, 178, 249, 62,  183, 84,  105, 0,   121, 98,
             249, 90,  161, 114, 121, 241, 21,  199, 196, 119,
             231, 209, 250, 180, 192, 213, 116, 105, 114, 169,
             1,   142, 3,   30,  140, 245, 201, 109, 19,  26,
             224, 68,  123, 228, 64,  150, 184, 212, 136, 172,
             241, 152, 222, 233, 15,  72,  130, 144, 107, 130,
             242, 79,  195, 46,  226, 57,  183, 36,  88,  161,
             121, 170, 2,   215, 109, 212, 35,  18,  76,  197,
             117, 81,  208, 8,   237, 75,  15,  20,  16,  192,
             61,  113, 96,  126, 211, 57,  49,  62,  185, 211,
             155, 87,  233, 163, 164, 84,  61,  28,  1,   11,
             190, 253, 145, 30,  38,  98,  153, 56,  231, 152,
             12,  204, 96,  8,   47,  87,  25,  237, 21,  150,
             173, 19,  41,  175, 164, 231, 39,  145, 39,  187,
             210, 123, 165, 98,  87,  242, 38,  136, 182, 145,
             41,  47,  147, 171, 172, 35,  170, 148, 26,  89,
             107, 151, 130, 232, 65,  217, 27,  206, 68,  219,
             60,  106, 3,   209, 175, 189, 191, 32,  119, 141,
             56,  48,  105, 58,  94,  163, 185, 60,  83,  249,
             112, 245, 137, 60,  178, 51,  177, 106, 199, 209,
             4,   247, 3,   127, 88,  46});

    // Reference output values.
    auto expected = TensorValue(
            {1, 16, 2, 2, 4}, dtype::QuantizedS8(10.f),
            {122, -1,  -8,  4,   92,  -13, -5,  7,   99,  4,
             5,   3,   89,  7,   2,   -6,  3,   -8,  -10, 2,
             -1,  0,   4,   -3,  -5,  -8,  -11, 1,   14,  4,
             -10, -18, 3,   12,  -14, -2,  -4,  -9,  12,  4,
             -2,  -2,  2,   6,   -9,  6,   1,   5,   -5,  -1,
             2,   -12, 4,   -5,  -0,  4,   1,   5,   -8,  5,
             -3,  4,   2,   6,   -0,  9,   -4,  -7,  -4,  -5,
             -2,  8,   2,   4,   0,   7,   -8,  4,   -2,  3,
             -6,  -5,  19,  5,   -4,  -4,  -5,  -16, -8,  -3,
             -5,  19,  4,   3,   4,   -6,  1,   -12, -1,  7,
             11,  -5,  -1,  -8,  2,   -12, -9,  -2,  -4,  -20,
             -11, -15, -15, -9,  -2,  -9,  -2,  -3,  13,  2,
             5,   6,   7,   -4,  1,   -7,  6,   4,   2,   6,
             0,   -0,  8,   8,   -6,  5,   1,   -2,  -2,  -12,
             2,   -12, -2,  6,   7,   3,   4,   14,  14,  -3,
             1,   -3,  6,   0,   -20, 2,   -10, 10,  -5,  -5,
             13,  0,   -3,  7,   -12, -17, -13, 1,   -6,  10,
             -1,  -9,  4,   -16, 3,   2,   5,   1,   -4,  9,
             -0,  1,   3,   15,  -4,  -13, -6,  4,   3,   -2,
             -1,  -4,  -7,  -7,  -2,  8,   -16, -4,  -10, 5,
             1,   -3,  2,   -9,  -4,  1,   -1,  -1,  -4,  -6,
             -4,  1,   0,   -9,  15,  -1,  -7,  -3,  -5,  -0,
             3,   -0,  -6,  -17, 16,  -3,  3,   -2,  -3,  5,
             3,   -2,  3,   13,  8,   1,   -3,  -8,  -7,  -4,
             6,   -6,  -15, -7,  0,   4,   -3,  -3,  -10, 14,
             1,   3,   14,  4,   -1,  14});

    checker.set_param(param).exect(Testcase{src, {}, {}, {}},
                                   Testcase{{}, {}, {}, expected});
}
/*!
 * Runs the naive DctChannelSelectForward in NCHW4 format with explicit
 * mask_offset / mask_val tensors on a {1, 3, 8, 16} uint8 input, for two
 * masks: one selecting 32 channels in total and one selecting 28.
 *
 * NOTE(review): both cases spell out the same input values instead of sharing
 * one tensor object; presumably each Testcase takes ownership of the tensors
 * it is given - confirm before deduplicating the data.
 */
TEST_F(NAIVE, DCT_INT8_MASK) {
    using Opr = DctChannelSelectForward;
    Checker<Opr> checker(handle(), /* check_dispatch */ false);
    Opr::Param param;
    param.format = Opr::Param::Format::NCHW4;

    // Case 1: per-channel offsets {0, 16, 24, 32}, 32 selected coefficients.
    {
        auto src = TensorValue(
                {1, 3, 8, 16}, dtype::Uint8(),
                {195, 165, 82,  30,  154, 60,  175, 195, 179, 165, 132, 37,  250,
                 107, 36,  80,  5,   54,  247, 218, 191, 211, 239, 76,  140, 33,
                 253, 85,  132, 101, 105, 177, 46,  183, 102, 99,  19,  175, 108,
                 252, 42,  238, 48,  251, 108, 90,  176, 2,   35,  46,  161, 252,
                 38,  225, 195, 174, 58,  165, 198, 249, 162, 118, 198, 41,  154,
                 10,  87,  24,  201, 12,  188, 1,   93,  179, 246, 134, 18,  178,
                 173, 36,  122, 89,  115, 46,  43,  205, 232, 55,  149, 30,  206,
                 97,  186, 125, 35,  209, 51,  48,  222, 222, 130, 173, 63,  0,
                 223, 19,  5,   162, 154, 143, 134, 63,  123, 102, 102, 212, 145,
                 80,  87,  212, 42,  26,  219, 225, 120, 94,  213, 238,
                 25,  172, 141, 45,  182, 203, 50,  94,  44,  88,  74,  76,  151,
                 105, 138, 87,  125, 55,  60,  211, 15,  158, 198, 37,  54,  203,
                 239, 79,  56,  6,   53,  201, 97,  233, 178, 74,  193, 46,  249,
                 65,  5,   208, 130, 67,  191, 168, 152, 129, 253, 195, 231, 3,
                 109, 229, 254, 193, 229, 202, 108, 22,  89,  251, 13,  53,  47,
                 192, 12,  81,  19,  53,  93,  104, 41,  217, 215, 184, 136, 249,
                 14,  244, 4,   220, 33,  53,  142, 219, 43,  28,  68,  198, 202,
                 88,  235, 7,   233, 47,  84,  127, 28,  17,  189, 135, 183, 192,
                 239, 116, 31,  118, 186, 49,  251, 233, 220, 27,  97,  30,  43,
                 193, 217, 48,  24,  225, 15,  3,   26,  71,  82,  104,
                 175, 125, 79,  195, 50,  236, 114, 179, 180, 177, 230, 173, 43,
                 195, 123, 111, 106, 5,   91,  254, 34,  76,  52,  82,  193, 179,
                 185, 71,  57,  215, 18,  5,   151, 13,  59,  206, 154, 95,  149,
                 40,  229, 16,  116, 144, 249, 67,  97,  223, 208, 144, 92,  174,
                 246, 77,  196, 211, 20,  123, 239, 250, 235, 65,  184, 54,  239,
                 168, 135, 17,  79,  117, 171, 173, 109, 39,  57,  13,  129, 79,
                 236, 117, 134, 123, 149, 113, 198, 160, 249, 242, 220, 226, 44,
                 113, 164, 217, 46,  249, 182, 22,  98,  228, 49,  78,  101, 236,
                 181, 5,   245, 72,  62,  182, 151, 210, 254, 190, 35,  73,  190,
                 247, 50,  81,  49,  217, 86,  229, 139, 203, 57,  194});
        auto mask_offset = TensorValue({4}, dtype::Int32(), {0, 16, 24, 32});
        auto mask_val = TensorValue(
                {32}, dtype::Int32(),
                {0,  1,  8,  16, 9, 2, 3, 10, 17, 24, 32,
                 25, 18, 11, 4,  5, 0, 1, 8,  16, 9,  2,
                 3,  10, 0,  1,  8, 16, 9, 2,  3,  10});
        auto expected = TensorValue(
                {1, 8, 1, 2, 4}, dtype::QuantizedS8(10.f),
                {100, -12, 7,  7,   104, 2,  -2, -2,  -7, -7, -3,
                 8,   12,  -12, -5, -1,  5,  -7, -1,  7,  -7, -3,
                 6,   7,   -0, -2,  -7,  11, 6,  3,   -1, 7,  94,
                 -5,  6,   -5, 98,  0,   -3, -16, 5,  7,  13, -8,
                 1,   5,   -5, -8,  108, -3, -8, -7,  110, 1, -2,
                 5,   -0,  7,  8,   -9,  14, -0, 1,   -4});
        checker.set_param(param).exect(
                Testcase{src, mask_offset, mask_val, {}},
                Testcase{{}, {}, {}, expected});
    }

    // Case 2: per-channel offsets {0, 12, 20, 28}, 28 selected coefficients.
    {
        auto src = TensorValue(
                {1, 3, 8, 16}, dtype::Uint8(),
                {195, 165, 82,  30,  154, 60,  175, 195, 179, 165,
                 132, 37,  250, 107, 36,  80,  5,   54,  247, 218,
                 191, 211, 239, 76,  140, 33,  253, 85,  132, 101,
                 105, 177, 46,  183, 102, 99,  19,  175, 108, 252,
                 42,  238, 48,  251, 108, 90,  176, 2,   35,  46,
                 161, 252, 38,  225, 195, 174, 58,  165, 198, 249,
                 162, 118, 198, 41,  154, 10,  87,  24,  201, 12,
                 188, 1,   93,  179, 246, 134, 18,  178, 173, 36,
                 122, 89,  115, 46,  43,  205, 232, 55,  149, 30,
                 206, 97,  186, 125, 35,  209, 51,  48,  222, 222,
                 130, 173, 63,  0,   223, 19,  5,   162, 154, 143,
                 134, 63,  123, 102, 102, 212, 145, 80,  87,  212,
                 42,  26,  219, 225, 120, 94,  213, 238,
                 25,  172, 141, 45,  182, 203, 50,  94,  44,  88,
                 74,  76,  151, 105, 138, 87,  125, 55,  60,  211,
                 15,  158, 198, 37,  54,  203, 239, 79,  56,  6,
                 53,  201, 97,  233, 178, 74,  193, 46,  249, 65,
                 5,   208, 130, 67,  191, 168, 152, 129, 253, 195,
                 231, 3,   109, 229, 254, 193, 229, 202, 108, 22,
                 89,  251, 13,  53,  47,  192, 12,  81,  19,  53,
                 93,  104, 41,  217, 215, 184, 136, 249, 14,  244,
                 4,   220, 33,  53,  142, 219, 43,  28,  68,  198,
                 202, 88,  235, 7,   233, 47,  84,  127, 28,  17,
                 189, 135, 183, 192, 239, 116, 31,  118, 186, 49,
                 251, 233, 220, 27,  97,  30,  43,  193, 217, 48,
                 24,  225, 15,  3,   26,  71,  82,  104,
                 175, 125, 79,  195, 50,  236, 114, 179, 180, 177,
                 230, 173, 43,  195, 123, 111, 106, 5,   91,  254,
                 34,  76,  52,  82,  193, 179, 185, 71,  57,  215,
                 18,  5,   151, 13,  59,  206, 154, 95,  149, 40,
                 229, 16,  116, 144, 249, 67,  97,  223, 208, 144,
                 92,  174, 246, 77,  196, 211, 20,  123, 239, 250,
                 235, 65,  184, 54,  239, 168, 135, 17,  79,  117,
                 171, 173, 109, 39,  57,  13,  129, 79,  236, 117,
                 134, 123, 149, 113, 198, 160, 249, 242, 220, 226,
                 44,  113, 164, 217, 46,  249, 182, 22,  98,  228,
                 49,  78,  101, 236, 181, 5,   245, 72,  62,  182,
                 151, 210, 254, 190, 35,  73,  190, 247, 50,  81,
                 49,  217, 86,  229, 139, 203, 57,  194});
        auto mask_offset = TensorValue({4}, dtype::Int32(), {0, 12, 20, 28});
        auto mask_val = TensorValue(
                {28}, dtype::Int32(),
                {0,  1,  8, 16, 9, 2,  3, 10, 17, 24,
                 32, 25, 0, 1,  8, 16, 9, 2,  3,  10,
                 0,  1,  8, 16, 9, 2,  3, 10});
        auto expected = TensorValue(
                {1, 7, 1, 2, 4}, dtype::QuantizedS8(10.f),
                {100, -12, 7,   7,  104, 2,  -2, -2,  -7, -7, -3,
                 8,   12,  -12, -5, -1,  5,  -7, -1,  7,  -7, -3,
                 6,   7,
                 94,  -5,  6,   -5, 98,  0,  -3, -16, 5,  7,  13,
                 -8,  1,   5,   -5, -8,  108, -3, -8, -7, 110, 1,
                 -2,  5,   -0,  7,  8,   -9, 14, -0,  1,  -4});
        checker.set_param(param).exect(
                Testcase{src, mask_offset, mask_val, {}},
                Testcase{{}, {}, {}, expected});
    }
}
/*!
 * Runs the naive DctChannelSelectForward with dct_block_size = 4 on a
 * {1, 1, 8, 8} uint8 input: first without a mask (expecting a {1, 16, 2, 2}
 * float32 tensor), then with a 6-entry mask (expecting {1, 6, 2, 2}).
 *
 * NOTE(review): the input values are written out once per case rather than
 * shared; presumably each Testcase takes ownership of its tensors - confirm
 * before deduplicating.
 */
TEST_F(NAIVE, DCT_4x4) {
    using Opr = DctChannelSelectForward;
    Checker<Opr> checker(handle(), /* check_dispatch */ false);
    Opr::Param param;
    param.dct_block_size = 4;

    // Case 1: no mask tensors.
    {
        auto src = TensorValue(
                {1, 1, 8, 8}, dtype::Uint8(),
                {186, 120, 112, 220, 69,  80,  201, 127, 246, 254,
                 175, 50,  240, 251, 76,  37,  34,  166, 250, 195,
                 231, 139, 128, 233, 75,  80,  3,   2,   19,  140,
                 193, 203, 115, 107, 250, 209, 14,  243, 199, 60,
                 234, 107, 174, 156, 81,  87,  13,  116, 96,  140,
                 197, 253, 113, 223, 229, 159, 249, 252, 89,  84,
                 45,  16,  41,  72});
        auto expected = TensorValue(
                {1, 16, 2, 2}, dtype::Float32(),
                {5.42000000e+02,  5.91750000e+02,  6.78000000e+02,
                 4.27750000e+02,  3.49953423e+01,  -1.17686939e+01,
                 -1.66842098e+01, -3.85316620e+01, -3.80000000e+01,
                 -1.22500000e+01, 2.00000000e+01,  -9.77500000e+01,
                 -1.61191311e+01, -9.46695328e+00, 3.28882408e+01,
                 -4.92537880e+01, 1.66958221e+02,  -4.26609573e+01,
                 2.56999969e-01,  5.39384537e+01,  1.71819706e+01,
                 9.00009003e+01,  -1.23818558e+02, 1.18912420e+01,
                 6.61014938e+01,  -2.49261990e+01, 4.95798302e+00,
                 -1.02324417e+02, 7.85859919e+00,  3.73140755e+01,
                 1.03783745e+02,  -4.61430321e+01, -1.43000000e+02,
                 -7.57500000e+01, -5.00000000e-01, -8.27500000e+01,
                 1.34834738e+01,  -1.93409515e+02, 6.84791718e+01,
                 -4.01652241e+00, 1.22000000e+02,  -8.57500000e+01,
                 -4.05000000e+01, -5.62500000e+01, -2.88564739e+01,
                 5.76532059e+01,  -2.67414131e+01, 1.70877876e+01,
                 3.85416756e+01,  3.09300461e+01,  5.84670639e+00,
                 1.85747864e+02,  -2.05141403e+02, -9.91859360e+01,
                 -1.66716263e+02, -1.71430378e+01, 6.71520996e+00,
                 8.41980438e+01,  -3.50666313e+01, -1.48387482e+02,
                 1.08180256e+01,  5.49991112e+01,  -1.06814528e+01,
                 1.86087704e+01});
        checker.set_param(param).exect(Testcase{src, {}, {}, {}},
                                       Testcase{{}, {}, {}, expected});
    }

    // Case 2: offsets {0, 6} with mask values {0, 1, 8, 4, 2, 3}.
    {
        auto src = TensorValue(
                {1, 1, 8, 8}, dtype::Uint8(),
                {186, 120, 112, 220, 69,  80,  201, 127, 246, 254,
                 175, 50,  240, 251, 76,  37,  34,  166, 250, 195,
                 231, 139, 128, 233, 75,  80,  3,   2,   19,  140,
                 193, 203, 115, 107, 250, 209, 14,  243, 199, 60,
                 234, 107, 174, 156, 81,  87,  13,  116, 96,  140,
                 197, 253, 113, 223, 229, 159, 249, 252, 89,  84,
                 45,  16,  41,  72});
        auto mask_offset = TensorValue({2}, dtype::Int32(), {0, 6});
        auto mask_val =
                TensorValue({6}, dtype::Int32(), {0, 1, 8, 4, 2, 3});
        auto expected = TensorValue(
                {1, 6, 2, 2}, dtype::Float32(),
                {5.4200000e+02,  5.9175000e+02,  6.7800000e+02,
                 4.2775000e+02,  3.4995342e+01,  -1.1768694e+01,
                 -1.6684210e+01, -3.8531662e+01, -1.4300000e+02,
                 -7.5750000e+01, -5.0000000e-01, -8.2750000e+01,
                 1.6695822e+02,  -4.2660957e+01, 2.5699997e-01,
                 5.3938454e+01,  -3.8000000e+01, -1.2250000e+01,
                 2.0000000e+01,  -9.7750000e+01, -1.6119131e+01,
                 -9.4669533e+00, 3.2888241e+01,  -4.9253788e+01});
        checker.set_param(param).exect(
                Testcase{src, mask_offset, mask_val, {}},
                Testcase{{}, {}, {}, expected});
    }
}
/*!
 * Runs the naive DctChannelSelectForward with default parameters and explicit
 * mask tensors on a {1, 3, 8, 16} uint8 input, for two masks: offsets
 * {0, 16, 24, 32} (expecting {1, 32, 1, 2} float32) and offsets {0, 8, 16, 24}
 * (expecting {1, 24, 1, 2} float32).
 *
 * NOTE(review): the input values are written out once per case rather than
 * shared; presumably each Testcase takes ownership of its tensors - confirm
 * before deduplicating.
 */
TEST_F(NAIVE, DCT_WITH_MASK) {
    using Opr = DctChannelSelectForward;
    Checker<Opr> checker(handle(), /* check_dispatch */ false);
    Opr::Param param;

    // Case 1: 16 + 8 + 8 selected coefficients.
    {
        auto src = TensorValue(
                {1, 3, 8, 16}, dtype::Uint8(),
                {109, 39,  30,  115, 71,  15,  206, 139, 221, 5,
                 18,  16,  93,  185, 99,  102, 205, 172, 191, 29,
                 185, 6,   47,  84,  0,   47,  105, 203, 251, 73,
                 196, 83,  3,   211, 32,  181, 49,  111, 114, 83,
                 148, 232, 77,  17,  35,  2,   154, 100, 41,  135,
                 141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
                 78,  65,  184, 69,  91,  82,  2,   172, 194, 240,
                 49,  145, 87,  210, 97,  190, 179, 93,  125, 105,
                 181, 207, 148, 178, 133, 53,  25,  198, 238, 151,
                 14,  120, 213, 195, 145, 20,  122, 107, 217, 185,
                 65,  5,   115, 110, 82,  206, 163, 86,  2,   2,
                 44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
                 238, 181, 232, 191, 161, 57,  23,  204,
                 109, 39,  30,  115, 71,  15,  206, 139, 221, 5,
                 18,  16,  93,  185, 99,  102, 205, 172, 191, 29,
                 185, 6,   47,  84,  0,   47,  105, 203, 251, 73,
                 196, 83,  3,   211, 32,  181, 49,  111, 114, 83,
                 148, 232, 77,  17,  35,  2,   154, 100, 41,  135,
                 141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
                 78,  65,  184, 69,  91,  82,  2,   172, 194, 240,
                 49,  145, 87,  210, 97,  190, 179, 93,  125, 105,
                 181, 207, 148, 178, 133, 53,  25,  198, 238, 151,
                 14,  120, 213, 195, 145, 20,  122, 107, 217, 185,
                 65,  5,   115, 110, 82,  206, 163, 86,  2,   2,
                 44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
                 238, 181, 232, 191, 161, 57,  23,  204,
                 109, 39,  30,  115, 71,  15,  206, 139, 221, 5,
                 18,  16,  93,  185, 99,  102, 205, 172, 191, 29,
                 185, 6,   47,  84,  0,   47,  105, 203, 251, 73,
                 196, 83,  3,   211, 32,  181, 49,  111, 114, 83,
                 148, 232, 77,  17,  35,  2,   154, 100, 41,  135,
                 141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
                 78,  65,  184, 69,  91,  82,  2,   172, 194, 240,
                 49,  145, 87,  210, 97,  190, 179, 93,  125, 105,
                 181, 207, 148, 178, 133, 53,  25,  198, 238, 151,
                 14,  120, 213, 195, 145, 20,  122, 107, 217, 185,
                 65,  5,   115, 110, 82,  206, 163, 86,  2,   2,
                 44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
                 238, 181, 232, 191, 161, 57,  23,  204});
        auto mask_offset = TensorValue({4}, dtype::Int32(), {0, 16, 24, 32});
        auto mask_val = TensorValue(
                {32}, dtype::Int32(),
                {0,  1,  8,  16, 9, 2, 3, 10, 17, 24, 32,
                 25, 18, 11, 4,  5, 0, 1, 8,  16, 9,  2,
                 3,  10, 0,  1,  8, 16, 9, 2,  3,  10});
        auto expected = TensorValue(
                {1, 32, 1, 2}, dtype::Float32(),
                {890.12494,   941.25,      -7.0498576,
                 99.47632,    -22.850792,  -97.862236,
                 -101.043236, -4.727012,   28.275675,
                 -157.96654,  42.1377,     45.06531,
                 -149.77373,  24.487143,   -8.054966,
                 -13.990831,  -6.9395194,  -3.9211385,
                 64.79172,    -12.363858,  -47.875,
                 59.,         56.271786,   -62.725567,
                 120.522675,  16.559765,   85.74334,
                 112.904495,  99.375,      29.499973,
                 2.0220923,   -19.681704,  890.12494,
                 941.25,      -7.0498576,  99.47632,
                 -22.850792,  -97.862236,  -101.043236,
                 -4.727012,   28.275675,   -157.96654,
                 42.1377,     45.06531,    -149.77373,
                 24.487143,   -8.054966,   -13.990831,
                 890.12494,   941.25,      -7.0498576,
                 99.47632,    -22.850792,  -97.862236,
                 -101.043236, -4.727012,   28.275675,
                 -157.96654,  42.1377,     45.06531,
                 -149.77373,  24.487143,   -8.054966,
                 -13.990831});
        checker.set_param(param).exect(
                Testcase{src, mask_offset, mask_val, {}},
                Testcase{{}, {}, {}, expected});
    }

    // Case 2: 8 + 8 + 8 selected coefficients.
    {
        auto src = TensorValue(
                {1, 3, 8, 16}, dtype::Uint8(),
                {109, 39,  30,  115, 71,  15,  206, 139, 221, 5,
                 18,  16,  93,  185, 99,  102, 205, 172, 191, 29,
                 185, 6,   47,  84,  0,   47,  105, 203, 251, 73,
                 196, 83,  3,   211, 32,  181, 49,  111, 114, 83,
                 148, 232, 77,  17,  35,  2,   154, 100, 41,  135,
                 141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
                 78,  65,  184, 69,  91,  82,  2,   172, 194, 240,
                 49,  145, 87,  210, 97,  190, 179, 93,  125, 105,
                 181, 207, 148, 178, 133, 53,  25,  198, 238, 151,
                 14,  120, 213, 195, 145, 20,  122, 107, 217, 185,
                 65,  5,   115, 110, 82,  206, 163, 86,  2,   2,
                 44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
                 238, 181, 232, 191, 161, 57,  23,  204,
                 109, 39,  30,  115, 71,  15,  206, 139, 221, 5,
                 18,  16,  93,  185, 99,  102, 205, 172, 191, 29,
                 185, 6,   47,  84,  0,   47,  105, 203, 251, 73,
                 196, 83,  3,   211, 32,  181, 49,  111, 114, 83,
                 148, 232, 77,  17,  35,  2,   154, 100, 41,  135,
                 141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
                 78,  65,  184, 69,  91,  82,  2,   172, 194, 240,
                 49,  145, 87,  210, 97,  190, 179, 93,  125, 105,
                 181, 207, 148, 178, 133, 53,  25,  198, 238, 151,
                 14,  120, 213, 195, 145, 20,  122, 107, 217, 185,
                 65,  5,   115, 110, 82,  206, 163, 86,  2,   2,
                 44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
                 238, 181, 232, 191, 161, 57,  23,  204,
                 109, 39,  30,  115, 71,  15,  206, 139, 221, 5,
                 18,  16,  93,  185, 99,  102, 205, 172, 191, 29,
                 185, 6,   47,  84,  0,   47,  105, 203, 251, 73,
                 196, 83,  3,   211, 32,  181, 49,  111, 114, 83,
                 148, 232, 77,  17,  35,  2,   154, 100, 41,  135,
                 141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
                 78,  65,  184, 69,  91,  82,  2,   172, 194, 240,
                 49,  145, 87,  210, 97,  190, 179, 93,  125, 105,
                 181, 207, 148, 178, 133, 53,  25,  198, 238, 151,
                 14,  120, 213, 195, 145, 20,  122, 107, 217, 185,
                 65,  5,   115, 110, 82,  206, 163, 86,  2,   2,
                 44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
                 238, 181, 232, 191, 161, 57,  23,  204});
        auto mask_offset = TensorValue({4}, dtype::Int32(), {0, 8, 16, 24});
        auto mask_val = TensorValue(
                {24}, dtype::Int32(),
                {17, 24, 32, 25, 18, 11, 4, 5, 0, 1, 8, 16,
                 9,  2,  3,  10, 0,  1,  8, 16, 9, 2, 3, 10});
        auto expected = TensorValue(
                {1, 24, 1, 2}, dtype::Float32(),
                {-6.9395194,  -3.9211385,  64.79172,
                 -12.363858,  -47.875,     59.,
                 56.271786,   -62.725567,  120.522675,
                 16.559765,   85.74334,    112.904495,
                 99.375,      29.499973,   2.0220923,
                 -19.681704,  890.12494,   941.25,
                 -7.0498576,  99.47632,    -22.850792,
                 -97.862236,  -101.043236, -4.727012,
                 28.275675,   -157.96654,  42.1377,
                 45.06531,    -149.77373,  24.487143,
                 -8.054966,   -13.990831,  890.12494,
                 941.25,      -7.0498576,  99.47632,
                 -22.850792,  -97.862236,  -101.043236,
                 -4.727012,   28.275675,   -157.96654,
                 42.1377,     45.06531,    -149.77373,
                 24.487143,   -8.054966,   -13.990831});
        checker.set_param(param).exect(
                Testcase{src, mask_offset, mask_val, {}},
                Testcase{{}, {}, {}, expected});
    }
}
// Fixed-data run of DctChannelSelectForward on the naive handle with
// Param::FastImpl::FIX_32_MASK selected. One {1, 3, 8, 16} uint8 tensor and
// two int32 tensors are handed to checker.exect() together with a hard-coded
// {1, 32, 1, 2} float32 tensor.
// NOTE(review): the second Testcase is presumably the expected output that
// exect() compares against -- confirm against the Checker implementation.
TEST_F(NAIVE, DCT_WITH_FIX_32_MASK) { | |||
Checker<DctChannelSelectForward> checker(handle(), | |||
/* check_dispatch */ false); | |||
using Param = DctChannelSelectForward::Param; | |||
Param param; | |||
param.fastImpl = Param::FastImpl::FIX_32_MASK; | |||
// First Testcase: three tensors are supplied and the fourth slot is left
// empty ({}).
checker.set_param(param).exect( | |||
Testcase{TensorValue( | |||
{1, 3, 8, 16}, dtype::Uint8(), | |||
// 384 values in total: the same 128-value block is written three times.
{109, 39, 30, 115, 71, 15, 206, 139, 221, 5, | |||
18, 16, 93, 185, 99, 102, 205, 172, 191, 29, | |||
185, 6, 47, 84, 0, 47, 105, 203, 251, 73, | |||
196, 83, 3, 211, 32, 181, 49, 111, 114, 83, | |||
148, 232, 77, 17, 35, 2, 154, 100, 41, 135, | |||
141, 206, 56, 91, 137, 199, 104, 192, 75, 122, | |||
78, 65, 184, 69, 91, 82, 2, 172, 194, 240, | |||
49, 145, 87, 210, 97, 190, 179, 93, 125, 105, | |||
181, 207, 148, 178, 133, 53, 25, 198, 238, 151, | |||
14, 120, 213, 195, 145, 20, 122, 107, 217, 185, | |||
65, 5, 115, 110, 82, 206, 163, 86, 2, 2, | |||
44, 125, 50, 38, 41, 106, 30, 5, 151, 243, | |||
238, 181, 232, 191, 161, 57, 23, 204, | |||
109, 39, 30, 115, 71, 15, 206, 139, 221, 5, | |||
18, 16, 93, 185, 99, 102, 205, 172, 191, 29, | |||
185, 6, 47, 84, 0, 47, 105, 203, 251, 73, | |||
196, 83, 3, 211, 32, 181, 49, 111, 114, 83, | |||
148, 232, 77, 17, 35, 2, 154, 100, 41, 135, | |||
141, 206, 56, 91, 137, 199, 104, 192, 75, 122, | |||
78, 65, 184, 69, 91, 82, 2, 172, 194, 240, | |||
49, 145, 87, 210, 97, 190, 179, 93, 125, 105, | |||
181, 207, 148, 178, 133, 53, 25, 198, 238, 151, | |||
14, 120, 213, 195, 145, 20, 122, 107, 217, 185, | |||
65, 5, 115, 110, 82, 206, 163, 86, 2, 2, | |||
44, 125, 50, 38, 41, 106, 30, 5, 151, 243, | |||
238, 181, 232, 191, 161, 57, 23, 204, | |||
109, 39, 30, 115, 71, 15, 206, 139, 221, 5, | |||
18, 16, 93, 185, 99, 102, 205, 172, 191, 29, | |||
185, 6, 47, 84, 0, 47, 105, 203, 251, 73, | |||
196, 83, 3, 211, 32, 181, 49, 111, 114, 83, | |||
148, 232, 77, 17, 35, 2, 154, 100, 41, 135, | |||
141, 206, 56, 91, 137, 199, 104, 192, 75, 122, | |||
78, 65, 184, 69, 91, 82, 2, 172, 194, 240, | |||
49, 145, 87, 210, 97, 190, 179, 93, 125, 105, | |||
181, 207, 148, 178, 133, 53, 25, 198, 238, 151, | |||
14, 120, 213, 195, 145, 20, 122, 107, 217, 185, | |||
65, 5, 115, 110, 82, 206, 163, 86, 2, 2, | |||
44, 125, 50, 38, 41, 106, 30, 5, 151, 243, | |||
238, 181, 232, 191, 161, 57, 23, 204}), | |||
// NOTE(review): going by the argument order of DctChannelSelectForward::exec,
// the next two int32 tensors look like mask_offset ({4}) and mask_val ({32})
// -- confirm.
TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}), | |||
TensorValue({32}, dtype::Int32(), | |||
{0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, | |||
25, 18, 11, 4, 5, 0, 1, 8, 16, 9, 2, | |||
3, 10, 0, 1, 8, 16, 9, 2, 3, 10}), | |||
{}}, | |||
// Second Testcase: only the fourth slot is populated, with 64 float32 values.
// The third and fourth 16-value groups repeat the first 16 values exactly.
Testcase{{}, | |||
{}, | |||
{}, | |||
TensorValue({1, 32, 1, 2}, dtype::Float32(), | |||
{890.12494, 941.25, -7.0498576, | |||
99.47632, -22.850792, -97.862236, | |||
-101.043236, -4.727012, 28.275675, | |||
-157.96654, 42.1377, 45.06531, | |||
-149.77373, 24.487143, -8.054966, | |||
-13.990831, -6.9395194, -3.9211385, | |||
64.79172, -12.363858, -47.875, | |||
59., 56.271786, -62.725567, | |||
120.522675, 16.559765, 85.74334, | |||
112.904495, 99.375, 29.499973, | |||
2.0220923, -19.681704, 890.12494, | |||
941.25, -7.0498576, 99.47632, | |||
-22.850792, -97.862236, -101.043236, | |||
-4.727012, 28.275675, -157.96654, | |||
42.1377, 45.06531, -149.77373, | |||
24.487143, -8.054966, -13.990831, | |||
890.12494, 941.25, -7.0498576, | |||
99.47632, -22.850792, -97.862236, | |||
-101.043236, -4.727012, 28.275675, | |||
-157.96654, 42.1377, 45.06531, | |||
-149.77373, 24.487143, -8.054966, | |||
-13.990831})}); | |||
} | |||
// Sweeps DctChannelSelectForward on the naive handle over a grid of input
// sizes. For every size combination a mask channel count is derived from one
// random draw, a case is generated with gen_dct_case(), and its input/output
// Testcases are passed to checker.exect().
TEST_F(NAIVE, DCT_WITH_MASK2) {
    using Opr = DctChannelSelectForward;
    Checker<Opr> checker(handle(), /* check_dispatch */ false);
    Opr::Param param;
    UniformIntRNG rng_oc(0, 3 * 64);

    // Runs a single generated case; consumes exactly one value from rng_oc.
    auto run_case = [&](size_t in_n, size_t in_c, size_t in_h, size_t in_w) {
        const int drawn = static_cast<int>(rng_oc.gen_single_val());
        const int oc_limit = in_c * 64;
        // For a non-negative draw this lands in [1, oc_limit].
        int mask_oc = drawn % oc_limit + 1;
        auto generated =
                gen_dct_case(in_n, in_c, in_h, in_w, mask_oc, param);
        checker.set_param(param).exect(generated->testcase_in,
                                       generated->testcase_out);
    };

    for (size_t in_n : {1, 3})
        for (size_t in_c : {1, 3})
            for (size_t in_h : {8, 16, 32, 512, 1024})
                for (size_t in_w : {8, 16, 32, 64, 128, 256, 512, 1024})
                    run_case(in_n, in_c, in_h, in_w);
}
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |