@@ -194,6 +194,26 @@ R"__usage__( | |||
Execute operators with kernels implemented in MegDNN with CHWN4 tensor format. Can only be used | |||
on Nvidia GPUs, whose compute capability is above 6.1. | |||
)__usage__" | |||
R"__usage__( | |||
--enable-nchw44 | |||
Execute operators with kernels implemented in MegDNN with NCHW44 tensor format. This can only | |||
be used on arm of armv7 and arm64, support data type of float32, qint8 and int8x8x16. | |||
)__usage__" | |||
R"__usage__( | |||
--enable-nchw88 | |||
Execute operators with kernels implemented in MegDNN with NCHW88 tensor format. This can only | |||
be used on x86 with data type float. | |||
)__usage__" | |||
R"__usage__( | |||
--enable-nchw44-dot | |||
Execute operators with kernels implemented in MegDNN with NCHW44-DOT tensor format. This can | |||
only be used on arm32 and arm64 with dot-product supported, and only supports qint8 models. | |||
)__usage__" | |||
R"__usage__( | |||
--weight-preprocess | |||
Execute operators with weight preprocess, which can optimize the operator execution time with | |||
algo of winograd, im2col, etc., but it may consume more memory. | |||
)__usage__" | |||
; | |||
@@ -1226,6 +1246,11 @@ Args Args::from_argv(int argc, char **argv) { | |||
graph_opt.graph_opt.weight_winograd_transform = true; | |||
continue; | |||
} | |||
if (!strcmp(argv[i], "--weight-preprocess")) { | |||
mgb_log_warn("enable weight-preprocess optimization"); | |||
graph_opt.graph_opt.enable_weight_preprocess(); | |||
continue; | |||
} | |||
fprintf(stderr, "invalid arg: %s\n", argv[i]); | |||
ret.args_parse_ret = -1; | |||
@@ -97,6 +97,9 @@ struct GraphCommonOptimizeOptions { | |||
bool fuse_conv_bias_with_z = false; | |||
//! whether to enable fast-run profiled winograd opr replace | |||
bool weight_winograd_transform = false; | |||
//! whether to enable weight preprocess, if enabled it may use more | |||
//! memory, default disable now | |||
bool weight_preprocess = false; | |||
enum LayoutTransform : uint32_t { | |||
DEFAULT, | |||
NCHW4, ///< compute using NCHW4 tensor format | |||
@@ -127,6 +130,7 @@ struct GraphCommonOptimizeOptions { | |||
SET(fuse_conv_bias_nonlinearity); | |||
SET(fuse_conv_bias_with_z); | |||
SET(weight_winograd_transform); | |||
SET(weight_preprocess); | |||
#undef SET | |||
#define SET(_trans, _trans_capital) \ | |||
GraphCommonOptimizeOptions& enable_##_trans() { \ | |||
@@ -963,6 +963,9 @@ void mixin::WeightPreprocessExecutor::record_preprocessed_weight( | |||
bool mixin::WeightPreprocessExecutor::mixin_allow_weight_preprocess( | |||
const cg::OperatorNodeBase& opr) const { | |||
if (!opr.owner_graph()->options().graph_opt.weight_preprocess) { | |||
return false; | |||
} | |||
if (!opr.input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE)) | |||
return false; | |||
if (cg::is_const_var_value(opr.input(1))) | |||
@@ -2225,6 +2225,7 @@ protected: | |||
iw = ih; | |||
comp_node = CompNode::load("cpux"); | |||
graph = ComputingGraph::make(); | |||
graph->options().graph_opt.weight_preprocess = is_weight_preprocess(); | |||
TensorShape x_shape{1, ic, ih, iw}, w_shape{oc, ic, fh, fh}; | |||
x_host = std::make_shared<HostTensorND>(comp_node, x_shape); | |||
auto x = opr::Host2DeviceCopy::make(*graph, x_host); | |||
@@ -2247,6 +2248,8 @@ protected: | |||
void run() { func->execute().wait(); } | |||
virtual bool is_weight_preprocess() { return true; } | |||
void TearDown() override { | |||
func.reset(); | |||
// Triggers mock check | |||
@@ -2346,6 +2349,33 @@ TEST_F(TestWeightPreprocess, PreprocessCalledOnlyOnce) { | |||
} | |||
} | |||
class TestNoWeightPreprocess : public TestWeightPreprocess { | |||
bool is_weight_preprocess() override { return false; } | |||
}; | |||
TEST_F(TestNoWeightPreprocess, NoPreprocess) { | |||
using ::testing::_; | |||
using ::testing::Return; | |||
auto& mock = mock_conv(); | |||
MockAlgorithm algo; | |||
EXPECT_CALL(mock, get_algorithm_heuristic(_, _, _, _, _)) | |||
.WillRepeatedly(Return(&algo)); | |||
EXPECT_CALL(mock, get_workspace_in_bytes(_, _, _, _)) | |||
.WillRepeatedly(Return(0)); | |||
EXPECT_CALL(mock, get_preprocess_workspace_in_bytes(_, _, _)) | |||
.WillRepeatedly(Return(0)); | |||
{ | |||
::testing::InSequence seq; | |||
// Return empty preprocess filters, indicating no need to preprocess | |||
EXPECT_CALL(mock, deduce_preprocessed_filter_layout(_, _, _)).Times(0); | |||
EXPECT_CALL(mock, exec_preprocess(_, _, _, _, _)).Times(0); | |||
EXPECT_CALL(mock, exec(_, _, _, nullptr, _)); | |||
run(); | |||
} | |||
} | |||
} // anonymous namespace | |||
#endif | |||