GitOrigin-RevId: 4d1a9c6c84
tags/v0.5.0
@@ -542,7 +542,8 @@ def optimize_for_inference( | |||||
use_nchw32=False, | use_nchw32=False, | ||||
fuse_conv_bias_with_z=False, | fuse_conv_bias_with_z=False, | ||||
use_nchw88=False, | use_nchw88=False, | ||||
use_nchw44=False | |||||
use_nchw44=False, | |||||
use_chwn4=False | |||||
): | ): | ||||
"""optimize computing graph for inference | """optimize computing graph for inference | ||||
@@ -566,6 +567,8 @@ def optimize_for_inference( | |||||
times. | times. | ||||
:param use_nchw32: whether to use NCHW32 tensor format. Mainly used for | :param use_nchw32: whether to use NCHW32 tensor format. Mainly used for | ||||
nvidia tensorcore. | nvidia tensorcore. | ||||
:param use_chwn4: whether to use CHWN4 tensor format. Mainly used for | |||||
nvidia tensorcore. | |||||
:return: list of transformed vars corresponding to given output vars | :return: list of transformed vars corresponding to given output vars | ||||
@@ -589,6 +592,7 @@ def optimize_for_inference( | |||||
"use_nchw32": "nchw2nchw32", | "use_nchw32": "nchw2nchw32", | ||||
"use_nchw88": "nchw2nchw88", | "use_nchw88": "nchw2nchw88", | ||||
"use_nchw44": "nchw2nchw44", | "use_nchw44": "nchw2nchw44", | ||||
"use_chwn4": "nchw42chwn4", | |||||
}.items(): | }.items(): | ||||
if settings[k]: | if settings[k]: | ||||
assert ( | assert ( | ||||
@@ -84,6 +84,7 @@ struct _OptimizeForInferenceOptions { | |||||
SET(nchw2nchw88, NCHW2NCHW88); | SET(nchw2nchw88, NCHW2NCHW88); | ||||
SET(nchw2nchw44, NCHW2NCHW44); | SET(nchw2nchw44, NCHW2NCHW44); | ||||
SET(nchw2nchw32, NCHW2NCHW32); | SET(nchw2nchw32, NCHW2NCHW32); | ||||
SET(nchw42chwn4, NCHW42CHWN4); | |||||
#undef SET | #undef SET | ||||
}; | }; | ||||
@@ -254,8 +254,9 @@ def optimize_for_inference(args, outputs): | |||||
'enable_hwcd4': 'use_nhwcd4', | 'enable_hwcd4': 'use_nhwcd4', | ||||
'enable_nchw88': 'use_nchw88', | 'enable_nchw88': 'use_nchw88', | ||||
'enable_nchw44': 'use_nchw44', | 'enable_nchw44': 'use_nchw44', | ||||
'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity', | |||||
'enable_nchw32': 'use_nchw32', | 'enable_nchw32': 'use_nchw32', | ||||
'enable_chwn4': 'use_chwn4', | |||||
'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity', | |||||
'enable_fuse_conv_bias_with_z': 'fuse_conv_bias_with_z', | 'enable_fuse_conv_bias_with_z': 'fuse_conv_bias_with_z', | ||||
} | } | ||||
kwargs = {} | kwargs = {} | ||||
@@ -399,6 +400,12 @@ def main(): | |||||
'for inference on nvidia TensoCore' | 'for inference on nvidia TensoCore' | ||||
) | ) | ||||
parser.add_argument( | parser.add_argument( | ||||
'--enable-chwn4', | |||||
action='store_true', | |||||
help='transform the model format to CHWN4 ' | |||||
'for inference, mainly used for nvidia tensorcore' | |||||
) | |||||
parser.add_argument( | |||||
'--enable-fuse-conv-bias-with-z', | '--enable-fuse-conv-bias-with-z', | ||||
action='store_true', | action='store_true', | ||||
help='fuse conv_bias with z input for inference on ' | help='fuse conv_bias with z input for inference on ' | ||||
@@ -724,6 +724,13 @@ void GraphOptimizer::apply_optimize_options( | |||||
add_pass<ShuffleShuffleRemovePass>(); | add_pass<ShuffleShuffleRemovePass>(); | ||||
add_pass<RemoveRedundantTypeCvtPass>(); | add_pass<RemoveRedundantTypeCvtPass>(); | ||||
} | } | ||||
if (options->transform_nchw42chwn4()) { | |||||
add_pass<FuseConvBiasNonlinPass>(); | |||||
add_pass<FuseConvBiasZPass>(); | |||||
add_pass(EnableCHWN4Pass::make_chwn4_converter()); | |||||
add_pass<ShuffleShuffleRemovePass>(); | |||||
add_pass<RemoveRedundantTypeCvtPass>(); | |||||
} | |||||
if (options->fuse_conv_bias_nonlinearity) { | if (options->fuse_conv_bias_nonlinearity) { | ||||
add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
@@ -395,6 +395,8 @@ namespace gopt { | |||||
NCHW2NCHW44, ///< compute using NCHW44 tensor format | NCHW2NCHW44, ///< compute using NCHW44 tensor format | ||||
NCHW2NCHW32, ///< compute using NCHW32 tensor format, used for | NCHW2NCHW32, ///< compute using NCHW32 tensor format, used for | ||||
///< tensorcore | ///< tensorcore | ||||
NCHW42CHWN4, ///< compute using CHWN4 tensor format, transformed | |||||
///< from NCHW4, mainly used for cuda | |||||
}; | }; | ||||
LayoutTransform layout_transform = LayoutTransform::DEFAULT; | LayoutTransform layout_transform = LayoutTransform::DEFAULT; | ||||
//! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) | //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) | ||||
@@ -424,6 +426,7 @@ namespace gopt { | |||||
SET(nchw2nchw88, NCHW2NCHW88); | SET(nchw2nchw88, NCHW2NCHW88); | ||||
SET(nchw2nchw44, NCHW2NCHW44); | SET(nchw2nchw44, NCHW2NCHW44); | ||||
SET(nchw2nchw32, NCHW2NCHW32); | SET(nchw2nchw32, NCHW2NCHW32); | ||||
SET(nchw42chwn4, NCHW42CHWN4); | |||||
#undef SET | #undef SET | ||||
}; | }; | ||||
@@ -2011,14 +2011,11 @@ TEST(TestGoptInference, EnableCHWN4) { | |||||
y4 = opr::TypeCvt::make(y4, dtype::Float32()); | y4 = opr::TypeCvt::make(y4, dtype::Float32()); | ||||
SymbolVar y_opt; | SymbolVar y_opt; | ||||
SymbolVar y_cudnn; | SymbolVar y_cudnn; | ||||
unpack_vector( | |||||
gopt::GraphOptimizer{} | |||||
.add_pass<gopt::FuseConvBiasNonlinPass>() | |||||
.add_pass(gopt::EnableCHWN4Pass::make_chwn4_converter()) | |||||
.add_pass<gopt::FuseConvBiasZPass>() | |||||
.apply({{y4}}) | |||||
.endpoint_vars(), | |||||
y_opt); | |||||
{ | |||||
auto options = gopt::OptimizeForInferenceOptions{}; | |||||
options.enable_nchw42chwn4(); | |||||
unpack_vector(gopt::optimize_for_inference({y4}, options), y_opt); | |||||
} | |||||
unpack_vector(gopt::GraphOptimizer{} | unpack_vector(gopt::GraphOptimizer{} | ||||
.add_pass<gopt::FuseConvBiasNonlinPass>() | .add_pass<gopt::FuseConvBiasNonlinPass>() | ||||
.add_pass<gopt::FuseConvBiasZPass>() | .add_pass<gopt::FuseConvBiasZPass>() | ||||
@@ -2100,13 +2097,11 @@ TEST(TestGoptInference, EnableCHWN4WarpPespective) { | |||||
auto y2 = opr::WarpPerspective::make(y1, mat_var, TensorShape{16, 16}, warp_param); | auto y2 = opr::WarpPerspective::make(y1, mat_var, TensorShape{16, 16}, warp_param); | ||||
SymbolVar y_opt; | SymbolVar y_opt; | ||||
SymbolVar y_cudnn; | SymbolVar y_cudnn; | ||||
unpack_vector(gopt::GraphOptimizer{} | |||||
.add_pass<gopt::FuseConvBiasNonlinPass>() | |||||
.add_pass<gopt::FuseConvBiasZPass>() | |||||
.add_pass(gopt::EnableCHWN4Pass::make_chwn4_converter()) | |||||
.apply({{y2}}) | |||||
.endpoint_vars(), | |||||
y_opt); | |||||
{ | |||||
auto options = gopt::OptimizeForInferenceOptions{}; | |||||
options.enable_nchw42chwn4(); | |||||
unpack_vector(gopt::optimize_for_inference({y2}, options), y_opt); | |||||
} | |||||
unpack_vector(gopt::GraphOptimizer{} | unpack_vector(gopt::GraphOptimizer{} | ||||
.add_pass<gopt::FuseConvBiasNonlinPass>() | .add_pass<gopt::FuseConvBiasNonlinPass>() | ||||
.add_pass<gopt::FuseConvBiasZPass>() | .add_pass<gopt::FuseConvBiasZPass>() | ||||