@@ -194,6 +194,26 @@ R"__usage__(
     Execute operators with kernels implemented in MegDNN with CHWN4 tensor format. Can only be used
     on Nvidia GPUs, whose compute capability is above 6.1.
 )__usage__"
R"__usage__( | |||||
--enable-nchw44 | |||||
Execute operators with kernels implemented in MegDNN with NCHW44 tensor format. This can only | |||||
be used on arm of armv7 and arm64, support data tyep of float32, qint8 and int8x8x16. | |||||
)__usage__" | |||||
R"__usage__( | |||||
--enable-nhw88 | |||||
Execute operators with kernels implemented in MegDNN with NCHW88 tensor format. This can only | |||||
be used on x86 with data type float. | |||||
)__usage__" | |||||
R"__usage__( | |||||
--enable-nhw44-dot | |||||
Execute operators with kernels implemented in MegDNN with NCHW44-DOT tensor format. This Can | |||||
only be used on arm32 and arm64 with dot-product supported, and only support qint8 model | |||||
)__usage__" | |||||
R"__usage__( | |||||
--weight-preprocess | |||||
Execute operators with weight preprocess, which can optimize the operator execution time with | |||||
algo of winograd, im2col ,etc., but it may consume more memory. | |||||
)__usage__" | |||||
 ;
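For reference, the same optimization can also be enabled without the command line, which is exactly what the tests later in this patch do. A minimal sketch using only calls that appear elsewhere in this diff:

    // Enable weight preprocessing on a graph programmatically; this
    // trades extra memory for faster operator execution.
    auto graph = ComputingGraph::make();
    graph->options().graph_opt.weight_preprocess = true;
    // Equivalently, via the chained setter generated by the SET macro
    // (see the options header below):
    graph->options().graph_opt.enable_weight_preprocess();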
@@ -1226,6 +1246,11 @@ Args Args::from_argv(int argc, char **argv) {
             graph_opt.graph_opt.weight_winograd_transform = true;
             continue;
         }
+        if (!strcmp(argv[i], "--weight-preprocess")) {
+            mgb_log_warn("enable weight-preprocess optimization");
+            graph_opt.graph_opt.enable_weight_preprocess();
+            continue;
+        }
         fprintf(stderr, "invalid arg: %s\n", argv[i]);
         ret.args_parse_ret = -1;
@@ -97,6 +97,9 @@ struct GraphCommonOptimizeOptions {
     bool fuse_conv_bias_with_z = false;
     //! whether to enable fast-run profiled winograd opr replace
     bool weight_winograd_transform = false;
+    //! whether to enable weight preprocessing; this may use more memory,
+    //! and it is currently disabled by default
+    bool weight_preprocess = false;
     enum LayoutTransform : uint32_t {
         DEFAULT,
         NCHW4,  ///< compute using NCHW4 tensor format
@@ -127,6 +130,7 @@ struct GraphCommonOptimizeOptions {
         SET(fuse_conv_bias_nonlinearity);
         SET(fuse_conv_bias_with_z);
         SET(weight_winograd_transform);
+        SET(weight_preprocess);
 #undef SET
 #define SET(_trans, _trans_capital)                                 \
     GraphCommonOptimizeOptions& enable_##_trans() {                 \
@@ -963,6 +963,9 @@ void mixin::WeightPreprocessExecutor::record_preprocessed_weight(
 bool mixin::WeightPreprocessExecutor::mixin_allow_weight_preprocess(
         const cg::OperatorNodeBase& opr) const {
+    if (!opr.owner_graph()->options().graph_opt.weight_preprocess) {
+        return false;
+    }
     if (!opr.input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE))
         return false;
     if (cg::is_const_var_value(opr.input(1)))
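With this change the function becomes an early-out chain: the graph-level option is consulted first, so graphs that never opt in pay no preprocessing cost at all. A condensed sketch of just the checks visible in this hunk (the function continues past the diff context, so the const-value branch is omitted):

    // The graph must opt in, and the weight input (input 1) must be a
    // value that stays resident on the device across executions.
    if (!opr.owner_graph()->options().graph_opt.weight_preprocess)
        return false;  // new guard: feature is off unless enabled
    if (!opr.input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE))
        return false;  // non-persistent weights cannot be preprocessed once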
@@ -2225,6 +2225,7 @@ protected:
         iw = ih;
         comp_node = CompNode::load("cpux");
         graph = ComputingGraph::make();
+        graph->options().graph_opt.weight_preprocess = is_weight_preprocess();
         TensorShape x_shape{1, ic, ih, iw}, w_shape{oc, ic, fh, fh};
         x_host = std::make_shared<HostTensorND>(comp_node, x_shape);
         auto x = opr::Host2DeviceCopy::make(*graph, x_host);
@@ -2247,6 +2248,8 @@ protected:
     void run() { func->execute().wait(); }
 
+    virtual bool is_weight_preprocess() { return true; }
+
     void TearDown() override {
         func.reset();
         // Triggers mock check
@@ -2346,6 +2349,33 @@ TEST_F(TestWeightPreprocess, PreprocessCalledOnlyOnce) {
     }
 }
 
+class TestNoWeightPreprocess : public TestWeightPreprocess {
+    bool is_weight_preprocess() override { return false; }
+};
+
+TEST_F(TestNoWeightPreprocess, NoPreprocess) {
+    using ::testing::_;
+    using ::testing::Return;
+    auto& mock = mock_conv();
+    MockAlgorithm algo;
+    EXPECT_CALL(mock, get_algorithm_heuristic(_, _, _, _, _))
+            .WillRepeatedly(Return(&algo));
+    EXPECT_CALL(mock, get_workspace_in_bytes(_, _, _, _))
+            .WillRepeatedly(Return(0));
+    EXPECT_CALL(mock, get_preprocess_workspace_in_bytes(_, _, _))
+            .WillRepeatedly(Return(0));
+    {
+        ::testing::InSequence seq;
+        // With preprocessing disabled, the preprocess hooks must never run
+        EXPECT_CALL(mock, deduce_preprocessed_filter_layout(_, _, _)).Times(0);
+        EXPECT_CALL(mock, exec_preprocess(_, _, _, _, _)).Times(0);
+        EXPECT_CALL(mock, exec(_, _, _, nullptr, _));
+        run();
+    }
+}
+
 }  // anonymous namespace
 #endif