@@ -434,7 +434,7 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) | |||||
'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'), | 'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'), | ||||
Doc('MK8', 'Split 8 from M and K, better for neon compute:' | Doc('MK8', 'Split 8 from M and K, better for neon compute:' | ||||
'(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the ' | '(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the ' | ||||
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), | |||||
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), | |||||
Doc('MK4_DOT', 'Split 4 from M and K, better for neon dotprod:' | Doc('MK4_DOT', 'Split 4 from M and K, better for neon dotprod:' | ||||
'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the ' | 'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the ' | ||||
'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))')) | 'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))')) | ||||
@@ -858,7 +858,10 @@ when the ``I`` suffix is present. | |||||
'NCHW_NCHW88_CONV_CHAN_WEIGHT', | 'NCHW_NCHW88_CONV_CHAN_WEIGHT', | ||||
'NCHW_NCHW88_CONV_GROUP_WEIGHT', | 'NCHW_NCHW88_CONV_GROUP_WEIGHT', | ||||
'NCHW_NCHW88', | 'NCHW_NCHW88', | ||||
'NCHW88_NCHW') | |||||
'NCHW88_NCHW', | |||||
'NCHW_NCHW4_IC_SMALL', | |||||
'NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT', | |||||
) | |||||
) | ) | ||||
@@ -90,12 +90,11 @@ inline int8x16_t vqtbl1q_s8_v7(int8x16_t a, uint8x16_t index) { | |||||
_sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem); | _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem); | ||||
template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
void conv_bias::conv_direct_stride1_2x2_int8_dot(const int8_t* src, | |||||
const int8_t* filter, | |||||
const int32_t* bias, int32_t* temp, | |||||
int8_t* dst, const size_t IH, | |||||
const size_t IW, const size_t OH, | |||||
const size_t OW, const Op& op) { | |||||
void conv_bias::conv_direct_stride1_2x2_int8_dot( | |||||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | |||||
const size_t OH, const size_t OW, const Op& op) { | |||||
MEGDNN_MARK_USED_VAR(IH); | |||||
const size_t tail_step = IW - OW; | const size_t tail_step = IW - OW; | ||||
const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, | const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, | ||||
2, 3, 16, 16, 3, 4, 16, 16}; | 2, 3, 16, 16, 3, 4, 16, 16}; | ||||
@@ -326,12 +325,11 @@ void conv_bias::conv_direct_stride1_2x2_int8_dot(const int8_t* src, | |||||
} | } | ||||
template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
void conv_bias::conv_direct_stride1_3x3_int8_dot(const int8_t* src, | |||||
const int8_t* filter, | |||||
const int32_t* bias, int32_t* temp, | |||||
int8_t* dst, const size_t IH, | |||||
const size_t IW, const size_t OH, | |||||
const size_t OW, const Op& op) { | |||||
void conv_bias::conv_direct_stride1_3x3_int8_dot( | |||||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | |||||
const size_t OH, const size_t OW, const Op& op) { | |||||
MEGDNN_MARK_USED_VAR(IH); | |||||
const size_t tail_step = IW - OW; | const size_t tail_step = IW - OW; | ||||
const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, | const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, | ||||
@@ -562,12 +560,11 @@ void conv_bias::conv_direct_stride1_3x3_int8_dot(const int8_t* src, | |||||
} | } | ||||
template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
void conv_bias::conv_direct_stride2_2x2_int8_dot(const int8_t* src, | |||||
const int8_t* filter, | |||||
const int32_t* bias, int32_t* temp, | |||||
int8_t* dst, const size_t IH, | |||||
const size_t IW, const size_t OH, | |||||
const size_t OW, const Op& op) { | |||||
void conv_bias::conv_direct_stride2_2x2_int8_dot( | |||||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | |||||
const size_t OH, const size_t OW, const Op& op) { | |||||
MEGDNN_MARK_USED_VAR(IH); | |||||
const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
const uint8x16_t _idx0 = {0, 1, 16, 16, 2, 3, 16, 16, | const uint8x16_t _idx0 = {0, 1, 16, 16, 2, 3, 16, 16, | ||||
@@ -658,12 +655,11 @@ void conv_bias::conv_direct_stride2_2x2_int8_dot(const int8_t* src, | |||||
} | } | ||||
template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
void conv_bias::conv_direct_stride2_3x3_int8_dot(const int8_t* src, | |||||
const int8_t* filter, | |||||
const int32_t* bias, int32_t* temp, | |||||
int8_t* dst, const size_t IH, | |||||
const size_t IW, const size_t OH, | |||||
const size_t OW, const Op& op) { | |||||
void conv_bias::conv_direct_stride2_3x3_int8_dot( | |||||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | |||||
const size_t OH, const size_t OW, const Op& op) { | |||||
MEGDNN_MARK_USED_VAR(IH); | |||||
const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
const uint8x16_t _idx0 = {0, 1, 2, 16, 2, 3, 4, 16, | const uint8x16_t _idx0 = {0, 1, 2, 16, 2, 3, 4, 16, | ||||
@@ -814,12 +810,11 @@ void conv_bias::conv_direct_stride2_3x3_int8_dot(const int8_t* src, | |||||
_sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem); | _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem); | ||||
template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
void conv_bias::conv_direct_stride2_5x5_int8_dot(const int8_t* src, | |||||
const int8_t* filter, | |||||
const int32_t* bias, int32_t* temp, | |||||
int8_t* dst, const size_t IH, | |||||
const size_t IW, const size_t OH, | |||||
const size_t OW, const Op& op) { | |||||
void conv_bias::conv_direct_stride2_5x5_int8_dot( | |||||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | |||||
const size_t OH, const size_t OW, const Op& op) { | |||||
MEGDNN_MARK_USED_VAR(IH); | |||||
const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
const uint8x16_t _idx00 = {0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9}; | const uint8x16_t _idx00 = {0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9}; | ||||
@@ -1113,12 +1108,11 @@ void conv_bias::conv_direct_stride2_5x5_int8_dot(const int8_t* src, | |||||
} | } | ||||
template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
void conv_bias::conv_direct_stride2_7x7_int8_dot(const int8_t* src, | |||||
const int8_t* filter, | |||||
const int32_t* bias, int32_t* temp, | |||||
int8_t* dst, const size_t IH, | |||||
const size_t IW, const size_t OH, | |||||
const size_t OW, const Op& op) { | |||||
void conv_bias::conv_direct_stride2_7x7_int8_dot( | |||||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | |||||
const size_t OH, const size_t OW, const Op& op) { | |||||
MEGDNN_MARK_USED_VAR(IH); | |||||
const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
const uint8x16_t _idx00 = {0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9}; | const uint8x16_t _idx00 = {0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9}; | ||||
@@ -1476,12 +1470,11 @@ void conv_bias::conv_direct_stride2_7x7_int8_dot(const int8_t* src, | |||||
} | } | ||||
template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
void conv_bias::conv_direct_stride1_5x5_int8_dot(const int8_t* src, | |||||
const int8_t* filter, | |||||
const int32_t* bias, int32_t* temp, | |||||
int8_t* dst, const size_t IH, | |||||
const size_t IW, const size_t OH, | |||||
const size_t OW, const Op& op) { | |||||
void conv_bias::conv_direct_stride1_5x5_int8_dot( | |||||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | |||||
const size_t OH, const size_t OW, const Op& op) { | |||||
MEGDNN_MARK_USED_VAR(IH); | |||||
const size_t tail_step = IW - OW; | const size_t tail_step = IW - OW; | ||||
const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; | const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; | ||||
@@ -1777,12 +1770,11 @@ void conv_bias::conv_direct_stride1_5x5_int8_dot(const int8_t* src, | |||||
} | } | ||||
template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
void conv_bias::conv_direct_stride1_7x7_int8_dot(const int8_t* src, | |||||
const int8_t* filter, | |||||
const int32_t* bias, int32_t* temp, | |||||
int8_t* dst, const size_t IH, | |||||
const size_t IW, const size_t OH, | |||||
const size_t OW, const Op& op) { | |||||
void conv_bias::conv_direct_stride1_7x7_int8_dot( | |||||
const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | |||||
const size_t OH, const size_t OW, const Op& op) { | |||||
MEGDNN_MARK_USED_VAR(IH); | |||||
const size_t tail_step = IW - OW; | const size_t tail_step = IW - OW; | ||||
const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; | const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; | ||||
@@ -29,6 +29,7 @@ void copy_packed_src_int8_nchw44<1>(int8_t* dst, const int dst_step, | |||||
const int ih, const int pad_left, | const int ih, const int pad_left, | ||||
const int pad_right, const int pad_top, | const int pad_right, const int pad_top, | ||||
const int pad_bottom) { | const int pad_bottom) { | ||||
MEGDNN_MARK_USED_VAR(pad_right); | |||||
constexpr int IC_PACK_SIZE = 4; | constexpr int IC_PACK_SIZE = 4; | ||||
rep_step(ic_idx, ic, IC_PACK_SIZE) { | rep_step(ic_idx, ic, IC_PACK_SIZE) { | ||||
const int8_t* i_src = src + ic_idx * ic_step; | const int8_t* i_src = src + ic_idx * ic_step; | ||||
@@ -66,6 +67,7 @@ void copy_packed_src_int8_nchw44<2>(int8_t* dst, const int dst_step, | |||||
const int ih, const int pad_left, | const int ih, const int pad_left, | ||||
const int pad_right, const int pad_top, | const int pad_right, const int pad_top, | ||||
const int pad_bottom) { | const int pad_bottom) { | ||||
MEGDNN_MARK_USED_VAR(pad_right); | |||||
constexpr int IC_PACK_SIZE = 4; | constexpr int IC_PACK_SIZE = 4; | ||||
int odd_start = megdnn::div_ceil(dst_step, 2); | int odd_start = megdnn::div_ceil(dst_step, 2); | ||||
bool nochange = pad_left % 2 == 0; | bool nochange = pad_left % 2 == 0; | ||||
@@ -367,4 +369,4 @@ FOR_FILTER(2) | |||||
} // namespace megdnn | } // namespace megdnn | ||||
#endif | #endif | ||||
//vim: syntax=cpp.doxygen | |||||
//vim: syntax=cpp.doxygen |
@@ -163,6 +163,7 @@ static void conv_kern(WorkspaceBundle bundle, | |||||
bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::usable( | bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::usable( | ||||
FallbackConvBiasImpl*, const NCBKernSizeParam& param, | FallbackConvBiasImpl*, const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const { | AlgoSelectionStrategy algo_selection_strategy) const { | ||||
MEGDNN_MARK_USED_VAR(algo_selection_strategy); | |||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
auto FW = fm.spatial[1]; | auto FW = fm.spatial[1]; | ||||
@@ -199,6 +200,7 @@ bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::usable( | |||||
bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::is_preferred( | bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::is_preferred( | ||||
megdnn::fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | megdnn::fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | ||||
MEGDNN_MARK_USED_VAR(param); | |||||
return true; | return true; | ||||
} | } | ||||
@@ -338,4 +340,4 @@ ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns( | |||||
#endif | #endif | ||||
//vim: syntax=cpp.doxygen | |||||
//vim: syntax=cpp.doxygen |
@@ -98,6 +98,7 @@ template <int ow_remain, typename Op, typename T> | |||||
struct StoreOCxOWx<1, ow_remain, Op, T> { | struct StoreOCxOWx<1, ow_remain, Op, T> { | ||||
static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr, | static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr, | ||||
const int ld_dst_oc) { | const int ld_dst_oc) { | ||||
MEGDNN_MARK_USED_VAR(ld_dst_oc); | |||||
switch (ow_remain) { | switch (ow_remain) { | ||||
case 8: | case 8: | ||||
UNROLL_CALL_RAW(4, cb12); | UNROLL_CALL_RAW(4, cb12); | ||||
@@ -337,14 +337,11 @@ ConvBias::WinogradParam ConvBias::parse_winograd_name( | |||||
&(ret.channel_block_size), &(ret.output_block_size), | &(ret.channel_block_size), &(ret.output_block_size), | ||||
&(ret.tile_size)); | &(ret.tile_size)); | ||||
if (strcmp(name, pre.c_str())) { | if (strcmp(name, pre.c_str())) { | ||||
megdnn_log_warn("algo %s is not %s algo", name, pre.c_str()); | |||||
ret = INVALID_WINOGRAD_PARAM; | ret = INVALID_WINOGRAD_PARAM; | ||||
return false; | return false; | ||||
} | } | ||||
if (ret.tile_size == 0 || ret.output_block_size == 0 || | if (ret.tile_size == 0 || ret.output_block_size == 0 || | ||||
ret.channel_block_size == 0) { | ret.channel_block_size == 0) { | ||||
megdnn_log_warn("the algo name %s is not suitable for %s", | |||||
algo_name.c_str(), pre.c_str()); | |||||
ret = INVALID_WINOGRAD_PARAM; | ret = INVALID_WINOGRAD_PARAM; | ||||
return false; | return false; | ||||
} | } | ||||
@@ -28,6 +28,26 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, | |||||
dst[3] = src[3]; | dst[3] = src[3]; | ||||
dst[4] = 4; | dst[4] = 4; | ||||
break; | break; | ||||
case Param::Mode::NCHW_NCHW4_IC_SMALL: | |||||
dst.ndim = 5; | |||||
megdnn_assert(src[1] <= 4_z, "ic should be less equal 4"); | |||||
dst[0] = src[0]; | |||||
dst[1] = div_ceil(src[1], 4_z); | |||||
dst[2] = src[2]; | |||||
dst[3] = src[3]; | |||||
dst[4] = 4; | |||||
break; | |||||
case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: | |||||
megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4"); | |||||
megdnn_assert(src[1] <= 4_z, "ic should be less equal 4"); | |||||
dst.ndim = 5; | |||||
dst[0] = src[0]; | |||||
dst[1] = div_ceil(src[1], 4_z); | |||||
dst[2] = src[2]; | |||||
dst[3] = src[3]; | |||||
dst[4] = 4; | |||||
break; | |||||
case Param::Mode::NCHW_NCHW88: | case Param::Mode::NCHW_NCHW88: | ||||
dst.ndim = 5; | dst.ndim = 5; | ||||
dst[0] = src[0]; | dst[0] = src[0]; | ||||
@@ -276,6 +296,8 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { | |||||
case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: | case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: | ||||
case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: | case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: | ||||
case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: | case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: | ||||
case Param::Mode::NCHW_NCHW4_IC_SMALL: | |||||
case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: | |||||
CHECK_SRC(DefaultTensorFormat::make()); | CHECK_SRC(DefaultTensorFormat::make()); | ||||
dst = src; | dst = src; | ||||
break; | break; | ||||
@@ -284,6 +306,15 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { | |||||
megdnn_throw("Invalid relayout format mode"); | megdnn_throw("Invalid relayout format mode"); | ||||
break; | break; | ||||
} | } | ||||
if (!dst.is_default() && | |||||
( | |||||
handle()->type() != Handle::HandleType::NAIVE)) { | |||||
megdnn_throw( | |||||
"Only naive and opencl handle support " | |||||
"Image2DPack4TensorFormat, try to export MGB_USE_MEGDNN_DBG=2 " | |||||
"to enable naive handle"); | |||||
} | |||||
#undef CHECK_SRC | #undef CHECK_SRC | ||||
} | } | ||||
@@ -374,6 +405,23 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, | |||||
exec_dst = dst; | exec_dst = dst; | ||||
} | } | ||||
break; | break; | ||||
case Param::Mode::NCHW_NCHW4_IC_SMALL: | |||||
case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: | |||||
// nchw to nchw4c or oihw to oihw4i | |||||
{ | |||||
TensorLayout work_space_layout( | |||||
{src[0], round_up(src[1], 4_z), src[2], src[3]}, | |||||
src.dtype, src.format); | |||||
exec_src = work_space_layout | |||||
.reshape({src[0], div_ceil(src[1], 4_z), 4, | |||||
src[2], src[3]}) | |||||
.dimshuffle({0, 1, 3, 4, 2}); | |||||
exec_dst = dst; | |||||
} | |||||
break; | |||||
case Param::Mode::NCHW_NHWCD4: | case Param::Mode::NCHW_NHWCD4: | ||||
case Param::Mode::NCHW_NHWCD4I: | case Param::Mode::NCHW_NHWCD4I: | ||||
// src is {N, C, H, W} | // src is {N, C, H, W} | ||||
@@ -10,6 +10,7 @@ | |||||
*/ | */ | ||||
#include "src/cuda/convolution/opr_impl.h" | #include "src/cuda/convolution/opr_impl.h" | ||||
#include "megdnn/dtype.h" | |||||
#include "src/cuda/convolution/helper.h" | #include "src/cuda/convolution/helper.h" | ||||
#include "src/cuda/convolution/backward_data/algo.h" | #include "src/cuda/convolution/backward_data/algo.h" | ||||
#include "src/cuda/convolution/backward_filter/algo.h" | #include "src/cuda/convolution/backward_filter/algo.h" | ||||
@@ -28,10 +29,35 @@ using namespace convolution; | |||||
/* ============== ConvolutionForwardImpl ============== */ | /* ============== ConvolutionForwardImpl ============== */ | ||||
ConvolutionForwardImpl::ConvBiasExtraData | ConvolutionForwardImpl::ConvBiasExtraData | ||||
ConvolutionForwardImpl::conv_bias_extra_data(const TensorLayout& dst) { | |||||
ConvolutionForwardImpl::conv_bias_extra_data(const TensorLayout& src, | |||||
const TensorLayout& filter, | |||||
const TensorLayout& dst) { | |||||
auto conv_param = param(); | auto conv_param = param(); | ||||
DType bias_type; | |||||
if (src.dtype.enumv() == DTypeEnum::QuantizedS8) { | |||||
bias_type = dtype::QuantizedS32( | |||||
src.dtype.param<dtype::QuantizedS8>().scale * | |||||
filter.dtype.param<dtype::QuantizedS8>().scale); | |||||
} else if (src.dtype.enumv() == DTypeEnum::Quantized8Asymm) { | |||||
bias_type = dtype::QuantizedS32( | |||||
src.dtype.param<dtype::Quantized8Asymm>().scale * | |||||
filter.dtype.param<dtype::Quantized8Asymm>().scale); | |||||
} else if (src.dtype.enumv() == DTypeEnum::Uint8 || | |||||
src.dtype.enumv() == DTypeEnum::Int8) { | |||||
bias_type = dtype::Int32{}; | |||||
} else if (src.dtype.enumv() == DTypeEnum::Quantized4Asymm) { | |||||
bias_type = dtype::QuantizedS32( | |||||
src.dtype.param<dtype::Quantized4Asymm>().scale * | |||||
filter.dtype.param<dtype::Quantized4Asymm>().scale); | |||||
} else { | |||||
megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT); | |||||
bias_type = src.dtype; | |||||
} | |||||
ConvBiasExtraData ret = {this->handle()->create_operator<ConvBiasForward>(), | ConvBiasExtraData ret = {this->handle()->create_operator<ConvBiasForward>(), | ||||
TensorLayout(dst.dtype), TensorLayout(dst.dtype)}; | |||||
TensorLayout(bias_type), TensorLayout(dst.dtype)}; | |||||
ret.convbias_opr->param() = {param::ConvBias::NonlineMode::IDENTITY, | ret.convbias_opr->param() = {param::ConvBias::NonlineMode::IDENTITY, | ||||
conv_param.mode, | conv_param.mode, | ||||
conv_param.sparse, | conv_param.sparse, | ||||
@@ -54,7 +80,7 @@ ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src, | |||||
const TensorLayout& dst, | const TensorLayout& dst, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | bool reproducible) { | ||||
auto extra_data = conv_bias_extra_data(dst); | |||||
auto extra_data = conv_bias_extra_data(src, filter, dst); | |||||
return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get()) | return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get()) | ||||
->get_algorithm_heuristic(src, filter, extra_data.bias_layout, | ->get_algorithm_heuristic(src, filter, extra_data.bias_layout, | ||||
extra_data.z_layout, dst, | extra_data.z_layout, dst, | ||||
@@ -65,7 +91,7 @@ std::vector<ConvolutionForwardImpl::Algorithm*> | |||||
ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, | ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, | ||||
const TensorLayout& filter, | const TensorLayout& filter, | ||||
const TensorLayout& dst) { | const TensorLayout& dst) { | ||||
auto extra_data = conv_bias_extra_data(dst); | |||||
auto extra_data = conv_bias_extra_data(src, filter, dst); | |||||
return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get()) | return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get()) | ||||
->get_all_algorithms(src, filter, extra_data.bias_layout, | ->get_all_algorithms(src, filter, extra_data.bias_layout, | ||||
extra_data.z_layout, dst); | extra_data.z_layout, dst); | ||||
@@ -75,7 +101,7 @@ size_t ConvolutionForwardImpl::get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& dst, | const TensorLayout& dst, | ||||
const PreprocessedFilter* preprocessed_filter) { | const PreprocessedFilter* preprocessed_filter) { | ||||
auto extra_data = conv_bias_extra_data(dst); | |||||
auto extra_data = conv_bias_extra_data(src, filter, dst); | |||||
return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get()) | return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get()) | ||||
->get_workspace_in_bytes( | ->get_workspace_in_bytes( | ||||
src, filter, extra_data.bias_layout, extra_data.z_layout, | src, filter, extra_data.bias_layout, extra_data.z_layout, | ||||
@@ -90,7 +116,8 @@ void ConvolutionForwardImpl::exec(_megdnn_tensor_in src, | |||||
_megdnn_tensor_out dst, | _megdnn_tensor_out dst, | ||||
const PreprocessedFilter* preprocessed_filter, | const PreprocessedFilter* preprocessed_filter, | ||||
_megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
auto extra_data = conv_bias_extra_data(dst.layout); | |||||
auto extra_data = | |||||
conv_bias_extra_data(src.layout, filter.layout, dst.layout); | |||||
TensorND bias(nullptr, extra_data.bias_layout); | TensorND bias(nullptr, extra_data.bias_layout); | ||||
TensorND z(nullptr, extra_data.z_layout); | TensorND z(nullptr, extra_data.z_layout); | ||||
return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get()) | return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get()) | ||||
@@ -61,7 +61,9 @@ class ConvolutionForwardImpl: public ConvolutionForward { | |||||
TensorLayout z_layout; | TensorLayout z_layout; | ||||
}; | }; | ||||
private: | private: | ||||
ConvBiasExtraData conv_bias_extra_data(const TensorLayout&); | |||||
ConvBiasExtraData conv_bias_extra_data(const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&); | |||||
}; | }; | ||||
class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | ||||
@@ -32,7 +32,7 @@ void create_param(const DeformablePSROIPoolingBase* opr, | |||||
p.sample_per_part = param.sample_per_part; | p.sample_per_part = param.sample_per_part; | ||||
p.trans_std = param.trans_std; | p.trans_std = param.trans_std; | ||||
p.scale = param.spatial_scale; | p.scale = param.spatial_scale; | ||||
p.nr_cls = p.no_trans ? 1 : trans[0]; | |||||
p.nr_cls = p.no_trans ? 1 : trans[1] / 2; | |||||
p.nr_bbox = rois[0]; | p.nr_bbox = rois[0]; | ||||
p.IC = data[1]; | p.IC = data[1]; | ||||
p.IH = data[2]; | p.IH = data[2]; | ||||
@@ -11,6 +11,7 @@ | |||||
#include "src/cuda/relayout_format/opr_impl.h" | #include "src/cuda/relayout_format/opr_impl.h" | ||||
#include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
#include "src/cuda/utils.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
@@ -20,15 +21,22 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
auto src_dtype = src.layout.dtype; | auto src_dtype = src.layout.dtype; | ||||
megdnn_assert( | megdnn_assert( | ||||
param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 || | param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 || | ||||
param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4, | |||||
param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4 || | |||||
param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL || | |||||
param().mode == | |||||
Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT, | |||||
"relayout format of cuda only support NCHW4->CHWN4 or " | "relayout format of cuda only support NCHW4->CHWN4 or " | ||||
"CHWN4->NCHW4"); | |||||
if (src_dtype.enumv() == DTypeEnum::QuantizedS8) { | |||||
"CHWN4->NCHW4 or NCHW->NCHW4"); | |||||
if ((param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 || | |||||
param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4) && | |||||
src_dtype.enumv() == DTypeEnum::QuantizedS8) { | |||||
size_t row = 0, col = 0; | size_t row = 0, col = 0; | ||||
if (param().mode == Param::RelayoutFormat::Mode::NCHW4_CHWN4) { | if (param().mode == Param::RelayoutFormat::Mode::NCHW4_CHWN4) { | ||||
row = src.layout[0], | row = src.layout[0], | ||||
col = src.layout[1] * src.layout[2] * src.layout[3]; | col = src.layout[1] * src.layout[2] * src.layout[3]; | ||||
} else { | } else { | ||||
megdnn_assert(param().mode == | |||||
param::RelayoutFormat::Mode::CHWN4_NCHW4); | |||||
row = src.layout[0] * src.layout[1] * src.layout[2], | row = src.layout[0] * src.layout[1] * src.layout[2], | ||||
col = src.layout[3]; | col = src.layout[3]; | ||||
} | } | ||||
@@ -43,6 +51,27 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
return handle()->create_operator<RelayoutForward>()->exec(trans_in, | return handle()->create_operator<RelayoutForward>()->exec(trans_in, | ||||
trans_out); | trans_out); | ||||
} | } | ||||
if ((param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL || | |||||
param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) && | |||||
src.layout[1] % 4 != 0) { | |||||
megdnn_assert(src.raw_ptr != dst.raw_ptr && src.layout.ndim == 4, | |||||
"The mode of NCHW_NCHW4 and NCHW_NCHW4_CONV_DENSE_WEIGHT " | |||||
"of RelayoutFormat opr(cuda backend) does not support " | |||||
"src.ptr == dst.ptr"); | |||||
megdnn_assert(src.layout[1] <= 4); | |||||
cuda_check(cudaMemsetAsync(dst.raw_ptr, 0, | |||||
dst.layout.span().dist_byte(), | |||||
cuda_stream(this->handle()))); | |||||
TensorLayout exec_dst_layout = dst.layout; | |||||
exec_dst_layout[4] = src.layout[1]; | |||||
TensorLayout exec_src_layout = | |||||
src.layout | |||||
.reshape({src.layout[0], src.layout[1], 1, | |||||
src.layout[2], src.layout[3]}) | |||||
.dimshuffle({0, 2, 3, 4, 1}); | |||||
return handle()->create_operator<RelayoutForward>()->exec( | |||||
{src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout}); | |||||
} | |||||
TensorLayout exec_src, exec_dst; | TensorLayout exec_src, exec_dst; | ||||
deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst); | deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst); | ||||
TensorND exec_src_nd{src.raw_ptr, exec_src}; | TensorND exec_src_nd{src.raw_ptr, exec_src}; | ||||
@@ -293,7 +293,7 @@ void Fwd::exec(_megdnn_tensor_in data, _megdnn_tensor_in rois, | |||||
float trans_std = param.trans_std, scale = param.spatial_scale; | float trans_std = param.trans_std, scale = param.spatial_scale; | ||||
size_t nr_bbox = rois.layout[0]; | size_t nr_bbox = rois.layout[0]; | ||||
size_t nr_cls = no_trans ? 1 : trans.layout[0]; | |||||
size_t nr_cls = no_trans ? 1 : trans.layout[1] / 2; | |||||
size_t IC = data.layout[1], IH = data.layout[2], IW = data.layout[3]; | size_t IC = data.layout[1], IH = data.layout[2], IW = data.layout[3]; | ||||
const float* data_ptr = data.ptr<float>(); | const float* data_ptr = data.ptr<float>(); | ||||
@@ -339,7 +339,7 @@ void Bwd::exec(_megdnn_tensor_in data, _megdnn_tensor_in rois, | |||||
float trans_std = param.trans_std, scale = param.spatial_scale; | float trans_std = param.trans_std, scale = param.spatial_scale; | ||||
size_t nr_bbox = rois.layout[0]; | size_t nr_bbox = rois.layout[0]; | ||||
size_t nr_cls = no_trans ? 1 : trans.layout[0]; | |||||
size_t nr_cls = no_trans ? 1 : trans.layout[1] / 2; | |||||
size_t IC = data.layout[1], IH = data.layout[2], IW = data.layout[3]; | size_t IC = data.layout[1], IH = data.layout[2], IW = data.layout[3]; | ||||
const float* data_ptr = data.ptr<float>(); | const float* data_ptr = data.ptr<float>(); | ||||
@@ -107,11 +107,7 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t computing_handle, | |||||
m_dispatcher{megcoreGetCPUDispatcher(computing_handle)} {} | m_dispatcher{megcoreGetCPUDispatcher(computing_handle)} {} | ||||
size_t HandleImpl::image2d_pitch_alignment() const { | size_t HandleImpl::image2d_pitch_alignment() const { | ||||
if (type() == Handle::HandleType::NAIVE) { | |||||
// only naive CPU handle supports this format | |||||
return g_image2d_pitch_alignment; | |||||
} | |||||
megdnn_throw("Image2DTensorFormat is not supported on this handle"); | |||||
return g_image2d_pitch_alignment; | |||||
} | } | ||||
size_t HandleImpl::exchange_image2d_pitch_alignment(size_t alignment) { | size_t HandleImpl::exchange_image2d_pitch_alignment(size_t alignment) { | ||||
@@ -370,65 +370,67 @@ void pooling_backward_max_impl(const ctype* __restrict src, | |||||
} | } | ||||
} | } | ||||
} // anonymous namespace | |||||
} // namespace | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace naive { | namespace naive { | ||||
void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | ||||
_megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
MIDOUT_BEGIN(megdnn_naive_pooling) { | |||||
check_exec(src.layout, dst.layout, workspace.size); | |||||
size_t c_pos, spatial_pos, batch_pos = 0; | |||||
if (param().format == Param::Format::NCHW || | |||||
param().format == Param::Format::NCHW4 || | |||||
param().format == Param::Format::NCHW88 || | |||||
param().format == Param::Format::NCHW44 || | |||||
param().format == Param::Format::NCHW32) { | |||||
c_pos = 1; | |||||
spatial_pos = 2; | |||||
} else if (param().format == Param::Format::NHWC) { | |||||
c_pos = 3; | |||||
spatial_pos = 1; | |||||
} else if (param().format == Param::Format::CHWN4) { | |||||
c_pos = 0; | |||||
spatial_pos = 1; | |||||
batch_pos = 3; | |||||
} else { | |||||
megdnn_assert(param().format == Param::Format::NHWCD4); | |||||
c_pos = 2; | |||||
spatial_pos = 1; | |||||
} | |||||
size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos], | |||||
IH = src.layout.shape[spatial_pos + 0], | |||||
IW = src.layout.shape[spatial_pos + 1]; | |||||
size_t OH = dst.layout.shape[spatial_pos + 0], | |||||
OW = dst.layout.shape[spatial_pos + 1]; | |||||
if (param().format == Param::Format::NHWCD4) { | |||||
C *= 4; | |||||
IW = src.layout.shape[spatial_pos + 2]; | |||||
OW = dst.layout.shape[spatial_pos + 2]; | |||||
} | |||||
if (param().format == Param::Format::NCHW4 || | |||||
param().format == Param::Format::NCHW44 || | |||||
param().format == Param::Format::CHWN4) { | |||||
C *= 4; | |||||
} | |||||
if (param().format == Param::Format::NCHW88) { | |||||
C *= 8; | |||||
} | |||||
if (param().format == Param::Format::NCHW32) { | |||||
C *= 32; | |||||
} | |||||
size_t PH = param().pad_h, PW = param().pad_w; | |||||
size_t FH = param().window_h, FW = param().window_w; | |||||
size_t SH = param().stride_h, SW = param().stride_w; | |||||
#define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \ | |||||
MEGDNN_DISPATCH_CPU_KERN( \ | |||||
static_cast<naive::HandleImpl*>(handle()), \ | |||||
pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \ | |||||
sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, PH, \ | |||||
PW, SH, SW, FH, FW)); | |||||
check_exec(src.layout, dst.layout, workspace.size); | |||||
size_t c_pos, spatial_pos, batch_pos = 0; | |||||
if (param().format == Param::Format::NCHW || | |||||
param().format == Param::Format::NCHW4 || | |||||
param().format == Param::Format::NCHW88 || | |||||
param().format == Param::Format::NCHW44 || | |||||
param().format == Param::Format::NCHW32) { | |||||
c_pos = 1; | |||||
spatial_pos = 2; | |||||
} else if (param().format == Param::Format::NHWC) { | |||||
c_pos = 3; | |||||
spatial_pos = 1; | |||||
} else if (param().format == Param::Format::CHWN4) { | |||||
c_pos = 0; | |||||
spatial_pos = 1; | |||||
batch_pos = 3; | |||||
} else { | |||||
megdnn_assert(param().format == Param::Format::NHWCD4); | |||||
c_pos = 2; | |||||
spatial_pos = 1; | |||||
} | |||||
size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos], | |||||
IH = src.layout.shape[spatial_pos + 0], | |||||
IW = src.layout.shape[spatial_pos + 1]; | |||||
size_t OH = dst.layout.shape[spatial_pos + 0], | |||||
OW = dst.layout.shape[spatial_pos + 1]; | |||||
if (param().format == Param::Format::NHWCD4) { | |||||
C *= 4; | |||||
IW = src.layout.shape[spatial_pos + 2]; | |||||
OW = dst.layout.shape[spatial_pos + 2]; | |||||
} | |||||
if (param().format == Param::Format::NCHW4 || | |||||
param().format == Param::Format::NCHW44 || | |||||
param().format == Param::Format::CHWN4) { | |||||
C *= 4; | |||||
} | |||||
if (param().format == Param::Format::NCHW88) { | |||||
C *= 8; | |||||
} | |||||
if (param().format == Param::Format::NCHW32) { | |||||
C *= 32; | |||||
} | |||||
size_t PH = param().pad_h, PW = param().pad_w; | |||||
size_t FH = param().window_h, FW = param().window_w; | |||||
size_t SH = param().stride_h, SW = param().stride_w; | |||||
#define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \ | |||||
MIDOUT_BEGIN(megdnn_naive_pooling, midout_iv(#Pooler #IdxGetter##_hash)) { \ | |||||
MEGDNN_DISPATCH_CPU_KERN( \ | |||||
static_cast<naive::HandleImpl*>(handle()), \ | |||||
pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \ | |||||
sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, \ | |||||
PH, PW, SH, SW, FH, FW)); \ | |||||
} \ | |||||
MIDOUT_END(); | |||||
#define DISPATCH_WITH_POOLER(Pooler) \ | #define DISPATCH_WITH_POOLER(Pooler) \ | ||||
switch (param().format) { \ | switch (param().format) { \ | ||||
@@ -484,14 +486,12 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
} \ | } \ | ||||
} \ | } \ | ||||
} | } | ||||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||||
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||||
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||||
#undef cb | #undef cb | ||||
#undef DISPATCH_WITH_POOLER_AND_IDX_GETTER | #undef DISPATCH_WITH_POOLER_AND_IDX_GETTER | ||||
#undef DISPATCH_WITH_POOLER | #undef DISPATCH_WITH_POOLER | ||||
megdnn_assert_internal(0); | |||||
} | |||||
MIDOUT_END(); | |||||
megdnn_assert_internal(0); | |||||
} | } | ||||
WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle( | WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle( | ||||
@@ -14,6 +14,10 @@ | |||||
#include "megdnn/tensor_iter.h" | #include "megdnn/tensor_iter.h" | ||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_naive_relayout_format) | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace naive; | using namespace naive; | ||||
@@ -79,6 +83,7 @@ void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src, | |||||
} | } | ||||
cb(Float32, dt_float32); | cb(Float32, dt_float32); | ||||
cb(QuantizedS8, dt_qint8); | |||||
default: | default: | ||||
megdnn_assert(0); | megdnn_assert(0); | ||||
#undef cb | #undef cb | ||||
@@ -138,7 +143,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
return n * c * h * w * src.dtype.size(); | return n * c * h * w * src.dtype.size(); | ||||
} | } | ||||
case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: { | case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: { | ||||
megdnn_assert(src.ndim == 4, "src must be oihw ,nmdim == 5"); | |||||
megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 5"); | |||||
megdnn_assert(src[0] % 8 == 0, | megdnn_assert(src[0] % 8 == 0, | ||||
"NCHW_NCHW88_CONV_DENSE_WEIGHT oc must align to 8"); | "NCHW_NCHW88_CONV_DENSE_WEIGHT oc must align to 8"); | ||||
if (src[1] % 8 == 0) | if (src[1] % 8 == 0) | ||||
@@ -150,7 +155,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
return oc * ic * h * w * src.dtype.size(); | return oc * ic * h * w * src.dtype.size(); | ||||
} | } | ||||
case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: { | case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: { | ||||
megdnn_assert(src.ndim == 5, "src must be goihw ,nmdim == 5"); | |||||
megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5"); | |||||
megdnn_assert(src[1] % 8 == 0, | megdnn_assert(src[1] % 8 == 0, | ||||
"NCHW_NCHW88_CONV_CHAN_WEIGHT oc per group must " | "NCHW_NCHW88_CONV_CHAN_WEIGHT oc per group must " | ||||
"align to 8"); | "align to 8"); | ||||
@@ -164,7 +169,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
return group * ocpg * icpg * h * w * src.dtype.size(); | return group * ocpg * icpg * h * w * src.dtype.size(); | ||||
} | } | ||||
case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: { | case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: { | ||||
megdnn_assert(src.ndim == 5, "src must be goihw ,nmdim == 5"); | |||||
megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5"); | |||||
if (src[0] % 8 == 0) | if (src[0] % 8 == 0) | ||||
return 0; | return 0; | ||||
size_t group = round_up(src[0], 8_z); | size_t group = round_up(src[0], 8_z); | ||||
@@ -174,6 +179,27 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
size_t w = src[4]; | size_t w = src[4]; | ||||
return group * ocpg * icpg * h * w * src.dtype.size(); | return group * ocpg * icpg * h * w * src.dtype.size(); | ||||
} | } | ||||
case Param::Mode::NCHW_NCHW4_IC_SMALL: { | |||||
if (src[1] % 4 == 0) | |||||
return 0; | |||||
size_t n = src[0]; | |||||
size_t c = round_up(src[1], 4_z); | |||||
size_t h = src[2]; | |||||
size_t w = src[3]; | |||||
return n * c * h * w * src.dtype.size(); | |||||
} | |||||
case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: { | |||||
megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 5"); | |||||
if (src[1] % 4 == 0) | |||||
return 0; | |||||
size_t oc = src[0]; | |||||
size_t ic = round_up(src[1], 4_z); | |||||
size_t h = src[2]; | |||||
size_t w = src[3]; | |||||
return oc * ic * h * w * src.dtype.size(); | |||||
} | |||||
default: | default: | ||||
return 0; | return 0; | ||||
} | } | ||||
@@ -200,14 +226,18 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
//! ic % 4 != 0 | //! ic % 4 != 0 | ||||
if ((IC & 0x3)) { | if ((IC & 0x3)) { | ||||
switch (src.layout.dtype.enumv()) { | switch (src.layout.dtype.enumv()) { | ||||
#define cb(name, ctype) \ | |||||
case (DTypeEnum::name): { \ | |||||
ctype* sptr = src.compatible_ptr<ctype>(); \ | |||||
ctype* dptr = workspace.ptr<ctype>(); \ | |||||
MEGDNN_DISPATCH_CPU_KERN( \ | |||||
m_handle, \ | |||||
padding_src_to_workspace<ctype>(dptr, sptr, N, IC, IH, IW);); \ | |||||
break; \ | |||||
#define cb(name, ctype) \ | |||||
case (DTypeEnum::name): { \ | |||||
MIDOUT_BEGIN(megdnn_naive_relayout_format, ctype, \ | |||||
midout_iv(Param::Mode::NCHW_NHWCD4I)) { \ | |||||
ctype* sptr = src.compatible_ptr<ctype>(); \ | |||||
ctype* dptr = workspace.ptr<ctype>(); \ | |||||
MEGDNN_DISPATCH_CPU_KERN( \ | |||||
m_handle, padding_src_to_workspace<ctype>(dptr, sptr, N, \ | |||||
IC, IH, IW);); \ | |||||
} \ | |||||
MIDOUT_END(); \ | |||||
break; \ | |||||
} | } | ||||
cb(Float32, dt_float32); | cb(Float32, dt_float32); | ||||
MEGDNN_INC_FLOAT16(cb(Float16, dt_float16)); | MEGDNN_INC_FLOAT16(cb(Float16, dt_float16)); | ||||
@@ -226,14 +256,18 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
size_t FW = src.layout[3]; | size_t FW = src.layout[3]; | ||||
if ((IC & 0x3)) { | if ((IC & 0x3)) { | ||||
switch (src.layout.dtype.enumv()) { | switch (src.layout.dtype.enumv()) { | ||||
#define cb(name, ctype) \ | |||||
case (DTypeEnum::name): { \ | |||||
ctype* sptr = src.compatible_ptr<ctype>(); \ | |||||
ctype* dptr = workspace.ptr<ctype>(); \ | |||||
MEGDNN_DISPATCH_CPU_KERN( \ | |||||
m_handle, padding_filter_to_workspace<ctype>(dptr, sptr, OC, \ | |||||
IC, FH, FW);); \ | |||||
break; \ | |||||
#define cb(name, ctype) \ | |||||
case (DTypeEnum::name): { \ | |||||
MIDOUT_BEGIN(megdnn_naive_relayout_format, ctype, \ | |||||
midout_iv(Param::Mode::INTER_WEIGHT_DENSEI_DOT)) { \ | |||||
ctype* sptr = src.compatible_ptr<ctype>(); \ | |||||
ctype* dptr = workspace.ptr<ctype>(); \ | |||||
MEGDNN_DISPATCH_CPU_KERN(m_handle, \ | |||||
padding_filter_to_workspace<ctype>( \ | |||||
dptr, sptr, OC, IC, FH, FW);); \ | |||||
} \ | |||||
MIDOUT_END(); \ | |||||
break; \ | |||||
} | } | ||||
cb(Quantized8Asymm, dt_uint8); | cb(Quantized8Asymm, dt_uint8); | ||||
cb(QuantizedS8, dt_int8); | cb(QuantizedS8, dt_int8); | ||||
@@ -244,33 +278,35 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
exec_src_nd.raw_ptr = workspace.raw_ptr; | exec_src_nd.raw_ptr = workspace.raw_ptr; | ||||
} | } | ||||
} else if (param().mode == Param::Mode::NCHW_NCHW88) { | } else if (param().mode == Param::Mode::NCHW_NCHW88) { | ||||
size_t ic = src.layout[1]; | |||||
if (ic % 8 != 0) { | |||||
padding_to_workspace({workspace.raw_ptr, exec_src}, src, 1, 8); | |||||
exec_src_nd.raw_ptr = workspace.raw_ptr; | |||||
} | |||||
#define cb(_idx, _pack_size, _mode) \ | |||||
MIDOUT_BEGIN(megdnn_naive_relayout_format, \ | |||||
midout_iv(Param::Mode::_mode)) { \ | |||||
size_t val = src.layout[_idx]; \ | |||||
if (val % _pack_size != 0) { \ | |||||
padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \ | |||||
_pack_size); \ | |||||
exec_src_nd.raw_ptr = workspace.raw_ptr; \ | |||||
} \ | |||||
} \ | |||||
MIDOUT_END(); | |||||
cb(1, 8, NCHW_NCHW88); | |||||
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) { | } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) { | ||||
megdnn_assert(src.layout[0] % 8 == 0); | megdnn_assert(src.layout[0] % 8 == 0); | ||||
size_t ic = src.layout[1]; | |||||
if (ic % 8 != 0) { | |||||
padding_to_workspace({workspace.raw_ptr, exec_src}, src, 1, 8_z); | |||||
exec_src_nd.raw_ptr = workspace.raw_ptr; | |||||
} | |||||
cb(1, 8, NCHW_NCHW88_CONV_DENSE_WEIGHT); | |||||
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) { | } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) { | ||||
size_t group = src.layout[0]; | |||||
if (group % 8 != 0) { | |||||
padding_to_workspace({workspace.raw_ptr, exec_src}, src, 0, 8_z); | |||||
exec_src_nd.raw_ptr = workspace.raw_ptr; | |||||
} | |||||
cb(0, 8, NCHW_NCHW88_CONV_CHAN_WEIGHT); | |||||
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) { | } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) { | ||||
megdnn_assert(src.layout[1] % 8 == 0); | megdnn_assert(src.layout[1] % 8 == 0); | ||||
size_t ic = src.layout[2]; | |||||
if (ic % 8 != 0) { | |||||
padding_to_workspace({workspace.raw_ptr, exec_src}, src, 2, 8_z); | |||||
exec_src_nd.raw_ptr = workspace.raw_ptr; | |||||
} | |||||
cb(2, 8, NCHW_NCHW88_CONV_GROUP_WEIGHT); | |||||
} else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL) { | |||||
cb(1, 4, NCHW_NCHW4_IC_SMALL); | |||||
} else if (param().mode == | |||||
Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) { | |||||
cb(1, 4, NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT); | |||||
} | } | ||||
m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle()); | m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle()); | ||||
#undef cb | |||||
} | } | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -8,6 +8,7 @@ | |||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "megdnn/dtype.h" | |||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "test/common/checker.h" | #include "test/common/checker.h" | ||||
#include "test/common/rng.h" | #include "test/common/rng.h" | ||||
@@ -30,4 +31,25 @@ TEST_F(CUDA, RELAYOUT_FORMAT) { | |||||
checker.execs({{22, 23, 24, 25, 4}, {}}); | checker.execs({{22, 23, 24, 25, 4}, {}}); | ||||
} | } | ||||
TEST_F(CUDA, RELAYOUT_FORMAT_NCHW4) { | |||||
Checker<RelayoutFormat> checker(handle_cuda()); | |||||
UniformIntRNG rng{-50, 50}; | |||||
param::RelayoutFormat param; | |||||
param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL; | |||||
for (DType dtype : | |||||
std::vector<DType>({dtype::QuantizedS8{0.1f}, dtype::Float32{}})) { | |||||
checker.set_dtype(0, dtype).set_rng(0, &rng); | |||||
checker.set_param(param).execs({{2, 4, 35, 36}, {}}); | |||||
checker.set_param(param).execs({{2, 3, 35, 36}, {}}); | |||||
checker.set_param(param).execs({{2, 1, 35, 36}, {}}); | |||||
param.mode = param::RelayoutFormat::Mode:: | |||||
NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT; | |||||
checker.set_param(param).execs({{4, 3, 3, 3}, {}}); | |||||
checker.set_param(param).execs({{4, 4, 3, 3}, {}}); | |||||
checker.set_param(param).execs({{1, 4, 3, 3}, {}}); | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -25,7 +25,7 @@ from . import config, craniotome, dtype | |||||
from . import global_init as _global_init | from . import global_init as _global_init | ||||
from . import helper as _helper | from . import helper as _helper | ||||
from . import mgb as _detail | from . import mgb as _detail | ||||
from . import opr, opr_param_defs, plugin | |||||
from . import opr, opr_extra, opr_param_defs, plugin | |||||
from .exc import MegBrainError | from .exc import MegBrainError | ||||
from .logconf import get_logger | from .logconf import get_logger | ||||
from .mgb import ( | from .mgb import ( | ||||
@@ -0,0 +1,3 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# Copyright (c) 2015-2019 Megvii Inc. All rights reserved. | |||||
@@ -154,6 +154,7 @@ class Function(metaclass=ABCMeta): | |||||
memo[id(self)] = result | memo[id(self)] = result | ||||
for k, v in self.__dict__.items(): | for k, v in self.__dict__.items(): | ||||
setattr(result, k, copy.deepcopy(v, memo)) | setattr(result, k, copy.deepcopy(v, memo)) | ||||
setattr(result, "saved_tensors", tmp) | |||||
self.saved_tensors = tmp | self.saved_tensors = tmp | ||||
return result | return result | ||||
@@ -235,6 +235,14 @@ class Tensor: | |||||
return self.__val.dtype | return self.__val.dtype | ||||
return self._symvar.dtype | return self._symvar.dtype | ||||
def set_dtype(self, dtype: str = None): | |||||
r"""Set the data type of the tensor. | |||||
""" | |||||
if self.__val is not None: | |||||
self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy()) | |||||
elif self.__sym is not None: | |||||
self.__sym = self.__sym.astype(dtype) | |||||
@property | @property | ||||
def _comp_node(self): | def _comp_node(self): | ||||
if self.__val is not None: | if self.__val is not None: | ||||
@@ -26,7 +26,7 @@ def _clear_plasma_store(): | |||||
# `_PlasmaStoreManager.__del__` will not be called automaticly in subprocess, | # `_PlasmaStoreManager.__del__` will not be called automaticly in subprocess, | ||||
# so this function should be called explicitly | # so this function should be called explicitly | ||||
global MGE_PLASMA_STORE_MANAGER | global MGE_PLASMA_STORE_MANAGER | ||||
if MGE_PLASMA_STORE_MANAGER is not None: | |||||
if MGE_PLASMA_STORE_MANAGER is not None and MGE_PLASMA_STORE_MANAGER.refcount == 0: | |||||
del MGE_PLASMA_STORE_MANAGER | del MGE_PLASMA_STORE_MANAGER | ||||
MGE_PLASMA_STORE_MANAGER = None | MGE_PLASMA_STORE_MANAGER = None | ||||
@@ -50,6 +50,7 @@ class _PlasmaStoreManager: | |||||
stderr=None if debug_flag else subprocess.DEVNULL, | stderr=None if debug_flag else subprocess.DEVNULL, | ||||
) | ) | ||||
self.__initialized = True | self.__initialized = True | ||||
self.refcount = 1 | |||||
def __del__(self): | def __del__(self): | ||||
if self.__initialized and self.plasma_store.returncode is None: | if self.__initialized and self.plasma_store.returncode is None: | ||||
@@ -83,6 +84,8 @@ class PlasmaShmQueue: | |||||
"Exception happened in starting plasma_store: {}\n" | "Exception happened in starting plasma_store: {}\n" | ||||
"Tips: {}".format(str(e), err_info) | "Tips: {}".format(str(e), err_info) | ||||
) | ) | ||||
else: | |||||
MGE_PLASMA_STORE_MANAGER.refcount += 1 | |||||
self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name | self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name | ||||
@@ -133,6 +136,8 @@ class PlasmaShmQueue: | |||||
def close(self): | def close(self): | ||||
self.queue.close() | self.queue.close() | ||||
self.disconnect_client() | self.disconnect_client() | ||||
global MGE_PLASMA_STORE_MANAGER | |||||
MGE_PLASMA_STORE_MANAGER.refcount -= 1 | |||||
_clear_plasma_store() | _clear_plasma_store() | ||||
def cancel_join_thread(self): | def cancel_join_thread(self): | ||||
@@ -44,7 +44,7 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor | |||||
ret = mgb.opr.matrix_mul(inp, weight, transposeB=True) | ret = mgb.opr.matrix_mul(inp, weight, transposeB=True) | ||||
ret = ret.reshape(orig_shape[:-1], weight.shape[0]) | ret = ret.reshape(orig_shape[:-1], weight.shape[0]) | ||||
if bias is not None: | if bias is not None: | ||||
ret += bias | |||||
ret += bias.reshape(1, bias.shape[0]) | |||||
return ret | return ret | ||||
@@ -442,17 +442,38 @@ class trace: | |||||
Serialize trace to file system. | Serialize trace to file system. | ||||
:param fpath: positional only argument. Path of output file. | :param fpath: positional only argument. Path of output file. | ||||
:param arg_names: names of the input tensors in the traced function | |||||
:param append: whether output is appended to ``fpath`` | |||||
:param f16_io_f32_comp: whether to use float16 for I/O between oprs and use | |||||
:param arg_names: names of the input tensors in the traced function. | |||||
:param append: whether output is appended to ``fpath``. | |||||
:param optimize_for_inference: whether to enable optimize_for_inference | |||||
pass before dump. | |||||
:param enable_io16xc32: whether to use float16 for I/O between oprs and use | |||||
float32 as internal computation precision. Note the output var would be | float32 as internal computation precision. Note the output var would be | ||||
changed to float16 | |||||
:param f16_io_comp: whether to use float16 for both I/O and computation | |||||
precision | |||||
:param use_nhwcd4: whether to use NHWCD4 data format. This is faster on some | |||||
OpenCL devices | |||||
:param fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty | |||||
into one opr. This is supported only in NHWCD4 format. | |||||
changed to float16. | |||||
:param enable_ioc16: whether to use float16 for both I/O and computation | |||||
precision. | |||||
:param enable_hwcd4: whether to use NHWCD4 data layout. This is faster on some | |||||
OpenCL backend. | |||||
:param enable_nchw88: whether to use NCHW4 data layout. it currently | |||||
used in X86 AVX backend. | |||||
:param enable_nchw44: whether to use NCHW4 data layout. it currently | |||||
used in arm backend. | |||||
:param enable_nchw44_dot: whether to use NCHW4 data layout. it currently | |||||
used in armv8.2+dotprod backend. | |||||
:param enable_nchw4: whether to use NCHW4 data layout. it currently | |||||
used in nvidia backend(based on cudnn). | |||||
:param enable_nchw32 whether to use NCHW32 data layout. it currently | |||||
used in nvidia backend with tensorcore(based on cudnn). | |||||
:param enable_chwn4 whether to use CHWN4 data layout. it currently | |||||
used in nvidia backend with tensorcore. | |||||
:param enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty | |||||
into one opr. | |||||
:param enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z | |||||
input for inference on nvidia backend(this optimization pass will | |||||
result in mismatch of the precision of output of training and | |||||
inference) | |||||
""" | """ | ||||
if self._status != self._FINISHED: | if self._status != self._FINISHED: | ||||
raise ValueError("not traced") | raise ValueError("not traced") | ||||
@@ -475,6 +496,7 @@ class trace: | |||||
"enable_nchw88": "use_nchw88", | "enable_nchw88": "use_nchw88", | ||||
"enable_nchw32": "use_nchw32", | "enable_nchw32": "use_nchw32", | ||||
"enable_nchw44": "use_nchw44", | "enable_nchw44": "use_nchw44", | ||||
"enable_nchw44_dot": "use_nchw44_dot", | |||||
"enable_chwn4": "use_chwn4", | "enable_chwn4": "use_chwn4", | ||||
"enable_fuse_conv_bias_nonlinearity": "fuse_conv_bias_nonlinearity", | "enable_fuse_conv_bias_nonlinearity": "fuse_conv_bias_nonlinearity", | ||||
"enable_fuse_conv_bias_with_z": "fuse_conv_bias_with_z", | "enable_fuse_conv_bias_with_z": "fuse_conv_bias_with_z", | ||||
@@ -11,6 +11,7 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | |||||
import numpy as np | import numpy as np | ||||
from .._internal.dtype import is_quantize | |||||
from ..core import Buffer, Parameter, Tensor | from ..core import Buffer, Parameter, Tensor | ||||
from ..logger import get_logger | from ..logger import get_logger | ||||
@@ -460,6 +461,10 @@ class Module(metaclass=ABCMeta): | |||||
), "param `{}` shape mismatch, should be {}, get {}".format( | ), "param `{}` shape mismatch, should be {}, get {}".format( | ||||
k, var.shape, to_be_load.shape | k, var.shape, to_be_load.shape | ||||
) | ) | ||||
# For quantized dtype, the initialized dtype | |||||
# scale/zero_points maybe invalid, use pretrained dtype instead. | |||||
if is_quantize(to_be_load.dtype) and is_quantize(var.dtype): | |||||
var.set_dtype(to_be_load.dtype) | |||||
var.set_value(to_be_load) | var.set_value(to_be_load) | ||||
loaded.append(k) | loaded.append(k) | ||||
@@ -37,15 +37,14 @@ class QATModule(Module): | |||||
Set quantization related configs with ``qconfig``, including | Set quantization related configs with ``qconfig``, including | ||||
observer and fake_quant for weight and activation. | observer and fake_quant for weight and activation. | ||||
""" | """ | ||||
self.weight_observer = qconfig.weight_observer() | |||||
self.act_observer = qconfig.act_observer() | |||||
if qconfig.fake_quant is None: | |||||
self.weight_fake_quant = None | |||||
self.act_fake_quant = None | |||||
else: | |||||
self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype) | |||||
self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype) | |||||
def safe_call(func): | |||||
return func() if func is not None else None | |||||
self.weight_observer = safe_call(qconfig.weight_observer) | |||||
self.act_observer = safe_call(qconfig.act_observer) | |||||
self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) | |||||
self.act_fake_quant = safe_call(qconfig.act_fake_quant) | |||||
def _apply_fakequant_with_observer( | def _apply_fakequant_with_observer( | ||||
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer | self, target: Tensor, fake_quant: FakeQuantize, observer: Observer | ||||
@@ -77,13 +76,19 @@ class QATModule(Module): | |||||
r""" | r""" | ||||
Get weight's quantization dtype as the method from ``qconfig``. | Get weight's quantization dtype as the method from ``qconfig``. | ||||
""" | """ | ||||
return self.weight_observer.get_dtype() | |||||
if hasattr(self.act_fake_quant, "get_dtype"): | |||||
return self.weight_fake_quant.get_dtype() | |||||
else: | |||||
return self.weight_observer.get_dtype() | |||||
def get_activation_dtype(self): | def get_activation_dtype(self): | ||||
r""" | r""" | ||||
Get activation's quantization dtype as the method from ``qconfig``. | Get activation's quantization dtype as the method from ``qconfig``. | ||||
""" | """ | ||||
return self.act_observer.get_dtype() | |||||
if hasattr(self.act_fake_quant, "get_dtype"): | |||||
return self.act_fake_quant.get_dtype() | |||||
else: | |||||
return self.act_observer.get_dtype() | |||||
@classmethod | @classmethod | ||||
@abstractmethod | @abstractmethod | ||||
@@ -12,4 +12,5 @@ from .qconfig import ( | |||||
calibration_qconfig, | calibration_qconfig, | ||||
ema_fakequant_qconfig, | ema_fakequant_qconfig, | ||||
min_max_fakequant_qconfig, | min_max_fakequant_qconfig, | ||||
tqt_quant_qconfig, | |||||
) | ) |
@@ -5,18 +5,21 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import copy | |||||
import math | |||||
import numpy as np | |||||
from .. import functional as F | from .. import functional as F | ||||
from .._internal.dtype import _metadata_dict | |||||
from .._internal.dtype import _metadata_dict, get_quantized_dtype | |||||
from ..core import Buffer, Function, Parameter | |||||
from ..jit import sideeffect | |||||
from ..module import Module | from ..module import Module | ||||
from .observer import ObserverMode, Round | from .observer import ObserverMode, Round | ||||
class FakeQuantize(Module): | |||||
r""" | |||||
A module to do quant and dequant according to observer's scale and zero_point. | |||||
""" | |||||
def __init__(self, dtype: str, enable: bool = True): | |||||
class _FakeQuantize(Module): | |||||
def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | |||||
super().__init__() | super().__init__() | ||||
if not dtype in _metadata_dict.keys(): | if not dtype in _metadata_dict.keys(): | ||||
raise ValueError( | raise ValueError( | ||||
@@ -25,7 +28,10 @@ class FakeQuantize(Module): | |||||
) | ) | ||||
) | ) | ||||
self.dtype = dtype | self.dtype = dtype | ||||
self.qmin = _metadata_dict[dtype].qmin | |||||
self.narrow_range = narrow_range | |||||
self.qmin = ( | |||||
-_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin | |||||
) | |||||
self.qmax = _metadata_dict[dtype].qmax | self.qmax = _metadata_dict[dtype].qmax | ||||
self.enabled = enable | self.enabled = enable | ||||
@@ -35,25 +41,108 @@ class FakeQuantize(Module): | |||||
def disable(self): | def disable(self): | ||||
self.enabled = False | self.enabled = False | ||||
def fake_quant_forward(self, inp, q_dict): | |||||
return inp | |||||
def normal_foward(self, inp, q_dict): | |||||
return inp | |||||
def forward(self, inp, q_dict): | def forward(self, inp, q_dict): | ||||
if self.enabled: | if self.enabled: | ||||
if q_dict["mode"] == ObserverMode.SYMMERTIC: | |||||
scale = q_dict["scale"] | |||||
# Quant | |||||
oup = Round()(inp / scale) | |||||
# clip | |||||
oup = F.minimum(F.maximum(oup, self.qmin), self.qmax) | |||||
# DeQuant | |||||
oup = (oup) * scale | |||||
return oup | |||||
else: | |||||
scale = q_dict["scale"] | |||||
zero_point = q_dict["zero_point"] | |||||
# Quant | |||||
oup = Round()(inp / scale) + zero_point | |||||
# clip | |||||
oup = F.minimum(F.maximum(oup, self.qmin), self.qmax) | |||||
# DeQuant | |||||
oup = (oup - zero_point) * scale | |||||
return oup | |||||
return self.fake_quant_forward(inp, q_dict) | |||||
else: | |||||
return self.normal_foward(inp, q_dict) | |||||
class TQT_Function(Function): | |||||
def __init__(self, lowerbound, upperbound): | |||||
super().__init__() | |||||
self.lowerbound = lowerbound | |||||
self.upperbound = upperbound | |||||
def forward(self, inp, scale): | |||||
t = 2 ** scale | |||||
# t = F.maximum(t, 1e-4) | |||||
inp_scaled = inp / t | |||||
inp_clipped = F.maximum(F.minimum(inp_scaled, self.upperbound), self.lowerbound) | |||||
inp_rounded = F.round(inp_clipped) | |||||
inp_flq = inp_rounded * t | |||||
self.save_for_backward(inp_scaled, inp_rounded, t) | |||||
return inp_flq | |||||
def backward(self, grad_inp_flq): | |||||
(inp_scaled, inp_rounded, t) = self.saved_tensors | |||||
mask_clip = (inp_scaled < -0.5 + self.lowerbound) + ( | |||||
inp_scaled > self.upperbound + 0.5 | |||||
) # mask for accumulating the gradients of |data_scaled|>L | |||||
mask_quant = F.abs( | |||||
mask_clip - 1 | |||||
) # mask for accumulating the gradients with |data_scaled|<=L | |||||
grad_quant = ( | |||||
grad_inp_flq * mask_quant * (inp_rounded - inp_scaled) | |||||
) # gradient within |data_scaled|<=L | |||||
grad_clip = ( | |||||
grad_inp_flq * mask_clip * inp_rounded | |||||
) # gradient with | data_scaled|>L | |||||
grad_s = grad_clip.sum() + grad_quant.sum() | |||||
# dL/ds = dL/dt * t * ln(2) | |||||
grad_s = grad_s * t * math.log(2) | |||||
grad_inp = grad_inp_flq * mask_quant | |||||
return grad_inp, grad_s | |||||
class TQT(_FakeQuantize): | |||||
""" | |||||
TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds | |||||
for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks | |||||
""" | |||||
def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | |||||
super().__init__(dtype, narrow_range, enable) | |||||
self.scale = Parameter(0.0, dtype=np.float32) | |||||
def fake_quant_forward(self, inp, q_dict): | |||||
# when enable, TQT will do fakequant forward, finetune the scale | |||||
return TQT_Function(self.qmin, self.qmax)(inp, self.scale) | |||||
def normal_foward(self, inp, q_dict): | |||||
# when disable, TQT will do normal forward, initialize scale weight | |||||
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"])) | |||||
tmp_scale = F.log(tmp_scale / 127) / F.log(2) | |||||
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) | |||||
return inp | return inp | ||||
def get_dtype(self): | |||||
return get_quantized_dtype(self.dtype, 2 ** self.scale.numpy()[0], None) | |||||
class FakeQuantize(_FakeQuantize): | |||||
r""" | |||||
A module to do quant and dequant according to observer's scale and zero_point. | |||||
:param dtype: A string indicating the target quantization type of input. | |||||
:param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``, | |||||
instead of 1 greater. Usually True for weight and False for activation. | |||||
:param enable: Whether do ``normal_forward`` or ``fake_quant_forward``. | |||||
""" | |||||
def fake_quant_forward(self, inp, q_dict): | |||||
if q_dict["mode"] == ObserverMode.SYMMERTIC: | |||||
scale = q_dict["scale"] | |||||
# Quant | |||||
oup = Round()(inp / scale) | |||||
# clip | |||||
oup = F.minimum(F.maximum(oup, self.qmin), self.qmax) | |||||
# DeQuant | |||||
oup = (oup) * scale | |||||
return oup | |||||
else: | |||||
scale = q_dict["scale"] | |||||
zero_point = q_dict["zero_point"] | |||||
# Quant | |||||
oup = Round()(inp / scale) + zero_point | |||||
# clip | |||||
oup = F.minimum(F.maximum(oup, self.qmin), self.qmax) | |||||
# DeQuant | |||||
oup = (oup - zero_point) * scale | |||||
return oup |
@@ -31,9 +31,11 @@ class Observer(Module): | |||||
A base class for Observer Module. | A base class for Observer Module. | ||||
:param dtype: a string indicating to collect scale and zero_point of which dtype | :param dtype: a string indicating to collect scale and zero_point of which dtype | ||||
:param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``, | |||||
instead of 1 greater. Usually True for weight and False for activation. | |||||
""" | """ | ||||
def __init__(self, dtype="qint8"): | |||||
def __init__(self, dtype: str, narrow_range: bool = False): | |||||
super().__init__() | super().__init__() | ||||
if dtype not in _metadata_dict.keys(): | if dtype not in _metadata_dict.keys(): | ||||
raise ValueError( | raise ValueError( | ||||
@@ -42,7 +44,10 @@ class Observer(Module): | |||||
) | ) | ||||
) | ) | ||||
self.dtype = dtype | self.dtype = dtype | ||||
self.qmin = _metadata_dict[dtype].qmin | |||||
self.narrow_range = narrow_range | |||||
self.qmin = ( | |||||
-_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin | |||||
) | |||||
self.qmax = _metadata_dict[dtype].qmax | self.qmax = _metadata_dict[dtype].qmax | ||||
self.enabled = True | self.enabled = True | ||||
@@ -96,8 +101,14 @@ def create_observer_dict(mode): | |||||
class MinMaxObserver(Observer): | class MinMaxObserver(Observer): | ||||
def __init__(self, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"): | |||||
super().__init__(dtype) | |||||
def __init__( | |||||
self, | |||||
mode=ObserverMode.SYMMERTIC, | |||||
eps=0.00001, | |||||
dtype="qint8", | |||||
narrow_range: bool = False, | |||||
): | |||||
super().__init__(dtype, narrow_range) | |||||
self.mode = mode | self.mode = mode | ||||
self.min_val = Buffer(np.finfo(np.float32).max, dtype=np.float32) | self.min_val = Buffer(np.finfo(np.float32).max, dtype=np.float32) | ||||
self.max_val = Buffer(np.finfo(np.float32).min, dtype=np.float32) | self.max_val = Buffer(np.finfo(np.float32).min, dtype=np.float32) | ||||
@@ -107,6 +118,8 @@ class MinMaxObserver(Observer): | |||||
min_val = F.minimum(0.0, inp_min_val) | min_val = F.minimum(0.0, inp_min_val) | ||||
max_val = F.maximum(0.0, inp_max_val) | max_val = F.maximum(0.0, inp_max_val) | ||||
q_dict = create_observer_dict(self.mode) | q_dict = create_observer_dict(self.mode) | ||||
q_dict["min_val"] = inp_min_val | |||||
q_dict["max_val"] = inp_max_val | |||||
if self.mode == ObserverMode.SYMMERTIC: | if self.mode == ObserverMode.SYMMERTIC: | ||||
symmetric_max_vals = F.maximum(-min_val, max_val) | symmetric_max_vals = F.maximum(-min_val, max_val) | ||||
# use maximun to avoid scale too small at the begin | # use maximun to avoid scale too small at the begin | ||||
@@ -151,9 +164,14 @@ class MinMaxObserver(Observer): | |||||
class ExponentialMovingAverageObserver(MinMaxObserver): | class ExponentialMovingAverageObserver(MinMaxObserver): | ||||
def __init__( | def __init__( | ||||
self, momentum=0.9, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8" | |||||
self, | |||||
momentum=0.9, | |||||
mode=ObserverMode.SYMMERTIC, | |||||
eps=0.00001, | |||||
dtype="qint8", | |||||
narrow_range: bool = False, | |||||
): | ): | ||||
super().__init__(mode, eps, dtype) | |||||
super().__init__(mode, eps, dtype, narrow_range) | |||||
self.momentum = Buffer(momentum) | self.momentum = Buffer(momentum) | ||||
self.runtime_momentum = Buffer(0.0) | self.runtime_momentum = Buffer(0.0) | ||||
@@ -186,11 +204,12 @@ class HistogramObserver(MinMaxObserver): | |||||
self, | self, | ||||
bins=2048, | bins=2048, | ||||
upsample_rate=128, | upsample_rate=128, | ||||
dtype="qint8", | |||||
mode=ObserverMode.SYMMERTIC, | mode=ObserverMode.SYMMERTIC, | ||||
eps=0.00001, | eps=0.00001, | ||||
dtype="qint8", | |||||
narrow_range: bool = False, | |||||
): | ): | ||||
super().__init__(mode, eps, dtype) | |||||
super().__init__(mode, eps, dtype, narrow_range) | |||||
self.bins = bins | self.bins = bins | ||||
self.upsample_rate = upsample_rate | self.upsample_rate = upsample_rate | ||||
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 | self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 | ||||
@@ -1,12 +1,14 @@ | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
# | # | ||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
# | |||||
#' | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
from functools import partial | |||||
from ..module import Module | from ..module import Module | ||||
from .fake_quant import FakeQuantize | |||||
from .fake_quant import TQT, FakeQuantize | |||||
from .observer import ( | from .observer import ( | ||||
ExponentialMovingAverageObserver, | ExponentialMovingAverageObserver, | ||||
HistogramObserver, | HistogramObserver, | ||||
@@ -22,9 +24,9 @@ class QConfig: | |||||
:param weight_observer: interface to instantiate an :class:`~.Observer` indicating | :param weight_observer: interface to instantiate an :class:`~.Observer` indicating | ||||
how to collect scales and zero_point of wegiht. | how to collect scales and zero_point of wegiht. | ||||
:param act_observer: similar to ``weight_observer`` but toward activation. | :param act_observer: similar to ``weight_observer`` but toward activation. | ||||
:param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating | |||||
how to do fake_quant calculation. can be invoked multi times to get different | |||||
instance for each target tensor, for better control on enable and disable. | |||||
:param weight_fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating | |||||
how to do fake_quant calculation. | |||||
:param act_observer: similar to ``weight_fake_quant`` but toward activation. | |||||
Examples: | Examples: | ||||
@@ -32,14 +34,24 @@ class QConfig: | |||||
# Default EMA QConfig for QAT. | # Default EMA QConfig for QAT. | ||||
ema_fakequant_qconfig = QConfig( | ema_fakequant_qconfig = QConfig( | ||||
weight_observer=MinMaxObserver, | |||||
act_observer=ExponentialMovingAverageObserver, | |||||
fake_quant=FakeQuantize, | |||||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||||
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False), | |||||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||||
) | ) | ||||
Each parameter is a ``class`` rather than an instance. And we recommand using ``functools.partial`` | |||||
to add initialization parameters of the ``class``, so that don't need to provide parameters in | |||||
:meth:`~.QATModule.set_qconfig`. | |||||
Usually we set ``narrow_range`` of weight related paramters to ``True`` and of activation related | |||||
parameters to ``False``. For the result of multiplication and addition as ``a * b + c * d``, if | |||||
four variables are all -128 of dtype ``qint8``, then the result will be ``2^15`` and cause overflow. | |||||
Weights are commonly calculated in this way, so needed to narrow the range. | |||||
""" | """ | ||||
def __init__( | def __init__( | ||||
self, act_observer, weight_observer, fake_quant, | |||||
self, weight_observer, act_observer, weight_fake_quant, act_fake_quant | |||||
): | ): | ||||
if isinstance(act_observer, Module) or isinstance(weight_observer, Module): | if isinstance(act_observer, Module) or isinstance(weight_observer, Module): | ||||
raise ValueError( | raise ValueError( | ||||
@@ -47,24 +59,42 @@ class QConfig: | |||||
" class generator using `partial(Observer, ...)` instead. Use" | " class generator using `partial(Observer, ...)` instead. Use" | ||||
" partial(MyObserver, x=1) to override arguments to constructor if needed" | " partial(MyObserver, x=1) to override arguments to constructor if needed" | ||||
) | ) | ||||
self.act_observer = act_observer | |||||
self.weight_observer = weight_observer | self.weight_observer = weight_observer | ||||
self.fake_quant = fake_quant | |||||
self.act_observer = act_observer | |||||
self.weight_fake_quant = weight_fake_quant | |||||
self.act_fake_quant = act_fake_quant | |||||
# Default QAT QConfigs | |||||
tqt_quant_qconfig = QConfig( | |||||
weight_observer=partial( | |||||
ExponentialMovingAverageObserver, dtype="qint8", narrow_range=True | |||||
), | |||||
act_observer=partial( | |||||
ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False | |||||
), | |||||
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True), | |||||
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False), | |||||
) | |||||
min_max_fakequant_qconfig = QConfig( | min_max_fakequant_qconfig = QConfig( | ||||
weight_observer=MinMaxObserver, | |||||
act_observer=MinMaxObserver, | |||||
fake_quant=FakeQuantize, | |||||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||||
act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False), | |||||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||||
) | ) | ||||
ema_fakequant_qconfig = QConfig( | ema_fakequant_qconfig = QConfig( | ||||
weight_observer=MinMaxObserver, | |||||
act_observer=ExponentialMovingAverageObserver, | |||||
fake_quant=FakeQuantize, | |||||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||||
act_observer=partial( | |||||
ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False | |||||
), | |||||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||||
) | ) | ||||
calibration_qconfig = QConfig( | calibration_qconfig = QConfig( | ||||
weight_observer=MinMaxObserver, act_observer=HistogramObserver, fake_quant=None, | |||||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||||
act_observer=partial(HistogramObserver, dtype="qint8", narrow_range=False), | |||||
weight_fake_quant=None, | |||||
act_fake_quant=None, | |||||
) | ) |
@@ -6,6 +6,10 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
from megengine._internal.plugin import load_tensor_binary | |||||
def prod(iterable): | def prod(iterable): | ||||
result = 1 | result = 1 | ||||
for i in iterable: | for i in iterable: | ||||
@@ -1,2 +1 @@ | |||||
__version__ = "0.4.0" | |||||
__version__ = "0.5.1" |
@@ -96,7 +96,6 @@ def test_deepcopy(): | |||||
origin = Sigmoid(0) | origin = Sigmoid(0) | ||||
new = copy.deepcopy(Sigmoid(0)) | new = copy.deepcopy(Sigmoid(0)) | ||||
assert new.param == origin.param | assert new.param == origin.param | ||||
assert new.saved_tensors == None | |||||
def test_save_context(): | def test_save_context(): | ||||
@@ -10,6 +10,7 @@ import numpy as np | |||||
import pytest | import pytest | ||||
import megengine as mge | import megengine as mge | ||||
import megengine._internal as mgb | |||||
def test_wrong_dtype(): | def test_wrong_dtype(): | ||||
@@ -26,3 +27,48 @@ def test_tensor_routine(): | |||||
mge.tensor([1]) | mge.tensor([1]) | ||||
mge.tensor(1.5) | mge.tensor(1.5) | ||||
def test_tensor_set_dtype():
    """``set_dtype`` to a quantized dtype should rescale the stored integer
    values so that the represented float value is preserved (and back to
    float32 should dequantize)."""

    def check_dtype_value(tensor, dtype_scale, value):
        # Verify the quantization scale (when quantized) and the raw
        # element value at [0][0].
        if mgb.dtype.is_quantize(tensor.dtype):
            if np.abs(mgb.dtype.get_scale(tensor.dtype) - dtype_scale) > 1e-5:
                raise AssertionError(
                    "compare scale failed expect {} got {}".format(
                        dtype_scale, mgb.dtype.get_scale(tensor.dtype)
                    )
                )
        if np.abs(tensor.numpy()[0][0] - value) > 1e-5:
            # fixed: expected value first, actual value second
            raise AssertionError(
                "compare value failed expect {} got {}".format(
                    value, tensor.numpy()[0][0]
                )
            )

    # float32 Parameter of ones -> qint8 with scale 0.1 stores 1/0.1 == 10
    t = mge.Parameter(np.ones((3, 4), dtype="float32"))
    t.set_dtype(mgb.dtype.qint8(0.1))
    check_dtype_value(t, 0.1, 10)

    # re-quantize from scale 1 to scale 0.3: stored ints are rescaled
    t = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
    t.set_dtype(mgb.dtype.qint8(0.3))
    check_dtype_value(t, 0.3, 3)

    # same behavior for Buffer
    t = mge.Buffer(np.ones((3, 4), dtype="float32"))
    t.set_dtype(mgb.dtype.qint8(0.1))
    check_dtype_value(t, 0.1, 10)

    t = mge.Buffer(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
    t.set_dtype(mgb.dtype.qint8(0.3))
    check_dtype_value(t, 0.3, 3)

    # set_dtype on an expression result (t + 1) and on its input
    t = mge.Buffer(np.ones((3, 4), dtype="float32"))
    s = t + 1
    s.set_dtype(mgb.dtype.qint8(0.2))
    check_dtype_value(s, 0.2, 10)
    t.set_dtype(mgb.dtype.qint8(0.3))
    s = t + 1
    s.set_dtype(mgb.dtype.qint8(0.1))
    check_dtype_value(s, 0.1, 18)
    # converting back to float32 dequantizes: 18 * 0.1 == 1.8
    s.set_dtype("float32")
    check_dtype_value(s, 0, 1.8)
@@ -132,3 +132,52 @@ def test_dataloader_parallel_worker_exception(): | |||||
with pytest.raises(RuntimeError, match=r"worker.*died"): | with pytest.raises(RuntimeError, match=r"worker.*died"): | ||||
data_iter = iter(dataloader) | data_iter = iter(dataloader) | ||||
batch_data = next(data_iter) | batch_data = next(data_iter) | ||||
def _multi_instances_parallel_dataloader_worker():
    """Iterate a train and a val DataLoader concurrently (both with worker
    subprocesses) to exercise multiple parallel DataLoader instances
    sharing the same dataset."""
    dataset = init_dataset()

    for divide_flag in [True, False]:
        train_dataloader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
            num_workers=2,
            divide=divide_flag,
        )
        val_dataloader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
            num_workers=2,
            divide=divide_flag,
        )
        for idx, (data, label) in enumerate(train_dataloader):
            assert data.shape == (4, 1, 32, 32)
            assert label.shape == (4,)
            # interleave a full validation sweep every 5 train batches
            if idx % 5 == 0:
                for val_data, val_label in val_dataloader:
                    assert val_data.shape == (10, 1, 32, 32)
                    assert val_label.shape == (10,)
