diff --git a/dnn/src/common/convolution.cpp b/dnn/src/common/convolution.cpp
index 2e1228fc..0fc9afbb 100644
--- a/dnn/src/common/convolution.cpp
+++ b/dnn/src/common/convolution.cpp
@@ -831,7 +831,7 @@ typename ConvolutionBase::CanonizedFilterMeta ConvolutionBase::CanonizedFilterMeta ConvolutionBase(); + add_pass(); } + return *this; }
diff --git a/src/gopt/impl/global_layout_transform/layout_transform_context.cpp b/src/gopt/impl/global_layout_transform/layout_transform_context.cpp
index 48b98268..093c6ecf 100644
--- a/src/gopt/impl/global_layout_transform/layout_transform_context.cpp
+++ b/src/gopt/impl/global_layout_transform/layout_transform_context.cpp
@@ -66,7 +66,6 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
     ctx->add_opr_config(
                opr::ConvBiasForward::typeinfo(),
                {OprFormatConfigID::NCHW, OprFormatConfigID::NHWC,
-                OprFormatConfigID::NCHW4_NCHW32, OprFormatConfigID::NCHW32_NCHW4,
                 OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
                 OprFormatConfigID::NCHW64, OprFormatConfigID::CHWN4})
             .add_opr_config(
diff --git a/src/gopt/impl/global_layout_transform/layout_transform_pass.cpp b/src/gopt/impl/global_layout_transform/layout_transform_pass.cpp
index 8a198a0c..8f03a02e 100644
--- a/src/gopt/impl/global_layout_transform/layout_transform_pass.cpp
+++ b/src/gopt/impl/global_layout_transform/layout_transform_pass.cpp
@@ -18,6 +18,7 @@
 #include "megbrain/gopt/solver.h"
 #include "megbrain/opr/dnn/pooling.h"
 #include "megbrain/opr/imgproc.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "megbrain/serialization/sereg.h"
 #include "megbrain/utils/hash_ct.h"
@@ -64,11 +65,6 @@ void LayoutTransformPass::apply(OptState& opt) const {
     auto&& base_cfg_id = m_ctx->attribute().base_config_id;
     auto&& reformat_attribute = m_ctx->attribute().reformat_attribute;
     ThinHashMap<VarNode*, TensorFormats> var2fmts;
-    static ThinHashSet<Typeinfo*> format_aware_oprs = {
-#define cb(_Opr) opr::_Opr::typeinfo(),
-            FOREACH_FORMAT_AWARE_OPR(cb)
-#undef cb
-    };
     auto rewriter = opt.graph().make_rewriter();
     auto on_opr = [&opr_configs, &base_fmt, &base_cfg_id, &reformat_attribute,
                    &rewriter, &solution, &var2fmts,
@@ -141,8 +137,12 @@ void LayoutTransformPass::apply(OptState& opt) const {
             new_inp[i] = new_var;
         }
         VarNode* new_out;
-        if (format_aware_oprs.count(opr->dyn_typeinfo()) > 0) {
-            new_out = intl::modify_opr_format(opr_fmt.val(), new_inp, opr);
+        if (intl::has_opr_format_modifier(opr)) {
+            intl::OprFormatInfo opr_format_info;
+            opr_format_info.opr_format = opr_fmt.val();
+            opr_format_info.tensor_formats = {
+                    base_fmt, opr_format_to_tensor_formats(opr_fmt.val())};
+            new_out = intl::modify_opr_format(opr_format_info, new_inp, opr);
         } else {
             new_out = serialization::copy_opr_shallow(*opr, new_inp, opr->config())
                               ->output(0);
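The `apply()` hunk above replaces the pass-local static typeinfo set with a call to `intl::has_opr_format_modifier(opr)`, which is generated from the `FOREACH_MODIFY_OPR_FORMAT_OPR` X-macro introduced in `opr_format_modifier.h` further down, so Concat only has to be registered in one place. A minimal self-contained sketch of that X-macro membership pattern, using illustrative stand-in types rather than the real megbrain `Typeinfo`:

```cpp
// Standalone sketch (not megbrain code) of the X-macro membership test behind
// has_opr_format_modifier(): the opr list lives in one macro, so appending Concat
// automatically extends every predicate generated from it.
#include <cassert>
#include <string>

struct Opr { std::string type; };  // stand-in for cg::OperatorNodeBase + Typeinfo

#define FOREACH_FORMAT_AWARE_OPR(cb) cb("Convolution") cb("ConvBias") cb("Pooling")
#define FOREACH_MODIFY_OPR_FORMAT_OPR(cb) FOREACH_FORMAT_AWARE_OPR(cb) cb("Concat")

bool has_opr_format(const Opr& opr) {
    bool ret = false;
#define cb(_name) ret |= opr.type == _name;
    FOREACH_FORMAT_AWARE_OPR(cb)
#undef cb
    return ret;
}

bool has_opr_format_modifier(const Opr& opr) {
    bool ret = false;
#define cb(_name) ret |= opr.type == _name;
    FOREACH_MODIFY_OPR_FORMAT_OPR(cb)
#undef cb
    return ret;
}

int main() {
    assert(!has_opr_format(Opr{"Concat"}));          // Concat carries no opr_format param
    assert(has_opr_format_modifier(Opr{"Concat"}));  // but it does have a format modifier
}
```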
diff --git a/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp b/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp
index 89fe19e5..7e1c8ec9 100644
--- a/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp
+++ b/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp
@@ -11,10 +11,12 @@
  */
 #include "./opr_format_modifier.h"
+#include "./utils.h"
 #include "megbrain/opr/dnn/convolution.h"
 #include "megbrain/opr/dnn/pooling.h"
 #include "megbrain/opr/imgproc.h"
 #include "megbrain/opr/io.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "megbrain/serialization/sereg.h"
 #include "midout.h"
@@ -201,6 +203,37 @@
 INST(ConvolutionBackwardData)
 INST(PoolingForward)
 #undef APPLY
 #undef INST
+
+VarNode* modify_concat_opr_format(
+        gopt::intl::OprFormatInfo::TensorFormatsInfo tensor_formats,
+        const VarNodeArray& i, const cg::OperatorNodeBase* opr) {
+    auto base_format = tensor_formats.from;
+    auto tensor_format = tensor_formats.to;
+    int axis = opr->cast_final_safe<opr::Concat>().axis();
+    /// modify axis
+    using Dimension = megdnn::Dimension;
+    static constexpr uint32_t UNDETERMINED_EXTENT = Dimension::UNDETERMINED_EXTENT;
+    auto orig_shape = tensor_formats_to_named_tensor_shape(base_format);
+    auto target_shape = tensor_formats_to_named_tensor_shape(tensor_format);
+    mgb_assert(
+            static_cast<size_t>(axis) < orig_shape.ndim,
+            "invalid axis of concat opr(axis:%d,shp:%s)", axis,
+            orig_shape.to_string().c_str());
+    if (orig_shape[axis].extent() != UNDETERMINED_EXTENT)
+        return nullptr;
+    auto axis_name = orig_shape[axis].name();
+    int new_axis = target_shape.ndim;
+    for (size_t i = 0; i < target_shape.ndim; ++i) {
+        if (target_shape[i].name() == axis_name &&
+            target_shape[i].extent() == UNDETERMINED_EXTENT) {
+            new_axis = i;
+            break;
+        }
+    }
+    if (static_cast<size_t>(new_axis) >= target_shape.ndim)
+        return nullptr;
+    return opr::Concat::make(i, new_axis, opr->config()).node();
+}
 }  // namespace
 namespace mgb {
@@ -275,13 +308,16 @@
 INST(Resize, 2);
 #undef INST

 VarNode* modify_opr_format(
-        opr::ConvBias::Param::Format opr_format, const VarNodeArray& i,
+        OprFormatInfo opr_format_info, const VarNodeArray& i,
         const cg::OperatorNodeBase* opr) {
-#define cb(_Opr)                                                  \
-    if (opr->dyn_typeinfo() == _Opr::typeinfo()) {                \
-        return OprFormatModifier<_Opr>::make(opr_format, i, opr); \
+#define cb(_Opr)                                                                  \
+    if (opr->dyn_typeinfo() == _Opr::typeinfo()) {                                \
+        return OprFormatModifier<_Opr>::make(opr_format_info.opr_format, i, opr); \
     } else
-    FOREACH_FORMAT_AWARE_OPR(cb) {
+    FOREACH_FORMAT_AWARE_OPR(cb)
+    if (opr->dyn_typeinfo() == opr::Concat::typeinfo()) {
+        return modify_concat_opr_format(opr_format_info.tensor_formats, i, opr);
+    } else {
         mgb_throw(
                 InternalError, "invalid format aware operator(got:%s)",
                 opr->dyn_typeinfo()->name);
@@ -302,6 +338,28 @@ bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr)
                 InternalError, "invalid multi-algo operator(got:%s)",
                 opr->dyn_typeinfo()->name);
     }
+#undef cb
+}
+
+bool has_opr_format(const cg::OperatorNodeBase* opr) {
+    bool ret = false;
+#define cb(_Opr) ret |= opr->dyn_typeinfo() == _Opr::typeinfo();
+    FOREACH_FORMAT_AWARE_OPR(cb)
+#undef cb
+    return ret;
+}
+
+bool has_opr_format_modifier(const cg::OperatorNodeBase* opr) {
+    bool ret = false;
+#define cb(_Opr) ret |= opr->dyn_typeinfo() == _Opr::typeinfo();
+    FOREACH_MODIFY_OPR_FORMAT_OPR(cb)
+#undef cb
+    return ret;
+}
+
+bool allow_aligned_layout(const cg::OperatorNodeBase* opr) {
+    return opr->dyn_typeinfo() != opr::Concat::typeinfo() &&
+           opr->dyn_typeinfo() != opr::Reduce::typeinfo();
 }
 }  // namespace intl
diff --git a/src/gopt/impl/global_layout_transform/opr_format_modifier.h b/src/gopt/impl/global_layout_transform/opr_format_modifier.h
index 11e634f6..2ab6697a 100644
--- a/src/gopt/impl/global_layout_transform/opr_format_modifier.h
+++ b/src/gopt/impl/global_layout_transform/opr_format_modifier.h
@@ -16,17 +16,36 @@
 namespace mgb {
 namespace gopt {
+enum class TensorFormats : uint32_t;
 namespace intl {
 #define FOREACH_FORMAT_AWARE_OPR(cb)                                                    \
     cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) cb(PoolingForward)  \
             cb(WarpPerspective) cb(Resize)
+
+#define FOREACH_MODIFY_OPR_FORMAT_OPR(cb) FOREACH_FORMAT_AWARE_OPR(cb) cb(Concat)
+
 bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr);

+struct OprFormatInfo {
+    opr::Convolution::Param::Format opr_format;
+    struct TensorFormatsInfo {
+        TensorFormats from;
+        TensorFormats to;
+    };
+    TensorFormatsInfo tensor_formats;
+};
+
 VarNode* modify_opr_format(
-        opr::Convolution::Param::Format opr_format, const VarNodeArray& i,
+        OprFormatInfo opr_format, const VarNodeArray& i,
         const cg::OperatorNodeBase* opr);

+bool has_opr_format(const cg::OperatorNodeBase* opr);
+
+bool has_opr_format_modifier(const cg::OperatorNodeBase* opr);
+
+bool allow_aligned_layout(const cg::OperatorNodeBase* opr);
+
 }  // namespace intl
 }  // namespace gopt
 }  // namespace mgb
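`modify_concat_opr_format()` above remaps the concat axis by dimension name: the axis must be an undetermined-extent dimension of the source layout, and it is mapped to the dimension with the same name and undetermined extent in the target layout; otherwise `nullptr` is returned and the candidate is rejected. A simplified, self-contained sketch of that remapping (the `Dim`/`Layout` types are stand-ins for `megdnn::Dimension` / `NamedTensorShape`, not the real API):

```cpp
// Standalone sketch of the axis remapping performed by modify_concat_opr_format().
#include <cstdio>
#include <string>
#include <vector>

constexpr int UNDETERMINED = -1;
struct Dim { std::string name; int extent; };  // extent < 0 means "undetermined"
using Layout = std::vector<Dim>;

// Returns the remapped axis, or -1 when the axis cannot be expressed in the target
// layout (e.g. the fixed-extent C%4 dimension of NCHWc4).
int remap_concat_axis(const Layout& from, const Layout& to, int axis) {
    if (axis < 0 || axis >= (int)from.size() || from[axis].extent != UNDETERMINED)
        return -1;
    for (size_t i = 0; i < to.size(); ++i)
        if (to[i].name == from[axis].name && to[i].extent == UNDETERMINED)
            return (int)i;
    return -1;
}

int main() {
    Layout nchw = {{"N", UNDETERMINED}, {"C", UNDETERMINED},
                   {"H", UNDETERMINED}, {"W", UNDETERMINED}};
    Layout nchwc4 = {{"N", UNDETERMINED}, {"C", UNDETERMINED},  // C//4
                     {"H", UNDETERMINED}, {"W", UNDETERMINED},
                     {"C", 4}};                                 // C%4, fixed extent
    std::printf("axis 1 -> %d\n", remap_concat_axis(nchw, nchwc4, 1));  // prints 1
    std::printf("axis 4 -> %d\n", remap_concat_axis(nchwc4, nchw, 4));  // prints -1
}
```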
diff --git a/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp b/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp
index 9a2809ab..6ea806bc 100644
--- a/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp
+++ b/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp
@@ -625,7 +625,13 @@ struct ConvTensorFormatsDispatcherImpl {
             TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
             config.input_tensor_types.emplace_back(tensor_type);
         }
-        available &= check_dtype(opr->output(0)->dtype(), false);
+        // FIXME: hack for nchw nchw44 hybrid mode
+        static_assert(
+                std::is_same<Opr, opr::ConvolutionForward>::value ||
+                        std::is_same<Opr, opr::ConvBiasForward>::value,
+                "nchw44 hybrid only support conv or conv_bias opr");
+        size_t in_channel = opr->input(0)->shape()[1];
+        available &= in_channel <= 4_z && check_dtype(opr->output(0)->dtype(), false);
         config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
         available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
         config.input_tensor_formats = {
@@ -696,7 +702,14 @@ struct ConvTensorFormatsDispatcherImpl {
             TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
             config.input_tensor_types.emplace_back(tensor_type);
         }
-        available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
+        // FIXME: hack for nchw nchw88 hybrid mode
+        static_assert(
+                std::is_same<Opr, opr::ConvolutionForward>::value ||
+                        std::is_same<Opr, opr::ConvBiasForward>::value,
+                "nchw nchw88 hybrid mode only support conv or conv_bias opr");
+        size_t in_channel = opr->input(0)->shape()[1];
+        available &= in_channel <= 8_z &&
+                     opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
         config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
         available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
         // setup tensor formats
@@ -783,6 +796,13 @@ struct ConvTensorFormatsDispatcherImpl
                 opr->output(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm;
         config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
         available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
+        // FIXME: hack for nchw nchw44 dot hybrid mode
+        static_assert(
+                std::is_same<Opr, opr::ConvolutionForward>::value ||
+                        std::is_same<Opr, opr::ConvBiasForward>::value,
+                "nchw44 dot hybrid only support conv or conv_bias opr");
+        size_t in_channel = opr->input(0)->shape()[1];
+        available &= in_channel <= 4_z;
         // setup tensor formats
         config.input_tensor_formats = {
                 TensorFormats::NCHW, TensorFormats::KRSCk4, TensorFormats::NCHWc4,
@@ -940,6 +960,8 @@ StaticData::StaticData() {
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW4);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, CHWN4);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32);
+    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32_NCHW4);
+    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW4_NCHW32);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW64);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88);
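The three hunks above add the same kind of guard: the NCHW→NCHW44/NCHW88/NCHW44_DOT hybrid configs are only reported as available when the input channel count does not exceed the pack size, i.e. essentially for the first convolution of a network. A small self-contained sketch of how such an availability flag accumulates (the names below are illustrative, not the dispatcher's real API):

```cpp
// Standalone sketch of the "available" accumulation used by the dispatchers, with the
// newly added in_channel <= pack_size restriction for hybrid (channel-padding) mode.
#include <cassert>
#include <cstddef>

struct ConvDesc {
    size_t in_channels;
    bool dense;     // stands for Sparse::DENSE in the real dispatcher
    bool dtype_ok;  // stands for the dtype check
};

bool hybrid_config_available(const ConvDesc& d, size_t pack_size) {
    bool available = true;
    available &= d.in_channels <= pack_size;  // hybrid mode: only small-channel inputs
    available &= d.dense;
    available &= d.dtype_ok;
    return available;
}

int main() {
    assert(hybrid_config_available({3, true, true}, 4));    // e.g. RGB input, NCHW44
    assert(!hybrid_config_available({16, true, true}, 4));  // channels already packed
}
```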
diff --git a/src/gopt/impl/global_layout_transform/profiler_impl.cpp b/src/gopt/impl/global_layout_transform/profiler_impl.cpp
index 6c700195..b1ef404b 100644
--- a/src/gopt/impl/global_layout_transform/profiler_impl.cpp
+++ b/src/gopt/impl/global_layout_transform/profiler_impl.cpp
@@ -19,6 +19,7 @@
 #include "megbrain/opr/imgproc.h"
 #include "megbrain/opr/io.h"
 #include "megbrain/opr/nn_int.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "megbrain/plugin/base.h"
 #include "megbrain/serialization/sereg.h"
@@ -202,20 +203,43 @@ float ProfilerImpl::profile_operator(
     auto graph = ComputingGraph::make();
     graph->options().graph_opt_level = 0;
     graph->options().var_sanity_check_first_run = false;
+    OperatorNodeBase* new_opr;
+    /// \note Concat operators are treated specially, for two reasons:
+    /// 1. Padding the input varnodes of a Concat opr is not allowed. If we padded
+    /// them, the padding information of the output varnode would have to be
+    /// propagated through the layout selection algorithm, which has not been
+    /// implemented yet.
+    /// 2. The axis of the Concat operator has to be modified, because the layouts
+    /// of the input varnodes have been changed; the new axis is handled in the
+    /// OprMaker function.
+    bool allow_aligned = intl::allow_aligned_layout(opr);
     VarNodeArray new_inps(opr->input().size());
     for (size_t i = 0; i < opr->input().size(); ++i) {
         auto&& var = opr->input(i);
         auto&& cn = var->comp_node();
         auto&& dtype = var->dtype();
         auto dval = std::make_shared<DeviceTensorND>(cn, dtype);
-        auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape(
-                var, base_format, tensor_format, extra_attribute);
-        dval->resize(aligned_tensor_shape);
-        auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval);
-        new_inps[i] = aligned_var.node();
+        auto new_shape = ReformatManager::try_make_tensor_shape(
+                var, base_format, tensor_format, extra_attribute, allow_aligned);
+        if (new_shape.ndim == 0)
+            return PROFILE_TIME_OUT;
+        dval->resize(new_shape);
+        auto new_var = opr::VolatileSharedDeviceTensor::make(*graph, dval);
+        new_inps[i] = new_var.node();
+    }
+    if (intl::has_opr_format_modifier(opr)) {
+        intl::OprFormatInfo opr_format_info;
+        opr_format_info.tensor_formats = {base_format, tensor_format};
+        auto new_var = intl::modify_opr_format(opr_format_info, new_inps, opr);
+        if (new_var)
+            new_opr = new_var->owner_opr();
+        else
+            return PROFILE_TIME_OUT;
+    } else {
+        new_opr = serialization::copy_opr_shallow(
+                *opr, new_inps, opr->config(), {graph.get()});
     }
-    auto new_opr = serialization::copy_opr_shallow(
-            *opr, new_inps, opr->config(), {graph.get()});
     if (!m_opr_filter(opr, new_opr))
         return PROFILE_TIME_OUT;
     auto y = new_opr->output(0);
@@ -248,6 +272,8 @@ float ProfilerImpl::profile_operator(
         ReformatAttribute extra_attribute) const {
     auto graph = ComputingGraph::make();
     graph->options().graph_opt_level = 0;
+    graph->options().graph_opt.weight_preprocess =
+            opr->owner_graph()->options().graph_opt.weight_preprocess;
     graph->options().var_sanity_check_first_run = false;
     VarNodeArray new_inps(opr->input().size());
     size_t i = 0;
@@ -274,8 +300,11 @@ float ProfilerImpl::profile_operator(
                     config.input_tensor_formats[i], extra_attribute);
         }
         dval->resize(aligned_shape);
-        auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval);
-        new_inps[i] = aligned_var.node();
+        if (config.input_tensor_types[i] == TensorType::WEIGHT) {
+            new_inps[i] = opr::SharedDeviceTensor::make_const(*graph, dval).node();
+        } else {
+            new_inps[i] = opr::VolatileSharedDeviceTensor::make(*graph, dval).node();
+        }
     }
     for (; i < opr->input().size(); ++i) {
         auto&& var = opr->input(i);
@@ -291,7 +320,9 @@ float ProfilerImpl::profile_operator(
         auto imm = opr::ImmutableTensor::make(*graph, *hval);
         new_inps[i] = imm.node();
     }
-    VarNode* y = mgb::gopt::intl::modify_opr_format(config.opr_format, new_inps, opr);
+    intl::OprFormatInfo opr_format_info;
+    opr_format_info.opr_format = config.opr_format;
+    VarNode* y = mgb::gopt::intl::modify_opr_format(opr_format_info, new_inps, opr);
     static const ThinHashSet<Typeinfo*> multi_algo_oprs = {
             opr::Convolution::typeinfo(),
             opr::ConvBiasForward::typeinfo(),
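The profiler hunks above switch from `make_aligned_tensor_shape()` to `try_make_tensor_shape()`: when padding is not allowed for an operator (Concat inputs), a dimension that does not divide evenly by the packing factor yields an empty shape, and `profile_operator()` maps that to `PROFILE_TIME_OUT` so the layout candidate can never win. A self-contained sketch of that contract under simplified types (the helper names and the timeout constant below are illustrative):

```cpp
// Standalone sketch of the try_make_tensor_shape() / PROFILE_TIME_OUT contract.
#include <cassert>
#include <cstddef>
#include <vector>

constexpr float PROFILE_TIME_OUT = 1e7f;    // stand-in for the real sentinel value
using Shape = std::vector<size_t>;          // empty == "cannot be expressed"

Shape try_make_nchwc(const Shape& nchw, size_t pack, bool allow_padding) {
    size_t c = nchw[1];
    if (c % pack != 0 && !allow_padding)
        return {};                                                    // sentinel
    size_t outer = allow_padding ? (c + pack - 1) / pack : c / pack;  // divup vs exact
    return {nchw[0], outer, nchw[2], nchw[3], pack};
}

float profile_candidate(const Shape& inp, size_t pack, bool allow_padding) {
    auto shp = try_make_nchwc(inp, pack, allow_padding);
    if (shp.empty())
        return PROFILE_TIME_OUT;  // candidate rejected, as in profile_operator()
    return 1.0f;                  // placeholder for the measured kernel time
}

int main() {
    assert(profile_candidate({16, 2, 368, 640}, 4, false) == PROFILE_TIME_OUT);
    assert(profile_candidate({16, 2, 368, 640}, 4, true) == 1.0f);   // padded to C=4
    assert(profile_candidate({16, 8, 184, 320}, 4, false) == 1.0f);  // divides evenly
}
```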
diff --git a/src/gopt/impl/global_layout_transform/reformat_manager.cpp b/src/gopt/impl/global_layout_transform/reformat_manager.cpp
index 82b9a862..38c301c3 100644
--- a/src/gopt/impl/global_layout_transform/reformat_manager.cpp
+++ b/src/gopt/impl/global_layout_transform/reformat_manager.cpp
@@ -587,9 +587,9 @@ const ReformatManager& ReformatManager::instance() {
     return inst;
 }

-TensorShape ReformatManager::make_aligned_tensor_shape(
+TensorShape ReformatManager::try_make_tensor_shape(
         const VarNode* var, TensorFormats orig_formats, TensorFormats target_formats,
-        ReformatKey::Attribute extra_attribute) {
+        ReformatKey::Attribute extra_attribute, bool allow_aligned) {
     using Dimension = megdnn::Dimension;
     static constexpr uint32_t UNDETERMINED_EXTENT = Dimension::UNDETERMINED_EXTENT;
     auto orig_shape = tensor_formats_to_named_tensor_shape(orig_formats);
@@ -623,8 +623,17 @@ TensorShape ReformatManager::make_aligned_tensor_shape(
                            : (orig_shape[idx] / target_shape[i]).extent();
             if (mul)
                 tshp[i] = oshp[idx] * factor;
-            else
-                tshp[i] = divup(oshp[idx], factor);
+            else {
+                if (allow_aligned)
+                    tshp[i] = divup(oshp[idx], factor);
+                else if (!(oshp[idx] % factor)) {
+                    tshp[i] = oshp[idx] / factor;
+                } else {
+                    return TensorShape{};
+                }
+            }
+
+            /// hack for nhwc auto padding
             if (name == Dimension::Name::C) {
                 size_t channel_alignment = target_shape[i].stride();
                 size_t channels = tshp[i] * channel_alignment;
@@ -641,6 +650,15 @@ TensorShape ReformatManager::make_aligned_tensor_shape(
     return tshp;
 }

+TensorShape ReformatManager::make_aligned_tensor_shape(
+        const VarNode* var, TensorFormats orig_formats, TensorFormats target_formats,
+        ReformatKey::Attribute extra_attribute) {
+    auto tshp = ReformatManager::try_make_tensor_shape(
+            var, orig_formats, target_formats, extra_attribute);
+    mgb_assert(tshp.ndim);
+    return tshp;
+}
+
 TensorShape ReformatManager::make_aligned_weight_shape(
         const VarNode* var, TensorFormats orig_formats, TensorFormats target_formats,
         TensorFormats extra_formats, ReformatKey::Attribute extra_attribute) {
diff --git a/src/gopt/impl/global_layout_transform/utils.h b/src/gopt/impl/global_layout_transform/utils.h
index 9cc3b82d..64f25b10 100644
--- a/src/gopt/impl/global_layout_transform/utils.h
+++ b/src/gopt/impl/global_layout_transform/utils.h
@@ -93,6 +93,8 @@ static inline TensorFormats opr_format_to_tensor_formats(
             return TensorFormats::NCHWc4;
         case OprFormat::NCHW8:
             return TensorFormats::NCHWc8;
+        case OprFormat::NCHW44_DOT:
+            return TensorFormats::NCHWc4;
         default:
             mgb_throw(
                     AssertionError, "format(%s) is not supported",
@@ -171,7 +173,6 @@ static inline megdnn::NamedTensorShape tensor_formats_to_named_tensor_shape(
                     static_cast(format));
     }
 }
-
 }  // namespace gopt
 }  // namespace mgb
diff --git a/src/gopt/include/megbrain/gopt/reformat_manager.h b/src/gopt/include/megbrain/gopt/reformat_manager.h
index 7426a568..45091a9a 100644
--- a/src/gopt/include/megbrain/gopt/reformat_manager.h
+++ b/src/gopt/include/megbrain/gopt/reformat_manager.h
@@ -135,6 +135,13 @@ public:
             const VarNode* orig_var, const
ReformatKey& key, const AlignmentDesc& extra_alignment = {}) const; + /// return empty shape, if shape of origin varnode does not satisfy the alignment + /// requirement of the target tensor formats + static TensorShape try_make_tensor_shape( + const VarNode* var, TensorFormats orig_formats, + TensorFormats target_formats, + ReformatKey::Attribute extra_attribute = ReformatKey::Attribute::DEFAULT, + bool allow_aligned = true); static TensorShape make_aligned_tensor_shape( const VarNode* var, TensorFormats orig_formats, TensorFormats target_formats, diff --git a/src/gopt/test/cache_data.h b/src/gopt/test/cache_data.h index a347b83f..0cd5c465 100644 --- a/src/gopt/test/cache_data.h +++ b/src/gopt/test/cache_data.h @@ -18104,3 +18104,217 @@ static const std::vector TestLayoutTransform_Wide = { 108, 111, 97, 116, 51, 50, 124, 49, 54, 44, 49, 54, 44, 56, 44, 56, 59, 70, 108, 111, 97, 116, 51, 50, 124, 49, 48, 4, 0, 0, 0, 233, 38, 177, 64}; + +static const std::vector TestLayoutTransform_Concat = { + 1, 0, 0, 0, 34, 0, 0, 0, 108, 97, 121, 111, 117, 116, 95, 116, + 114, 97, 110, 115, 102, 111, 114, 109, 95, 112, 114, 111, 102, 105, 108, 101, + 58, 112, 108, 97, 116, 61, 99, 117, 100, 97, 33, 0, 0, 0, 77, 0, + 0, 0, 49, 54, 44, 50, 44, 51, 54, 56, 44, 54, 52, 48, 59, 81, + 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 123, 78, 44, 67, 47, + 47, 52, 44, 72, 44, 87, 44, 67, 37, 52, 125, 59, 123, 78, 44, 67, + 44, 72, 44, 87, 125, 59, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, + 100, 83, 56, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 4, + 0, 0, 0, 28, 90, 20, 67, 89, 0, 0, 0, 47, 225, 79, 220, 129, + 196, 236, 15, 1, 0, 0, 0, 49, 54, 44, 50, 44, 51, 54, 56, 44, + 54, 52, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, + 49, 54, 44, 50, 44, 51, 54, 56, 44, 54, 52, 48, 59, 81, 117, 97, + 110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 54, 44, 52, 44, 51, 54, + 56, 44, 54, 52, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, + 56, 124, 51, 48, 4, 0, 0, 0, 128, 150, 24, 75, 77, 0, 0, 0, + 49, 54, 44, 52, 44, 49, 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, + 110, 116, 105, 122, 101, 100, 83, 56, 124, 123, 78, 44, 67, 47, 47, 52, + 44, 72, 44, 87, 44, 67, 37, 52, 125, 59, 123, 78, 44, 67, 44, 72, + 44, 87, 125, 59, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, + 56, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 4, 0, 0, + 0, 236, 81, 208, 66, 52, 0, 0, 0, 145, 22, 69, 252, 191, 174, 0, + 1, 49, 54, 44, 51, 44, 55, 51, 54, 44, 49, 50, 56, 48, 59, 85, + 105, 110, 116, 56, 124, 49, 54, 44, 51, 44, 55, 51, 54, 44, 49, 50, + 56, 48, 59, 70, 108, 111, 97, 116, 51, 50, 124, 48, 48, 4, 0, 0, + 0, 88, 57, 79, 68, 66, 0, 0, 0, 49, 54, 44, 51, 44, 55, 51, + 54, 44, 49, 50, 56, 48, 59, 70, 108, 111, 97, 116, 51, 50, 124, 123, + 78, 44, 67, 44, 72, 44, 87, 125, 59, 123, 78, 44, 67, 47, 47, 52, + 44, 72, 44, 87, 44, 67, 37, 52, 125, 59, 48, 59, 70, 108, 111, 97, + 116, 51, 50, 59, 70, 108, 111, 97, 116, 51, 50, 4, 0, 0, 0, 134, + 75, 130, 69, 163, 0, 0, 0, 10, 151, 159, 190, 73, 14, 145, 194, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 1, + 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, + 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 49, 54, 44, 56, 44, + 49, 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, + 100, 83, 56, 124, 49, 54, 44, 56, 44, 51, 44, 51, 59, 81, 117, 97, + 110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 44, 49, 54, 44, 49, 44, + 49, 59, 81, 117, 97, 110, 116, 105, 122, 
101, 100, 83, 51, 50, 124, 49, + 54, 44, 49, 54, 44, 57, 50, 44, 49, 54, 48, 59, 81, 117, 97, 110, + 116, 105, 122, 101, 100, 83, 56, 124, 51, 48, 4, 0, 0, 0, 68, 139, + 104, 67, 60, 0, 0, 0, 49, 54, 44, 51, 44, 55, 51, 54, 44, 49, + 50, 56, 48, 59, 85, 105, 110, 116, 56, 124, 123, 78, 44, 67, 47, 47, + 52, 44, 72, 44, 87, 44, 67, 37, 52, 125, 59, 123, 78, 44, 67, 44, + 72, 44, 87, 125, 59, 48, 59, 85, 105, 110, 116, 56, 59, 85, 105, 110, + 116, 56, 4, 0, 0, 0, 201, 161, 135, 68, 64, 0, 0, 0, 218, 182, + 120, 146, 221, 155, 179, 235, 11, 0, 0, 0, 49, 54, 44, 52, 44, 49, + 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, + 83, 56, 124, 49, 54, 44, 52, 44, 49, 56, 52, 44, 51, 50, 48, 59, + 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 48, 48, 4, 0, + 0, 0, 182, 243, 29, 66, 161, 0, 0, 0, 10, 151, 159, 190, 73, 14, + 145, 194, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, + 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, + 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 49, 54, + 44, 52, 44, 51, 54, 56, 44, 54, 52, 48, 59, 81, 117, 97, 110, 116, + 105, 122, 101, 100, 83, 56, 124, 52, 44, 52, 44, 51, 44, 51, 59, 81, + 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 44, 52, 44, 49, + 44, 49, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 51, 50, 124, + 49, 54, 44, 52, 44, 49, 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, + 110, 116, 105, 122, 101, 100, 83, 56, 124, 48, 48, 4, 0, 0, 0, 128, + 150, 24, 75, 60, 0, 0, 0, 49, 54, 44, 51, 44, 55, 51, 54, 44, + 49, 50, 56, 48, 59, 85, 105, 110, 116, 56, 124, 123, 78, 44, 67, 44, + 72, 44, 87, 125, 59, 123, 78, 44, 67, 47, 47, 52, 44, 72, 44, 87, + 44, 67, 37, 52, 125, 59, 48, 59, 85, 105, 110, 116, 56, 59, 85, 105, + 110, 116, 56, 4, 0, 0, 0, 203, 33, 219, 68, 68, 0, 0, 0, 218, + 182, 120, 146, 221, 155, 179, 235, 16, 0, 0, 0, 49, 54, 44, 51, 44, + 55, 51, 54, 44, 49, 50, 56, 48, 59, 70, 108, 111, 97, 116, 51, 50, + 124, 49, 59, 70, 108, 111, 97, 116, 51, 50, 124, 49, 54, 44, 51, 44, + 55, 51, 54, 44, 49, 50, 56, 48, 59, 70, 108, 111, 97, 116, 51, 50, + 124, 48, 48, 4, 0, 0, 0, 205, 204, 161, 68, 77, 0, 0, 0, 49, + 54, 44, 52, 44, 51, 54, 56, 44, 54, 52, 48, 59, 81, 117, 97, 110, + 116, 105, 122, 101, 100, 83, 56, 124, 123, 78, 44, 67, 47, 47, 52, 44, + 72, 44, 87, 44, 67, 37, 52, 125, 59, 123, 78, 44, 67, 44, 72, 44, + 87, 125, 59, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, + 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 4, 0, 0, 0, + 200, 118, 174, 67, 163, 0, 0, 0, 10, 151, 159, 190, 73, 14, 145, 194, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, + 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, + 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 49, 54, 44, 56, + 44, 49, 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, 110, 116, 105, 122, + 101, 100, 83, 56, 124, 49, 54, 44, 56, 44, 51, 44, 51, 59, 81, 117, + 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 44, 49, 54, 44, 49, + 44, 49, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 51, 50, 124, + 49, 54, 44, 49, 54, 44, 57, 50, 44, 49, 54, 48, 59, 81, 117, 97, + 110, 116, 105, 122, 101, 100, 83, 56, 124, 48, 48, 4, 0, 0, 0, 128, + 150, 24, 75, 161, 0, 0, 0, 10, 151, 159, 190, 73, 14, 145, 194, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 1, + 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, + 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 49, 54, 44, 52, 44, + 51, 54, 
56, 44, 54, 52, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, + 100, 83, 56, 124, 52, 44, 52, 44, 51, 44, 51, 59, 81, 117, 97, 110, + 116, 105, 122, 101, 100, 83, 56, 124, 49, 44, 52, 44, 49, 44, 49, 59, + 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 51, 50, 124, 49, 54, 44, + 52, 44, 49, 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, 110, 116, 105, + 122, 101, 100, 83, 56, 124, 51, 48, 4, 0, 0, 0, 138, 65, 164, 67, + 77, 0, 0, 0, 49, 54, 44, 50, 44, 51, 54, 56, 44, 54, 52, 48, + 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 123, 78, 44, + 67, 44, 72, 44, 87, 125, 59, 123, 78, 44, 67, 47, 47, 52, 44, 72, + 44, 87, 44, 67, 37, 52, 125, 59, 48, 59, 81, 117, 97, 110, 116, 105, + 122, 101, 100, 83, 56, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, + 56, 4, 0, 0, 0, 137, 65, 190, 67, 58, 0, 0, 0, 145, 22, 69, + 252, 191, 174, 0, 1, 49, 54, 44, 51, 44, 55, 51, 54, 44, 49, 50, + 56, 48, 59, 70, 108, 111, 97, 116, 51, 50, 124, 49, 54, 44, 51, 44, + 55, 51, 54, 44, 49, 50, 56, 48, 59, 81, 117, 97, 110, 116, 105, 122, + 101, 100, 83, 56, 124, 48, 48, 4, 0, 0, 0, 38, 177, 146, 68, 89, + 0, 0, 0, 47, 225, 79, 220, 129, 196, 236, 15, 1, 0, 0, 0, 49, + 54, 44, 50, 44, 51, 54, 56, 44, 54, 52, 48, 59, 81, 117, 97, 110, + 116, 105, 122, 101, 100, 83, 56, 124, 49, 54, 44, 50, 44, 51, 54, 56, + 44, 54, 52, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, + 124, 49, 54, 44, 52, 44, 51, 54, 56, 44, 54, 52, 48, 59, 81, 117, + 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 48, 48, 4, 0, 0, 0, + 150, 67, 235, 66, 162, 0, 0, 0, 10, 151, 159, 190, 73, 14, 145, 194, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, + 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, + 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 49, 54, 44, 51, + 44, 55, 51, 54, 44, 49, 50, 56, 48, 59, 81, 117, 97, 110, 116, 105, + 122, 101, 100, 83, 56, 124, 50, 44, 51, 44, 51, 44, 51, 59, 81, 117, + 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 44, 50, 44, 49, 44, + 49, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 51, 50, 124, 49, + 54, 44, 50, 44, 51, 54, 56, 44, 54, 52, 48, 59, 81, 117, 97, 110, + 116, 105, 122, 101, 100, 83, 56, 124, 51, 48, 4, 0, 0, 0, 128, 150, + 24, 75, 64, 0, 0, 0, 218, 182, 120, 146, 221, 155, 179, 235, 11, 0, + 0, 0, 49, 54, 44, 50, 44, 51, 54, 56, 44, 54, 52, 48, 59, 81, + 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 54, 44, 50, 44, + 51, 54, 56, 44, 54, 52, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, + 100, 83, 56, 124, 51, 48, 4, 0, 0, 0, 227, 165, 243, 66, 77, 0, + 0, 0, 49, 54, 44, 52, 44, 49, 56, 52, 44, 51, 50, 48, 59, 81, + 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 123, 78, 44, 67, 44, + 72, 44, 87, 125, 59, 123, 78, 44, 67, 47, 47, 52, 44, 72, 44, 87, + 44, 67, 37, 52, 125, 59, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, + 100, 83, 56, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 4, + 0, 0, 0, 174, 71, 169, 66, 162, 0, 0, 0, 10, 151, 159, 190, 73, + 14, 145, 194, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, + 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, + 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 49, + 54, 44, 51, 44, 55, 51, 54, 44, 49, 50, 56, 48, 59, 81, 117, 97, + 110, 116, 105, 122, 101, 100, 83, 56, 124, 50, 44, 51, 44, 51, 44, 51, + 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 44, 50, + 44, 49, 44, 49, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 51, + 50, 124, 49, 54, 44, 50, 44, 51, 54, 56, 
44, 54, 52, 48, 59, 81, + 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 48, 48, 4, 0, 0, + 0, 128, 150, 24, 75, 66, 0, 0, 0, 49, 54, 44, 51, 44, 55, 51, + 54, 44, 49, 50, 56, 48, 59, 70, 108, 111, 97, 116, 51, 50, 124, 123, + 78, 44, 67, 47, 47, 52, 44, 72, 44, 87, 44, 67, 37, 52, 125, 59, + 123, 78, 44, 67, 44, 72, 44, 87, 125, 59, 48, 59, 70, 108, 111, 97, + 116, 51, 50, 59, 70, 108, 111, 97, 116, 51, 50, 4, 0, 0, 0, 196, + 32, 63, 69, 89, 0, 0, 0, 47, 225, 79, 220, 129, 196, 236, 15, 1, + 0, 0, 0, 49, 54, 44, 52, 44, 49, 56, 52, 44, 51, 50, 48, 59, + 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 54, 44, 52, + 44, 49, 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, 110, 116, 105, 122, + 101, 100, 83, 56, 124, 49, 54, 44, 56, 44, 49, 56, 52, 44, 51, 50, + 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 48, 48, + 4, 0, 0, 0, 156, 196, 136, 66, 77, 0, 0, 0, 49, 54, 44, 49, + 54, 44, 57, 50, 44, 49, 54, 48, 59, 81, 117, 97, 110, 116, 105, 122, + 101, 100, 83, 56, 124, 123, 78, 44, 67, 47, 47, 52, 44, 72, 44, 87, + 44, 67, 37, 52, 125, 59, 123, 78, 44, 67, 44, 72, 44, 87, 125, 59, + 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 59, 81, 117, + 97, 110, 116, 105, 122, 101, 100, 83, 56, 4, 0, 0, 0, 207, 247, 211, + 66, 77, 0, 0, 0, 49, 54, 44, 56, 44, 49, 56, 52, 44, 51, 50, + 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 123, 78, + 44, 67, 47, 47, 52, 44, 72, 44, 87, 44, 67, 37, 52, 125, 59, 123, + 78, 44, 67, 44, 72, 44, 87, 125, 59, 48, 59, 81, 117, 97, 110, 116, + 105, 122, 101, 100, 83, 56, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, + 83, 56, 4, 0, 0, 0, 54, 94, 58, 67, 64, 0, 0, 0, 218, 182, + 120, 146, 221, 155, 179, 235, 11, 0, 0, 0, 49, 54, 44, 50, 44, 51, + 54, 56, 44, 54, 52, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, + 83, 56, 124, 49, 54, 44, 50, 44, 51, 54, 56, 44, 54, 52, 48, 59, + 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 48, 48, 4, 0, + 0, 0, 155, 196, 128, 66, 78, 0, 0, 0, 49, 54, 44, 51, 44, 55, + 51, 54, 44, 49, 50, 56, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, + 100, 83, 56, 124, 123, 78, 44, 67, 44, 72, 44, 87, 125, 59, 123, 78, + 44, 67, 47, 47, 52, 44, 72, 44, 87, 44, 67, 37, 52, 125, 59, 48, + 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 59, 81, 117, 97, + 110, 116, 105, 122, 101, 100, 83, 56, 4, 0, 0, 0, 191, 74, 159, 68, + 89, 0, 0, 0, 47, 225, 79, 220, 129, 196, 236, 15, 1, 0, 0, 0, + 49, 54, 44, 52, 44, 49, 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, + 110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 54, 44, 52, 44, 49, 56, + 52, 44, 51, 50, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, + 56, 124, 49, 54, 44, 56, 44, 49, 56, 52, 44, 51, 50, 48, 59, 81, + 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 51, 48, 4, 0, 0, + 0, 170, 241, 138, 66, 77, 0, 0, 0, 49, 54, 44, 52, 44, 51, 54, + 56, 44, 54, 52, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, + 56, 124, 123, 78, 44, 67, 44, 72, 44, 87, 125, 59, 123, 78, 44, 67, + 47, 47, 52, 44, 72, 44, 87, 44, 67, 37, 52, 125, 59, 48, 59, 81, + 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 59, 81, 117, 97, 110, 116, + 105, 122, 101, 100, 83, 56, 4, 0, 0, 0, 79, 141, 147, 67, 77, 0, + 0, 0, 49, 54, 44, 49, 54, 44, 57, 50, 44, 49, 54, 48, 59, 81, + 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 123, 78, 44, 67, 44, + 72, 44, 87, 125, 59, 123, 78, 44, 67, 47, 47, 52, 44, 72, 44, 87, + 44, 67, 37, 52, 125, 59, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, + 100, 83, 56, 59, 81, 117, 97, 110, 116, 105, 122, 101, 
100, 83, 56, 4, + 0, 0, 0, 170, 241, 170, 66, 77, 0, 0, 0, 49, 54, 44, 56, 44, + 49, 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, + 100, 83, 56, 124, 123, 78, 44, 67, 44, 72, 44, 87, 125, 59, 123, 78, + 44, 67, 47, 47, 52, 44, 72, 44, 87, 44, 67, 37, 52, 125, 59, 48, + 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 59, 81, 117, 97, + 110, 116, 105, 122, 101, 100, 83, 56, 4, 0, 0, 0, 203, 161, 25, 67, + 64, 0, 0, 0, 218, 182, 120, 146, 221, 155, 179, 235, 11, 0, 0, 0, + 49, 54, 44, 52, 44, 49, 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, + 110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 54, 44, 52, 44, 49, 56, + 52, 44, 51, 50, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, + 56, 124, 51, 48, 4, 0, 0, 0, 102, 102, 22, 66, 78, 0, 0, 0, + 49, 54, 44, 51, 44, 55, 51, 54, 44, 49, 50, 56, 48, 59, 81, 117, + 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 123, 78, 44, 67, 47, 47, + 52, 44, 72, 44, 87, 44, 67, 37, 52, 125, 59, 123, 78, 44, 67, 44, + 72, 44, 87, 125, 59, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, + 83, 56, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 4, 0, + 0, 0, 54, 94, 72, 68}; diff --git a/src/gopt/test/embed_cache.py b/src/gopt/test/embed_cache.py index 3da02450..7ebb4c15 100644 --- a/src/gopt/test/embed_cache.py +++ b/src/gopt/test/embed_cache.py @@ -20,7 +20,7 @@ # 2. 编译megbrain_test,并运行所有全局图优化相关测试: # ./megbrain_test --gtest_filter="*LayoutTransform*" # 3. 用这个脚本把所有的cache文件打包在一起 -# python3 embed_cache.py -o cache_data.h -r $(ls /path/to/cache/*.cache) +# python3 embed_cache.py -o cache_data.h -r -a $(ls /path/to/cache/*.cache) # 4. 将步骤1中的 define 语句改回原样,这样 profile 过程就会使用 cache 下来的数据。 # 5. 最后可以重新构建一下 megbrain_test ,确保测试结果正确。 import os.path @@ -44,9 +44,10 @@ def _u32(data): class CacheDataGenerator: _cache_files = None - def __init__(self, cache_files, remove_plat_info = True): + def __init__(self, cache_files, remove_plat_info=True, append_cache=True): self._cache_files = cache_files self._remove_plat_info = remove_plat_info + self._append_cache = append_cache def _get_hash(self): return _u32(self._hash.digest()[:4]) @@ -71,9 +72,10 @@ class CacheDataGenerator: return ','.join(ret) def gen_cache_data_header(self, fout, src_map): - fout.write('// generated embed_cache.py\n') - fout.write('#include \n') - fout.write('#include \n') + if not self._append_cache: + fout.write('// generated embed_cache.py\n') + fout.write('#include \n') + fout.write('#include \n') for k, v in sorted(src_map.items()): fout.write(""" static const std::vector {} = {{ @@ -89,7 +91,11 @@ static const std::vector {} = {{ assert ext == ".cache", "ext: {}, fname {}".format(ext, fname) assert base not in fname2cache_data, "duplicated kernel: " + base fname2cache_data[base] = self.gen_cache_data(fname) - with open(output, 'w') as fout: + if self._append_cache: + mode = 'a' + else: + mode = 'w' + with open(output, mode) as fout: self.gen_cache_data_header(fout, fname2cache_data) logger.info('done') @@ -107,7 +113,15 @@ if __name__ == '__main__': default=True, help="whether remove platform infomation in the cache (default: True)" ) + parser.add_argument( + "-a", + "--append-cache", + action='store_true', + default=True, + help="whether append the cache (default: True)" + ) parser.add_argument('cache', help='cache files to be embedded', nargs='+') args = parser.parse_args() - cache_generator = CacheDataGenerator(args.cache, args.remove_plat_info) + cache_generator = CacheDataGenerator(args.cache, args.remove_plat_info, + args.append_cache) 
    cache_generator.invoke(args.output)
diff --git a/src/gopt/test/layout_transform_pass.cpp b/src/gopt/test/layout_transform_pass.cpp
index e7019350..b8308f3d 100644
--- a/src/gopt/test/layout_transform_pass.cpp
+++ b/src/gopt/test/layout_transform_pass.cpp
@@ -812,7 +812,9 @@ TEST(TestLayoutTransform, Resnet18_F16) {
                     .add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
                     .add_pass()
                     .add_pass(FuseNCHW4Int8Preprocess::make())
+#if CUDA_VERSION >= 10020
                     .add_pass()
+#endif
                     .add_pass()
                     .add_pass()
                     .apply({{output}})
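The Concat test added in the next hunk builds the chain conv_1 (2 output channels) → concat → conv_2 (4) → concat → y (16), so only the tensors from the first concat onwards have channel counts divisible by 4 and can be expressed in NCHWc4 without padding, which is what the test's final shape assertions rely on. A quick standalone check of that channel arithmetic:

```cpp
// Channel counts of the test graph below; a Concat of {t, -t} along axis 1 doubles C.
#include <cstdio>

int main() {
    int channels[] = {
            2,      // conv_1 output (weight {2, 3, 3, 3})
            2 * 2,  // conv_1_cat = Concat({conv_1, -conv_1}, 1)
            4,      // conv_2 output (weight {4, 4, 3, 3})
            4 * 2,  // conv_2_cat
            16,     // y (weight {16, 8, 3, 3})
    };
    const char* names[] = {"conv_1", "conv_1_cat", "conv_2", "conv_2_cat", "y"};
    for (int i = 0; i < 5; ++i)
        std::printf("%-10s C=%2d  NCHWc4 without padding: %s\n", names[i], channels[i],
                    channels[i] % 4 == 0 ? "yes" : "no");
}
```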
@@ -1205,4 +1207,112 @@ TEST(TestLayoutTransform, MobileNetV2_NCHW44_DOT) {
     MGB_ASSERT_TENSOR_EQ(t1, t2);
 }
+
+#if MGB_CUDA
+TEST(TestLayoutTransform, Concat) {
+    REQUIRE_GPU(1);
+    auto cn = CompNode::load("gpu0");
+    cn.activate();
+    REQUIRE_CUDA_COMPUTE_CAPABILITY(6, 1);
+
+    constexpr size_t N = 16, C = 3, H = 736, W = 1280;
+    HostTensorGenerator<dtype::Uint8> gen;
+
+    auto graph = ComputingGraph::make();
+    auto h2d = opr::Host2DeviceCopy::make(*graph, gen({N, C, H, W}, cn));
+    auto data = opr::TypeCvt::make(h2d, dtype::Float32());
+    auto sub_128 = data + (-128);
+    auto x = opr::TypeCvt::make(sub_128, dtype::QuantizedS8(1.f));
+    auto mkcvar = [&](const char* name, const TensorShape& shp, const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name),
+                dtype);
+    };
+    auto w = mkcvar("w", {2, 3, 3, 3}, dtype::QuantizedS8(1.f));
+    auto b = mkcvar("b", {1, 2, 1, 1}, dtype::QuantizedS32(1.f));
+    opr::ConvBias::Param param;
+    param.format = opr::ConvBias::Param::Format::NCHW;
+    param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
+    param.stride_h = param.stride_w = 2;
+    param.pad_h = param.pad_w = 1;
+    auto conv_1 = opr::ConvBias::make(
+            x, w, b, param, {}, OperatorNodeConfig(dtype::QuantizedS8(1.f)));
+    auto conv_1_cat = opr::Concat::make({conv_1, -conv_1}, 1);
+
+    auto w2 = mkcvar("w", {4, 4, 3, 3}, dtype::QuantizedS8(1.f));
+    auto b2 = mkcvar("b", {1, 4, 1, 1}, dtype::QuantizedS32(1.f));
+    auto conv_2 = opr::ConvBias::make(
+            conv_1_cat, w2, b2, param, {}, OperatorNodeConfig(dtype::QuantizedS8(1.f)));
+    auto conv_2_cat = opr::Concat::make({conv_2, -conv_2}, 1);
+    auto w3 = mkcvar("w", {16, 8, 3, 3}, dtype::QuantizedS8(1.f));
+    auto b3 = mkcvar("b", {1, 16, 1, 1}, dtype::QuantizedS32(1.f));
+    auto y = opr::ConvBias::make(
+            conv_2_cat, w3, b3, param, {}, OperatorNodeConfig(dtype::QuantizedS8(1.f)));
+
+    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
+    S strategy = S::PROFILE;
+    gopt::modify_opr_algo_strategy_inplace({y}, strategy);
+
+    using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
+    using OprList = LayoutTransformContext::OprList;
+    using Attribute = LayoutTransformContext::Attribute;
+    using Target = LayoutTransformContext::Target;
+    OprList opr_list = {
+            opr::ConvBiasForward::typeinfo(), opr::ElemwiseMultiType::typeinfo(),
+            opr::Elemwise::typeinfo(),        opr::TypeCvt::typeinfo(),
+            opr::Concat::typeinfo(),
+    };
+    SmallVector<TensorFormats> available_tensor_formats = {
+            TensorFormats::NCHW, TensorFormats::NCHWc4};
+    Attribute attribute = {
+            OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC};
+    auto ctx = std::make_unique<LayoutTransformContext>(
+            std::move(opr_list), std::move(available_tensor_formats), attribute);
+    ctx->add_opr_config(
+            opr::ConvBiasForward::typeinfo(),
+            {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4});
+#if MGB_WITH_CACHED_TEST
+    auto profiler = std::make_unique(
+            static_cast<const uint8_t*>(TestLayoutTransform_Concat.data()),
+            TestLayoutTransform_Concat.size());
+#else
+    auto profiler =
+            ProfilerBase::make_cached_profiler("TestLayoutTransform.Concat.cache");
+#endif
+    std::unique_ptr<Solver> solver{
+            new DynamicProgrammingSolver(std::move(profiler))};
+    auto new_out_vars =
+            gopt::GraphOptimizer{}
+                    .add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
+                    .add_pass()
+                    .add_pass(FuseNCHW4Int8Preprocess::make())
+#if CUDA_VERSION >= 10020
+                    .add_pass()
+                    .add_pass()
+#endif
+                    .add_pass()
+                    .add_pass()
+                    .apply(SymbolVarArray{y})
+                    .endpoint_vars();
+    const auto& v = new_out_vars[0];
+    using OutputSpecItem = cg::ComputingGraph::OutputSpecItem;
+    std::vector<OutputSpecItem> outs;
+    for (auto&& i : new_out_vars) {
+        outs.emplace_back(OutputSpecItem{i, {}});
+    }
+    GraphProfiler gprof{graph.get()};
+    auto func = graph->compile(outs);
+    func->execute();
+    gprof.to_json_full(func.get())->writeto_fpath(output_file("conv_cat.json"));
+    SmallVector<cg::OperatorNodeBase*> oprs;
+    auto cb = [&oprs](cg::OperatorNodeBase* opr) {
+        if (opr->same_type()) {
+            oprs.push_back(opr);
+        }
+    };
+    cg::DepOprIter{cb}.add(v.node()->owner_opr());
+    ASSERT_EQ(oprs.size(), 4);
+    ASSERT_EQ(oprs[0]->output(0)->shape().ndim, 4);
+    ASSERT_EQ(oprs[2]->output(0)->shape().ndim, 5);
+}
+#endif
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
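The final assertions of the test distinguish operators left in NCHW (4-d output shape) from those converted to NCHW4/NCHWc4 (5-d output shape). A standalone illustration of that shape relation, using the feature-map sizes produced by this test graph (the helper name is illustrative):

```cpp
// NCHW vs NCHWc4 output shapes: {N, C, H, W} becomes {N, C/4, H, W, 4} when C % 4 == 0.
#include <cassert>
#include <vector>

using Shape = std::vector<size_t>;

Shape to_nchwc4(const Shape& nchw) {
    assert(nchw.size() == 4 && nchw[1] % 4 == 0);
    return {nchw[0], nchw[1] / 4, nchw[2], nchw[3], 4};
}

int main() {
    Shape conv1_out = {16, 2, 368, 640};             // stays NCHW: ndim == 4
    Shape conv2_out = to_nchwc4({16, 4, 184, 320});  // transformed: ndim == 5
    assert(conv1_out.size() == 4);
    assert(conv2_out.size() == 5);
}
```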