@@ -23,24 +23,42 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
     if (format == target)
         return as(tensor, target);
+    auto&& shape = tensor.value().shape().cast<ShapeValue>();
     if (format == FT::NHWC && (target == FT::NCHW || target == FT::DEFAULT)) {
         // FIXME(czh): temporary fast path for group conv 5D weight.
-        if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
+        if (shape.ndim == 5) {
             pattern = {0, 1, 4, 2, 3};
-        } else {
+        } else if (shape.ndim == 4) {
             pattern = {0, 3, 1, 2};
+        } else {
+            mgb_throw(
+                    MegBrainError,
+                    "Unsupported format conversion for tensor %s(shape=%s) from %s to %s",
+                    tensor.to_string().c_str(), shape.to_string().c_str(),
+                    format.to_string().c_str(), Format(target).to_string().c_str());
         }
     } else if ((format == FT::NCHW || format == FT::DEFAULT) && target == FT::NHWC) {
-        if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
+        if (shape.ndim == 5) {
             pattern = {0, 1, 3, 4, 2};
-        } else {
+        } else if (shape.ndim == 4) {
             pattern = {0, 2, 3, 1};
+        } else {
+            mgb_throw(
+                    MegBrainError,
+                    "Unsupported format conversion for tensor %s(shape=%s) from %s to %s",
+                    tensor.to_string().c_str(), shape.to_string().c_str(),
+                    format.to_string().c_str(), Format(target).to_string().c_str());
         }
     } else {
         mgb_throw(
-                MegBrainError, "Unsupport format conversion from %s to %s",
+                MegBrainError,
+                "Unsupported format conversion for tensor %s(shape=%s) from %s to %s",
+                tensor.to_string().c_str(), shape.to_string().c_str(),
                 format.to_string().c_str(), Format(target).to_string().c_str());
     }
     mgb_log_debug(
             "Change tensor %s from %s to %s", tensor.to_string().c_str(),
             format.to_string().c_str(), Format(target).to_string().c_str());
     auto output =
             imperative::apply(*Dimshuffle::make(pattern, scope), {tensor.value()})[0];
     return m_value_type.make(output, target);
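
Note: the dimshuffle patterns above are axis permutations, where entry i of the pattern names the input axis that lands at output position i, so {0, 3, 1, 2} moves C in front of H and W. A minimal standalone sketch of that convention, with no MegEngine dependencies; permute_shape is a hypothetical helper for illustration only:

    #include <array>
    #include <cstdio>

    // out[i] = in[pattern[i]] -- the same convention Dimshuffle uses above.
    template <size_t N>
    std::array<int, N> permute_shape(
            const std::array<int, N>& in, const std::array<int, N>& pattern) {
        std::array<int, N> out{};
        for (size_t i = 0; i < N; ++i)
            out[i] = in[pattern[i]];
        return out;
    }

    int main() {
        std::array<int, 4> nhwc{2, 8, 8, 3};  // N=2, H=8, W=8, C=3
        auto nchw = permute_shape(nhwc, {0, 3, 1, 2});
        std::printf("%d %d %d %d\n", nchw[0], nchw[1], nchw[2], nchw[3]);  // 2 3 8 8
    }
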
@@ -380,9 +398,7 @@ inline ValueRefList unify_inputs_format(
     ValueRefList unified_inputs(inputs.size());
     for (size_t i = 0; i < inputs.size(); ++i) {
         auto&& inp = inputs[i].cast(t.value_type());
-        if (inp.format() != dst_fmt &&
-            (inp.value().shape().cast<ShapeValue>().ndim == 4 ||
-             inp.value().shape().cast<ShapeValue>().ndim == 5)) {
+        if (inp.format() != dst_fmt) {
             unified_inputs[i] = t.to(inp, dst_fmt, scope);
         } else {
             unified_inputs[i] = inputs[i];
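
Note: the ndim == 4/5 guard can be dropped here because to() (first hunk) now validates the rank itself and throws for anything it cannot transpose, while values that must never be transposed, namely scalars, are filtered out by the callers (next two hunks). A rough standalone sketch of the resulting contract; Fmt, Val, to and unify are illustrative stand-ins, not MegEngine types:

    #include <stdexcept>
    #include <vector>

    enum class Fmt { DEFAULT, NCHW, NHWC };
    struct Val { Fmt fmt; int ndim; };

    // Stand-in for FormatTransformation::to(): the conversion validates rank,
    // so the unify loop no longer pre-filters on ndim.
    Val to(const Val& v, Fmt dst) {
        if (v.ndim != 4 && v.ndim != 5)
            throw std::runtime_error("unsupported format conversion");
        return {dst, v.ndim};
    }

    std::vector<Val> unify(const std::vector<Val>& inputs, Fmt dst) {
        std::vector<Val> out;
        for (auto&& v : inputs)
            out.push_back(v.fmt != dst ? to(v, dst) : v);  // no ndim check here
        return out;
    }
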
@@ -396,7 +412,16 @@ ValueRefList elemwise_rule(
         const FormatTransformation& t) {
     FT format = get_inputs_format(inputs, t);
     if (format == FT::NHWC && auto_convert) {
-        auto unified_inputs = unify_inputs_format(inputs, FT::NHWC, op.scope(), t);
+        ValueRefList unified_inputs(inputs.size());
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            auto&& inp = inputs[i].cast(t.value_type());
+            if (inp.format() != FT::NHWC && inp.value().is_scalar()) {
+                unified_inputs[i] = t.value_type().make(inp.value(), FT::NHWC);
+            } else {
+                unified_inputs[i] = inputs[i];
+            }
+        }
+        unified_inputs = unify_inputs_format(unified_inputs, FT::NHWC, op.scope(), t);
         return t.wrap_outputs(
                 imperative::apply(op, t.unwrap_inputs(unified_inputs)), format);
     }
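
Note: scalars carry no layout, so rather than sending them through to() (which would now throw, since a 0-d value is neither 4-d nor 5-d), they are simply relabelled as NHWC before unify_inputs_format converts the remaining mismatched inputs for real. A standalone sketch of that first phase; Fmt and Val are illustrative stand-ins:

    #include <vector>

    enum class Fmt { DEFAULT, NCHW, NHWC };
    struct Val { Fmt fmt; bool scalar; };

    // Phase one of the loop above: relabel mismatched scalars (no data moves),
    // leave everything else for the real format unification.
    std::vector<Val> promote_scalars(const std::vector<Val>& inputs) {
        std::vector<Val> out;
        for (auto&& v : inputs)
            out.push_back(v.fmt != Fmt::NHWC && v.scalar ? Val{Fmt::NHWC, true} : v);
        return out;
    }
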
@@ -410,7 +435,16 @@ ValueRefList concat_rule(
     if (!(format == FT::NHWC && auto_convert)) {
         return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
     }
-    auto unified_inputs = unify_inputs_format(inputs, FT::NHWC, op.scope(), t);
+    ValueRefList unified_inputs(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        auto&& inp = inputs[i].cast(t.value_type());
+        if (inp.format() != FT::NHWC && inp.value().is_scalar()) {
+            unified_inputs[i] = t.value_type().make(inp.value(), FT::NHWC);
+        } else {
+            unified_inputs[i] = inputs[i];
+        }
+    }
+    unified_inputs = unify_inputs_format(unified_inputs, FT::NHWC, op.scope(), t);
     // TODO: handle 5D NHWC Tensor from group conv
     auto axis = op.axis;
     if (axis == 2 || axis == 3) {
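
Note: the remainder of this hunk is cut off above, but the axis check exists because the inputs were just transposed to NHWC while op.axis still names an NCHW position, so the concat axis has to be remapped. A hedged sketch of what such a remap looks like (not the literal code from the patch):

    // NCHW axis -> NHWC axis: N=0 stays 0, C=1 -> 3, H=2 -> 1, W=3 -> 2.
    int nchw_axis_to_nhwc(int axis) {
        switch (axis) {
            case 1: return 3;
            case 2: return 1;
            case 3: return 2;
            default: return axis;  // N (and out-of-range axes) left untouched
        }
    }
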
@@ -441,7 +475,7 @@ ValueRefList batchnorm_rule(
         const FormatTransformation& t) {
     auto&& inp_format = inputs[0].cast(t.value_type()).format();
     if (inp_format == FT::NHWC) {
-        auto&& new_param = op.param();
+        auto new_param = op.param();
         new_param.param_dim = BatchNorm::ParamDim::DIM_111C;
         auto new_op = BatchNorm::make(new_param);
         return identity_rule_helper(*new_op, inputs, t);
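
Note: the auto&& to auto change here (and in the three hunks that follow) is more than style. A plain auto always produces an independent, mutable copy of the param, whereas auto&& deduces a reference: fine while param() returns by value, but if it ever returned const Param&, the write to new_param would stop compiling (or, with a non-const reference, silently mutate the op itself). A minimal illustration with stand-in types:

    struct Param { int dim = 0; };
    struct Op {
        Param p;
        const Param& param() const { return p; }  // a by-reference getter
    };

    int main() {
        Op op;
        // auto&& new_param = op.param();  // deduces const Param&; the write
        // new_param.dim = 1;              // below would fail to compile
        auto new_param = op.param();       // independent, mutable copy
        new_param.dim = 1;                 // op.p stays 0
        return op.p.dim;
    }
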
@@ -454,7 +488,7 @@ ValueRefList adaptive_pooling_rule(
         const FormatTransformation& t) {
     auto&& inp_format = inputs[0].cast(t.value_type()).format();
     if (inp_format == FT::NHWC) {
-        auto&& new_param = op.param();
+        auto new_param = op.param();
         new_param.format = AdaptivePooling::Format::NHWC;
         auto new_op = AdaptivePooling::make(new_param, op.shape);
         return identity_rule_helper(*new_op, inputs, t);
@@ -518,7 +552,7 @@ FOREACH_IDENTITY_OP(CREATE_IDENTITY_OP_RULE)
             const FormatTransformation& t) { \
         auto&& inp_format = inputs[0].cast(t.value_type()).format(); \
         if (inp_format == FT::NHWC) { \
-            auto&& new_param = _op.param(); \
+            auto new_param = _op.param(); \
             new_param.format = Op::Format::NHWC; \
             auto new_op = Op::make(new_param); \
             return identity_rule_helper(*new_op, inputs, t); \
@@ -535,7 +569,7 @@ FOREACH_FORMAT_OP(CREATE_FORMAT_OP_RULE)
             const FormatTransformation& t) { \
         auto&& inp_format = inputs[0].cast(t.value_type()).format(); \
         if (inp_format == FT::NHWC) { \
-            auto&& new_param = _op.param(); \
+            auto new_param = _op.param(); \
             new_param.format = Op::Format::NHWC; \
             auto new_op = Op::make(new_param, _op.policy()); \
             return identity_rule_helper(*new_op, inputs, t); \
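
Note: the last two hunks patch the bodies of X-macros; FOREACH_IDENTITY_OP and FOREACH_FORMAT_OP invoke the CREATE_*_RULE callback once per op name, stamping out one *_rule function each (Op and _op are the macro's type and instance parameters). A self-contained demo of the pattern; FOREACH_DEMO_OP and CREATE_DEMO_RULE are made-up names, not MegEngine's:

    #include <cstdio>

    // The list macro calls the given callback once per entry...
    #define FOREACH_DEMO_OP(cb) cb(Pooling) cb(Resize)

    // ...and the callback stamps out one function per op name.
    #define CREATE_DEMO_RULE(Op) \
        void Op##_rule() { std::printf("rule for " #Op "\n"); }

    FOREACH_DEMO_OP(CREATE_DEMO_RULE)

    int main() {
        Pooling_rule();  // prints "rule for Pooling"
        Resize_rule();   // prints "rule for Resize"
    }
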