def test_dataloader_parallel_multi_instances():
    """Run two concurrent DataLoaders in this process under a capped
    shared-memory budget."""
    # set max shared memory to 100M
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"
    _multi_instances_parallel_dataloader_worker()
def test_dataloader_parallel_multi_instances_multiprocessing():
    """Run the two-DataLoader scenario in 4 OS processes simultaneously."""
    # set max shared memory to 100M
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"
    import multiprocessing as mp

    # mp.set_start_method("spawn")
    processes = []
    for _ in range(4):
        p = mp.Process(target=_multi_instances_parallel_dataloader_worker)
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
        # surface child failures: without this a crashed worker process
        # would let the test pass silently
        assert p.exitcode == 0
@@ -14,8 +14,10 @@ import pytest | |||||
from helpers import MLP | from helpers import MLP | ||||
import megengine as mge | import megengine as mge | ||||
import megengine._internal as mgb | |||||
from megengine.core import Buffer, Parameter, Tensor, tensor | from megengine.core import Buffer, Parameter, Tensor, tensor | ||||
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential | from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential | ||||
from megengine.quantization.quantize import quantize, quantize_qat | |||||
from megengine.test import assertTensorClose | from megengine.test import assertTensorClose | ||||
@@ -347,3 +349,38 @@ def test_dump_model(): | |||||
pred = mlp(data) | pred = mlp(data) | ||||
with tempfile.NamedTemporaryFile() as f: | with tempfile.NamedTemporaryFile() as f: | ||||
mge.dump(pred, f.name) | mge.dump(pred, f.name) | ||||
def test_load_quantized():
    """Saving a quantized MLP's state_dict and loading it back must restore
    the quantized weights (including their per-tensor scales) exactly, so
    predictions before save and after load match."""
    data_shape = (2, 28)
    data = tensor(np.random.random(data_shape), dtype="float32")
    data = data.astype(mgb.dtype.qint8(0.1))
    mlp = MLP()
    quantize_qat(mlp)
    quantize(mlp)
    # give the two dense layers distinct, known weight scales
    mlp.dense0.weight = Parameter(
        mlp.dense0.weight.astype(mgb.dtype.qint8(0.001)).numpy()
    )
    mlp.dense1.weight = Parameter(
        mlp.dense1.weight.astype(mgb.dtype.qint8(0.0002)).numpy()
    )
    mlp.eval()
    pred0 = mlp(data)

    with BytesIO() as fout:
        mge.save(mlp.state_dict(), fout)
        fout.seek(0)
        checkpoint = mge.load(fout)
        # clobber the weights with different scales, then restore from the
        # checkpoint — load_state_dict must bring back the saved scales
        mlp.dense0.weight = Parameter(
            mlp.dense0.weight.astype(mgb.dtype.qint8(0.00001)).numpy()
        )
        mlp.dense1.weight = Parameter(
            mlp.dense1.weight.astype(mgb.dtype.qint8(0.2)).numpy()
        )
        mlp.load_state_dict(checkpoint)
        pred1 = mlp(data)

    assertTensorClose(
        pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), max_err=5e-6
    )
