add a special opr_format modify function for concat operators to modify the concat axis when the input layout has been changed
GitOrigin-RevId: 4094208057
release-1.7
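The core mechanism of this change is remapping a Concat axis between tensor layouts by dimension name: the concat axis is looked up in the named shape of the source format and matched against the dimension of the target format whose extent is still undetermined (the block-outer channel dimension, not the fixed C%4 block). The sketch below illustrates that idea in a self-contained form; the Dim/NamedShape types and the remap_concat_axis helper are illustrative only, while the real implementation (modify_concat_opr_format in the diff) uses megdnn::Dimension, UNDETERMINED_EXTENT and tensor_formats_to_named_tensor_shape.

    // Minimal sketch: remap a concat axis from one named layout to another.
    // Dim{name, extent}: extent == 0 stands for an "undetermined" (free) extent,
    // a nonzero extent is a fixed block factor (e.g. the 4 in NCHWc4).
    #include <cstdio>
    #include <string>
    #include <vector>

    struct Dim {
        std::string name;
        size_t extent;  // 0 means undetermined
    };
    using NamedShape = std::vector<Dim>;

    // Returns the remapped axis, or -1 if the axis cannot be remapped
    // (e.g. it sits on a fixed block dimension in the source format).
    int remap_concat_axis(const NamedShape& from, const NamedShape& to, int axis) {
        if (from[axis].extent != 0)  // axis on a fixed-extent dim: give up
            return -1;
        for (size_t i = 0; i < to.size(); ++i) {
            if (to[i].name == from[axis].name && to[i].extent == 0)
                return static_cast<int>(i);
        }
        return -1;
    }

    int main() {
        NamedShape nchw   = {{"N", 0}, {"C", 0}, {"H", 0}, {"W", 0}};
        // NCHWc4: N, C//4 (undetermined), H, W, C%4 (fixed extent 4);
        // both channel dims carry the name "C", as in megdnn's named shapes.
        NamedShape nchwc4 = {{"N", 0}, {"C", 0}, {"H", 0}, {"W", 0}, {"C", 4}};
        // Concat along C (axis 1) of NCHW maps onto the C//4 dim of NCHWc4.
        printf("new axis = %d\n", remap_concat_axis(nchw, nchwc4, 1));  // prints 1
    }

With this remapping, a concat along C of NCHW inputs stays a concat along the block-outer channel dimension once the inputs have been rewritten to NCHWc4.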
@@ -831,7 +831,7 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
         dst[4] = 32;
     } else if (param().format == Param::Format::NCHW88) {
         megdnn_assert(
-                src.ndim == 5 || src.ndim == 4,
+                src.ndim == 5 || (src.ndim == 4 && src[1] <= 8),
                 "invalid src ndim for NCHW88, expected=5 or 4, got=%zu", src.ndim);
         dst.ndim = 5;
         dst[0] = src[0];
@@ -854,7 +854,7 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
             param().format == Param::Format::NCHW44 ||
             param().format == Param::Format::NCHW44_DOT) {
         megdnn_assert(
-                src.ndim == 5 || src.ndim == 4,
+                src.ndim == 5 || (src.ndim == 4 && src[1] <= 4),
                 "invalid src ndim for NCHW44, expected=5 or 4, got=%zu", src.ndim);
         dst.ndim = 5;
         dst[0] = src[0];
@@ -840,7 +840,9 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
     if (need_param_fuse) {
         add_pass<ParamFusePass>();
+        add_pass<ParamMergePass>();
     }
     return *this;
 }
@@ -66,7 +66,6 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
     ctx->add_opr_config(
                opr::ConvBiasForward::typeinfo(),
                {OprFormatConfigID::NCHW, OprFormatConfigID::NHWC,
-                OprFormatConfigID::NCHW4_NCHW32, OprFormatConfigID::NCHW32_NCHW4,
                 OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
                 OprFormatConfigID::NCHW64, OprFormatConfigID::CHWN4})
             .add_opr_config(
@@ -18,6 +18,7 @@
 #include "megbrain/gopt/solver.h"
 #include "megbrain/opr/dnn/pooling.h"
 #include "megbrain/opr/imgproc.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "megbrain/serialization/sereg.h"
 #include "megbrain/utils/hash_ct.h"
@@ -64,11 +65,6 @@ void LayoutTransformPass::apply(OptState& opt) const {
     auto&& base_cfg_id = m_ctx->attribute().base_config_id;
     auto&& reformat_attribute = m_ctx->attribute().reformat_attribute;
     ThinHashMap<VarNode*, TensorFormats> var2fmts;
-    static ThinHashSet<Typeinfo*> format_aware_oprs = {
-#define cb(_Opr) opr::_Opr::typeinfo(),
-            FOREACH_FORMAT_AWARE_OPR(cb)
-#undef cb
-    };
     auto rewriter = opt.graph().make_rewriter();
     auto on_opr = [&opr_configs, &base_fmt, &base_cfg_id, &reformat_attribute,
                    &rewriter, &solution, &var2fmts,
@@ -141,8 +137,12 @@ void LayoutTransformPass::apply(OptState& opt) const {
             new_inp[i] = new_var;
         }
         VarNode* new_out;
-        if (format_aware_oprs.count(opr->dyn_typeinfo()) > 0) {
-            new_out = intl::modify_opr_format(opr_fmt.val(), new_inp, opr);
+        if (intl::has_opr_format_modifier(opr)) {
+            intl::OprFormatInfo opr_format_info;
+            opr_format_info.opr_format = opr_fmt.val();
+            opr_format_info.tensor_formats = {
+                    base_fmt, opr_format_to_tensor_formats(opr_fmt.val())};
+            new_out = intl::modify_opr_format(opr_format_info, new_inp, opr);
         } else {
             new_out = serialization::copy_opr_shallow(*opr, new_inp, opr->config())
                               ->output(0);
@@ -11,10 +11,12 @@ | |||||
*/ | */ | ||||
#include "./opr_format_modifier.h" | #include "./opr_format_modifier.h" | ||||
#include "./utils.h" | |||||
#include "megbrain/opr/dnn/convolution.h" | #include "megbrain/opr/dnn/convolution.h" | ||||
#include "megbrain/opr/dnn/pooling.h" | #include "megbrain/opr/dnn/pooling.h" | ||||
#include "megbrain/opr/imgproc.h" | #include "megbrain/opr/imgproc.h" | ||||
#include "megbrain/opr/io.h" | #include "megbrain/opr/io.h" | ||||
#include "megbrain/opr/tensor_manip.h" | |||||
#include "megbrain/serialization/sereg.h" | #include "megbrain/serialization/sereg.h" | ||||
#include "midout.h" | #include "midout.h" | ||||
@@ -201,6 +203,37 @@ INST(ConvolutionBackwardData) | |||||
INST(PoolingForward) | INST(PoolingForward) | ||||
#undef APPLY | #undef APPLY | ||||
#undef INST | #undef INST | ||||
VarNode* modify_concat_opr_format( | |||||
gopt::intl::OprFormatInfo::TensorFormatsInfo tensor_formats, | |||||
const VarNodeArray& i, const cg::OperatorNodeBase* opr) { | |||||
auto base_format = tensor_formats.from; | |||||
auto tensor_format = tensor_formats.to; | |||||
int axis = opr->cast_final_safe<Concat>().axis(); | |||||
/// modify axis | |||||
using Dimension = megdnn::Dimension; | |||||
static constexpr uint32_t UNDETERMINED_EXTENT = Dimension::UNDETERMINED_EXTENT; | |||||
auto orig_shape = tensor_formats_to_named_tensor_shape(base_format); | |||||
auto target_shape = tensor_formats_to_named_tensor_shape(tensor_format); | |||||
mgb_assert( | |||||
static_cast<size_t>(axis) < orig_shape.ndim, | |||||
"invalid axis of concat opr(axis:%d,shp:%s)", axis, | |||||
orig_shape.to_string().c_str()); | |||||
if (orig_shape[axis].extent() != UNDETERMINED_EXTENT) | |||||
return nullptr; | |||||
auto axis_name = orig_shape[axis].name(); | |||||
int new_axis = target_shape.ndim; | |||||
for (size_t i = 0; i < target_shape.ndim; ++i) { | |||||
if (target_shape[i].name() == axis_name && | |||||
target_shape[i].extent() == UNDETERMINED_EXTENT) { | |||||
new_axis = i; | |||||
break; | |||||
} | |||||
} | |||||
if (static_cast<size_t>(new_axis) >= target_shape.ndim) | |||||
return nullptr; | |||||
return opr::Concat::make(i, new_axis, opr->config()).node(); | |||||
} | |||||
} // namespace | } // namespace | ||||
namespace mgb { | namespace mgb { | ||||
@@ -275,13 +308,16 @@ INST(Resize, 2);
 #undef INST
 VarNode* modify_opr_format(
-        opr::ConvBias::Param::Format opr_format, const VarNodeArray& i,
+        OprFormatInfo opr_format_info, const VarNodeArray& i,
         const cg::OperatorNodeBase* opr) {
-#define cb(_Opr) \
-    if (opr->dyn_typeinfo() == _Opr::typeinfo()) { \
-        return OprFormatModifier<_Opr>::make(opr_format, i, opr); \
+#define cb(_Opr) \
+    if (opr->dyn_typeinfo() == _Opr::typeinfo()) { \
+        return OprFormatModifier<_Opr>::make(opr_format_info.opr_format, i, opr); \
     } else
-    FOREACH_FORMAT_AWARE_OPR(cb) {
+    FOREACH_FORMAT_AWARE_OPR(cb)
+    if (opr->dyn_typeinfo() == opr::Concat::typeinfo()) {
+        return modify_concat_opr_format(opr_format_info.tensor_formats, i, opr);
+    } else {
         mgb_throw(
                 InternalError, "invalid format aware operator(got:%s)",
                 opr->dyn_typeinfo()->name);
@@ -302,6 +338,28 @@ bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr)
                 InternalError, "invalid multi-algo operator(got:%s)",
                 opr->dyn_typeinfo()->name);
     }
+#undef cb
+}
+bool has_opr_format(const cg::OperatorNodeBase* opr) {
+    bool ret = false;
+#define cb(_Opr) ret |= opr->dyn_typeinfo() == _Opr::typeinfo();
+    FOREACH_FORMAT_AWARE_OPR(cb)
+#undef cb
+    return ret;
+}
+bool has_opr_format_modifier(const cg::OperatorNodeBase* opr) {
+    bool ret = false;
+#define cb(_Opr) ret |= opr->dyn_typeinfo() == _Opr::typeinfo();
+    FOREACH_MODIFY_OPR_FORMAT_OPR(cb)
+#undef cb
+    return ret;
+}
+bool allow_aligned_layout(const cg::OperatorNodeBase* opr) {
+    return opr->dyn_typeinfo() != opr::Concat::typeinfo() &&
+           opr->dyn_typeinfo() != opr::Reduce::typeinfo();
+}
 }
 }  // namespace intl
@@ -16,17 +16,36 @@
 namespace mgb {
 namespace gopt {
+enum class TensorFormats : uint32_t;
 namespace intl {
 #define FOREACH_FORMAT_AWARE_OPR(cb) \
     cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) cb(PoolingForward) \
             cb(WarpPerspective) cb(Resize)
+#define FOREACH_MODIFY_OPR_FORMAT_OPR(cb) FOREACH_FORMAT_AWARE_OPR(cb) cb(Concat)
 bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr);
+struct OprFormatInfo {
+    opr::Convolution::Param::Format opr_format;
+    struct TensorFormatsInfo {
+        TensorFormats from;
+        TensorFormats to;
+    };
+    TensorFormatsInfo tensor_formats;
+};
 VarNode* modify_opr_format(
-        opr::Convolution::Param::Format opr_format, const VarNodeArray& i,
+        OprFormatInfo opr_format, const VarNodeArray& i,
         const cg::OperatorNodeBase* opr);
+bool has_opr_format(const cg::OperatorNodeBase* opr);
+bool has_opr_format_modifier(const cg::OperatorNodeBase* opr);
+bool allow_aligned_layout(const cg::OperatorNodeBase* opr);
 }  // namespace intl
 }  // namespace gopt
 }  // namespace mgb
@@ -625,7 +625,13 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_HYBRID> {
             TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
             config.input_tensor_types.emplace_back(tensor_type);
         }
-        available &= check_dtype(opr->output(0)->dtype(), false);
+        // FIXME: hack for nchw nchw44 hybrid mode
+        static_assert(
+                std::is_same<Opr, opr::ConvolutionForward>::value ||
+                        std::is_same<Opr, opr::ConvBiasForward>::value,
+                "nchw44 hybrid only support conv or conv_bias opr");
+        size_t in_channel = opr->input(0)->shape()[1];
+        available &= in_channel <= 4_z && check_dtype(opr->output(0)->dtype(), false);
         config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
         available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
         config.input_tensor_formats = {
@@ -696,7 +702,14 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW88_HYBRID> {
             TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
             config.input_tensor_types.emplace_back(tensor_type);
         }
-        available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
+        // FIXME: hack for nchw nchw88 hybrid mode
+        static_assert(
+                std::is_same<Opr, opr::ConvolutionForward>::value ||
+                        std::is_same<Opr, opr::ConvBiasForward>::value,
+                "nchw nchw88 hybrid mode only support conv or conv_bias opr");
+        size_t in_channel = opr->input(0)->shape()[1];
+        available &= in_channel <= 8_z &&
+                     opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
         config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
         available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
         // setup tensor formats
@@ -783,6 +796,13 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_DOT_HYBRID
                 opr->output(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm;
         config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
         available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
+        // FIXME: hack for nchw nchw44 dot hybrid mode
+        static_assert(
+                std::is_same<Opr, opr::ConvolutionForward>::value ||
+                        std::is_same<Opr, opr::ConvBiasForward>::value,
+                "nchw44 dot hybrid only support conv or conv_bias opr");
+        size_t in_channel = opr->input(0)->shape()[1];
+        available &= in_channel <= 4_z;
         // setup tensor formats
         config.input_tensor_formats = {
                 TensorFormats::NCHW, TensorFormats::KRSCk4, TensorFormats::NCHWc4,
@@ -940,6 +960,8 @@ StaticData::StaticData() {
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW4);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, CHWN4);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32);
+    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32_NCHW4);
+    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW4_NCHW32);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW64);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88);
@@ -19,6 +19,7 @@
 #include "megbrain/opr/imgproc.h"
 #include "megbrain/opr/io.h"
 #include "megbrain/opr/nn_int.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "megbrain/plugin/base.h"
 #include "megbrain/serialization/sereg.h"
@@ -202,20 +203,43 @@ float ProfilerImpl::profile_operator(
     auto graph = ComputingGraph::make();
     graph->options().graph_opt_level = 0;
     graph->options().var_sanity_check_first_run = false;
+    OperatorNodeBase* new_opr;
+    /// \note Concat operators are treated specially, for two reasons:
+    /// 1. Padding the input varnodes of a Concat opr is not allowed. If we padded
+    /// the inputs, the padding information of the output varnode would have to be
+    /// propagated through the layout selection algorithm, and this feature has not
+    /// been implemented yet.
+    /// 2. The axis of the Concat operator has to be modified, because the layouts
+    /// of the input varnodes have been changed. The new axis is therefore handled
+    /// in the OprMaker function.
+    bool allow_aligned = intl::allow_aligned_layout(opr);
     VarNodeArray new_inps(opr->input().size());
     for (size_t i = 0; i < opr->input().size(); ++i) {
         auto&& var = opr->input(i);
         auto&& cn = var->comp_node();
         auto&& dtype = var->dtype();
         auto dval = std::make_shared<DeviceTensorND>(cn, dtype);
-        auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape(
-                var, base_format, tensor_format, extra_attribute);
-        dval->resize(aligned_tensor_shape);
-        auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval);
-        new_inps[i] = aligned_var.node();
+        auto new_shape = ReformatManager::try_make_tensor_shape(
+                var, base_format, tensor_format, extra_attribute, allow_aligned);
+        if (new_shape.ndim == 0)
+            return PROFILE_TIME_OUT;
+        dval->resize(new_shape);
+        auto new_var = opr::VolatileSharedDeviceTensor::make(*graph, dval);
+        new_inps[i] = new_var.node();
+    }
+    if (intl::has_opr_format_modifier(opr)) {
+        intl::OprFormatInfo opr_format_info;
+        opr_format_info.tensor_formats = {base_format, tensor_format};
+        auto new_var = intl::modify_opr_format(opr_format_info, new_inps, opr);
+        if (new_var)
+            new_opr = new_var->owner_opr();
+        else
+            return PROFILE_TIME_OUT;
+    } else {
+        new_opr = serialization::copy_opr_shallow(
+                *opr, new_inps, opr->config(), {graph.get()});
     }
-    auto new_opr = serialization::copy_opr_shallow(
-            *opr, new_inps, opr->config(), {graph.get()});
     if (!m_opr_filter(opr, new_opr))
         return PROFILE_TIME_OUT;
     auto y = new_opr->output(0);
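A concrete illustration of point 1 in the note above (the numbers are illustrative, not taken from the profiler): concatenating two NCHW tensors of shape (16, 2, 368, 640) along the channel axis gives (16, 4, 368, 640), but if each input were first padded to a 4-channel-aligned NCHWc4 shape, every input would grow to 4 channels and the concatenation would yield 8 channels with the padding interleaved in the middle of the result; recovering the true 4-channel output would require propagating per-input padding through the layout selection algorithm, which is not implemented. For that reason allow_aligned_layout excludes Concat (and Reduce), try_make_tensor_shape is called in strict mode for their inputs, and profiling simply bails out with PROFILE_TIME_OUT when a layout would require padding.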
@@ -248,6 +272,8 @@ float ProfilerImpl::profile_operator(
         ReformatAttribute extra_attribute) const {
     auto graph = ComputingGraph::make();
     graph->options().graph_opt_level = 0;
+    graph->options().graph_opt.weight_preprocess =
+            opr->owner_graph()->options().graph_opt.weight_preprocess;
     graph->options().var_sanity_check_first_run = false;
     VarNodeArray new_inps(opr->input().size());
     size_t i = 0;
@@ -274,8 +300,11 @@ float ProfilerImpl::profile_operator(
                     config.input_tensor_formats[i], extra_attribute);
         }
         dval->resize(aligned_shape);
-        auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval);
-        new_inps[i] = aligned_var.node();
+        if (config.input_tensor_types[i] == TensorType::WEIGHT) {
+            new_inps[i] = opr::SharedDeviceTensor::make_const(*graph, dval).node();
+        } else {
+            new_inps[i] = opr::VolatileSharedDeviceTensor::make(*graph, dval).node();
+        }
     }
     for (; i < opr->input().size(); ++i) {
         auto&& var = opr->input(i);
@@ -291,7 +320,9 @@ float ProfilerImpl::profile_operator(
         auto imm = opr::ImmutableTensor::make(*graph, *hval);
         new_inps[i] = imm.node();
     }
-    VarNode* y = mgb::gopt::intl::modify_opr_format(config.opr_format, new_inps, opr);
+    intl::OprFormatInfo opr_format_info;
+    opr_format_info.opr_format = config.opr_format;
+    VarNode* y = mgb::gopt::intl::modify_opr_format(opr_format_info, new_inps, opr);
     static const ThinHashSet<Typeinfo*> multi_algo_oprs = {
             opr::Convolution::typeinfo(),
             opr::ConvBiasForward::typeinfo(),
@@ -587,9 +587,9 @@ const ReformatManager& ReformatManager::instance() {
     return inst;
 }
-TensorShape ReformatManager::make_aligned_tensor_shape(
+TensorShape ReformatManager::try_make_tensor_shape(
         const VarNode* var, TensorFormats orig_formats, TensorFormats target_formats,
-        ReformatKey::Attribute extra_attribute) {
+        ReformatKey::Attribute extra_attribute, bool allow_aligned) {
     using Dimension = megdnn::Dimension;
     static constexpr uint32_t UNDETERMINED_EXTENT = Dimension::UNDETERMINED_EXTENT;
     auto orig_shape = tensor_formats_to_named_tensor_shape(orig_formats);
@@ -623,8 +623,17 @@ TensorShape ReformatManager::make_aligned_tensor_shape(
                             : (orig_shape[idx] / target_shape[i]).extent();
             if (mul)
                 tshp[i] = oshp[idx] * factor;
-            else
-                tshp[i] = divup(oshp[idx], factor);
+            else {
+                if (allow_aligned)
+                    tshp[i] = divup(oshp[idx], factor);
+                else if (!(oshp[idx] % factor)) {
+                    tshp[i] = oshp[idx] / factor;
+                } else {
+                    return TensorShape{};
+                }
+            }
+            /// hack for nhwc auto padding
             if (name == Dimension::Name::C) {
                 size_t channel_alignment = target_shape[i].stride();
                 size_t channels = tshp[i] * channel_alignment;
@@ -641,6 +650,15 @@ TensorShape ReformatManager::make_aligned_tensor_shape(
     return tshp;
 }
+TensorShape ReformatManager::make_aligned_tensor_shape(
+        const VarNode* var, TensorFormats orig_formats, TensorFormats target_formats,
+        ReformatKey::Attribute extra_attribute) {
+    auto tshp = ReformatManager::try_make_tensor_shape(
+            var, orig_formats, target_formats, extra_attribute);
+    mgb_assert(tshp.ndim);
+    return tshp;
+}
 TensorShape ReformatManager::make_aligned_weight_shape(
         const VarNode* var, TensorFormats orig_formats, TensorFormats target_formats,
         TensorFormats extra_formats, ReformatKey::Attribute extra_attribute) {
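To make the new allow_aligned flag concrete, here is a small standalone sketch (the helper name is hypothetical, not the ReformatManager API) of the channel handling added above: with alignment allowed, a channel count that does not divide evenly by the block factor is rounded up with divup and padded later; in strict mode the reshape fails, which the profiler maps to PROFILE_TIME_OUT.

    #include <cstdio>

    // Sketch of the strict vs. aligned reshape of a channel dim into blocks of
    // `factor` channels (e.g. factor = 4 for NCHWc4). Returns 0 on failure,
    // mirroring the empty TensorShape returned by try_make_tensor_shape.
    static size_t reshape_channel_blocks(size_t channels, size_t factor, bool allow_aligned) {
        if (channels % factor == 0)
            return channels / factor;                 // exact split, always fine
        if (allow_aligned)
            return (channels + factor - 1) / factor;  // divup: pad to the next block
        return 0;                                     // strict mode: refuse to pad
    }

    int main() {
        printf("%zu\n", reshape_channel_blocks(6, 4, true));   // 2 -> padded to 8 channels
        printf("%zu\n", reshape_channel_blocks(6, 4, false));  // 0 -> reshape rejected
        printf("%zu\n", reshape_channel_blocks(8, 4, false));  // 2 -> exact, no padding needed
    }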
@@ -93,6 +93,8 @@ static inline TensorFormats opr_format_to_tensor_formats(
             return TensorFormats::NCHWc4;
         case OprFormat::NCHW8:
             return TensorFormats::NCHWc8;
+        case OprFormat::NCHW44_DOT:
+            return TensorFormats::NCHWc4;
         default:
             mgb_throw(
                     AssertionError, "format(%s) is not supported",
@@ -171,7 +173,6 @@ static inline megdnn::NamedTensorShape tensor_formats_to_named_tensor_shape(
                     static_cast<uint32_t>(format));
     }
 }
 }  // namespace gopt
 }  // namespace mgb
@@ -135,6 +135,13 @@ public:
             const VarNode* orig_var, const ReformatKey& key,
             const AlignmentDesc& extra_alignment = {}) const;
+    /// Returns an empty shape if the shape of the original varnode does not
+    /// satisfy the alignment requirement of the target tensor format.
+    static TensorShape try_make_tensor_shape(
+            const VarNode* var, TensorFormats orig_formats,
+            TensorFormats target_formats,
+            ReformatKey::Attribute extra_attribute = ReformatKey::Attribute::DEFAULT,
+            bool allow_aligned = true);
     static TensorShape make_aligned_tensor_shape(
             const VarNode* var, TensorFormats orig_formats,
             TensorFormats target_formats,
@@ -18104,3 +18104,217 @@ static const std::vector<uint8_t> TestLayoutTransform_Wide = { | |||||
108, 111, 97, 116, 51, 50, 124, 49, 54, 44, 49, 54, 44, 56, 44, 56, | 108, 111, 97, 116, 51, 50, 124, 49, 54, 44, 49, 54, 44, 56, 44, 56, | ||||
59, 70, 108, 111, 97, 116, 51, 50, 124, 49, 48, 4, 0, 0, 0, 233, | 59, 70, 108, 111, 97, 116, 51, 50, 124, 49, 48, 4, 0, 0, 0, 233, | ||||
38, 177, 64}; | 38, 177, 64}; | ||||
static const std::vector<uint8_t> TestLayoutTransform_Concat = { | |||||
1, 0, 0, 0, 34, 0, 0, 0, 108, 97, 121, 111, 117, 116, 95, 116, | |||||
114, 97, 110, 115, 102, 111, 114, 109, 95, 112, 114, 111, 102, 105, 108, 101, | |||||
58, 112, 108, 97, 116, 61, 99, 117, 100, 97, 33, 0, 0, 0, 77, 0, | |||||
0, 0, 49, 54, 44, 50, 44, 51, 54, 56, 44, 54, 52, 48, 59, 81, | |||||
117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 123, 78, 44, 67, 47, | |||||
47, 52, 44, 72, 44, 87, 44, 67, 37, 52, 125, 59, 123, 78, 44, 67, | |||||
44, 72, 44, 87, 125, 59, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, | |||||
100, 83, 56, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 4, | |||||
0, 0, 0, 28, 90, 20, 67, 89, 0, 0, 0, 47, 225, 79, 220, 129, | |||||
196, 236, 15, 1, 0, 0, 0, 49, 54, 44, 50, 44, 51, 54, 56, 44, | |||||
54, 52, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, | |||||
49, 54, 44, 50, 44, 51, 54, 56, 44, 54, 52, 48, 59, 81, 117, 97, | |||||
110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 54, 44, 52, 44, 51, 54, | |||||
56, 44, 54, 52, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, | |||||
56, 124, 51, 48, 4, 0, 0, 0, 128, 150, 24, 75, 77, 0, 0, 0, | |||||
49, 54, 44, 52, 44, 49, 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, | |||||
110, 116, 105, 122, 101, 100, 83, 56, 124, 123, 78, 44, 67, 47, 47, 52, | |||||
44, 72, 44, 87, 44, 67, 37, 52, 125, 59, 123, 78, 44, 67, 44, 72, | |||||
44, 87, 125, 59, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, | |||||
56, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 4, 0, 0, | |||||
0, 236, 81, 208, 66, 52, 0, 0, 0, 145, 22, 69, 252, 191, 174, 0, | |||||
1, 49, 54, 44, 51, 44, 55, 51, 54, 44, 49, 50, 56, 48, 59, 85, | |||||
105, 110, 116, 56, 124, 49, 54, 44, 51, 44, 55, 51, 54, 44, 49, 50, | |||||
56, 48, 59, 70, 108, 111, 97, 116, 51, 50, 124, 48, 48, 4, 0, 0, | |||||
0, 88, 57, 79, 68, 66, 0, 0, 0, 49, 54, 44, 51, 44, 55, 51, | |||||
54, 44, 49, 50, 56, 48, 59, 70, 108, 111, 97, 116, 51, 50, 124, 123, | |||||
78, 44, 67, 44, 72, 44, 87, 125, 59, 123, 78, 44, 67, 47, 47, 52, | |||||
44, 72, 44, 87, 44, 67, 37, 52, 125, 59, 48, 59, 70, 108, 111, 97, | |||||
116, 51, 50, 59, 70, 108, 111, 97, 116, 51, 50, 4, 0, 0, 0, 134, | |||||
75, 130, 69, 163, 0, 0, 0, 10, 151, 159, 190, 73, 14, 145, 194, 1, | |||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, | |||||
0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 1, | |||||
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, | |||||
0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 49, 54, 44, 56, 44, | |||||
49, 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, | |||||
100, 83, 56, 124, 49, 54, 44, 56, 44, 51, 44, 51, 59, 81, 117, 97, | |||||
110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 44, 49, 54, 44, 49, 44, | |||||
49, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 51, 50, 124, 49, | |||||
54, 44, 49, 54, 44, 57, 50, 44, 49, 54, 48, 59, 81, 117, 97, 110, | |||||
116, 105, 122, 101, 100, 83, 56, 124, 51, 48, 4, 0, 0, 0, 68, 139, | |||||
104, 67, 60, 0, 0, 0, 49, 54, 44, 51, 44, 55, 51, 54, 44, 49, | |||||
50, 56, 48, 59, 85, 105, 110, 116, 56, 124, 123, 78, 44, 67, 47, 47, | |||||
52, 44, 72, 44, 87, 44, 67, 37, 52, 125, 59, 123, 78, 44, 67, 44, | |||||
72, 44, 87, 125, 59, 48, 59, 85, 105, 110, 116, 56, 59, 85, 105, 110, | |||||
116, 56, 4, 0, 0, 0, 201, 161, 135, 68, 64, 0, 0, 0, 218, 182, | |||||
120, 146, 221, 155, 179, 235, 11, 0, 0, 0, 49, 54, 44, 52, 44, 49, | |||||
56, 52, 44, 51, 50, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, | |||||
83, 56, 124, 49, 54, 44, 52, 44, 49, 56, 52, 44, 51, 50, 48, 59, | |||||
81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 48, 48, 4, 0, | |||||
0, 0, 182, 243, 29, 66, 161, 0, 0, 0, 10, 151, 159, 190, 73, 14, | |||||
145, 194, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, | |||||
0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, | |||||
0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 49, 54, | |||||
44, 52, 44, 51, 54, 56, 44, 54, 52, 48, 59, 81, 117, 97, 110, 116, | |||||
105, 122, 101, 100, 83, 56, 124, 52, 44, 52, 44, 51, 44, 51, 59, 81, | |||||
117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 44, 52, 44, 49, | |||||
44, 49, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 51, 50, 124, | |||||
49, 54, 44, 52, 44, 49, 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, | |||||
110, 116, 105, 122, 101, 100, 83, 56, 124, 48, 48, 4, 0, 0, 0, 128, | |||||
150, 24, 75, 60, 0, 0, 0, 49, 54, 44, 51, 44, 55, 51, 54, 44, | |||||
49, 50, 56, 48, 59, 85, 105, 110, 116, 56, 124, 123, 78, 44, 67, 44, | |||||
72, 44, 87, 125, 59, 123, 78, 44, 67, 47, 47, 52, 44, 72, 44, 87, | |||||
44, 67, 37, 52, 125, 59, 48, 59, 85, 105, 110, 116, 56, 59, 85, 105, | |||||
110, 116, 56, 4, 0, 0, 0, 203, 33, 219, 68, 68, 0, 0, 0, 218, | |||||
182, 120, 146, 221, 155, 179, 235, 16, 0, 0, 0, 49, 54, 44, 51, 44, | |||||
55, 51, 54, 44, 49, 50, 56, 48, 59, 70, 108, 111, 97, 116, 51, 50, | |||||
124, 49, 59, 70, 108, 111, 97, 116, 51, 50, 124, 49, 54, 44, 51, 44, | |||||
55, 51, 54, 44, 49, 50, 56, 48, 59, 70, 108, 111, 97, 116, 51, 50, | |||||
124, 48, 48, 4, 0, 0, 0, 205, 204, 161, 68, 77, 0, 0, 0, 49, | |||||
54, 44, 52, 44, 51, 54, 56, 44, 54, 52, 48, 59, 81, 117, 97, 110, | |||||
116, 105, 122, 101, 100, 83, 56, 124, 123, 78, 44, 67, 47, 47, 52, 44, | |||||
72, 44, 87, 44, 67, 37, 52, 125, 59, 123, 78, 44, 67, 44, 72, 44, | |||||
87, 125, 59, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, | |||||
59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 4, 0, 0, 0, | |||||
200, 118, 174, 67, 163, 0, 0, 0, 10, 151, 159, 190, 73, 14, 145, 194, | |||||
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, | |||||
1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, | |||||
0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 49, 54, 44, 56, | |||||
44, 49, 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, 110, 116, 105, 122, | |||||
101, 100, 83, 56, 124, 49, 54, 44, 56, 44, 51, 44, 51, 59, 81, 117, | |||||
97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 44, 49, 54, 44, 49, | |||||
44, 49, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 51, 50, 124, | |||||
49, 54, 44, 49, 54, 44, 57, 50, 44, 49, 54, 48, 59, 81, 117, 97, | |||||
110, 116, 105, 122, 101, 100, 83, 56, 124, 48, 48, 4, 0, 0, 0, 128, | |||||
150, 24, 75, 161, 0, 0, 0, 10, 151, 159, 190, 73, 14, 145, 194, 1, | |||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, | |||||
0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 1, | |||||
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, | |||||
0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 49, 54, 44, 52, 44, | |||||
51, 54, 56, 44, 54, 52, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, | |||||
100, 83, 56, 124, 52, 44, 52, 44, 51, 44, 51, 59, 81, 117, 97, 110, | |||||
116, 105, 122, 101, 100, 83, 56, 124, 49, 44, 52, 44, 49, 44, 49, 59, | |||||
81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 51, 50, 124, 49, 54, 44, | |||||
52, 44, 49, 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, 110, 116, 105, | |||||
122, 101, 100, 83, 56, 124, 51, 48, 4, 0, 0, 0, 138, 65, 164, 67, | |||||
77, 0, 0, 0, 49, 54, 44, 50, 44, 51, 54, 56, 44, 54, 52, 48, | |||||
59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 123, 78, 44, | |||||
67, 44, 72, 44, 87, 125, 59, 123, 78, 44, 67, 47, 47, 52, 44, 72, | |||||
44, 87, 44, 67, 37, 52, 125, 59, 48, 59, 81, 117, 97, 110, 116, 105, | |||||
122, 101, 100, 83, 56, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, | |||||
56, 4, 0, 0, 0, 137, 65, 190, 67, 58, 0, 0, 0, 145, 22, 69, | |||||
252, 191, 174, 0, 1, 49, 54, 44, 51, 44, 55, 51, 54, 44, 49, 50, | |||||
56, 48, 59, 70, 108, 111, 97, 116, 51, 50, 124, 49, 54, 44, 51, 44, | |||||
55, 51, 54, 44, 49, 50, 56, 48, 59, 81, 117, 97, 110, 116, 105, 122, | |||||
101, 100, 83, 56, 124, 48, 48, 4, 0, 0, 0, 38, 177, 146, 68, 89, | |||||
0, 0, 0, 47, 225, 79, 220, 129, 196, 236, 15, 1, 0, 0, 0, 49, | |||||
54, 44, 50, 44, 51, 54, 56, 44, 54, 52, 48, 59, 81, 117, 97, 110, | |||||
116, 105, 122, 101, 100, 83, 56, 124, 49, 54, 44, 50, 44, 51, 54, 56, | |||||
44, 54, 52, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, | |||||
124, 49, 54, 44, 52, 44, 51, 54, 56, 44, 54, 52, 48, 59, 81, 117, | |||||
97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 48, 48, 4, 0, 0, 0, | |||||
150, 67, 235, 66, 162, 0, 0, 0, 10, 151, 159, 190, 73, 14, 145, 194, | |||||
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, | |||||
1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, | |||||
0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 49, 54, 44, 51, | |||||
44, 55, 51, 54, 44, 49, 50, 56, 48, 59, 81, 117, 97, 110, 116, 105, | |||||
122, 101, 100, 83, 56, 124, 50, 44, 51, 44, 51, 44, 51, 59, 81, 117, | |||||
97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 44, 50, 44, 49, 44, | |||||
49, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 51, 50, 124, 49, | |||||
54, 44, 50, 44, 51, 54, 56, 44, 54, 52, 48, 59, 81, 117, 97, 110, | |||||
116, 105, 122, 101, 100, 83, 56, 124, 51, 48, 4, 0, 0, 0, 128, 150, | |||||
24, 75, 64, 0, 0, 0, 218, 182, 120, 146, 221, 155, 179, 235, 11, 0, | |||||
0, 0, 49, 54, 44, 50, 44, 51, 54, 56, 44, 54, 52, 48, 59, 81, | |||||
117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 54, 44, 50, 44, | |||||
51, 54, 56, 44, 54, 52, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, | |||||
100, 83, 56, 124, 51, 48, 4, 0, 0, 0, 227, 165, 243, 66, 77, 0, | |||||
0, 0, 49, 54, 44, 52, 44, 49, 56, 52, 44, 51, 50, 48, 59, 81, | |||||
117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 123, 78, 44, 67, 44, | |||||
72, 44, 87, 125, 59, 123, 78, 44, 67, 47, 47, 52, 44, 72, 44, 87, | |||||
44, 67, 37, 52, 125, 59, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, | |||||
100, 83, 56, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 4, | |||||
0, 0, 0, 174, 71, 169, 66, 162, 0, 0, 0, 10, 151, 159, 190, 73, | |||||
14, 145, 194, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, | |||||
0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, | |||||
0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 49, | |||||
54, 44, 51, 44, 55, 51, 54, 44, 49, 50, 56, 48, 59, 81, 117, 97, | |||||
110, 116, 105, 122, 101, 100, 83, 56, 124, 50, 44, 51, 44, 51, 44, 51, | |||||
59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 44, 50, | |||||
44, 49, 44, 49, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 51, | |||||
50, 124, 49, 54, 44, 50, 44, 51, 54, 56, 44, 54, 52, 48, 59, 81, | |||||
117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 48, 48, 4, 0, 0, | |||||
0, 128, 150, 24, 75, 66, 0, 0, 0, 49, 54, 44, 51, 44, 55, 51, | |||||
54, 44, 49, 50, 56, 48, 59, 70, 108, 111, 97, 116, 51, 50, 124, 123, | |||||
78, 44, 67, 47, 47, 52, 44, 72, 44, 87, 44, 67, 37, 52, 125, 59, | |||||
123, 78, 44, 67, 44, 72, 44, 87, 125, 59, 48, 59, 70, 108, 111, 97, | |||||
116, 51, 50, 59, 70, 108, 111, 97, 116, 51, 50, 4, 0, 0, 0, 196, | |||||
32, 63, 69, 89, 0, 0, 0, 47, 225, 79, 220, 129, 196, 236, 15, 1, | |||||
0, 0, 0, 49, 54, 44, 52, 44, 49, 56, 52, 44, 51, 50, 48, 59, | |||||
81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 54, 44, 52, | |||||
44, 49, 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, 110, 116, 105, 122, | |||||
101, 100, 83, 56, 124, 49, 54, 44, 56, 44, 49, 56, 52, 44, 51, 50, | |||||
48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 48, 48, | |||||
4, 0, 0, 0, 156, 196, 136, 66, 77, 0, 0, 0, 49, 54, 44, 49, | |||||
54, 44, 57, 50, 44, 49, 54, 48, 59, 81, 117, 97, 110, 116, 105, 122, | |||||
101, 100, 83, 56, 124, 123, 78, 44, 67, 47, 47, 52, 44, 72, 44, 87, | |||||
44, 67, 37, 52, 125, 59, 123, 78, 44, 67, 44, 72, 44, 87, 125, 59, | |||||
48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 59, 81, 117, | |||||
97, 110, 116, 105, 122, 101, 100, 83, 56, 4, 0, 0, 0, 207, 247, 211, | |||||
66, 77, 0, 0, 0, 49, 54, 44, 56, 44, 49, 56, 52, 44, 51, 50, | |||||
48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 123, 78, | |||||
44, 67, 47, 47, 52, 44, 72, 44, 87, 44, 67, 37, 52, 125, 59, 123, | |||||
78, 44, 67, 44, 72, 44, 87, 125, 59, 48, 59, 81, 117, 97, 110, 116, | |||||
105, 122, 101, 100, 83, 56, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, | |||||
83, 56, 4, 0, 0, 0, 54, 94, 58, 67, 64, 0, 0, 0, 218, 182, | |||||
120, 146, 221, 155, 179, 235, 11, 0, 0, 0, 49, 54, 44, 50, 44, 51, | |||||
54, 56, 44, 54, 52, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, | |||||
83, 56, 124, 49, 54, 44, 50, 44, 51, 54, 56, 44, 54, 52, 48, 59, | |||||
81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 48, 48, 4, 0, | |||||
0, 0, 155, 196, 128, 66, 78, 0, 0, 0, 49, 54, 44, 51, 44, 55, | |||||
51, 54, 44, 49, 50, 56, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, | |||||
100, 83, 56, 124, 123, 78, 44, 67, 44, 72, 44, 87, 125, 59, 123, 78, | |||||
44, 67, 47, 47, 52, 44, 72, 44, 87, 44, 67, 37, 52, 125, 59, 48, | |||||
59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 59, 81, 117, 97, | |||||
110, 116, 105, 122, 101, 100, 83, 56, 4, 0, 0, 0, 191, 74, 159, 68, | |||||
89, 0, 0, 0, 47, 225, 79, 220, 129, 196, 236, 15, 1, 0, 0, 0, | |||||
49, 54, 44, 52, 44, 49, 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, | |||||
110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 54, 44, 52, 44, 49, 56, | |||||
52, 44, 51, 50, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, | |||||
56, 124, 49, 54, 44, 56, 44, 49, 56, 52, 44, 51, 50, 48, 59, 81, | |||||
117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 51, 48, 4, 0, 0, | |||||
0, 170, 241, 138, 66, 77, 0, 0, 0, 49, 54, 44, 52, 44, 51, 54, | |||||
56, 44, 54, 52, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, | |||||
56, 124, 123, 78, 44, 67, 44, 72, 44, 87, 125, 59, 123, 78, 44, 67, | |||||
47, 47, 52, 44, 72, 44, 87, 44, 67, 37, 52, 125, 59, 48, 59, 81, | |||||
117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 59, 81, 117, 97, 110, 116, | |||||
105, 122, 101, 100, 83, 56, 4, 0, 0, 0, 79, 141, 147, 67, 77, 0, | |||||
0, 0, 49, 54, 44, 49, 54, 44, 57, 50, 44, 49, 54, 48, 59, 81, | |||||
117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 123, 78, 44, 67, 44, | |||||
72, 44, 87, 125, 59, 123, 78, 44, 67, 47, 47, 52, 44, 72, 44, 87, | |||||
44, 67, 37, 52, 125, 59, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, | |||||
100, 83, 56, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 4, | |||||
0, 0, 0, 170, 241, 170, 66, 77, 0, 0, 0, 49, 54, 44, 56, 44, | |||||
49, 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, | |||||
100, 83, 56, 124, 123, 78, 44, 67, 44, 72, 44, 87, 125, 59, 123, 78, | |||||
44, 67, 47, 47, 52, 44, 72, 44, 87, 44, 67, 37, 52, 125, 59, 48, | |||||
59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 59, 81, 117, 97, | |||||
110, 116, 105, 122, 101, 100, 83, 56, 4, 0, 0, 0, 203, 161, 25, 67, | |||||
64, 0, 0, 0, 218, 182, 120, 146, 221, 155, 179, 235, 11, 0, 0, 0, | |||||
49, 54, 44, 52, 44, 49, 56, 52, 44, 51, 50, 48, 59, 81, 117, 97, | |||||
110, 116, 105, 122, 101, 100, 83, 56, 124, 49, 54, 44, 52, 44, 49, 56, | |||||
52, 44, 51, 50, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, | |||||
56, 124, 51, 48, 4, 0, 0, 0, 102, 102, 22, 66, 78, 0, 0, 0, | |||||
49, 54, 44, 51, 44, 55, 51, 54, 44, 49, 50, 56, 48, 59, 81, 117, | |||||
97, 110, 116, 105, 122, 101, 100, 83, 56, 124, 123, 78, 44, 67, 47, 47, | |||||
52, 44, 72, 44, 87, 44, 67, 37, 52, 125, 59, 123, 78, 44, 67, 44, | |||||
72, 44, 87, 125, 59, 48, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, | |||||
83, 56, 59, 81, 117, 97, 110, 116, 105, 122, 101, 100, 83, 56, 4, 0, | |||||
0, 0, 54, 94, 72, 68}; |
@@ -20,7 +20,7 @@
 # 2. Build megbrain_test and run all the global graph-optimization related tests:
 #    ./megbrain_test --gtest_filter="*LayoutTransform*"
 # 3. Use this script to pack all the cache files together:
-#    python3 embed_cache.py -o cache_data.h -r $(ls /path/to/cache/*.cache)
+#    python3 embed_cache.py -o cache_data.h -r -a $(ls /path/to/cache/*.cache)
 # 4. Revert the define statements from step 1, so that the profiling step uses the cached data.
 # 5. Finally, rebuild megbrain_test to make sure the test results are correct.
 import os.path
@@ -44,9 +44,10 @@ def _u32(data):
 class CacheDataGenerator:
     _cache_files = None
-    def __init__(self, cache_files, remove_plat_info = True):
+    def __init__(self, cache_files, remove_plat_info=True, append_cache=True):
         self._cache_files = cache_files
         self._remove_plat_info = remove_plat_info
+        self._append_cache = append_cache
     def _get_hash(self):
         return _u32(self._hash.digest()[:4])
@@ -71,9 +72,10 @@ class CacheDataGenerator:
         return ','.join(ret)
     def gen_cache_data_header(self, fout, src_map):
-        fout.write('// generated embed_cache.py\n')
-        fout.write('#include <vector>\n')
-        fout.write('#include <stdint.h>\n')
+        if not self._append_cache:
+            fout.write('// generated embed_cache.py\n')
+            fout.write('#include <vector>\n')
+            fout.write('#include <stdint.h>\n')
         for k, v in sorted(src_map.items()):
             fout.write("""
 static const std::vector<uint8_t> {} = {{
@@ -89,7 +91,11 @@ static const std::vector<uint8_t> {} = {{
             assert ext == ".cache", "ext: {}, fname {}".format(ext, fname)
             assert base not in fname2cache_data, "duplicated kernel: " + base
             fname2cache_data[base] = self.gen_cache_data(fname)
-        with open(output, 'w') as fout:
+        if self._append_cache:
+            mode = 'a'
+        else:
+            mode = 'w'
+        with open(output, mode) as fout:
             self.gen_cache_data_header(fout, fname2cache_data)
         logger.info('done')
@@ -107,7 +113,15 @@ if __name__ == '__main__':
         default=True,
         help="whether remove platform infomation in the cache (default: True)"
     )
+    parser.add_argument(
+        "-a",
+        "--append-cache",
+        action='store_true',
+        default=True,
+        help="whether to append to the existing cache header (default: True)"
+    )
     parser.add_argument('cache', help='cache files to be embedded', nargs='+')
     args = parser.parse_args()
-    cache_generator = CacheDataGenerator(args.cache, args.remove_plat_info)
+    cache_generator = CacheDataGenerator(args.cache, args.remove_plat_info,
+                                         args.append_cache)
     cache_generator.invoke(args.output)
@@ -812,7 +812,9 @@ TEST(TestLayoutTransform, Resnet18_F16) {
                     .add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
                     .add_pass<ShuffleShuffleRemovePass>()
                     .add_pass(FuseNCHW4Int8Preprocess::make())
+#if CUDA_VERSION >= 10020
                     .add_pass<FoldingConvBiasDimshufflePass>()
+#endif
                     .add_pass<ParamFusePass>()
                     .add_pass<ParamMergePass>()
                     .apply({{output}})
@@ -1205,4 +1207,112 @@ TEST(TestLayoutTransform, MobileNetV2_NCHW44_DOT) {
     MGB_ASSERT_TENSOR_EQ(t1, t2);
 }
+#if MGB_CUDA
+TEST(TestLayoutTransform, Concat) {
+    REQUIRE_GPU(1);
+    auto cn = CompNode::load("gpu0");
+    cn.activate();
+    REQUIRE_CUDA_COMPUTE_CAPABILITY(6, 1);
+    constexpr size_t N = 16, C = 3, H = 736, W = 1280;
+    HostTensorGenerator<dtype::Uint8> gen;
+    auto graph = ComputingGraph::make();
+    auto h2d = opr::Host2DeviceCopy::make(*graph, gen({N, C, H, W}, cn));
+    auto data = opr::TypeCvt::make(h2d, dtype::Float32());
+    auto sub_128 = data + (-128);
+    auto x = opr::TypeCvt::make(sub_128, dtype::QuantizedS8(1.f));
+    auto mkcvar = [&](const char* name, const TensorShape& shp, const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name),
+                dtype);
+    };
+    auto w = mkcvar("w", {2, 3, 3, 3}, dtype::QuantizedS8(1.f));
+    auto b = mkcvar("b", {1, 2, 1, 1}, dtype::QuantizedS32(1.f));
+    opr::ConvBias::Param param;
+    param.format = opr::ConvBias::Param::Format::NCHW;
+    param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
+    param.stride_h = param.stride_w = 2;
+    param.pad_h = param.pad_w = 1;
+    auto conv_1 = opr::ConvBias::make(
+            x, w, b, param, {}, OperatorNodeConfig(dtype::QuantizedS8(1.f)));
+    auto conv_1_cat = opr::Concat::make({conv_1, -conv_1}, 1);
+    auto w2 = mkcvar("w", {4, 4, 3, 3}, dtype::QuantizedS8(1.f));
+    auto b2 = mkcvar("b", {1, 4, 1, 1}, dtype::QuantizedS32(1.f));
+    auto conv_2 = opr::ConvBias::make(
+            conv_1_cat, w2, b2, param, {}, OperatorNodeConfig(dtype::QuantizedS8(1.f)));
+    auto conv_2_cat = opr::Concat::make({conv_2, -conv_2}, 1);
+    auto w3 = mkcvar("w", {16, 8, 3, 3}, dtype::QuantizedS8(1.f));
+    auto b3 = mkcvar("b", {1, 16, 1, 1}, dtype::QuantizedS32(1.f));
+    auto y = opr::ConvBias::make(
+            conv_2_cat, w3, b3, param, {}, OperatorNodeConfig(dtype::QuantizedS8(1.f)));
+    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
+    S strategy = S::PROFILE;
+    gopt::modify_opr_algo_strategy_inplace({y}, strategy);
+    using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
+    using OprList = LayoutTransformContext::OprList;
+    using Attribute = LayoutTransformContext::Attribute;
+    using Target = LayoutTransformContext::Target;
+    OprList opr_list = {
+            opr::ConvBiasForward::typeinfo(), opr::ElemwiseMultiType::typeinfo(),
+            opr::Elemwise::typeinfo(), opr::TypeCvt::typeinfo(),
+            opr::Concat::typeinfo(),
+    };
+    SmallVector<TensorFormats> available_tensor_formats = {
+            TensorFormats::NCHW, TensorFormats::NCHWc4};
+    Attribute attribute = {
+            OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC};
+    auto ctx = std::make_unique<LayoutTransformContext>(
+            std::move(opr_list), std::move(available_tensor_formats), attribute);
+    ctx->add_opr_config(
+            opr::ConvBiasForward::typeinfo(),
+            {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4});
+#if MGB_WITH_CACHED_TEST
+    auto profiler = std::make_unique<ProfilerMock>(
+            static_cast<const uint8_t*>(TestLayoutTransform_Concat.data()),
+            TestLayoutTransform_Concat.size());
+#else
+    auto profiler =
+            ProfilerBase::make_cached_profiler("TestLayoutTransform.Concat.cache");
+#endif
+    std::unique_ptr<SolverBase> solver{
+            new DynamicProgrammingSolver(std::move(profiler))};
+    auto new_out_vars =
+            gopt::GraphOptimizer{}
+                    .add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
+                    .add_pass<ShuffleShuffleRemovePass>()
+                    .add_pass(FuseNCHW4Int8Preprocess::make())
+#if CUDA_VERSION >= 10020
+                    .add_pass<FoldingConvBiasDimshufflePass>()
+                    .add_pass<FoldingConvBiasTypecvtPass>()
+#endif
+                    .add_pass<ParamFusePass>()
+                    .add_pass<ParamMergePass>()
+                    .apply(SymbolVarArray{y})
+                    .endpoint_vars();
+    const auto& v = new_out_vars[0];
+    using OutputSpecItem = cg::ComputingGraph::OutputSpecItem;
+    std::vector<OutputSpecItem> outs;
+    for (auto&& i : new_out_vars) {
+        outs.emplace_back(OutputSpecItem{i, {}});
+    }
+    GraphProfiler gprof{graph.get()};
+    auto func = graph->compile(outs);
+    func->execute();
+    gprof.to_json_full(func.get())->writeto_fpath(output_file("conv_cat.json"));
+    SmallVector<cg::OperatorNodeBase*> oprs;
+    auto cb = [&oprs](cg::OperatorNodeBase* opr) {
+        if (opr->same_type<opr::Concat>()) {
+            oprs.push_back(opr);
+        }
+    };
+    cg::DepOprIter{cb}.add(v.node()->owner_opr());
+    ASSERT_EQ(oprs.size(), 4);
+    ASSERT_EQ(oprs[0]->output(0)->shape().ndim, 4);
+    ASSERT_EQ(oprs[2]->output(0)->shape().ndim, 5);
+}
+#endif
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}