GitOrigin-RevId: 038e372cbe
release-1.6
@@ -200,6 +200,15 @@ static inline bool is_nchw_nchw4_shuffle_vec( | |||||
param.pattern[4] == 2; | param.pattern[4] == 2; | ||||
} | } | ||||
static inline bool is_shape_before_nhwc(const TensorShape& shape) { | |||||
return shape.ndim == 4 && shape[1] == 4; | |||||
} | |||||
static inline bool is_nchw_nhwc_shuffle(const opr::Dimshuffle::Param param) { | |||||
return param.ndim == 4 && param.pattern[0] == 0 && param.pattern[1] == 2 && | |||||
param.pattern[2] == 3 && param.pattern[3] == 1; | |||||
} | |||||
template <typename T> | template <typename T> | ||||
static inline bool is_immutable_equal(OperatorNodeBase* opr, T val, | static inline bool is_immutable_equal(OperatorNodeBase* opr, T val, | ||||
DTypeEnum dtype_enum) { | DTypeEnum dtype_enum) { | ||||
@@ -276,14 +285,20 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() { | |||||
auto inp0 = opr->input()[0]; | auto inp0 = opr->input()[0]; | ||||
return is_shape_nchw(inp0->shape()); | return is_shape_nchw(inp0->shape()); | ||||
}}; | }}; | ||||
SGM::Node shuffle_root{ | SGM::Node shuffle_root{ | ||||
opr::Dimshuffle::typeinfo(), | opr::Dimshuffle::typeinfo(), | ||||
{{nchwx_reshape}}, | |||||
{{nchwx_reshape}, {broadcast_concat}}, | |||||
[](OperatorNodeBase* opr) { | [](OperatorNodeBase* opr) { | ||||
auto& shuffle_opr = opr->cast_final<opr::Dimshuffle>(); | auto& shuffle_opr = opr->cast_final<opr::Dimshuffle>(); | ||||
auto& input_vec = shuffle_opr.input(); | auto& input_vec = shuffle_opr.input(); | ||||
return is_shape_before_nchw4(input_vec[0]->shape()) && | |||||
is_nchw_nchw4_shuffle_vec(shuffle_opr.param()); | |||||
bool nchw_nchw4_ok = | |||||
is_shape_before_nchw4(input_vec[0]->shape()) && | |||||
is_nchw_nchw4_shuffle_vec(shuffle_opr.param()); | |||||
bool nchw_nhwc_ok = | |||||
is_shape_before_nhwc(input_vec[0]->shape()) && | |||||
is_nchw_nhwc_shuffle(shuffle_opr.param()); | |||||
return nchw_nchw4_ok || nchw_nhwc_ok; | |||||
}}; | }}; | ||||
return shuffle_root; | return shuffle_root; | ||||
}; | }; | ||||
@@ -382,6 +397,19 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() { | |||||
auto out_node = opr::RelayoutFormat::make( | auto out_node = opr::RelayoutFormat::make( | ||||
rewriter.get_var(src_node->output()[0]), param.mode, | rewriter.get_var(src_node->output()[0]), param.mode, | ||||
config); | config); | ||||
const auto& outshp = opr->output(0)->shape(); | |||||
if (outshp.ndim == 4) { | |||||
auto shpvar = opr::GetVarShape::make(out_node); | |||||
auto cv = [&out_node](int v) { | |||||
return out_node.make_scalar(v); | |||||
}; | |||||
auto sub = [&shpvar, &cv](int idx) { | |||||
return opr::IndexAt::make(shpvar, {{0, cv(idx)}}); | |||||
}; | |||||
auto nhwc_shp = | |||||
opr::Concat::make({sub(0), sub(2), sub(3), sub(4)}, 0); | |||||
out_node = opr::Reshape::make(out_node, nhwc_shp); | |||||
} | |||||
return out_node.node()->owner_opr(); | return out_node.node()->owner_opr(); | ||||
} else { | } else { | ||||
return serialization::copy_opr_shallow(*opr, new_inp, | return serialization::copy_opr_shallow(*opr, new_inp, | ||||
@@ -740,4 +768,4 @@ void FuseWarpPerspectiveDimshufflePass::apply(OptState& opt) const { | |||||
}; | }; | ||||
opt.graph().iter(on_opr); | opt.graph().iter(on_opr); | ||||
rewriter.apply_inplace(); | rewriter.apply_inplace(); | ||||
} | |||||
} |
@@ -92,19 +92,24 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||||
bool is_parameter = | bool is_parameter = | ||||
fmtcfg.valid() && fmtcfg.val().input_tensor_types[i] == | fmtcfg.valid() && fmtcfg.val().input_tensor_types[i] == | ||||
TensorType::WEIGHT; | TensorType::WEIGHT; | ||||
ReformatManager::ReformatImpl reformat; | |||||
ReformatManager::ReformatKey key{from, to, reformat_attribute, | |||||
var->dtype().enumv(), | |||||
var->dtype().enumv()}; | |||||
if (is_parameter) { | |||||
auto aligned_desc = make_aligned_desc(base_fmt, out_fmt); | |||||
reformat = ReformatManager::instance() | |||||
.auto_aligned_reformat_weight( | |||||
var, key, aligned_desc); | |||||
} else { | |||||
reformat = ReformatManager::instance() | |||||
.auto_aligned_reformat_featrue( | |||||
var, base_fmt, key); | |||||
// need relayout | |||||
if (from != to && !new_var->shape().is_scalar()) { | |||||
ReformatManager::ReformatImpl reformat; | |||||
ReformatManager::ReformatKey key{ | |||||
from, to, reformat_attribute, var->dtype().enumv(), | |||||
var->dtype().enumv()}; | |||||
if (is_parameter) { | |||||
auto aligned_desc = ReformatManager::make_aligned_desc( | |||||
base_fmt, out_fmt); | |||||
reformat = ReformatManager::instance() | |||||
.auto_aligned_reformat_weight( | |||||
var, key, aligned_desc); | |||||
} else { | |||||
reformat = ReformatManager::instance() | |||||
.auto_aligned_reformat_featrue( | |||||
var, base_fmt, key); | |||||
} | |||||
new_var = reformat({new_var}); | |||||
} | } | ||||
if (from != to && !new_var->shape().is_scalar()) | if (from != to && !new_var->shape().is_scalar()) | ||||
new_var = reformat({new_var}); | new_var = reformat({new_var}); | ||||
@@ -165,6 +165,7 @@ public: | |||||
private: | private: | ||||
static constexpr float PROFILE_TIME_OUT = 1e7; | static constexpr float PROFILE_TIME_OUT = 1e7; | ||||
using ReformatAttribute = ReformatKey::Attribute; | |||||
/*! | /*! | ||||
* \brief profile opr format agnostic operators (like elemwise, elemwise multi type, typecvt etc.) | * \brief profile opr format agnostic operators (like elemwise, elemwise multi type, typecvt etc.) | ||||
* | * | ||||
@@ -175,40 +176,48 @@ private: | |||||
*/ | */ | ||||
OperatorNodeRecord profile_operator( | OperatorNodeRecord profile_operator( | ||||
const OperatorNodeBase* opr, TensorFormats base_format, | const OperatorNodeBase* opr, TensorFormats base_format, | ||||
const SmallVector<TensorFormats>& available_tensor_formats) const; | |||||
const SmallVector<TensorFormats>& available_tensor_formats, | |||||
ReformatAttribute extra_attribute = | |||||
ReformatAttribute::DEFAULT) const; | |||||
float profile_operator(const OperatorNodeBase* opr, | float profile_operator(const OperatorNodeBase* opr, | ||||
TensorFormats base_format, | TensorFormats base_format, | ||||
TensorFormats tensor_format) const; | |||||
TensorFormats tensor_format, | |||||
ReformatAttribute extra_attribute = | |||||
ReformatAttribute::DEFAULT) const; | |||||
/*! | /*! | ||||
* \brief profile opr format aware operators (like conv, deconv, conv_bias, etc.) | |||||
* \brief profile opr format aware operators (like conv, deconv, conv_bias, | |||||
* etc.) | |||||
* | * | ||||
* \param opr pointer to the operator node to be profiled | * \param opr pointer to the operator node to be profiled | ||||
* \param base_config the tensor formats configuration of base opr format | * \param base_config the tensor formats configuration of base opr format | ||||
* \param config all the available configuration | |||||
* \param config all the available configuration | |||||
* \return the operator node record | * \return the operator node record | ||||
*/ | */ | ||||
OperatorNodeRecord profile_operator( | OperatorNodeRecord profile_operator( | ||||
const OperatorNodeBase* opr, | const OperatorNodeBase* opr, | ||||
const OprTensorFormatsConfiguration& base_config, | const OprTensorFormatsConfiguration& base_config, | ||||
const SmallVector<OprTensorFormatsConfiguration>& available_configs) | |||||
const; | |||||
const SmallVector<OprTensorFormatsConfiguration>& available_configs, | |||||
ReformatAttribute extra_attribute = | |||||
ReformatAttribute::DEFAULT) const; | |||||
float profile_operator(const OperatorNodeBase* opr, | float profile_operator(const OperatorNodeBase* opr, | ||||
const OprTensorFormatsConfiguration& base_config, | const OprTensorFormatsConfiguration& base_config, | ||||
const OprTensorFormatsConfiguration& config) const; | |||||
const OprTensorFormatsConfiguration& config, | |||||
ReformatAttribute extra_attribute = | |||||
ReformatAttribute::DEFAULT) const; | |||||
/*! | /*! | ||||
* \brief profile layout transform of the var node | * \brief profile layout transform of the var node | ||||
* | * | ||||
* \param var pointer to the var node to be profiled | * \param var pointer to the var node to be profiled | ||||
* \param base_format the original tensor formats in which the var node is stored | |||||
* \param available_tensor_formats the available tensor formats | |||||
* \param base_format the original tensor formats in which the var node is | |||||
* stored \param available_tensor_formats the available tensor formats | |||||
* \param extra_attribute the extra attributes (options) of the problem | * \param extra_attribute the extra attributes (options) of the problem | ||||
* \return the var node record | * \return the var node record | ||||
*/ | */ | ||||
VarNodeRecord profile_var_node( | VarNodeRecord profile_var_node( | ||||
const VarNode* var, TensorFormats base_format, | const VarNode* var, TensorFormats base_format, | ||||
const SmallVector<TensorFormats>& available_tensor_formats, | const SmallVector<TensorFormats>& available_tensor_formats, | ||||
ReformatKey::Attribute extra_attribute = | |||||
ReformatKey::Attribute::DEFAULT) const; | |||||
ReformatAttribute extra_attribute = | |||||
ReformatAttribute::DEFAULT) const; | |||||
float profile_var_node(const VarNode* var, TensorFormats base_format, | float profile_var_node(const VarNode* var, TensorFormats base_format, | ||||
const ReformatKey& key) const; | const ReformatKey& key) const; | ||||
int m_runs; /// sample times of the profiler | int m_runs; /// sample times of the profiler | ||||
@@ -216,20 +225,23 @@ private: | |||||
ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( | ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( | ||||
const OperatorNodeBase* opr, TensorFormats base_format, | const OperatorNodeBase* opr, TensorFormats base_format, | ||||
const SmallVector<TensorFormats>& available_tensor_formats) const { | |||||
const SmallVector<TensorFormats>& available_tensor_formats, | |||||
ReformatAttribute extra_attribute) const { | |||||
OperatorNodeRecord record; | OperatorNodeRecord record; | ||||
record.opr = opr; | record.opr = opr; | ||||
auto& costs = record.costs; | auto& costs = record.costs; | ||||
for (auto&& f : available_tensor_formats) { | for (auto&& f : available_tensor_formats) { | ||||
auto opr_format = tensor_formats_to_opr_format(f); | auto opr_format = tensor_formats_to_opr_format(f); | ||||
costs[opr_format] = profile_operator(opr, base_format, f); | |||||
costs[opr_format] = | |||||
profile_operator(opr, base_format, f, extra_attribute); | |||||
} | } | ||||
return record; | return record; | ||||
} | } | ||||
float ProfilerImpl::profile_operator(const OperatorNodeBase* opr, | float ProfilerImpl::profile_operator(const OperatorNodeBase* opr, | ||||
TensorFormats base_format, | TensorFormats base_format, | ||||
TensorFormats tensor_format) const { | |||||
TensorFormats tensor_format, | |||||
ReformatAttribute extra_attribute) const { | |||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
graph->options().graph_opt_level = 0; | graph->options().graph_opt_level = 0; | ||||
graph->options().var_sanity_check_first_run = false; | graph->options().var_sanity_check_first_run = false; | ||||
@@ -239,8 +251,8 @@ float ProfilerImpl::profile_operator(const OperatorNodeBase* opr, | |||||
auto&& cn = var->comp_node(); | auto&& cn = var->comp_node(); | ||||
auto&& dtype = var->dtype(); | auto&& dtype = var->dtype(); | ||||
auto dval = std::make_shared<DeviceTensorND>(cn, dtype); | auto dval = std::make_shared<DeviceTensorND>(cn, dtype); | ||||
auto aligned_tensor_shape = | |||||
make_aligned_tensor_shape(var, base_format, tensor_format); | |||||
auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape( | |||||
var, base_format, tensor_format, extra_attribute); | |||||
dval->resize(aligned_tensor_shape); | dval->resize(aligned_tensor_shape); | ||||
auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval); | auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval); | ||||
new_inps[i] = aligned_var.node(); | new_inps[i] = aligned_var.node(); | ||||
@@ -263,8 +275,8 @@ float ProfilerImpl::profile_operator(const OperatorNodeBase* opr, | |||||
ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( | ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( | ||||
const OperatorNodeBase* opr, | const OperatorNodeBase* opr, | ||||
const OprTensorFormatsConfiguration& base_config, | const OprTensorFormatsConfiguration& base_config, | ||||
const SmallVector<OprTensorFormatsConfiguration>& available_configs) | |||||
const { | |||||
const SmallVector<OprTensorFormatsConfiguration>& available_configs, | |||||
ReformatAttribute extra_attribute) const { | |||||
OperatorNodeRecord record; | OperatorNodeRecord record; | ||||
record.opr = opr; | record.opr = opr; | ||||
auto& costs = record.costs; | auto& costs = record.costs; | ||||
@@ -273,7 +285,8 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( | |||||
if (i.opr_format == OprFormat::NCHW && | if (i.opr_format == OprFormat::NCHW && | ||||
opr->input(0)->dtype().enumv() != DTypeEnum::Float32) | opr->input(0)->dtype().enumv() != DTypeEnum::Float32) | ||||
continue; | continue; | ||||
costs[i.opr_format] = profile_operator(opr, base_config, i); | |||||
costs[i.opr_format] = | |||||
profile_operator(opr, base_config, i, extra_attribute); | |||||
} | } | ||||
return record; | return record; | ||||
} | } | ||||
@@ -281,7 +294,8 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( | |||||
float ProfilerImpl::profile_operator( | float ProfilerImpl::profile_operator( | ||||
const OperatorNodeBase* opr, | const OperatorNodeBase* opr, | ||||
const OprTensorFormatsConfiguration& base_config, | const OprTensorFormatsConfiguration& base_config, | ||||
const OprTensorFormatsConfiguration& config) const { | |||||
const OprTensorFormatsConfiguration& config, | |||||
ReformatAttribute extra_attribute) const { | |||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
graph->options().graph_opt_level = 0; | graph->options().graph_opt_level = 0; | ||||
graph->options().var_sanity_check_first_run = false; | graph->options().var_sanity_check_first_run = false; | ||||
@@ -297,18 +311,18 @@ float ProfilerImpl::profile_operator( | |||||
TensorShape aligned_shape; | TensorShape aligned_shape; | ||||
if (config.input_tensor_types[i] == TensorType::WEIGHT) { | if (config.input_tensor_types[i] == TensorType::WEIGHT) { | ||||
mgb_assert(base_config.input_tensor_types[i] == TensorType::WEIGHT); | mgb_assert(base_config.input_tensor_types[i] == TensorType::WEIGHT); | ||||
aligned_shape = make_aligned_weight_shape( | |||||
aligned_shape = ReformatManager::make_aligned_weight_shape( | |||||
var, base_config.input_tensor_formats[i], | var, base_config.input_tensor_formats[i], | ||||
config.input_tensor_formats[i], | config.input_tensor_formats[i], | ||||
config.output_tensor_formats[0]); | |||||
config.output_tensor_formats[0], extra_attribute); | |||||
} else { | } else { | ||||
mgb_assert(base_config.input_tensor_types[i] == | mgb_assert(base_config.input_tensor_types[i] == | ||||
config.input_tensor_types[i]); | config.input_tensor_types[i]); | ||||
mgb_assert(base_config.input_tensor_types[i] == | mgb_assert(base_config.input_tensor_types[i] == | ||||
TensorType::FEATURE); | TensorType::FEATURE); | ||||
aligned_shape = make_aligned_tensor_shape( | |||||
aligned_shape = ReformatManager::make_aligned_tensor_shape( | |||||
var, base_config.input_tensor_formats[i], | var, base_config.input_tensor_formats[i], | ||||
config.input_tensor_formats[i]); | |||||
config.input_tensor_formats[i], extra_attribute); | |||||
} | } | ||||
dval->resize(aligned_shape); | dval->resize(aligned_shape); | ||||
auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval); | auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval); | ||||
@@ -357,7 +371,7 @@ float ProfilerImpl::profile_operator( | |||||
ProfilerImpl::VarNodeRecord ProfilerImpl::profile_var_node( | ProfilerImpl::VarNodeRecord ProfilerImpl::profile_var_node( | ||||
const VarNode* var, TensorFormats base_format, | const VarNode* var, TensorFormats base_format, | ||||
const SmallVector<TensorFormats>& available_tensor_formats, | const SmallVector<TensorFormats>& available_tensor_formats, | ||||
ReformatKey::Attribute attribute) const { | |||||
ReformatAttribute attribute) const { | |||||
VarNodeRecord record; | VarNodeRecord record; | ||||
record.var = var; | record.var = var; | ||||
auto& costs = record.costs; | auto& costs = record.costs; | ||||
@@ -379,8 +393,8 @@ float ProfilerImpl::profile_var_node(const VarNode* var, | |||||
auto&& cn = var->comp_node(); | auto&& cn = var->comp_node(); | ||||
auto&& dtype = var->dtype(); | auto&& dtype = var->dtype(); | ||||
auto dval = std::make_shared<DeviceTensorND>(cn, dtype); | auto dval = std::make_shared<DeviceTensorND>(cn, dtype); | ||||
auto aligned_tensor_shape = | |||||
make_aligned_tensor_shape(var, base_format, key.input_format); | |||||
auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape( | |||||
var, base_format, key.input_format, key.attribute); | |||||
dval->resize(aligned_tensor_shape); | dval->resize(aligned_tensor_shape); | ||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
graph->options().graph_opt_level = 0; | graph->options().graph_opt_level = 0; | ||||
@@ -468,13 +482,14 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile( | |||||
auto base_format = problem.base_format(); | auto base_format = problem.base_format(); | ||||
auto&& available_tensor_formats = problem.available_tensor_formats(); | auto&& available_tensor_formats = problem.available_tensor_formats(); | ||||
auto&& reformat_attribute = problem.attribute().reformat_attribute; | |||||
ProfilingResult profiling_result; | ProfilingResult profiling_result; | ||||
auto& opr_record = profiling_result.opr_record; | auto& opr_record = profiling_result.opr_record; | ||||
auto& var_record = profiling_result.var_record; | auto& var_record = profiling_result.var_record; | ||||
for (auto&& var : vars) { | for (auto&& var : vars) { | ||||
var_record[var] = | |||||
profile_var_node(var, base_format, available_tensor_formats); | |||||
var_record[var] = profile_var_node( | |||||
var, base_format, available_tensor_formats, reformat_attribute); | |||||
} | } | ||||
for (auto&& opr : oprs) { | for (auto&& opr : oprs) { | ||||
auto&& opr_configs = problem.opr_configs(); | auto&& opr_configs = problem.opr_configs(); | ||||
@@ -482,11 +497,12 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile( | |||||
if (find == opr_configs.end()) { | if (find == opr_configs.end()) { | ||||
if (skip_oprs.count(opr) > 0) { | if (skip_oprs.count(opr) > 0) { | ||||
SmallVector<TensorFormats> tensor_formats = {base_format}; | SmallVector<TensorFormats> tensor_formats = {base_format}; | ||||
opr_record[opr] = | |||||
profile_operator(opr, base_format, tensor_formats); | |||||
opr_record[opr] = profile_operator( | |||||
opr, base_format, tensor_formats, reformat_attribute); | |||||
} else { | } else { | ||||
opr_record[opr] = profile_operator(opr, base_format, | opr_record[opr] = profile_operator(opr, base_format, | ||||
available_tensor_formats); | |||||
available_tensor_formats, | |||||
reformat_attribute); | |||||
} | } | ||||
} else { | } else { | ||||
auto&& dispatchers = find->second; | auto&& dispatchers = find->second; | ||||
@@ -498,7 +514,8 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile( | |||||
} | } | ||||
} | } | ||||
auto base_config = problem.base_config(opr); | auto base_config = problem.base_config(opr); | ||||
opr_record[opr] = profile_operator(opr, base_config, configs); | |||||
opr_record[opr] = profile_operator(opr, base_config, configs, | |||||
reformat_attribute); | |||||
} | } | ||||
} | } | ||||
for (auto&& rpair : opr_record) { | for (auto&& rpair : opr_record) { | ||||
@@ -21,7 +21,7 @@ using NamedTensorShape = megdnn::NamedTensorShape; | |||||
using Dimension = megdnn::Dimension; | using Dimension = megdnn::Dimension; | ||||
namespace { | namespace { | ||||
int gcd(const int& p, const int& q) { | |||||
static inline int gcd(const int& p, const int& q) { | |||||
int x = p, y = q; | int x = p, y = q; | ||||
while (y != 0) { | while (y != 0) { | ||||
if (x < y) { | if (x < y) { | ||||
@@ -33,6 +33,47 @@ int gcd(const int& p, const int& q) { | |||||
} | } | ||||
return x; | return x; | ||||
} | } | ||||
static inline size_t extra_alignment( | |||||
ReformatManager::ReformatKey::Attribute attr, | |||||
TensorFormats target_formats, DType dt, size_t channel_alignment) { | |||||
using Attribute = ReformatManager::ReformatKey::Attribute; | |||||
if (attr & Attribute::AUTO_PADDING_NHWC) { | |||||
constexpr size_t alignment_in_bits = 32; | |||||
size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8; | |||||
size_t extra_alignment = alignment_in_bits >= dtype_bits | |||||
? alignment_in_bits / dtype_bits | |||||
: 1; | |||||
if (target_formats == TensorFormats::NHWC) | |||||
channel_alignment = extra_alignment * channel_alignment / | |||||
gcd(channel_alignment, extra_alignment); | |||||
return channel_alignment; | |||||
} | |||||
return channel_alignment; | |||||
} | |||||
static inline std::tuple<size_t, size_t> extra_alignment( | |||||
const ReformatManager::ReformatKey& key, DType dt, | |||||
size_t input_channel_alignment, size_t output_channel_alignment) { | |||||
using Attribute = ReformatManager::ReformatKey::Attribute; | |||||
if (key.attribute & Attribute::AUTO_PADDING_NHWC) { | |||||
constexpr size_t alignment_in_bits = 32; | |||||
size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8; | |||||
size_t extra_alignment = alignment_in_bits >= dtype_bits | |||||
? alignment_in_bits / dtype_bits | |||||
: 1; | |||||
if (key.input_format == TensorFormats::NHWC) | |||||
input_channel_alignment = | |||||
input_channel_alignment * extra_alignment / | |||||
gcd(input_channel_alignment, extra_alignment); | |||||
if (key.output_format == TensorFormats::NHWC) | |||||
output_channel_alignment = | |||||
output_channel_alignment * extra_alignment / | |||||
gcd(output_channel_alignment, extra_alignment); | |||||
return {input_channel_alignment, output_channel_alignment}; | |||||
} | |||||
return {input_channel_alignment, output_channel_alignment}; | |||||
} | |||||
}; // namespace | }; // namespace | ||||
// =================== ReformatManager::ReformatKey ====================*/ | // =================== ReformatManager::ReformatKey ====================*/ | ||||
@@ -293,7 +334,8 @@ ReformatManager::ReformatImpl ReformatManager::get( | |||||
auto rst = find->second; | auto rst = find->second; | ||||
return rst; | return rst; | ||||
} | } | ||||
mgb_assert(key.attribute == Attribute::DEFAULT); | |||||
mgb_assert(!(key.attribute & Attribute::IMAGE2D) && | |||||
!(key.attribute & Attribute::IC_SMALL)); | |||||
auto&& i = key.input_format; | auto&& i = key.input_format; | ||||
auto&& o = key.output_format; | auto&& o = key.output_format; | ||||
auto ishp = tensor_formats_to_named_tensor_shape(i); | auto ishp = tensor_formats_to_named_tensor_shape(i); | ||||
@@ -346,6 +388,8 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_featrue( | |||||
"invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)", | "invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)", | ||||
input_alignment, output_alignment, | input_alignment, output_alignment, | ||||
input_shape.to_string().c_str()); | input_shape.to_string().c_str()); | ||||
std::tie(input_alignment, output_alignment) = extra_alignment( | |||||
key, orig_var->dtype(), input_alignment, output_alignment); | |||||
NamedTensorShape orig_shape = | NamedTensorShape orig_shape = | ||||
tensor_formats_to_named_tensor_shape(orig_format); | tensor_formats_to_named_tensor_shape(orig_format); | ||||
size_t orig_channel = 0; | size_t orig_channel = 0; | ||||
@@ -451,6 +495,12 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight( | |||||
"invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)", | "invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)", | ||||
in_channel_alignment, out_channel_alignment, | in_channel_alignment, out_channel_alignment, | ||||
output_shape.to_string().c_str()); | output_shape.to_string().c_str()); | ||||
in_channel_alignment = | |||||
::extra_alignment(key.attribute, key.output_format, | |||||
orig_var->dtype(), in_channel_alignment); | |||||
out_channel_alignment = | |||||
::extra_alignment(key.attribute, key.output_format, | |||||
orig_var->dtype(), out_channel_alignment); | |||||
size_t aligned_in_channel = | size_t aligned_in_channel = | ||||
divup(in_channels, in_channel_alignment) * in_channel_alignment; | divup(in_channels, in_channel_alignment) * in_channel_alignment; | ||||
if (extra_alignment.name == out_channel_name) { | if (extra_alignment.name == out_channel_name) { | ||||
@@ -506,9 +556,9 @@ const ReformatManager& ReformatManager::instance() { | |||||
return inst; | return inst; | ||||
} | } | ||||
TensorShape mgb::gopt::make_aligned_tensor_shape(const VarNode* var, | |||||
TensorFormats orig_formats, | |||||
TensorFormats target_formats) { | |||||
TensorShape ReformatManager::make_aligned_tensor_shape( | |||||
const VarNode* var, TensorFormats orig_formats, | |||||
TensorFormats target_formats, ReformatKey::Attribute extra_attribute) { | |||||
using Dimension = megdnn::Dimension; | using Dimension = megdnn::Dimension; | ||||
static constexpr uint32_t UNDETERMINED_EXTENT = | static constexpr uint32_t UNDETERMINED_EXTENT = | ||||
Dimension::UNDETERMINED_EXTENT; | Dimension::UNDETERMINED_EXTENT; | ||||
@@ -545,6 +595,15 @@ TensorShape mgb::gopt::make_aligned_tensor_shape(const VarNode* var, | |||||
tshp[i] = oshp[idx] * factor; | tshp[i] = oshp[idx] * factor; | ||||
else | else | ||||
tshp[i] = divup(oshp[idx], factor); | tshp[i] = divup(oshp[idx], factor); | ||||
if (name == Dimension::Name::C) { | |||||
size_t channel_alignment = target_shape[i].stride(); | |||||
size_t channels = tshp[i] * channel_alignment; | |||||
size_t new_channel_alignment = | |||||
extra_alignment(extra_attribute, target_formats, | |||||
var->dtype(), channel_alignment); | |||||
tshp[i] = divup(channels, new_channel_alignment) * | |||||
new_channel_alignment / channel_alignment; | |||||
} | |||||
} else { | } else { | ||||
tshp[i] = target_shape[i].extent(); | tshp[i] = target_shape[i].extent(); | ||||
} | } | ||||
@@ -552,11 +611,12 @@ TensorShape mgb::gopt::make_aligned_tensor_shape(const VarNode* var, | |||||
return tshp; | return tshp; | ||||
} | } | ||||
TensorShape mgb::gopt::make_aligned_weight_shape(const VarNode* var, | |||||
TensorFormats orig_formats, | |||||
TensorFormats target_formats, | |||||
TensorFormats extra_formats) { | |||||
auto tshp = make_aligned_tensor_shape(var, orig_formats, target_formats); | |||||
TensorShape ReformatManager::make_aligned_weight_shape( | |||||
const VarNode* var, TensorFormats orig_formats, | |||||
TensorFormats target_formats, TensorFormats extra_formats, | |||||
ReformatKey::Attribute extra_attribute) { | |||||
auto tshp = make_aligned_tensor_shape(var, orig_formats, target_formats, | |||||
extra_attribute); | |||||
auto extra_shape = tensor_formats_to_named_tensor_shape(extra_formats); | auto extra_shape = tensor_formats_to_named_tensor_shape(extra_formats); | ||||
using Dimension = megdnn::Dimension; | using Dimension = megdnn::Dimension; | ||||
static constexpr uint32_t UNDETERMINED_EXTENT = | static constexpr uint32_t UNDETERMINED_EXTENT = | ||||
@@ -567,6 +627,9 @@ TensorShape mgb::gopt::make_aligned_weight_shape(const VarNode* var, | |||||
if (name == Dimension::Name::C && | if (name == Dimension::Name::C && | ||||
extra_shape[i].extent() == UNDETERMINED_EXTENT) { | extra_shape[i].extent() == UNDETERMINED_EXTENT) { | ||||
out_channel_alignment = extra_shape[i].stride(); | out_channel_alignment = extra_shape[i].stride(); | ||||
out_channel_alignment = | |||||
extra_alignment(extra_attribute, target_formats, | |||||
var->dtype(), out_channel_alignment); | |||||
} | } | ||||
} | } | ||||
@@ -583,9 +646,8 @@ TensorShape mgb::gopt::make_aligned_weight_shape(const VarNode* var, | |||||
return tshp; | return tshp; | ||||
} | } | ||||
ReformatManager::AlignmentDesc mgb::gopt::make_aligned_desc( | |||||
ReformatManager::AlignmentDesc ReformatManager::make_aligned_desc( | |||||
TensorFormats weight_format, TensorFormats out_feature_format) { | TensorFormats weight_format, TensorFormats out_feature_format) { | ||||
using AlignmentDesc = ReformatManager::AlignmentDesc; | |||||
using Name = Dimension::Name; | using Name = Dimension::Name; | ||||
auto weight_shape = tensor_formats_to_named_tensor_shape(weight_format); | auto weight_shape = tensor_formats_to_named_tensor_shape(weight_format); | ||||
auto out_shape = tensor_formats_to_named_tensor_shape(out_feature_format); | auto out_shape = tensor_formats_to_named_tensor_shape(out_feature_format); | ||||
@@ -143,6 +143,7 @@ public: | |||||
TensorFormats base_format() const { | TensorFormats base_format() const { | ||||
return m_ctx.attribute().base_tensor_formats; | return m_ctx.attribute().base_tensor_formats; | ||||
} | } | ||||
Attribute attribute() const { return m_ctx.attribute(); } | |||||
/*! | /*! | ||||
* \brief return the tensor formats configuration of an operator in the | * \brief return the tensor formats configuration of an operator in the | ||||
* default op format | * default op format | ||||
@@ -74,6 +74,7 @@ public: | |||||
DEFAULT = 0, | DEFAULT = 0, | ||||
IMAGE2D = 1 << 0, | IMAGE2D = 1 << 0, | ||||
IC_SMALL = 1 << 1, | IC_SMALL = 1 << 1, | ||||
AUTO_PADDING_NHWC = 1 << 2, | |||||
}; | }; | ||||
TensorFormats input_format, output_format; | TensorFormats input_format, output_format; | ||||
DTypeEnum input_dtype, output_dtype; | DTypeEnum input_dtype, output_dtype; | ||||
@@ -124,23 +125,40 @@ public: | |||||
ReformatImpl auto_aligned_reformat_weight( | ReformatImpl auto_aligned_reformat_weight( | ||||
const VarNode* orig_var, const ReformatKey& key, | const VarNode* orig_var, const ReformatKey& key, | ||||
const AlignmentDesc& extra_alignment = {}) const; | const AlignmentDesc& extra_alignment = {}) const; | ||||
static TensorShape make_aligned_tensor_shape( | |||||
const VarNode* var, TensorFormats orig_formats, | |||||
TensorFormats target_formats, | |||||
ReformatKey::Attribute extra_attribute = | |||||
ReformatKey::Attribute::DEFAULT); | |||||
static TensorShape make_aligned_weight_shape( | |||||
const VarNode* var, TensorFormats orig_formats, | |||||
TensorFormats target_formats, TensorFormats extra_formats, | |||||
ReformatKey::Attribute extra_attribute = | |||||
ReformatKey::Attribute::DEFAULT); | |||||
static AlignmentDesc make_aligned_desc(TensorFormats weight_format, | |||||
TensorFormats out_feature_format); | |||||
static const ReformatManager& instance(); | static const ReformatManager& instance(); | ||||
private: | private: | ||||
ReformatCache m_cache; | ReformatCache m_cache; | ||||
}; | }; | ||||
TensorShape make_aligned_tensor_shape(const VarNode* var, | |||||
TensorFormats orig_formats, | |||||
TensorFormats target_formats); | |||||
TensorShape make_aligned_weight_shape(const VarNode* var, | |||||
TensorFormats orig_formats, | |||||
TensorFormats target_formats, | |||||
TensorFormats extra_formats); | |||||
MGB_DEF_ENUM_CLASS_BIT_OPR(ReformatManager::ReformatKey::Attribute); | |||||
// | |||||
//TensorShape make_aligned_tensor_shape( | |||||
// const VarNode* var, TensorFormats orig_formats, | |||||
// TensorFormats target_formats, | |||||
// ReformatManager::ReformatKey::Attribute extra_attribute = | |||||
// ReformatManager::ReformatKey::Attribute::DEFAULT); | |||||
// | |||||
//TensorShape make_aligned_weight_shape( | |||||
// const VarNode* var, TensorFormats orig_formats, | |||||
// TensorFormats target_formats, TensorFormats extra_formats, | |||||
// ReformatManager::ReformatKey::Attribute extra_attribute = | |||||
// ReformatManager::ReformatKey::Attribute::DEFAULT); | |||||
ReformatManager::AlignmentDesc make_aligned_desc( | |||||
TensorFormats weight_format, TensorFormats out_feature_format); | |||||
} // namespace gopt | } // namespace gopt | ||||
} // namespace mgb | } // namespace mgb | ||||