|
|
@@ -54,7 +54,7 @@ namespace convolution { |
|
|
|
using KernLayout = _kern_layout; \ |
|
|
|
using OutputLayout = _output_layout; \ |
|
|
|
using Param = _conv_param; \ |
|
|
|
static constexpr bool check_bounds = check_bounds_; |
|
|
|
static constexpr bool check_bounds = check_bounds_ |
|
|
|
#define MEGDNN_COMMA , |
|
|
|
|
|
|
|
template < |
|
|
@@ -183,7 +183,7 @@ struct IConvIMMATraitReorderFilter { |
|
|
|
check_bounds MEGDNN_COMMA IMMAConfig MEGDNN_COMMA |
|
|
|
WarpTileConfig MEGDNN_COMMA ThreadConfig>::src_dtype |
|
|
|
MEGDNN_COMMA src_dtype>::value == true, |
|
|
|
"data type of input tensor should be int8_t"); |
|
|
|
"data type of input tensor should be int8_t") |
|
|
|
using DataTileCount = typename IConvIMMATrait< |
|
|
|
check_bounds, IMMAConfig, WarpTileConfig, ThreadConfig>::DataTileCount; |
|
|
|
struct FilterTileCount { |
|
|
@@ -284,7 +284,7 @@ struct IConvIMMATraitUnrollWidth { |
|
|
|
check_bounds MEGDNN_COMMA IMMAConfig MEGDNN_COMMA |
|
|
|
WarpTileConfig MEGDNN_COMMA ThreadConfig>:: |
|
|
|
filter_dtype MEGDNN_COMMA filter_dtype>::value == true, |
|
|
|
"data type of filter tensor should be int8_t"); |
|
|
|
"data type of filter tensor should be int8_t") |
|
|
|
using FilterTileCount = typename IConvIMMATraitReorderFilter< |
|
|
|
check_bounds, IMMAConfig, WarpTileConfig, ThreadConfig>::FilterTileCount; |
|
|
|
using BlockTileIterator = |
|
|
@@ -345,7 +345,7 @@ struct IConvIMMATraitUnrollWidthV2 { |
|
|
|
MEGDNN_STATIC_ASSERT( |
|
|
|
WarpTileConfig::warp_tile_k == 1, |
|
|
|
"kernel unrolling along width axis assumes tile k " |
|
|
|
"in warp-level must be 1"); |
|
|
|
"in warp-level must be 1") |
|
|
|
using copy_t = int4; |
|
|
|
using smem_storage_dtype = smem_storage_dtype; |
|
|
|
static int constexpr load_width = sizeof(copy_t) / sizeof(smem_storage_dtype); |
|
|
@@ -388,7 +388,7 @@ struct IConvIMMATraitUnrollWidthV2 { |
|
|
|
MEGDNN_STATIC_ASSERT( |
|
|
|
WarpTileConfig::warp_tile_k == 1, |
|
|
|
"kernel unrolling along width axis assumes tile k " |
|
|
|
"in warp-level must be 1"); |
|
|
|
"in warp-level must be 1") |
|
|
|
using copy_t = int4; |
|
|
|
using smem_storage_dtype = smem_storage_dtype; |
|
|
|
static int constexpr load_width = sizeof(copy_t) / sizeof(smem_storage_dtype); |
|
|
|