@@ -0,0 +1,77 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np | |||||
import pytest | |||||
import megengine as mge | |||||
import megengine._internal as mgb | |||||
from megengine.core import tensor | |||||
from megengine.quantization.fake_quant import TQT_Function | |||||
from megengine.test import assertTensorClose | |||||
class numpy_TQT_Function:
    """Pure-numpy reference implementation of the TQT fake-quant forward
    and backward passes, used to cross-check the megengine op.

    :param lowerbound: minimum representable integer level (e.g. -127).
    :param upperbound: maximum representable integer level (e.g. 127).
    """

    def __init__(self, lowerbound, upperbound):
        super().__init__()
        self.lowerbound = lowerbound
        self.upperbound = upperbound

    def forward(self, inp, scale):
        # quantization step is a learned power of two
        t = 2 ** scale
        # t = F.maximum(t, 1e-4)
        inp_scaled = inp / t
        inp_clipped = np.clip(inp_scaled, self.lowerbound, self.upperbound)
        inp_rounded = np.round(inp_clipped)
        inp_flq = inp_rounded * t
        # stash intermediates for the backward pass
        self.saved_tensors = (inp_scaled, inp_rounded, t)
        return inp_flq

    def backward(self, grad_inp_flq):
        (inp_scaled, inp_rounded, t) = self.saved_tensors
        # mask for accumulating the gradients of |data_scaled| > L
        mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
            inp_scaled > self.upperbound + 0.5
        )
        # mask for accumulating the gradients with |data_scaled| <= L
        mask_quant = np.abs(mask_clip - 1)
        # gradient within |data_scaled| <= L (round-off residual)
        grad_quant = grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
        # gradient with |data_scaled| > L (clipped boundary)
        grad_clip = grad_inp_flq * mask_clip * inp_rounded
        grad_s = grad_clip.sum() + grad_quant.sum()
        # dL/ds = dL/dt * t * ln(2)
        grad_s = grad_s * t * np.log(2)
        # clipped elements get no input gradient (straight-through estimator)
        grad_inp = grad_inp_flq * mask_quant
        return grad_inp, grad_s
