|
|
@@ -7,53 +7,62 @@ namespace imperative { |
|
|
|
|
|
|
|
using FT = Format::Type; |
|
|
|
|
|
|
|
TypedValueRef<FormattedTensorValue> FormattedTensorValue::as(const FT& target) const { |
|
|
|
return FormattedTensorValue::make(m_value, target); |
|
|
|
TypedValueRef<FormattedTensorValue> FormatTransformation::as( |
|
|
|
const FormattedTensorValue& tensor, const FT& target) const { |
|
|
|
return m_value_type.make(tensor.value(), target); |
|
|
|
} |
|
|
|
|
|
|
|
TypedValueRef<FormattedTensorValue> FormattedTensorValue::to( |
|
|
|
const FT& target, const std::string& scope) const { |
|
|
|
TypedValueRef<FormattedTensorValue> FormatTransformation::to( |
|
|
|
const FormattedTensorValue& tensor, const FT& target, |
|
|
|
const std::string& scope) const { |
|
|
|
std::vector<int32_t> pattern; |
|
|
|
if (m_format == FT::NHWC && target == FT::NCHW) { |
|
|
|
if (tensor.format() == FT::NHWC && target == FT::NCHW) { |
|
|
|
pattern = {0, 3, 1, 2}; |
|
|
|
} else if (m_format == FT::NCHW && target == FT::NHWC) { |
|
|
|
} else if (tensor.format() == FT::NCHW && target == FT::NHWC) { |
|
|
|
pattern = {0, 2, 3, 1}; |
|
|
|
} else { |
|
|
|
mgb_throw( |
|
|
|
MegBrainError, "Unsupport format conversion from %s to %s", |
|
|
|
m_format.to_string().c_str(), Format(target).to_string().c_str()); |
|
|
|
tensor.format().to_string().c_str(), |
|
|
|
Format(target).to_string().c_str()); |
|
|
|
} |
|
|
|
auto output = imperative::apply( |
|
|
|
*Dimshuffle::make(pattern, scope), std::vector<ValueRef>{m_value})[0]; |
|
|
|
return FormattedTensorValue::make(output, target); |
|
|
|
*Dimshuffle::make(pattern, scope), |
|
|
|
SmallVector<ValueRef>{tensor.value()})[0]; |
|
|
|
return m_value_type.make(output, target); |
|
|
|
} |
|
|
|
|
|
|
|
namespace { |
|
|
|
|
|
|
|
ValueRef unwrap_input(const ValueRef& input) { |
|
|
|
if (auto format_input = input.as_ref<FormattedTensorValue>()) { |
|
|
|
inline ValueRef FormatTransformation::unwrap_input(const ValueRef& input) const { |
|
|
|
if (auto format_input = input.as_ref(m_value_type)) { |
|
|
|
return format_input->value(); |
|
|
|
} else { |
|
|
|
return input; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<ValueRef> unwrap_inputs(const Span<ValueRef>& inputs) { |
|
|
|
std::vector<ValueRef> unwrapped_inputs; |
|
|
|
for (auto&& input : inputs) { |
|
|
|
unwrapped_inputs.push_back(unwrap_input(input)); |
|
|
|
inline ValueRefList FormatTransformation::unwrap_inputs( |
|
|
|
const Span<ValueRef>& inputs) const { |
|
|
|
ValueRefList unwrapped_inputs(inputs.size()); |
|
|
|
for (size_t i = 0; i < inputs.size(); ++i) { |
|
|
|
unwrapped_inputs[i] = unwrap_input(inputs[i]); |
|
|
|
} |
|
|
|
return unwrapped_inputs; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<ValueRef> wrap_outputs( |
|
|
|
const std::vector<ValueRef>& outputs, FT type = FT::DEFAULT) { |
|
|
|
std::vector<ValueRef> wrapped_outputs; |
|
|
|
for (auto&& output : outputs) { |
|
|
|
wrapped_outputs.push_back(FormattedTensorValue::make(output, type)); |
|
|
|
inline ValueRef FormatTransformation::wrap_output( |
|
|
|
const ValueRef& output, FT type) const { |
|
|
|
return m_value_type.make(output, type); |
|
|
|
} |
|
|
|
|
|
|
|
inline ValueRefList FormatTransformation::wrap_outputs( |
|
|
|
const ValueRefList& outputs, FT type) const { |
|
|
|
ValueRefList wrapped_outputs(outputs.size()); |
|
|
|
for (size_t i = 0; i < outputs.size(); ++i) { |
|
|
|
wrapped_outputs[i] = wrap_output(outputs[i], type); |
|
|
|
} |
|
|
|
return wrapped_outputs; |
|
|
|
} |
|
|
|
namespace { |
|
|
|
|
|
|
|
ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) { |
|
|
|
mgb_assert(shape.ndim == 4); |
|
|
@@ -64,20 +73,21 @@ ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) { |
|
|
|
return out; |
|
|
|
} |
|
|
|
|
|
|
|
using FormatRule = std::function<std::vector<ValueRef>( |
|
|
|
const OpDef&, Span<ValueRef>&, const bool&)>; |
|
|
|
using FormatRule = std::function<ValueRefList( |
|
|
|
const OpDef&, Span<ValueRef>&, const bool&, const FormatTransformation&)>; |
|
|
|
static std::unordered_map<Typeinfo*, FormatRule> format_rules; |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void register_format_rule( |
|
|
|
std::vector<ValueRef> (*rule)(const T&, Span<ValueRef>&, const bool&)) { |
|
|
|
void register_format_rule(ValueRefList (*rule)( |
|
|
|
const T&, Span<ValueRef>&, const bool&, const FormatTransformation&)) { |
|
|
|
format_rules[T::typeinfo()] = [rule](const OpDef& def, Span<ValueRef>& inputs, |
|
|
|
const bool& auto_convert) { |
|
|
|
return (*rule)(def.cast_final_safe<T>(), inputs, auto_convert); |
|
|
|
const bool& auto_convert, |
|
|
|
const FormatTransformation& t) { |
|
|
|
return (*rule)(def.cast_final_safe<T>(), inputs, auto_convert, t); |
|
|
|
}; |
|
|
|
} |
|
|
|
|
|
|
|
auto convert_nchw2nhwc_pattern(const std::vector<int32_t>& pattern) { |
|
|
|
inline auto convert_nchw2nhwc_pattern(const std::vector<int32_t>& pattern) { |
|
|
|
mgb_assert(pattern.size() == 4); |
|
|
|
auto nhwc_pattern = pattern; |
|
|
|
for (size_t idx = 0; idx < 4; ++idx) { |
|
|
@@ -93,19 +103,20 @@ auto convert_nchw2nhwc_pattern(const std::vector<int32_t>& pattern) { |
|
|
|
return nhwc_pattern; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<ValueRef> dimshuffle_rule( |
|
|
|
const Dimshuffle& op, Span<ValueRef>& inputs, const bool& auto_convert) { |
|
|
|
ValueRefList dimshuffle_rule( |
|
|
|
const Dimshuffle& op, Span<ValueRef>& inputs, const bool& auto_convert, |
|
|
|
const FormatTransformation& t) { |
|
|
|
mgb_assert(inputs.size() == 1); |
|
|
|
auto& src = inputs[0].cast<FormattedTensorValue>(); |
|
|
|
auto& src = inputs[0].cast(t.value_type()); |
|
|
|
// Only support converting pattern from NCHW to NHWC currently. |
|
|
|
if (auto_convert && src.format() == FT::NHWC) { |
|
|
|
auto pattern = convert_nchw2nhwc_pattern(op.pattern); |
|
|
|
// dimshuffle will not maintain NHWC Format |
|
|
|
return wrap_outputs(imperative::apply( |
|
|
|
return t.wrap_outputs(imperative::apply( |
|
|
|
*Dimshuffle::make(std::move(pattern), op.scope()), |
|
|
|
unwrap_inputs(inputs))); |
|
|
|
t.unwrap_inputs(inputs))); |
|
|
|
} |
|
|
|
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); |
|
|
|
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs))); |
|
|
|
} |
|
|
|
|
|
|
|
ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) { |
|
|
@@ -125,53 +136,55 @@ ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) { |
|
|
|
return nhwc_shape_input; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<ValueRef> reshape_rule( |
|
|
|
const Reshape& op, Span<ValueRef>& inputs, const bool& auto_convert) { |
|
|
|
ValueRefList reshape_rule( |
|
|
|
const Reshape& op, Span<ValueRef>& inputs, const bool& auto_convert, |
|
|
|
const FormatTransformation& t) { |
|
|
|
mgb_assert(inputs.size() == 2); |
|
|
|
auto& src = inputs[0].cast<FormattedTensorValue>(); |
|
|
|
auto& src = inputs[0].cast(t.value_type()); |
|
|
|
if (auto_convert && src.format() == FT::NHWC) { |
|
|
|
auto shape = unwrap_input(inputs[1]).numpy().cast<HostValue>().as_nd(); |
|
|
|
auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd(); |
|
|
|
if (shape.layout().total_nr_elems() == 4) { |
|
|
|
// output is still NHWC format |
|
|
|
auto nhwc_shape = convert_nchw2nhwc_tensornd(shape); |
|
|
|
auto outputs = imperative::apply( |
|
|
|
op, std::vector<ValueRef>{unwrap_input(inputs[0]), nhwc_shape}); |
|
|
|
return wrap_outputs(outputs, FT::NHWC); |
|
|
|
op, SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape}); |
|
|
|
return t.wrap_outputs(outputs, FT::NHWC); |
|
|
|
} else { |
|
|
|
// will not maintain src's format |
|
|
|
auto nchw_src = src.to(FT::NCHW, op.scope())->value(); |
|
|
|
auto nchw_src = t.to(src, FT::NCHW, op.scope())->value(); |
|
|
|
auto outputs = imperative::apply( |
|
|
|
op, std::vector<ValueRef>{nchw_src, unwrap_input(inputs[1])}); |
|
|
|
return wrap_outputs(outputs); |
|
|
|
op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])}); |
|
|
|
return t.wrap_outputs(outputs); |
|
|
|
} |
|
|
|
} |
|
|
|
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); |
|
|
|
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs))); |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<ValueRef> broadcast_rule( |
|
|
|
const Broadcast& op, Span<ValueRef>& inputs, const bool& auto_convert) { |
|
|
|
ValueRefList broadcast_rule( |
|
|
|
const Broadcast& op, Span<ValueRef>& inputs, const bool& auto_convert, |
|
|
|
const FormatTransformation& t) { |
|
|
|
mgb_assert(inputs.size() == 2); |
|
|
|
auto& src = inputs[0].cast<FormattedTensorValue>(); |
|
|
|
auto& src = inputs[0].cast(t.value_type()); |
|
|
|
if (auto_convert && src.format() == FT::NHWC) { |
|
|
|
auto shape = unwrap_input(inputs[1]).numpy().cast<HostValue>().as_nd(); |
|
|
|
auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd(); |
|
|
|
if (shape.layout().total_nr_elems() == 4) { |
|
|
|
// output is still NHWC format |
|
|
|
auto nhwc_shape = convert_nchw2nhwc_tensornd(shape); |
|
|
|
auto outputs = imperative::apply( |
|
|
|
op, std::vector<ValueRef>{unwrap_input(inputs[0]), nhwc_shape}); |
|
|
|
return wrap_outputs(outputs, FT::NHWC); |
|
|
|
op, SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape}); |
|
|
|
return t.wrap_outputs(outputs, FT::NHWC); |
|
|
|
} else { |
|
|
|
// will not maintain src's format |
|
|
|
auto nchw_src = src.to(FT::NCHW, op.scope())->value(); |
|
|
|
auto nchw_src = t.to(src, FT::NCHW, op.scope())->value(); |
|
|
|
auto outputs = imperative::apply( |
|
|
|
op, std::vector<ValueRef>{nchw_src, unwrap_input(inputs[1])}); |
|
|
|
return wrap_outputs(outputs); |
|
|
|
op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])}); |
|
|
|
return t.wrap_outputs(outputs); |
|
|
|
} |
|
|
|
} |
|
|
|
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); |
|
|
|
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs))); |
|
|
|
} |
|
|
|
|
|
|
|
bool is_reduce_ndim_idx_items( |
|
|
|
inline bool is_reduce_ndim_idx_items( |
|
|
|
const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items, |
|
|
|
const Span<ValueRef>& inputs) { |
|
|
|
for (auto i = 0; i < items.size(); ++i) { |
|
|
@@ -184,7 +197,7 @@ bool is_reduce_ndim_idx_items( |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
auto convert_nchw2nhwc_idx_items( |
|
|
|
inline auto convert_nchw2nhwc_idx_items( |
|
|
|
const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items) { |
|
|
|
auto nhwc_items = items; |
|
|
|
for (auto i = 0; i < nhwc_items.size(); ++i) { |
|
|
@@ -199,51 +212,55 @@ auto convert_nchw2nhwc_idx_items( |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
std::vector<ValueRef> subtensor_rule( |
|
|
|
const T& op, Span<ValueRef>& inputs, const bool& auto_convert) { |
|
|
|
ValueRefList subtensor_rule( |
|
|
|
const T& op, Span<ValueRef>& inputs, const bool& auto_convert, |
|
|
|
const FormatTransformation& t) { |
|
|
|
mgb_assert(inputs.size() >= 1); |
|
|
|
auto& src = inputs[0].cast<FormattedTensorValue>(); |
|
|
|
auto& src = inputs[0].cast(t.value_type()); |
|
|
|
bool is_reduce_ndim = is_reduce_ndim_idx_items( |
|
|
|
op.items, {&inputs[1], &inputs[inputs.size() - 1]}); |
|
|
|
if (!is_reduce_ndim) { |
|
|
|
// only support NHWC2NCHW convert, otherwise maintain src's format |
|
|
|
if (!(auto_convert && src.format() == FT::NHWC)) { |
|
|
|
return {FormattedTensorValue::make( |
|
|
|
imperative::apply(op, unwrap_inputs(inputs))[0], src.format())}; |
|
|
|
return {t.wrap_output( |
|
|
|
imperative::apply(op, t.unwrap_inputs(inputs))[0], |
|
|
|
src.format().type())}; |
|
|
|
} |
|
|
|
auto nhwc_items = convert_nchw2nhwc_idx_items(op.items); |
|
|
|
auto outputs = imperative::apply( |
|
|
|
*T::make(std::move(nhwc_items), op.scope()), unwrap_inputs(inputs)); |
|
|
|
return wrap_outputs(outputs, FT::NHWC); |
|
|
|
*T::make(std::move(nhwc_items), op.scope()), t.unwrap_inputs(inputs)); |
|
|
|
return t.wrap_outputs(outputs, FT::NHWC); |
|
|
|
} |
|
|
|
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); |
|
|
|
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs))); |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
std::vector<ValueRef> setsubtensor_rule( |
|
|
|
const T& op, Span<ValueRef>& inputs, const bool& auto_convert) { |
|
|
|
ValueRefList setsubtensor_rule( |
|
|
|
const T& op, Span<ValueRef>& inputs, const bool& auto_convert, |
|
|
|
const FormatTransformation& t) { |
|
|
|
mgb_assert(inputs.size() >= 2); |
|
|
|
auto& src = inputs[0].cast<FormattedTensorValue>(); |
|
|
|
auto& src = inputs[0].cast(t.value_type()); |
|
|
|
bool is_reduce_ndim = is_reduce_ndim_idx_items( |
|
|
|
op.items, {&inputs[2], &inputs[inputs.size() - 1]}); |
|
|
|
if (!is_reduce_ndim) { |
|
|
|
// only support NHWC2NCHW convert, otherwise maintain src's format |
|
|
|
if (!(auto_convert && src.format() == FT::NHWC)) { |
|
|
|
return {FormattedTensorValue::make( |
|
|
|
imperative::apply(op, unwrap_inputs(inputs))[0], src.format())}; |
|
|
|
return {t.wrap_output( |
|
|
|
imperative::apply(op, t.unwrap_inputs(inputs))[0], |
|
|
|
src.format().type())}; |
|
|
|
} |
|
|
|
// value has been broadcasted to src's fake NCHW shape. |
|
|
|
auto& value = inputs[1].cast<FormattedTensorValue>(); |
|
|
|
auto& value = inputs[1].cast(t.value_type()); |
|
|
|
auto& format = value.format(); |
|
|
|
auto nhwc_inputs = std::vector<ValueRef>(inputs.size()); |
|
|
|
auto nhwc_inputs = ValueRefList(inputs.size()); |
|
|
|
if (format == FT::DEFAULT || format == FT::NCHW) { |
|
|
|
// value for setsubtensor should transpose to match shape. |
|
|
|
auto nhwc_value = value.as(FT::NCHW)->to(FT::NHWC); |
|
|
|
auto nhwc_value = t.to(*(t.as(value, FT::NCHW)), FT::NHWC); |
|
|
|
// make new inputs for setsubtensor |
|
|
|
nhwc_inputs[0] = src.value(); |
|
|
|
nhwc_inputs[1] = nhwc_value->value(); |
|
|
|
for (auto i = 2; i < inputs.size(); ++i) { |
|
|
|
nhwc_inputs[i] = inputs[i].as_ref<FormattedTensorValue>()->value(); |
|
|
|
nhwc_inputs[i] = t.unwrap_input(inputs[i]); |
|
|
|
} |
|
|
|
} else if (format != FT::NHWC) { |
|
|
|
mgb_throw( |
|
|
@@ -253,15 +270,15 @@ std::vector<ValueRef> setsubtensor_rule( |
|
|
|
auto nhwc_items = convert_nchw2nhwc_idx_items(op.items); |
|
|
|
auto outputs = imperative::apply( |
|
|
|
*T::make(std::move(nhwc_items), op.scope()), nhwc_inputs); |
|
|
|
return wrap_outputs(outputs, FT::NHWC); |
|
|
|
return t.wrap_outputs(outputs, FT::NHWC); |
|
|
|
} |
|
|
|
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); |
|
|
|
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs))); |
|
|
|
} |
|
|
|
|
|
|
|
FT get_inputs_format(Span<ValueRef>& inputs) { |
|
|
|
inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation& t) { |
|
|
|
FT format(FT::DEFAULT); |
|
|
|
for (auto& inp : inputs) { |
|
|
|
auto& inp_format = inp.cast<FormattedTensorValue>().format(); |
|
|
|
auto& inp_format = inp.cast(t.value_type()).format(); |
|
|
|
if (inp_format != FT::DEFAULT) { |
|
|
|
mgb_assert(format == FT::DEFAULT || inp_format == format); |
|
|
|
format = inp_format.type(); |
|
|
@@ -270,11 +287,12 @@ FT get_inputs_format(Span<ValueRef>& inputs) { |
|
|
|
return format; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<ValueRef> concat_rule( |
|
|
|
const Concat& op, Span<ValueRef>& inputs, const bool& auto_convert) { |
|
|
|
FT format = get_inputs_format(inputs); |
|
|
|
ValueRefList concat_rule( |
|
|
|
const Concat& op, Span<ValueRef>& inputs, const bool& auto_convert, |
|
|
|
const FormatTransformation& t) { |
|
|
|
FT format = get_inputs_format(inputs, t); |
|
|
|
if (!(format == FT::NHWC && auto_convert)) { |
|
|
|
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format); |
|
|
|
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format); |
|
|
|
} |
|
|
|
// TODO: handle 5D NHWC Tensor from group conv |
|
|
|
auto axis = op.axis; |
|
|
@@ -283,25 +301,26 @@ std::vector<ValueRef> concat_rule( |
|
|
|
} else if (axis == 1) { |
|
|
|
axis = 3; |
|
|
|
} |
|
|
|
return wrap_outputs( |
|
|
|
return t.wrap_outputs( |
|
|
|
imperative::apply( |
|
|
|
*Concat::make(axis, op.comp_node, op.scope()), |
|
|
|
unwrap_inputs(inputs)), |
|
|
|
t.unwrap_inputs(inputs)), |
|
|
|
format); |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<ValueRef> elemwise_rule( |
|
|
|
const Elemwise& op, Span<ValueRef>& inputs, const bool& auto_convert) { |
|
|
|
FT format = get_inputs_format(inputs); |
|
|
|
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format); |
|
|
|
ValueRefList elemwise_rule( |
|
|
|
const Elemwise& op, Span<ValueRef>& inputs, const bool& auto_convert, |
|
|
|
const FormatTransformation& t) { |
|
|
|
FT format = get_inputs_format(inputs, t); |
|
|
|
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format); |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<ValueRef> identity_rule_helper( |
|
|
|
const OpDef& op, const Span<ValueRef>& inputs) { |
|
|
|
ValueRefList identity_rule_helper( |
|
|
|
const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) { |
|
|
|
// mgb_assert(inputs.size() == 1); |
|
|
|
auto& src = inputs[0].cast<FormattedTensorValue>(); |
|
|
|
return wrap_outputs( |
|
|
|
imperative::apply(op, unwrap_inputs(inputs)), src.format().type()); |
|
|
|
auto& src = inputs[0].cast(t.value_type()); |
|
|
|
return t.wrap_outputs( |
|
|
|
imperative::apply(op, t.unwrap_inputs(inputs)), src.format().type()); |
|
|
|
} |
|
|
|
|
|
|
|
// clang-format off |
|
|
@@ -318,10 +337,11 @@ std::vector<ValueRef> identity_rule_helper( |
|
|
|
cb(Identity) |
|
|
|
// clang-format on |
|
|
|
|
|
|
|
#define CREATE_IDENTITY_OP_RULE(op) \ |
|
|
|
std::vector<ValueRef> op##_rule( \ |
|
|
|
const op& _op, Span<ValueRef>& inputs, const bool& auto_convert) { \ |
|
|
|
return identity_rule_helper(_op, inputs); \ |
|
|
|
#define CREATE_IDENTITY_OP_RULE(op) \ |
|
|
|
ValueRefList op##_rule( \ |
|
|
|
const op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \ |
|
|
|
const FormatTransformation& t) { \ |
|
|
|
return identity_rule_helper(_op, inputs, t); \ |
|
|
|
} |
|
|
|
FOREACH_IDENTITY_OP(CREATE_IDENTITY_OP_RULE) |
|
|
|
#undef CREATE_IDENTITY_OP_RULE |
|
|
@@ -344,22 +364,26 @@ struct FormatRuleRegistry { |
|
|
|
#undef REGISTER_IDENTITY_OP_RULE |
|
|
|
} // namespace |
|
|
|
|
|
|
|
std::vector<ValueRef> FormatTransformation::apply_transformation( |
|
|
|
ValueRefList FormatTransformation::apply_transformation( |
|
|
|
const Operator& op, Span<ValueRef> inputs) { |
|
|
|
if (auto* apply_op = op.as<ApplyOp>()) { |
|
|
|
// all inputs should be FormattedTensorValue |
|
|
|
auto iter = format_rules.find(apply_op->op().dyn_typeinfo()); |
|
|
|
if (iter != format_rules.end()) { |
|
|
|
return iter->second(apply_op->op(), inputs, m_auto_convert); |
|
|
|
return iter->second(apply_op->op(), inputs, m_auto_convert, *this); |
|
|
|
} else { |
|
|
|
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); |
|
|
|
} |
|
|
|
} else if (auto* create_tensor = op.as<CreateTensor>()) { |
|
|
|
auto format = create_tensor->format(); |
|
|
|
return {FormattedTensorValue::make(imperative::apply(op, inputs)[0], format)}; |
|
|
|
return {wrap_output(imperative::apply(op, inputs)[0], format.type())}; |
|
|
|
} else if (auto* get_attr = op.as<GetAttr>()) { |
|
|
|
auto* src = inputs.as_array<1>()[0].as<FormattedTensorValue>(); |
|
|
|
if (!m_auto_convert || !src || src->format() != FT::NHWC) { |
|
|
|
auto&& input = inputs.item(); |
|
|
|
if (!input.is(m_value_type)) { |
|
|
|
return imperative::apply(op, input); |
|
|
|
} |
|
|
|
auto& src = input.cast(m_value_type); |
|
|
|
if (!(m_auto_convert && src.format() == FT::NHWC)) { |
|
|
|
return imperative::apply(op, unwrap_inputs(inputs)); |
|
|
|
} |
|
|
|
switch (get_attr->attr()) { |
|
|
@@ -369,16 +393,16 @@ std::vector<ValueRef> FormatTransformation::apply_transformation( |
|
|
|
return {ShapeValue::make(shape)}; |
|
|
|
} |
|
|
|
case GetAttr::Value: { |
|
|
|
auto nchw_src = unwrap_input(src->to(FT::NCHW, "")); |
|
|
|
return imperative::apply(op, std::vector<ValueRef>{nchw_src}); |
|
|
|
auto nchw_src = unwrap_input(to(src, FT::NCHW, "")); |
|
|
|
return imperative::apply(op, SmallVector<ValueRef>{nchw_src}); |
|
|
|
} |
|
|
|
default: |
|
|
|
return imperative::apply(op, unwrap_inputs(inputs)); |
|
|
|
} |
|
|
|
} else if (op.is<GetFormat>()) { |
|
|
|
bool is_formatted_tensor = inputs.as_array<1>()[0].is<FormattedTensorValue>(); |
|
|
|
bool is_formatted_tensor = inputs.item().is(m_value_type); |
|
|
|
if (is_formatted_tensor) { |
|
|
|
return {FormatValue::make(inputs[0].cast<FormattedTensorValue>().format())}; |
|
|
|
return {FormatValue::make(inputs[0].cast(m_value_type).format())}; |
|
|
|
} else { |
|
|
|
mgb_log_warn( |
|
|
|
"Not FormattedTensorValue input for GetFormat op: %s", |
|
|
@@ -386,9 +410,9 @@ std::vector<ValueRef> FormatTransformation::apply_transformation( |
|
|
|
return {FormatValue::make(FT::DEFAULT)}; |
|
|
|
} |
|
|
|
} else if (op.is<Operator::IdentityLike>()) { |
|
|
|
bool is_formatted_tensor = inputs.as_array<1>()[0].is<FormattedTensorValue>(); |
|
|
|
bool is_formatted_tensor = inputs.item().is(m_value_type); |
|
|
|
if (is_formatted_tensor) { |
|
|
|
auto& format = inputs[0].cast<FormattedTensorValue>().format(); |
|
|
|
auto&& format = inputs[0].cast(m_value_type).format(); |
|
|
|
return wrap_outputs( |
|
|
|
imperative::apply(op, unwrap_inputs(inputs)), format.type()); |
|
|
|
} else { |
|
|
|