def test_TQT():
    """Cross-check megengine's TQT_Function forward/backward against the
    pure-numpy reference implementation on random input."""
    f = TQT_Function(-127, 127)
    nf = numpy_TQT_Function(-127, 127)

    def check_inp(a, b, c, a_np, b_np, c_np):
        # forward outputs and both backward gradients (w.r.t. input and
        # scale) must match the numpy reference
        assertTensorClose(
            f.forward(a, b).numpy(), nf.forward(a_np, b_np).astype("float32")
        )
        c1, c2 = f.backward(c)
        c1_np, c2_np = nf.backward(c_np)
        assertTensorClose(c1.numpy(), c1_np.astype("float32"))
        assertTensorClose(c2.numpy(), c2_np.astype("float32"))

    a = tensor()
    b = tensor()
    a_np = np.random.random((4, 3)).astype("float32")
    b_np = np.random.random((1)).astype("float32")
    a.set_value(a_np)
    b.set_value(b_np)
    check_inp(a, b, b, a_np, b_np, b_np)
@@ -14,7 +14,7 @@ import struct | |||||
import cv2 | import cv2 | ||||
import numpy as np | import numpy as np | ||||
import megbrain as mgb | |||||
import megengine._internal as mgb | |||||
import megengine as mge | import megengine as mge | ||||
logger = mge.get_logger(__name__) | logger = mge.get_logger(__name__) | ||||
@@ -709,6 +709,41 @@ void run_test_st(Args &env) { | |||||
} | } | ||||
}; | }; | ||||
auto run_iters = [&](uint32_t case_idx) -> float { | |||||
double time_sqrsum = 0, time_sum = 0, | |||||
min_time = std::numeric_limits<double>::max(), max_time = 0; | |||||
for (int run = 0; run < env.nr_run; ++run) { | |||||
mgb_log_debug("load_and_run: before running iter %d", run); | |||||
timer.reset(); | |||||
func->execute(); | |||||
mgb_log_debug("load_and_run: before waiting iter %d", run); | |||||
auto exec_time = timer.get_msecs(); | |||||
func->wait(); | |||||
output_dumper.write_to_file(); | |||||
auto cur = timer.get_msecs(); | |||||
printf("iter %d/%d: %.3fms (exec=%.3f,device=%.3f)\n", run, | |||||
env.nr_run, cur, exec_time, | |||||
func->get_prev_exec_time() * 1e3); | |||||
time_sum += cur; | |||||
time_sqrsum += cur * cur; | |||||
fflush(stdout); | |||||
if (cur < min_time) { | |||||
min_time = cur; | |||||
} | |||||
if (cur > max_time) { | |||||
max_time = cur; | |||||
} | |||||
} | |||||
printf("=== finished test #%u: time=%.3fms avg_time=%.3fms " | |||||
"sd=%.3fms minmax=%.3f,%.3f\n\n", | |||||
case_idx, time_sum, time_sum / env.nr_run, | |||||
std::sqrt((time_sqrsum * env.nr_run - time_sum * time_sum) / | |||||
(env.nr_run * (env.nr_run - 1))), | |||||
min_time, max_time); | |||||
return time_sum; | |||||
}; | |||||
if (nr_test) { | if (nr_test) { | ||||
// run testcase, generated by dump_with_testcase.py | // run testcase, generated by dump_with_testcase.py | ||||
@@ -742,37 +777,7 @@ void run_test_st(Args &env) { | |||||
if (!env.nr_run) { | if (!env.nr_run) { | ||||
continue; | continue; | ||||
} | } | ||||
double time_sqrsum = 0, time_sum = 0, | |||||
min_time = std::numeric_limits<double>::max(), max_time = 0; | |||||
for (int run = 0; run < env.nr_run; ++ run) { | |||||
mgb_log_debug("load_and_run: before running iter %d", run); | |||||
timer.reset(); | |||||
func->execute(); | |||||
mgb_log_debug("load_and_run: before waiting iter %d", run); | |||||
auto exec_time = timer.get_msecs(); | |||||
func->wait(); | |||||
output_dumper.write_to_file(); | |||||
auto cur = timer.get_msecs(); | |||||
printf("iter %d/%d: %.3fms (exec=%.3f,device=%.3f)\n", run, | |||||
env.nr_run, cur, exec_time, | |||||
func->get_prev_exec_time() * 1e3); | |||||
time_sum += cur; | |||||
time_sqrsum += cur * cur; | |||||
fflush(stdout); | |||||
if (cur < min_time) { | |||||
min_time = cur; | |||||
} | |||||
if (cur > max_time) { | |||||
max_time = cur; | |||||
} | |||||
} | |||||
tot_time += time_sum; | |||||
printf("=== finished test #%u: time=%.3fms avg_time=%.3fms " | |||||
"sd=%.3fms minmax=%.3f,%.3f\n\n", | |||||
i, time_sum, time_sum / env.nr_run, | |||||
std::sqrt((time_sqrsum * env.nr_run - time_sum * time_sum) / | |||||
(env.nr_run * (env.nr_run - 1))), | |||||
min_time, max_time); | |||||
tot_time += run_iters(i); | |||||
} | } | ||||
printf("=== total time: %.3fms\n", tot_time); | printf("=== total time: %.3fms\n", tot_time); | ||||
@@ -793,15 +798,10 @@ void run_test_st(Args &env) { | |||||
in->copy_from(i.second); | in->copy_from(i.second); | ||||
} | } | ||||
warmup(); | |||||
timer.reset(); | timer.reset(); | ||||
func->execute(); | |||||
auto exec_time = timer.get_msecs(); | |||||
func->wait(); | |||||
output_dumper.write_to_file(); | |||||
auto cur = timer.get_msecs(); | |||||
printf("%.3fms %.3fms (device=%.3f)\n", cur, exec_time, | |||||
func->get_prev_exec_time() * 1e3); | |||||
printf("=== going to run input for %d times\n", env.nr_run); | |||||
run_iters(0); | |||||
} else { | } else { | ||||
// run speed test for a raw mgb graph | // run speed test for a raw mgb graph | ||||
mgb_assert(env.load_ret.tensor_map.empty(), | mgb_assert(env.load_ret.tensor_map.empty(), | ||||
@@ -34,6 +34,11 @@ if(MGE_WITH_CUDA AND MGE_WITH_TRT) | |||||
endif() | endif() | ||||
if(MGE_WITH_CUDA) | |||||
file(GLOB_RECURSE SOURCES_ opr/impl/standalone/*.cu) | |||||
list(APPEND SOURCES ${SOURCES_}) | |||||
endif() | |||||
add_library(megbrain OBJECT EXCLUDE_FROM_ALL ${SOURCES}) | add_library(megbrain OBJECT EXCLUDE_FROM_ALL ${SOURCES}) | ||||
target_link_libraries(megbrain PUBLIC mgb_opr_param_defs) | target_link_libraries(megbrain PUBLIC mgb_opr_param_defs) | ||||
target_include_directories(megbrain | target_include_directories(megbrain | ||||
@@ -795,7 +795,7 @@ bool CpuCompNode::CompNodeImpl::check_global_finalized(const char* reason) { | |||||
/* ======================== CompNode methods ======================== */ | /* ======================== CompNode methods ======================== */ | ||||
CompNode CompNode::default_cpu() { | CompNode CompNode::default_cpu() { | ||||
static Locator locator{DeviceType::CPU, Locator::DEVICE_CPU_DEFAULT, -1}; | |||||
static Locator locator{DeviceType::CPU, Locator::DEVICE_CPU_DEFAULT, {-1}}; | |||||
static auto empty_queue = | static auto empty_queue = | ||||
std::make_shared<CpuCompNode::WorkerQueue>(locator); | std::make_shared<CpuCompNode::WorkerQueue>(locator); | ||||
static CpuCompNodeImpl impl{locator, locator, empty_queue}; | static CpuCompNodeImpl impl{locator, locator, empty_queue}; | ||||
@@ -464,7 +464,7 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( | |||||
#if MGB_ENABLE_TENSOR_RT | #if MGB_ENABLE_TENSOR_RT | ||||
if (options().graph_opt.tensorrt) { | if (options().graph_opt.tensorrt) { | ||||
options().graph_opt.tensorrt = false; | options().graph_opt.tensorrt = false; | ||||
tensorrt::transform_dest_vars_inplace(dest_vars); | |||||
tensorrt::transform_dest_vars_inplace(dest_vars, options().graph_opt); | |||||
} | } | ||||
#endif | #endif | ||||
@@ -12,8 +12,8 @@ | |||||
#pragma once | #pragma once | ||||
#define MGB_MAJOR 8 | #define MGB_MAJOR 8 | ||||
#define MGB_MINOR 4 | |||||
#define MGB_PATCH 1 | |||||
#define MGB_MINOR 5 | |||||
#define MGB_PATCH 0 | |||||
//! whether it is development version | //! whether it is development version | ||||
#ifndef MGB_IS_DEV | #ifndef MGB_IS_DEV | ||||
#define MGB_IS_DEV 0 | #define MGB_IS_DEV 0 | ||||
@@ -756,6 +756,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( | |||||
cb(nchw32, { | cb(nchw32, { | ||||
add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
add_pass<FuseConvBiasZPass>(); | add_pass<FuseConvBiasZPass>(); | ||||
add_pass(EnableNCHW4Pass::make_nchw4_converter()); | |||||
add_pass(EnableTensorCorePass::make_tensorcore_converter()); | add_pass(EnableTensorCorePass::make_tensorcore_converter()); | ||||
add_pass<ShuffleShuffleRemovePass>(); | add_pass<ShuffleShuffleRemovePass>(); | ||||
add_pass<RemoveRedundantTypeCvtPass>(); | add_pass<RemoveRedundantTypeCvtPass>(); | ||||
@@ -763,6 +764,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( | |||||
cb(chwn4, { | cb(chwn4, { | ||||
add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
add_pass<FuseConvBiasZPass>(); | add_pass<FuseConvBiasZPass>(); | ||||
add_pass(EnableNCHW4Pass::make_nchw4_converter()); | |||||
add_pass(EnableCHWN4Pass::make_chwn4_converter()); | add_pass(EnableCHWN4Pass::make_chwn4_converter()); | ||||
add_pass<ShuffleShuffleRemovePass>(); | add_pass<ShuffleShuffleRemovePass>(); | ||||
add_pass<RemoveRedundantTypeCvtPass>(); | add_pass<RemoveRedundantTypeCvtPass>(); | ||||
@@ -60,19 +60,24 @@ MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder, | |||||
public: | public: | ||||
//! relayout type of this opr | //! relayout type of this opr | ||||
enum class LayoutType { | enum class LayoutType { | ||||
NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout | |||||
NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout | |||||
NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout | |||||
CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout | |||||
NCHW_TO_NCHW4, //!< from nchw layout to nchw4 layout | |||||
NCHW4_TO_NCHW, //!< from nchw4 layout to nchw layout | |||||
NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout | |||||
NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout | |||||
NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout | |||||
NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout | |||||
NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout | |||||
CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout | |||||
NCHW_TO_NCHW4, //!< from nchw layout to nchw4 layout | |||||
NCHW_TO_NCHW4_IC_SMALL_CONV, ///< from nchw layout to nchw4 whose | |||||
///< channel size less than 4 | |||||
NCHW4_TO_NCHW, //!< from nchw4 layout to nchw layout | |||||
NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout | |||||
NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout | |||||
WEIGHT_NCHW_TO_NCHW4_DENSE, //!< weight from nchw layout to nchw4 | WEIGHT_NCHW_TO_NCHW4_DENSE, //!< weight from nchw layout to nchw4 | ||||
//!< layout | //!< layout | ||||
WEIGHT_NCHW_TO_NCHW4_GROUP, //!< group weight from nchw layout to | WEIGHT_NCHW_TO_NCHW4_GROUP, //!< group weight from nchw layout to | ||||
//!< nchw4 layout | //!< nchw4 layout | ||||
WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV, //!< weight from nchw layout | |||||
//!< to nchw4 layout whose | |||||
//! channel size less than 4 | |||||
WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88 | WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88 | ||||
//!< layout | //!< layout | ||||
@@ -177,11 +182,21 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() { | |||||
dst[3] = inp_shape[2]; | dst[3] = inp_shape[2]; | ||||
dst[4] = inp_shape[4]; | dst[4] = inp_shape[4]; | ||||
} else if (layout_type() == | } else if (layout_type() == | ||||
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4){ | |||||
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0); | |||||
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4 || | |||||
layout_type() == RelayoutPlaceholder::LayoutType:: | |||||
NCHW_TO_NCHW4_IC_SMALL_CONV) { | |||||
if (layout_type() == | |||||
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4) { | |||||
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0); | |||||
} else { | |||||
mgb_assert(layout_type() == | |||||
RelayoutPlaceholder::LayoutType:: | |||||
NCHW_TO_NCHW4_IC_SMALL_CONV); | |||||
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] < 4); | |||||
} | |||||
dst.ndim = 5; | dst.ndim = 5; | ||||
dst[0] = inp_shape[0]; | dst[0] = inp_shape[0]; | ||||
dst[1] = inp_shape[1] / 4; | |||||
dst[1] = (inp_shape[1] + 4 - 1) / 4; | |||||
dst[2] = inp_shape[2]; | dst[2] = inp_shape[2]; | ||||
dst[3] = inp_shape[3]; | dst[3] = inp_shape[3]; | ||||
dst[4] = 4; | dst[4] = 4; | ||||
@@ -194,11 +209,23 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() { | |||||
dst[2] = inp_shape[2]; | dst[2] = inp_shape[2]; | ||||
dst[3] = inp_shape[3]; | dst[3] = inp_shape[3]; | ||||
} else if (layout_type() == RelayoutPlaceholder::LayoutType:: | } else if (layout_type() == RelayoutPlaceholder::LayoutType:: | ||||
WEIGHT_NCHW_TO_NCHW4_DENSE) { | |||||
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0); | |||||
WEIGHT_NCHW_TO_NCHW4_DENSE || | |||||
layout_type() == | |||||
RelayoutPlaceholder::LayoutType:: | |||||
WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV) { | |||||
if (layout_type() == | |||||
RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE) { | |||||
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0); | |||||
} else { | |||||
mgb_assert(layout_type() == | |||||
RelayoutPlaceholder::LayoutType:: | |||||
WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV); | |||||
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] < 4); | |||||
} | |||||
dst.ndim = 5; | dst.ndim = 5; | ||||
dst[0] = inp_shape[0]; | dst[0] = inp_shape[0]; | ||||
dst[1] = inp_shape[1] / 4; | |||||
dst[1] = (inp_shape[1] + 4 - 1) / 4; | |||||
dst[2] = inp_shape[2]; | dst[2] = inp_shape[2]; | ||||
dst[3] = inp_shape[3]; | dst[3] = inp_shape[3]; | ||||
dst[4] = 4; | dst[4] = 4; | ||||
@@ -427,6 +454,23 @@ void TensorReformatPass::translate_pass(OptState& opt) const { | |||||
auto y2 = opr::Reshape::make(y1, tshp1); | auto y2 = opr::Reshape::make(y1, tshp1); | ||||
return y2.node(); | return y2.node(); | ||||
}; | }; | ||||
reformat[LayoutType::NCHW_TO_NCHW4_IC_SMALL_CONV] = | |||||
[](VarNode* inp) -> VarNode* { | |||||
auto x = SymbolVar(inp); | |||||
auto y = opr::RelayoutFormat::make( | |||||
x, megdnn::param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL); | |||||
return y.node(); | |||||
}; | |||||
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV] = | |||||
[](VarNode* inp) -> VarNode* { | |||||
auto x = SymbolVar(inp); | |||||
auto y = opr::RelayoutFormat::make( | |||||
x, megdnn::param::RelayoutFormat::Mode:: | |||||
NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT); | |||||
return y.node(); | |||||
}; | |||||
reformat[LayoutType::NCHW_TO_NCHW4] = [](VarNode* inp) -> VarNode* { | reformat[LayoutType::NCHW_TO_NCHW4] = [](VarNode* inp) -> VarNode* { | ||||
auto x = SymbolVar(inp); | auto x = SymbolVar(inp); | ||||
auto xshp = opr::GetVarShape::make(x); | auto xshp = opr::GetVarShape::make(x); | ||||
@@ -435,13 +479,10 @@ void TensorReformatPass::translate_pass(OptState& opt) const { | |||||
return opr::IndexAt::make(xshp, {{0, cv(idx)}}); | return opr::IndexAt::make(xshp, {{0, cv(idx)}}); | ||||
}; | }; | ||||
auto tshp0 = opr::Concat::make( | auto tshp0 = opr::Concat::make( | ||||
{sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0), | |||||
tshp1 = opr::Concat::make( | |||||
{sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0); | |||||
{sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0); | |||||
auto y0 = opr::Reshape::make(x, tshp0); | auto y0 = opr::Reshape::make(x, tshp0); | ||||
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); | auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); | ||||
auto y2 = opr::Reshape::make(y1, tshp1); | |||||
return y2.node(); | |||||
return y1.node(); | |||||
}; | }; | ||||
reformat[LayoutType::NCHW4_TO_NCHW] = [](VarNode* inp) -> VarNode* { | reformat[LayoutType::NCHW4_TO_NCHW] = [](VarNode* inp) -> VarNode* { | ||||
auto x = SymbolVar(inp); | auto x = SymbolVar(inp); | ||||
@@ -455,7 +496,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const { | |||||
auto y1 = opr::Reshape::make(y0, tshp0); | auto y1 = opr::Reshape::make(y0, tshp0); | ||||
return y1.node(); | return y1.node(); | ||||
}; | }; | ||||
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] = [](VarNode* inp) -> VarNode* { | |||||
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] = | |||||
[](VarNode* inp) -> VarNode* { | |||||
auto x = SymbolVar(inp); | auto x = SymbolVar(inp); | ||||
auto xshp = opr::GetVarShape::make(x); | auto xshp = opr::GetVarShape::make(x); | ||||
auto cv = [&x](int v) { return x.make_scalar(v); }; | auto cv = [&x](int v) { return x.make_scalar(v); }; | ||||
@@ -471,7 +513,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const { | |||||
auto y2 = opr::Reshape::make(y1, tshp1); | auto y2 = opr::Reshape::make(y1, tshp1); | ||||
return y2.node(); | return y2.node(); | ||||
}; | }; | ||||
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] = [](VarNode* inp) -> VarNode* { | |||||
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] = | |||||
[](VarNode* inp) -> VarNode* { | |||||
auto x = SymbolVar(inp); | auto x = SymbolVar(inp); | ||||
auto xshp = opr::GetVarShape::make(x); | auto xshp = opr::GetVarShape::make(x); | ||||
auto cv = [&x](int v) { return x.make_scalar(v); }; | auto cv = [&x](int v) { return x.make_scalar(v); }; | ||||
@@ -1357,56 +1400,71 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ | |||||
using RelayoutMode = RelayoutPlaceholder::LayoutType; | using RelayoutMode = RelayoutPlaceholder::LayoutType; | ||||
megdnn::param::Convolution::Format conv_format = | megdnn::param::Convolution::Format conv_format = | ||||
megdnn::param::Convolution::Format::NCHW4; | megdnn::param::Convolution::Format::NCHW4; | ||||
megdnn::param::ConvBias::Format conv_bias_format = | |||||
megdnn::param::ConvBias::Format conv_bias_format = | |||||
megdnn::param::ConvBias::Format::NCHW4; | megdnn::param::ConvBias::Format::NCHW4; | ||||
megdnn::param::BatchConvBias::Format batch_conv_bias_format = | megdnn::param::BatchConvBias::Format batch_conv_bias_format = | ||||
megdnn::param::BatchConvBias::Format::NCHW4; | megdnn::param::BatchConvBias::Format::NCHW4; | ||||
RelayoutMode src_to_nchw4_mode = RelayoutMode::NCHW_TO_NCHW4; | RelayoutMode src_to_nchw4_mode = RelayoutMode::NCHW_TO_NCHW4; | ||||
RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW; | RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW; | ||||
RelayoutMode weight_to_nchw4_mode_dense = | |||||
RelayoutMode weight_to_nchw4_mode_dense = | |||||
RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE; | RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE; | ||||
RelayoutMode weight_to_nchw4_mode_group = | |||||
RelayoutMode weight_to_nchw4_mode_group = | |||||
RelayoutMode::WEIGHT_NCHW_TO_NCHW4_GROUP; | RelayoutMode::WEIGHT_NCHW_TO_NCHW4_GROUP; | ||||
auto trans_nchw4 = [weight_to_nchw4_mode_dense, | |||||
weight_to_nchw4_mode_group]( | |||||
struct ConvMode { | |||||
RelayoutMode weight; | |||||
RelayoutMode src; | |||||
}; | |||||
auto trans_nchw4 = | |||||
[weight_to_nchw4_mode_dense, weight_to_nchw4_mode_group, | |||||
src_to_nchw4_mode]( | |||||
const megdnn::param::Convolution::Sparse conv_mode, | const megdnn::param::Convolution::Sparse conv_mode, | ||||
const VarNode* filter) -> RelayoutMode { | |||||
const VarNode* filter) -> ConvMode { | |||||
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { | if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { | ||||
mgb_assert(filter->shape().ndim == 4, | mgb_assert(filter->shape().ndim == 4, | ||||
"The origin filter is not NCHW mode"); | "The origin filter is not NCHW mode"); | ||||
size_t IC = filter->shape()[1]; | size_t IC = filter->shape()[1]; | ||||
mgb_assert(IC % 4 == 0, | |||||
"The input channel should be divisible by 4"); | |||||
return weight_to_nchw4_mode_dense; | |||||
if (IC < 4) { | |||||
return {RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV, | |||||
RelayoutMode::NCHW_TO_NCHW4_IC_SMALL_CONV}; | |||||
} else { | |||||
return {weight_to_nchw4_mode_dense, src_to_nchw4_mode}; | |||||
} | |||||
} else { | } else { | ||||
mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); | mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); | ||||
mgb_assert(filter->shape().ndim == 5, | mgb_assert(filter->shape().ndim == 5, | ||||
"The origin filter if not NCHW mode"); | "The origin filter if not NCHW mode"); | ||||
size_t IC = filter->shape()[2]; | size_t IC = filter->shape()[2]; | ||||
mgb_assert(IC % 4 == 0, | mgb_assert(IC % 4 == 0, | ||||
"The input channel should be divisible by 4"); | |||||
return weight_to_nchw4_mode_group; | |||||
"The input channel should be divisible by 4 for group " | |||||
"conv"); | |||||
return {weight_to_nchw4_mode_group, src_to_nchw4_mode}; | |||||
} | } | ||||
}; | }; | ||||
auto replace_conv_opr = [trans_nchw4, conv_format, src_to_nchw4_mode]( | |||||
OperatorNodeBase* opr, const VarNodeArray& new_inp) { | |||||
auto replace_conv_opr = [trans_nchw4, conv_format]( | |||||
OperatorNodeBase* opr, | |||||
const VarNodeArray& new_inp) { | |||||
mgb_assert(opr->input().size() == new_inp.size()); | mgb_assert(opr->input().size() == new_inp.size()); | ||||
auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); | auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); | ||||
mgb_assert(conv_opr.param().format == | |||||
megdnn::param::Convolution::Format::NCHW, | |||||
"ConvertFormat Pass only support converting NCHW to NCHW4"); | |||||
if (conv_opr.param().format != | |||||
megdnn::param::Convolution::Format::NCHW) { | |||||
return serialization::copy_opr_shallow(*opr, new_inp, | |||||
opr->config()); | |||||
} | |||||
auto conv_mode = | |||||
trans_nchw4(conv_opr.param().sparse, new_inp[1]); | |||||
VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1]; | VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1]; | ||||
// src: NCHW --> NCWH4 | // src: NCHW --> NCWH4 | ||||
if (new_inp[0]->shape().ndim != 5) { | if (new_inp[0]->shape().ndim != 5) { | ||||
mgb_assert(new_inp[0]->shape().ndim == 4); | mgb_assert(new_inp[0]->shape().ndim == 4); | ||||
auto new_src = RelayoutPlaceholder::make(new_inp[0], | |||||
src_to_nchw4_mode); | |||||
auto new_src = | |||||
RelayoutPlaceholder::make(new_inp[0], conv_mode.src); | |||||
conv_src = new_src.node(); | conv_src = new_src.node(); | ||||
} | } | ||||
// weight: NCHW --> NCHW4 | // weight: NCHW --> NCHW4 | ||||
auto weight_mode = | |||||
trans_nchw4(conv_opr.param().sparse, new_inp[1]); | |||||
auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode); | |||||
auto new_filter = | |||||
RelayoutPlaceholder::make(new_inp[1], conv_mode.weight); | |||||
conv_filter = new_filter.node(); | conv_filter = new_filter.node(); | ||||
// format: NCHW --> NCHW4 | // format: NCHW --> NCHW4 | ||||
auto new_param = conv_opr.param(); | auto new_param = conv_opr.param(); | ||||
@@ -1428,7 +1486,13 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ | |||||
mgb_assert(opr->input().size() == new_inp.size()); | mgb_assert(opr->input().size() == new_inp.size()); | ||||
auto& batch_conv_bias_opr = | auto& batch_conv_bias_opr = | ||||
opr->cast_final_safe<opr::BatchConvBiasForward>(); | opr->cast_final_safe<opr::BatchConvBiasForward>(); | ||||
mgb_assert(batch_conv_bias_opr.param().format == | |||||
if (batch_conv_bias_opr.param().format != | |||||
megdnn::param::BatchConvBias::Format::NCHW) { | |||||
return serialization::copy_opr_shallow(*opr, new_inp, | |||||
opr->config()); | |||||
} | |||||
mgb_assert(batch_conv_bias_opr.param().format == | |||||
megdnn::param::BatchConvBias::Format::NCHW, | megdnn::param::BatchConvBias::Format::NCHW, | ||||
"ConvertFormat Pass only support converting NCHW to NCHW4"); | "ConvertFormat Pass only support converting NCHW to NCHW4"); | ||||
// what should be converted: src, weight | // what should be converted: src, weight | ||||
@@ -1491,26 +1555,30 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ | |||||
}; | }; | ||||
auto replace_conv_bias_opr = [trans_nchw4, conv_bias_format, | auto replace_conv_bias_opr = [trans_nchw4, conv_bias_format, | ||||
src_to_nchw4_mode]( | src_to_nchw4_mode]( | ||||
OperatorNodeBase* opr, | |||||
const VarNodeArray& new_inp) { | |||||
OperatorNodeBase* opr, | |||||
const VarNodeArray& new_inp) { | |||||
mgb_assert(opr->input().size() == new_inp.size()); | mgb_assert(opr->input().size() == new_inp.size()); | ||||
auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); | auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); | ||||
mgb_assert(conv_bias_opr.param().format == | |||||
megdnn::param::ConvBias::Format::NCHW, | |||||
"ConvertFormat Pass only support converting NCHW to NCHW4"); | |||||
if (conv_bias_opr.param().format != | |||||
megdnn::param::Convolution::Format::NCHW) { | |||||
return serialization::copy_opr_shallow(*opr, new_inp, | |||||
opr->config()); | |||||
} | |||||
// what should be converted: src, weight | // what should be converted: src, weight | ||||
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1]; | VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1]; | ||||
auto conv_mode = | |||||
trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]); | |||||
// src: NCHW --> NCHW4 | // src: NCHW --> NCHW4 | ||||
if (new_inp[0]->shape().ndim !=5) { | |||||
if (new_inp[0]->shape().ndim != 5) { | |||||
mgb_assert(new_inp[0]->shape().ndim == 4); | mgb_assert(new_inp[0]->shape().ndim == 4); | ||||
auto new_src = RelayoutPlaceholder::make(new_inp[0], | |||||
src_to_nchw4_mode); | |||||
auto new_src = | |||||
RelayoutPlaceholder::make(new_inp[0], conv_mode.src); | |||||
conv_bias_src = new_src.node(); | conv_bias_src = new_src.node(); | ||||
} | } | ||||
// weight: NCHW --> NCHW4 or GNCHW --> GNCHW4 | // weight: NCHW --> NCHW4 or GNCHW --> GNCHW4 | ||||
auto weight_mode = | |||||
trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]); | |||||
auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode); | |||||
auto new_filter = | |||||
RelayoutPlaceholder::make(new_inp[1], conv_mode.weight); | |||||
conv_bias_filter = new_filter.node(); | conv_bias_filter = new_filter.node(); | ||||
// format: NCHW --> NCHW4 | // format: NCHW --> NCHW4 | ||||
auto new_param = conv_bias_opr.param(); | auto new_param = conv_bias_opr.param(); | ||||
@@ -1527,8 +1595,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ | |||||
// bias: NCHW --> NCHW4 | // bias: NCHW --> NCHW4 | ||||
VarNode* conv_bias_bias = new_inp[2]; | VarNode* conv_bias_bias = new_inp[2]; | ||||
if (new_inp[2]->shape().ndim == 4) { | if (new_inp[2]->shape().ndim == 4) { | ||||
auto new_bias = RelayoutPlaceholder::make(new_inp[2], | |||||
src_to_nchw4_mode); | |||||
auto new_bias = | |||||
RelayoutPlaceholder::make(new_inp[2], src_to_nchw4_mode); | |||||
conv_bias_bias = new_bias.node(); | conv_bias_bias = new_bias.node(); | ||||
} | } | ||||
if (new_inp.size() == 3) { | if (new_inp.size() == 3) { | ||||
@@ -1543,8 +1611,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ | |||||
// z_inp: NCHW --> NCHW4 | // z_inp: NCHW --> NCHW4 | ||||
VarNode* z_inp = new_inp[3]; | VarNode* z_inp = new_inp[3]; | ||||
if (new_inp[3]->shape().ndim == 4) { | if (new_inp[3]->shape().ndim == 4) { | ||||
auto new_z = RelayoutPlaceholder::make(new_inp[3], | |||||
src_to_nchw4_mode); | |||||
auto new_z = | |||||
RelayoutPlaceholder::make(new_inp[3], src_to_nchw4_mode); | |||||
z_inp = new_z.node(); | z_inp = new_z.node(); | ||||
} | } | ||||
auto new_conv_bias_opr = opr::ConvBias::make(conv_bias_src, | auto new_conv_bias_opr = opr::ConvBias::make(conv_bias_src, | ||||
@@ -1599,18 +1667,100 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ | |||||
} | } | ||||
return serialization::copy_opr_shallow(*opr, temp_inp, opr->config()); | return serialization::copy_opr_shallow(*opr, temp_inp, opr->config()); | ||||
}; | }; | ||||
auto replace_pooling_opr = [](OperatorNodeBase* opr, | |||||
const VarNodeArray& new_inp) { | |||||
using Param = opr::PoolingForward::Param; | |||||
using Format = Param::Format; | |||||
mgb_assert(opr->input().size() == new_inp.size()); | |||||
auto& pooling = opr->cast_final_safe<opr::PoolingForward>(); | |||||
if (pooling.param().format != Format::NCHW) { | |||||
return opr; | |||||
} | |||||
if (new_inp[0]->shape().ndim == 5) { | |||||
mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8); | |||||
auto new_param = pooling.param(); | |||||
new_param.format = Format::NCHW4; | |||||
auto new_pooling = | |||||
opr::PoolingForward::make(new_inp[0], new_param, opr->config()); | |||||
mgb_assert(new_pooling.shape().ndim == 5, | |||||
"out var of Pooling opr after transform must be 5 (got: " | |||||
"%zu).", | |||||
new_pooling.shape().ndim); | |||||
return new_pooling.node()->owner_opr(); | |||||
} | |||||
auto new_opr = | |||||
serialization::copy_opr_shallow(*opr, new_inp, opr->config()); | |||||
return new_opr; | |||||
}; | |||||
auto replace_resize_opr = [](OperatorNodeBase* opr, | |||||
const VarNodeArray& new_inp) { | |||||
using Param = opr::ResizeForward::Param; | |||||
using Format = Param::Format; | |||||
mgb_assert(opr->input().size() == new_inp.size()); | |||||
auto& resize = opr->cast_final_safe<opr::ResizeForward>(); | |||||
if (new_inp[0]->shape().ndim == 5) { | |||||
mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8); | |||||
auto new_param = resize.param(); | |||||
new_param.format = Format::NCHW4; | |||||
auto new_resize = opr::ResizeForward::make( | |||||
new_inp[0], new_inp[1], new_param, opr->config()); | |||||
mgb_assert(new_resize.shape().ndim == 5, | |||||
"out var of Resize opr after transform must be 5 (got: " | |||||
"%zu).", | |||||
new_resize.shape().ndim); | |||||
return new_resize.node()->owner_opr(); | |||||
} | |||||
auto new_opr = | |||||
serialization::copy_opr_shallow(*opr, new_inp, opr->config()); | |||||
return new_opr; | |||||
}; | |||||
auto replace_warp_perspective_opr = [](OperatorNodeBase* opr, | |||||
const VarNodeArray& new_inp) { | |||||
using Param = opr::WarpPerspective::Param; | |||||
using Format = Param::Format; | |||||
mgb_assert(opr->input().size() == new_inp.size()); | |||||
auto& warp = opr->cast_final_safe<opr::WarpPerspectiveForward>(); | |||||
if (new_inp[0]->shape().ndim == 5) { | |||||
mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8); | |||||
auto new_param = warp.param(); | |||||
new_param.format = Format::NCHW4; | |||||
SymbolVar new_warp; | |||||
if (new_inp.size() == 3) { | |||||
new_warp = opr::WarpPerspectiveForward::make( | |||||
new_inp[0], new_inp[1], nullptr, new_inp[2], new_param, | |||||
opr->config()); | |||||
} else { | |||||
mgb_assert(new_inp.size() == 4); | |||||
new_warp = opr::WarpPerspectiveForward::make( | |||||
new_inp[0], new_inp[1], new_inp[2], new_inp[3], | |||||
new_param, opr->config()); | |||||
} | |||||
mgb_assert(new_warp.shape().ndim == 5, | |||||
"out var of WarpPerspective opr after transform must be " | |||||
"5 (got: " | |||||
"%zu).", | |||||
new_warp.shape().ndim); | |||||
return new_warp.node()->owner_opr(); | |||||
} | |||||
auto new_opr = | |||||
serialization::copy_opr_shallow(*opr, new_inp, opr->config()); | |||||
return new_opr; | |||||
}; | |||||
auto&& replace_func = ret->m_opr_replace_func; | auto&& replace_func = ret->m_opr_replace_func; | ||||
//! supportted nchw4 | //! supportted nchw4 | ||||
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; | replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; | ||||
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; | replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; | ||||
replace_func[opr::BatchConvBias::typeinfo()] = | replace_func[opr::BatchConvBias::typeinfo()] = | ||||
replace_batch_conv_bias_opr; | replace_batch_conv_bias_opr; | ||||
replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; | |||||
replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr; | |||||
replace_func[opr::WarpPerspectiveForward::typeinfo()] = | |||||
replace_warp_perspective_opr; | |||||
replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr; | replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr; | ||||
replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr; | replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr; | ||||
replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr; | replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr; | ||||
replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr; | replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr; | ||||
//! not supported nchw4 | //! not supported nchw4 | ||||
replace_func[opr::PoolingForward::typeinfo()] = relayout_inp_to_nchw; | |||||
replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw; | ||||
replace_func[opr::ConvolutionBackwardData::typeinfo()] = | replace_func[opr::ConvolutionBackwardData::typeinfo()] = | ||||
relayout_inp_to_nchw; | relayout_inp_to_nchw; | ||||
@@ -1620,9 +1770,6 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ | |||||
replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw; | ||||
replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw; | ||||
replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw; | ||||
replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw; | |||||
replace_func[opr::WarpPerspectiveForward::typeinfo()] = | |||||
relayout_inp_to_nchw; | |||||
replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; | ||||
return ret; | return ret; | ||||
} | } | ||||
@@ -1512,6 +1512,7 @@ TEST_PASS(FuseConvBiasNonlinPass, Basic) { | |||||
#if MGB_CUDA | #if MGB_CUDA | ||||
TEST(TestEnableTensorCore, SmallInputShape) { | TEST(TestEnableTensorCore, SmallInputShape) { | ||||
REQUIRE_GPU(1); | REQUIRE_GPU(1); | ||||
auto cn = CompNode::load("gpu0"); | auto cn = CompNode::load("gpu0"); | ||||
@@ -1579,6 +1580,104 @@ TEST(TestEnableTensorCore, SmallInputShape) { | |||||
MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); | MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); | ||||
} | } | ||||
TEST(TestEnableTensorCore, Nchw4Nchw) { | |||||
REQUIRE_GPU(1); | |||||
auto cn = CompNode::load("gpu0"); | |||||
cn.activate(); | |||||
auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop; | |||||
auto sm_ver = prop.major * 10 + prop.minor; | |||||
if (sm_ver < 75) { | |||||
printf("This testcast ignored due to insufficient cuda cap(got: %d, " | |||||
"expected: %d)\n", | |||||
sm_ver, 75); | |||||
return; | |||||
} | |||||
HostTensorGenerator<dtype::Int8> gen; | |||||
auto graph = ComputingGraph::make(); | |||||
graph->options().graph_opt_level = 0; | |||||
auto mkvar = [&](const char* name, const TensorShape& shp, | |||||
const DType& dtype) { | |||||
return opr::TypeCvt::make( | |||||
opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name), | |||||
dtype); | |||||
}; | |||||
auto mkcvar = [&](const char* name, const TensorShape& shp, | |||||
const DType& dtype) { | |||||
return opr::TypeCvt::make( | |||||
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||||
.rename(name), | |||||
dtype); | |||||
}; | |||||
auto mkshape = [](opr::ConvBias::Param::Format format, size_t N, size_t C, | |||||
size_t H, size_t W) -> TensorShape { | |||||
mgb_assert(C % 4 == 0); | |||||
if (format == opr::ConvBias::Param::Format::NCHW4) { | |||||
return {N, C / 4, H, W, 4}; | |||||
} else { | |||||
mgb_assert(format == opr::ConvBias::Param::Format::NCHW); | |||||
return {N, C, H, W}; | |||||
} | |||||
}; | |||||
for (auto format : {opr::ConvBias::Param::Format::NCHW, | |||||
opr::ConvBias::Param::Format::NCHW4}) { | |||||
auto x = mkvar("x", mkshape(format, 32, 64, 16, 16), | |||||
dtype::QuantizedS8(2.5f)), | |||||
w = mkcvar("w1", mkshape(format, 64, 64, 3, 3), | |||||
dtype::QuantizedS8(2.5f)), | |||||
b = mkcvar("b", mkshape(format, 1, 64, 1, 1), | |||||
dtype::QuantizedS32(6.25f)), | |||||
z = mkcvar("b1", mkshape(format, 32, 64, 8, 8), | |||||
dtype::QuantizedS8(2.5f)); | |||||
opr::ConvBias::Param param; | |||||
param.format = format; | |||||
param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; | |||||
param.stride_h = param.stride_w = 2; | |||||
param.pad_h = param.pad_w = 1; | |||||
auto y = opr::ConvBias::make( | |||||
x, w, b, z, param, {}, | |||||
OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); | |||||
y = opr::ConvBias::make(y, w, b, param, {}, | |||||
OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); | |||||
y = opr::TypeCvt::make(y, dtype::Float32()); | |||||
SymbolVar y_opt; | |||||
SymbolVar y_no_tc; | |||||
{ | |||||
auto options = gopt::OptimizeForInferenceOptions{}; | |||||
options.enable_nchw32().enable_fuse_conv_bias_nonlinearity(); | |||||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||||
} | |||||
{ | |||||
auto options = gopt::OptimizeForInferenceOptions{}; | |||||
options.enable_fuse_conv_bias_nonlinearity(); | |||||
unpack_vector(gopt::optimize_for_inference({y}, options), y_no_tc); | |||||
} | |||||
auto nr_dimshuffle = find_opr_num<mgb::opr::Dimshuffle>(y_opt); | |||||
std::string json_name; | |||||
ASSERT_EQ(2u, nr_dimshuffle); | |||||
if (format == opr::ConvBias::Param::Format::NCHW4) { | |||||
json_name = "TestGoptInference.Nchw4Nchw.NCHW4.json"; | |||||
} else { | |||||
mgb_assert(format == opr::ConvBias::Param::Format::NCHW); | |||||
json_name = "TestGoptInference.Nchw4Nchw.NCHW.json"; | |||||
} | |||||
graph->compile({{y_opt, {}}}) | |||||
->to_json() | |||||
->writeto_fpath(output_file(json_name.c_str())); | |||||
HostTensorND host_y, host_y_opt; | |||||
auto func = graph->compile({make_callback_copy(y_no_tc, host_y), | |||||
make_callback_copy(y_opt, host_y_opt)}); | |||||
func->execute(); | |||||
MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); | |||||
} | |||||
} | |||||
TEST(TestEnableTensorCore, ConvBiasWithZ) { | TEST(TestEnableTensorCore, ConvBiasWithZ) { | ||||
REQUIRE_GPU(1); | REQUIRE_GPU(1); | ||||
auto cn = CompNode::load("gpu0"); | auto cn = CompNode::load("gpu0"); | ||||
@@ -2043,53 +2142,74 @@ TEST(TestGoptInference, EnableCHWN4) { | |||||
.rename(name), | .rename(name), | ||||
dtype); | dtype); | ||||
}; | }; | ||||
auto mkshape = [](opr::ConvBias::Param::Format format, size_t N, size_t C, | |||||
size_t H, size_t W) -> TensorShape { | |||||
mgb_assert(C % 4 == 0); | |||||
if (format == opr::ConvBias::Param::Format::NCHW4) { | |||||
return {N, C / 4, H, W, 4}; | |||||
} else { | |||||
mgb_assert(format == opr::ConvBias::Param::Format::NCHW); | |||||
return {N, C, H, W}; | |||||
} | |||||
}; | |||||
auto x = mkvar("x", {32, 16, 16, 16, 4}, dtype::QuantizedS8(2.5f)), | |||||
w = mkcvar("w1", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)), | |||||
b = mkcvar("b", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)), | |||||
b1 = mkvar("b1", {32, 16, 16, 16, 4}, dtype::QuantizedS8(2.5f)); | |||||
opr::ConvBias::Param param; | |||||
param.format = opr::ConvBias::Param::Format::NCHW4; | |||||
param.stride_h = param.stride_w = 1; | |||||
param.pad_h = param.pad_w = 1; | |||||
param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; | |||||
for (auto format : {opr::ConvBias::Param::Format::NCHW, | |||||
opr::ConvBias::Param::Format::NCHW4}) { | |||||
auto x = mkvar("x", mkshape(format, 32, 64, 16, 16), | |||||
dtype::QuantizedS8(2.5f)), | |||||
w = mkcvar("w1", mkshape(format, 64, 64, 3, 3), | |||||
dtype::QuantizedS8(2.5f)), | |||||
b = mkcvar("b", mkshape(format, 1, 64, 1, 1), | |||||
dtype::QuantizedS32(6.25f)), | |||||
b1 = mkvar("b1", mkshape(format, 32, 64, 16, 16), | |||||
dtype::QuantizedS8(2.5f)); | |||||
opr::ConvBias::Param param; | |||||
param.format = format; | |||||
param.stride_h = param.stride_w = 1; | |||||
param.pad_h = param.pad_w = 1; | |||||
param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; | |||||
auto y = opr::ConvBiasForward::make( | |||||
x, w, b, param, {}, OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
auto y1 = opr::ElemwiseMultiType::make( | |||||
{y, b1}, opr::ElemwiseMultiType::Mode::QFUSE_ADD_RELU, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
auto y2 = opr::ConvBiasForward::make( | |||||
y, w, b, param, {}, OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
auto y3 = opr::ElemwiseMultiType::make( | |||||
{y, b1}, opr::ElemwiseMultiType::Param::Mode::QSUB, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
auto y4 = opr::ElemwiseMultiType::make( | |||||
{y1, y2}, opr::ElemwiseMultiType::Param::Mode::QADD, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
y4 = opr::ElemwiseMultiType::make( | |||||
{y3, y4}, opr::ElemwiseMultiType::Param::Mode::QADD, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
y4 = opr::TypeCvt::make(y4, dtype::Float32()); | |||||
SymbolVar y_opt; | |||||
SymbolVar y_cudnn; | |||||
{ | |||||
auto options = gopt::OptimizeForInferenceOptions{}; | |||||
options.enable_chwn4(); | |||||
unpack_vector(gopt::optimize_for_inference({y4}, options), y_opt); | |||||
auto y = opr::ConvBiasForward::make( | |||||
x, w, b, param, {}, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
auto y1 = opr::ElemwiseMultiType::make( | |||||
{y, b1}, opr::ElemwiseMultiType::Mode::QFUSE_ADD_RELU, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
auto y2 = opr::ConvBiasForward::make( | |||||
y, w, b, param, {}, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
auto y3 = opr::ElemwiseMultiType::make( | |||||
{y, b1}, opr::ElemwiseMultiType::Param::Mode::QSUB, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
auto y4 = opr::ElemwiseMultiType::make( | |||||
{y1, y2}, opr::ElemwiseMultiType::Param::Mode::QADD, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
y4 = opr::ElemwiseMultiType::make( | |||||
{y3, y4}, opr::ElemwiseMultiType::Param::Mode::QADD, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
y4 = opr::TypeCvt::make(y4, dtype::Float32()); | |||||
SymbolVar y_opt; | |||||
SymbolVar y_cudnn; | |||||
{ | |||||
auto options = gopt::OptimizeForInferenceOptions{}; | |||||
options.enable_chwn4(); | |||||
unpack_vector(gopt::optimize_for_inference({y4}, options), y_opt); | |||||
} | |||||
unpack_vector(gopt::GraphOptimizer{} | |||||
.add_pass<gopt::FuseConvBiasNonlinPass>() | |||||
.add_pass<gopt::FuseConvBiasZPass>() | |||||
.apply({{y4}}) | |||||
.endpoint_vars(), | |||||
y_cudnn); | |||||
ASSERT_EQ(opr::ConvBias::Param::Format::CHWN4, | |||||
find_opr<opr::ConvBias>(y_opt).param().format); | |||||
HostTensorND host_y, host_y_opt; | |||||
auto func = graph->compile({make_callback_copy(y_cudnn, host_y), | |||||
make_callback_copy(y_opt, host_y_opt)}); | |||||
func->execute(); | |||||
MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); | |||||
} | } | ||||
unpack_vector(gopt::GraphOptimizer{} | |||||
.add_pass<gopt::FuseConvBiasNonlinPass>() | |||||
.add_pass<gopt::FuseConvBiasZPass>() | |||||
.apply({{y4}}) | |||||
.endpoint_vars(), | |||||
y_cudnn); | |||||
HostTensorND host_y, host_y_opt; | |||||
auto func = graph->compile({make_callback_copy(y_cudnn, host_y), | |||||
make_callback_copy(y_opt, host_y_opt)}); | |||||
func->execute(); | |||||
MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); | |||||
} | } | ||||
TEST(TestGoptInference, EnableCHWN4WarpPespective) { | TEST(TestGoptInference, EnableCHWN4WarpPespective) { | ||||
@@ -2430,14 +2550,16 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) { | |||||
auto w1 = mkcvar("w1", {8, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | auto w1 = mkcvar("w1", {8, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | ||||
b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | ||||
auto conv1 = opr::ConvBiasForward::make( | auto conv1 = opr::ConvBiasForward::make( | ||||
x, w1, b1, param_conv_bias, {}, OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
x, w1, b1, param_conv_bias, {}, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
// group | // group | ||||
// icpg != 1 && ocpg != 1 | // icpg != 1 && ocpg != 1 | ||||
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; | param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; | ||||
auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | ||||
b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | ||||
auto conv2 = opr::ConvBiasForward::make(conv1, w2, b2, | |||||
param_conv_bias, {}, OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
auto conv2 = opr::ConvBiasForward::make( | |||||
conv1, w2, b2, param_conv_bias, {}, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
auto y = opr::TypeCvt::make(conv2, dtype::Float32()); | auto y = opr::TypeCvt::make(conv2, dtype::Float32()); | ||||
@@ -2450,11 +2572,13 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) { | |||||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4, | ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4, | ||||
find_opr<opr::ConvBias>(y_opt).param().format); | find_opr<opr::ConvBias>(y_opt).param().format); | ||||
auto nr_reshape = find_opr_num<mgb::opr::Reshape>(y_opt); | |||||
ASSERT_EQ(2u, nr_reshape); | |||||
graph->compile({{y_opt, {}}}) | graph->compile({{y_opt, {}}}) | ||||
->to_json() | ->to_json() | ||||
->writeto_fpath( | |||||
output_file("TestGoptInference.ConvertFormatNCHW4GPU.json")); | |||||
->writeto_fpath(output_file( | |||||
"TestGoptInference.ConvertFormatNCHW4GPU.json")); | |||||
HostTensorND host_y, host_y_opt; | HostTensorND host_y, host_y_opt; | ||||
auto func = graph->compile({make_callback_copy(y, host_y), | auto func = graph->compile({make_callback_copy(y, host_y), | ||||
@@ -2465,6 +2589,90 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) { | |||||
#endif | #endif | ||||
TEST(TestGoptInference, ConvertFormatNCHW4NonConvOpr) { | |||||
auto cn = CompNode::load("xpu0"); | |||||
HostTensorGenerator<dtype::Int8> gen; | |||||
auto graph = ComputingGraph::make(); | |||||
graph->options().graph_opt_level = 0; | |||||
auto mkvar = [&](const char* name, const TensorShape& shp, | |||||
const DType& dtype) { | |||||
return opr::TypeCvt::make( | |||||
opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name), | |||||
dtype); | |||||
}; | |||||
auto mkcvar = [&](const char* name, const TensorShape& shp, | |||||
const DType& dtype) { | |||||
return opr::TypeCvt::make( | |||||
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||||
.rename(name), | |||||
dtype); | |||||
}; | |||||
auto mkcvarf32 = [&](const char* name, const TensorShape& shp) { | |||||
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||||
.rename(name); | |||||
}; | |||||
auto x = mkvar("x", {2, 4, 16, 16}, dtype::QuantizedS8(2.5f)); | |||||
opr::ConvBias::Param param_conv_bias; | |||||
param_conv_bias.format = opr::ConvBias::Param::Format::NCHW; | |||||
param_conv_bias.stride_h = param_conv_bias.stride_w = 1; | |||||
param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | |||||
param_conv_bias.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; | |||||
// dense | |||||
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; | |||||
auto w1 = mkcvar("w1", {8, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | |||||
b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | |||||
auto conv1 = opr::ConvBiasForward::make( | |||||
x, w1, b1, param_conv_bias, {}, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
// test Resize | |||||
auto shape_of = opr::GetVarShape::make(x); | |||||
auto subtensor = opr::Subtensor::make( | |||||
shape_of, {opr::Subtensor::AxisIndexer::make_interval( | |||||
0, x.make_scalar(2), None, x.make_scalar(1))}); | |||||
opr::Resize::Param param_resize; | |||||
param_resize.format = opr::Resize::Param::Format::NCHW; | |||||
auto resize = opr::ResizeForward::make(conv1, subtensor * 2, param_resize); | |||||
// test WarpPerspective | |||||
auto mat = mkcvarf32("mat", {2, 3, 3}), | |||||
warp = opr::WarpPerspectiveForward::make( | |||||
resize, mat, nullptr, cg::var_from_tensor_shape(x, {32, 32})); | |||||
opr::Pooling::Param pool_param; | |||||
pool_param.format = opr::Pooling::Param::Format::NCHW; | |||||
// test Pooling | |||||
auto pool = opr::Pooling::make(warp, pool_param); | |||||
// group | |||||
// icpg != 1 && ocpg != 1 | |||||
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; | |||||
auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | |||||
b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | |||||
auto conv2 = opr::ConvBiasForward::make( | |||||
pool, w2, b2, param_conv_bias, {}, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
auto add = opr::ElemwiseMultiType::make( | |||||
{conv1, conv2}, {opr::ElemwiseMultiType::Param::Mode::QADD}, | |||||
OperatorNodeConfig{dtype::QuantizedS8{1.2f}}); | |||||
auto y = opr::TypeCvt::make(add, dtype::Float32()); | |||||
SymbolVar y_opt; | |||||
{ | |||||
auto options = gopt::OptimizeForInferenceOptions{}; | |||||
options.enable_nchw4(); | |||||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||||
} | |||||
auto nr_dimshuffle = find_opr_num<mgb::opr::Dimshuffle>(y_opt); | |||||
ASSERT_EQ(2u, nr_dimshuffle); | |||||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4, | |||||
find_opr<opr::ConvBias>(y_opt).param().format); | |||||
ASSERT_EQ(opr::ResizeForward::Param::Format::NCHW4, | |||||
find_opr<opr::ResizeForward>(y_opt).param().format); | |||||
ASSERT_EQ(opr::WarpPerspectiveForward::Param::Format::NCHW4, | |||||
find_opr<opr::WarpPerspectiveForward>(y_opt).param().format); | |||||
ASSERT_EQ(opr::PoolingForward::Param::Format::NCHW4, | |||||
find_opr<opr::PoolingForward>(y_opt).param().format); | |||||
} | |||||
TEST(TestGoptInference, ConvertFormatNCHW4) { | TEST(TestGoptInference, ConvertFormatNCHW4) { | ||||
HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
auto cn = CompNode::load("cpu0"); | auto cn = CompNode::load("cpu0"); | ||||
@@ -2479,7 +2687,7 @@ TEST(TestGoptInference, ConvertFormatNCHW4) { | |||||
}; | }; | ||||
auto x = mkvar("x", {2, 4, 16, 16}); | auto x = mkvar("x", {2, 4, 16, 16}); | ||||
// ConvBias | |||||
// ConvBias test dense | |||||
opr::ConvBias::Param param_conv_bias; | opr::ConvBias::Param param_conv_bias; | ||||
param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | ||||
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; | param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; | ||||
@@ -2517,6 +2725,67 @@ TEST(TestGoptInference, ConvertFormatNCHW4) { | |||||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | ||||
} | } | ||||
TEST(TestGoptInference, ConvertFormatNCHW4Ic3) { | |||||
REQUIRE_GPU(1); | |||||
HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen{ | |||||
1.2f, 127 * 127}; | |||||
auto graph = ComputingGraph::make(); | |||||
graph->options().graph_opt_level = 0; | |||||
auto mkvar = [&](const char* name, const TensorShape& shp, | |||||
const DType& dtype) { | |||||
return opr::TypeCvt::make( | |||||
opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name), | |||||
dtype); | |||||
}; | |||||
auto mkcvar = [&](const char* name, const TensorShape& shp, | |||||
const DType& dtype) { | |||||
return opr::TypeCvt::make( | |||||
opr::SharedDeviceTensor::make(*graph, *gen(shp)) | |||||
.rename(name), | |||||
dtype); | |||||
}; | |||||
auto x = mkvar("x", {2, 3, 16, 16}, dtype::QuantizedS8(2.5f)); | |||||
// ConvBias test dense | |||||
opr::ConvBias::Param param_conv_bias; | |||||
param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | |||||
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; | |||||
auto w1 = mkcvar("w1", {8, 3, 3, 3}, dtype::QuantizedS8(2.5f)), | |||||
b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | |||||
auto conv1 = | |||||
opr::ConvBias::make(x, w1, b1, param_conv_bias, {}, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; | |||||
auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | |||||
b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | |||||
auto conv2 = | |||||
opr::ConvBias::make(conv1, w2, b2, param_conv_bias, {}, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
auto y = opr::TypeCvt::make(conv2, dtype::Float32()); | |||||
SymbolVar y_opt; | |||||
{ | |||||
auto options = gopt::OptimizeForInferenceOptions{}; | |||||
options.enable_nchw4(); | |||||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||||
} | |||||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4, | |||||
find_opr<opr::ConvBias>(y_opt).param().format); | |||||
graph->compile({{y_opt, {}}}) | |||||
->to_json() | |||||
->writeto_fpath(output_file( | |||||
"TestGoptInference.ConvertFormatNCHW4Ic3.json")); | |||||
HostTensorND host_y_opt, host_y; | |||||
auto func = graph->compile({make_callback_copy(y, host_y), | |||||
make_callback_copy(y_opt, host_y_opt)}); | |||||
func->execute(); | |||||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | |||||
} | |||||
TEST(TestGoptInference, ConvertFormatNCHW88) { | TEST(TestGoptInference, ConvertFormatNCHW88) { | ||||
HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
auto cn = CompNode::load("cpu0"); | auto cn = CompNode::load("cpu0"); | ||||
@@ -55,3 +55,8 @@ struct IndexDescMaskItem { | |||||
table IndexDescMaskDump { | table IndexDescMaskDump { | ||||
items:[IndexDescMaskItem]; | items:[IndexDescMaskItem]; | ||||
} | } | ||||
table NMSKeep { | |||||
iou_thresh:float; | |||||
max_output:uint; | |||||
} |
@@ -30,74 +30,75 @@ table Blob { | |||||
table Reserved0 {} | table Reserved0 {} | ||||
union OperatorParam { | union OperatorParam { | ||||
param.Empty, | |||||
param.Axis, | |||||
param.Convolution, | |||||
param.MaskPropagate, | |||||
param.ConvPooling, | |||||
param.ConvBias, | |||||
param.SeparableConv, | |||||
param.Images2Neibs, | |||||
param.Pooling, | |||||
param.LRN, | |||||
param.BN, | |||||
param.ROIPooling, | |||||
param.WarpPerspective, | |||||
param.SpatialTfGridGenerator, | |||||
param.SpatialTfSampler, | |||||
param.MGBAddUpdate, | |||||
param.Elemwise, | |||||
param.ElemwiseMultiType, | |||||
param.PowC, | |||||
param.MatrixMul, | |||||
param.Winograd, | |||||
param.SVD, | |||||
param.Reduce, | |||||
param.Cumsum, | |||||
param.CondTake, | |||||
param.Argsort, | |||||
param.IndexingRemap, | |||||
param.MGBSleep, | |||||
param.Linspace, | |||||
param.LinspaceFull, | |||||
param.Eye, | |||||
param.UniformRNG, | |||||
param.GaussianRNG, | |||||
param.Flip, | |||||
param.Rotate, | |||||
param.ROICopy, | |||||
param.CvtColor, | |||||
param.WarpAffine, | |||||
param.GaussianBlur, | |||||
param.Resize, | |||||
param.Remap, | |||||
param.Convolution3D, | |||||
param.Conv3DBias, | |||||
param.SeparableConv3D, | |||||
param.TopK, | |||||
param.RelayoutFormat, | |||||
param.SeparableFilter, | |||||
param.LocalShare, | |||||
param.ROIAlign, | |||||
param.DeformablePSROIPooling, | |||||
param.BatchConvBias, | |||||
param.DType, | |||||
param.PersistentOutputStorage, | |||||
param.OptionalAxis, | |||||
param.OptionalAxisV1, | |||||
param.ExecutionPolicy, | |||||
param.AssertEqual, | |||||
Reserved0, | |||||
param.CollectiveComm, | |||||
param.CondExecPred, | |||||
param.CondExecPredLogical, | |||||
param.CondExecMark, | |||||
param.CondExecMerge, | |||||
param.Host2DeviceCopy, | |||||
param.Dimshuffle, | |||||
param.AxisAddRemove, | |||||
param.IndexDescMaskDump, | |||||
DType, | |||||
param.Empty = 1, | |||||
param.Axis = 2, | |||||
param.Convolution = 3, | |||||
param.MaskPropagate = 4, | |||||
param.ConvPooling = 5, | |||||
param.ConvBias = 6, | |||||
param.SeparableConv = 7, | |||||
param.Images2Neibs = 8, | |||||
param.Pooling = 9, | |||||
param.LRN = 10, | |||||
param.BN = 11, | |||||
param.ROIPooling = 12, | |||||
param.WarpPerspective = 13, | |||||
param.SpatialTfGridGenerator = 14, | |||||
param.SpatialTfSampler = 15, | |||||
param.MGBAddUpdate = 16, | |||||
param.Elemwise = 17, | |||||
param.ElemwiseMultiType = 18, | |||||
param.PowC = 19, | |||||
param.MatrixMul = 20, | |||||
param.Winograd = 21, | |||||
param.SVD = 22, | |||||
param.Reduce = 23, | |||||
param.Cumsum = 24, | |||||
param.CondTake = 25, | |||||
param.Argsort = 26, | |||||
param.IndexingRemap = 27, | |||||
param.MGBSleep = 28, | |||||
param.Linspace = 29, | |||||
param.LinspaceFull = 30, | |||||
param.Eye = 31, | |||||
param.UniformRNG = 32, | |||||
param.GaussianRNG = 33, | |||||
param.Flip = 34, | |||||
param.Rotate = 35, | |||||
param.ROICopy = 36, | |||||
param.CvtColor = 37, | |||||
param.WarpAffine = 38, | |||||
param.GaussianBlur = 39, | |||||
param.Resize = 40, | |||||
param.Convolution3D = 41, | |||||
param.Conv3DBias = 42, | |||||
param.SeparableConv3D = 43, | |||||
param.TopK = 44, | |||||
param.RelayoutFormat = 45, | |||||
param.SeparableFilter = 46, | |||||
param.LocalShare = 47, | |||||
param.ROIAlign = 48, | |||||
param.DeformablePSROIPooling = 49, | |||||
param.BatchConvBias = 50, | |||||
param.DType = 51, | |||||
param.PersistentOutputStorage = 52, | |||||
param.OptionalAxis = 53, | |||||
param.OptionalAxisV1 = 54, | |||||
param.ExecutionPolicy = 55, | |||||
param.AssertEqual = 56, | |||||
Reserved0 = 57, | |||||
param.CollectiveComm = 58, | |||||
param.CondExecPred = 59, | |||||
param.CondExecPredLogical = 60, | |||||
param.CondExecMark = 61, | |||||
param.CondExecMerge = 62, | |||||
param.Host2DeviceCopy = 63, | |||||
param.Dimshuffle = 64, | |||||
param.AxisAddRemove = 65, | |||||
param.IndexDescMaskDump = 66, | |||||
DType = 67, | |||||
param.Remap = 68, | |||||
param.NMSKeep = 69, | |||||
} | } | ||||
table Operator { | table Operator { | ||||
@@ -846,7 +846,7 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, | |||||
OprLoadContextImpl ctx{this, m_graph->mgb_version()}; | OprLoadContextImpl ctx{this, m_graph->mgb_version()}; | ||||
auto result = ctx.load_oprs(); | auto result = ctx.load_oprs(); | ||||
auto fbs_end = tensor_begin + offset_to_fbs + size; | |||||
auto fbs_end = tensor_begin + offset_to_fbs + sizeof(size) + size; | |||||
auto cur = m_file->tell(); | auto cur = m_file->tell(); | ||||
mgb_assert(fbs_end > cur); | mgb_assert(fbs_end > cur); | ||||
// Skip to Graph end | // Skip to Graph end | ||||
@@ -872,4 +872,4 @@ bool is_fbs_file(InputFile& file) { | |||||
} // namespace serialization | } // namespace serialization | ||||
} // namespace mgb | } // namespace mgb | ||||
#endif | |||||
#endif |
@@ -64,6 +64,34 @@ TEST(TestSerializer2, GraphDumpLoad) { | |||||
load(); | load(); | ||||
} | } | ||||
TEST(TestSerializer2, MultiGraphDumpLoad) { | |||||
auto fname = GET_OUTPUT_FILE(); | |||||
auto dump = [&]() { | |||||
auto cn = CompNode::load("cpu0"); | |||||
auto graph = ComputingGraph::make(); | |||||
auto x = opr::ImmutableTensor::make(*graph, 1926.0817f, {cn}); | |||||
x.rename("varz"); | |||||
auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()), | |||||
GraphDumpFormat::FLATBUFFERS); | |||||
// dump twice | |||||
dumper->dump({x}); | |||||
dumper->dump({x}); | |||||
}; | |||||
auto load = [&]() { | |||||
GraphLoader::LoadConfig load_config = {}; | |||||
auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()), | |||||
GraphDumpFormat::FLATBUFFERS); | |||||
// load twice | |||||
loader->load(load_config, false); | |||||
loader = GraphLoader::make(loader->reset_file(), loader->format()); | |||||
loader->load(load_config, false); | |||||
}; | |||||
dump(); | |||||
load(); | |||||
} | |||||
TEST(TestSerializer2, APlusB) { | TEST(TestSerializer2, APlusB) { | ||||
auto fname = GET_OUTPUT_FILE(); | auto fname = GET_OUTPUT_FILE(); | ||||
TensorShape shape{2, 3}; | TensorShape shape{2, 3}; | ||||
@@ -733,4 +761,4 @@ TEST(TestSerializer2, HasOutputDtype) { | |||||
load(); | load(); | ||||
} | } | ||||
#endif | |||||
#endif |
@@ -1727,8 +1727,17 @@ void TensorRTReplacePass::Impl::TensorRTGraph::mark_varnode_format_nchw4() { | |||||
} | } | ||||
} | } | ||||
void mgb::tensorrt::transform_dest_vars_inplace(mgb::cg::VarNodeArray& dest_vars) { | |||||
void mgb::tensorrt::transform_dest_vars_inplace( | |||||
mgb::cg::VarNodeArray& dest_vars, | |||||
cg::GraphCommonOptimizeOptions& options) { | |||||
gopt::GraphOptimizer optimizer; | gopt::GraphOptimizer optimizer; | ||||
//! As in megengine, the layout is NCHW, while tensorrt pass currently | |||||
//! only support NCHW4(int8), so we transform layout to nchw4 firstly. | |||||
if (options.has_set_nchw4()) { | |||||
options.disable_nchw4(); | |||||
optimizer.add_pass<FuseConvBiasNonlinPass>(); | |||||
optimizer.add_pass(EnableNCHW4Pass::make_nchw4_converter()); | |||||
} | |||||
optimizer.add_pass<ExpandFusedArithPass>(); | optimizer.add_pass<ExpandFusedArithPass>(); | ||||
optimizer.add_pass<gopt::TensorRTReplacePass>(); | optimizer.add_pass<gopt::TensorRTReplacePass>(); | ||||
optimizer.add_pass<ArithFusePass>(); | optimizer.add_pass<ArithFusePass>(); | ||||
@@ -32,7 +32,8 @@ public: | |||||
namespace tensorrt { | namespace tensorrt { | ||||
void transform_dest_vars_inplace(mgb::cg::VarNodeArray& dest_vars); | |||||
void transform_dest_vars_inplace(mgb::cg::VarNodeArray& dest_vars, | |||||
cg::GraphCommonOptimizeOptions& options); | |||||
} | } | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -1930,7 +1930,7 @@ TEST(TestTensorRTReplace, FuseConvAdd) { | |||||
param.stride_h = param.stride_w = 1; | param.stride_h = param.stride_w = 1; | ||||
param.pad_h = param.pad_w = 1; | param.pad_h = param.pad_w = 1; | ||||
auto y = opr::Convolution::make(x, w, param); | auto y = opr::Convolution::make(x, w, param); | ||||
auto nchw2nchw4 = [](SymbolVar x) { | auto nchw2nchw4 = [](SymbolVar x) { | ||||
auto xshp = opr::GetVarShape::make(x); | auto xshp = opr::GetVarShape::make(x); | ||||
@@ -1978,6 +1978,68 @@ TEST(TestTensorRTReplace, FuseConvAdd) { | |||||
MGB_ASSERT_TENSOR_NEAR(outputs[1], outputs[3], 1e-3); | MGB_ASSERT_TENSOR_NEAR(outputs[1], outputs[3], 1e-3); | ||||
} | } | ||||
TEST(TestTensorRTReplace, FuseConvAddNchw2nchw4) { | |||||
REQUIRE_GPU(1); | |||||
HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen{ | |||||
1.2f, 127 * 127}; | |||||
auto graph = ComputingGraph::make(); | |||||
graph->options().graph_opt_level = 0; | |||||
auto mkvar = [&](const char* name, const TensorShape& shp, | |||||
const DType& dtype) { | |||||
return opr::TypeCvt::make( | |||||
opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name), | |||||
dtype); | |||||
}; | |||||
auto mkcvar = [&](const char* name, const TensorShape& shp, | |||||
const DType& dtype) { | |||||
return opr::TypeCvt::make( | |||||
opr::SharedDeviceTensor::make(*graph, *gen(shp)) | |||||
.rename(name), | |||||
dtype); | |||||
}; | |||||
auto x = mkvar("x", {32, 4, 28, 28}, dtype::QuantizedS8(2.5f)), | |||||
w = mkcvar("w", {16, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | |||||
b = mkcvar("b", {1, 16, 1, 1}, dtype::QuantizedS32(6.25f)); | |||||
opr::ConvBias::Param param; | |||||
param.format = opr::ConvBias::Param::Format::NCHW; | |||||
param.stride_h = param.stride_w = 1; | |||||
param.pad_h = param.pad_w = 1; | |||||
auto y = opr::ConvBias::make(x, w, b, param, {}, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
auto z = opr::TypeCvt::make(y, dtype::Float32()); | |||||
SymbolVar trt_z; | |||||
SymbolVar mgb_z; | |||||
ComputingGraph::Options opt; | |||||
opt.graph_opt_level = 0; | |||||
unpack_vector( | |||||
gopt::GraphOptimizer{} | |||||
.add_pass<gopt::FuseConvBiasNonlinPass>() | |||||
.add_pass(gopt::EnableNCHW4Pass::make_nchw4_converter()) | |||||
.add_pass<gopt::ExpandFusedArithPass>() | |||||
.add_pass<gopt::TensorRTReplacePass>() | |||||
.add_pass<gopt::ArithFusePass>() | |||||
.apply({{z}}) | |||||
.endpoint_vars(), | |||||
trt_z); | |||||
opt.graph_opt_level = 0; | |||||
unpack_vector(gopt::GraphOptimizer{}.apply({{z}}).endpoint_vars(), | |||||
mgb_z); | |||||
ComputingGraph::OutputSpec outspec(2); | |||||
SmallVector<HostTensorND> outputs(2); | |||||
outspec[0] = make_callback_copy(trt_z, outputs[0], false); | |||||
outspec[1] = make_callback_copy(mgb_z, outputs[1], false); | |||||
graph->options().graph_opt.tensorrt = false; | |||||
auto func = graph->compile(outspec); | |||||
func->execute(); | |||||
MGB_ASSERT_TENSOR_NEAR(outputs[0], outputs[1], 1e-3); | |||||
} | |||||
#endif // MGB_ENABLE_TENSOR_RT | #endif // MGB_ENABLE_TENSOR_RT | ||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |