|
|
@@ -37,12 +37,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
     }
 
+    const dt_int32* oshp2d = nullptr;
     dst_layout.ndim = 4u;
     if (nr_inp == 1) {
-        dst_layout[0] = src.layout[0];
-        dst_layout[1] = src.layout[1];
-        dst_layout[2] = pool.shape[0];
-        dst_layout[3] = pool.shape[1];
+        oshp2d = pool.shape.data();
     } else {
         auto&& tshp = inputs[1];
         if (tshp.value.empty()) {
@@ -52,11 +50,21 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
                 tshp.layout.ndim == 1,
                 "target shape of AdaptivePooling expects ndim=1; got ndim=%lu actually",
                 tshp.layout.ndim);
+        oshp2d = tshp.value.ptr<dt_int32>();
+    }
+    auto param_format = pool.param().format;
+    if (param_format == opr::AdaptivePooling::Param::Format::NCHW) {
         dst_layout[0] = src.layout[0];
         dst_layout[1] = src.layout[1];
-        auto* ptr = tshp.value.ptr<dt_int32>();
-        dst_layout[2] = ptr[0];
-        dst_layout[3] = ptr[1];
+        dst_layout[2] = oshp2d[0];
+        dst_layout[3] = oshp2d[1];
+    } else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) {
+        dst_layout[0] = src.layout[0];
+        dst_layout[1] = oshp2d[0];
+        dst_layout[2] = oshp2d[1];
+        dst_layout[3] = src.layout[3];
+    } else {
+        mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
     }
     dst_layout.init_contiguous_stride();
     return {{{dst_layout, src.comp_node}}, true};
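The two hunks above rework infer_output_attrs_fallible() so the 2-D target shape is first collected into a single oshp2d pointer (from pool.shape or from the shape tensor) and then dispatched on format: (OH, OW) fills axes 2/3 for NCHW but axes 1/2 for NHWC. As a minimal standalone sketch of that mapping (the helper name, Format enum, and array types here are illustrative assumptions, not MegEngine APIs):

    #include <array>
    #include <cstdint>
    #include <stdexcept>

    enum class Format { NCHW, NHWC };

    // Place the 2-D target shape (OH, OW) into a 4-D output layout,
    // keeping batch/channel extents from the source layout.
    std::array<int64_t, 4> infer_adaptive_pooling_shape(
            const std::array<int64_t, 4>& src, const int32_t* oshp2d, Format fmt) {
        switch (fmt) {
            case Format::NCHW:  // N, C, OH, OW
                return {src[0], src[1], oshp2d[0], oshp2d[1]};
            case Format::NHWC:  // N, OH, OW, C
                return {src[0], oshp2d[0], oshp2d[1], src[3]};
        }
        throw std::runtime_error("only NCHW or NHWC supported");
    }

For example, src = {8, 3, 32, 32} in NCHW with target {7, 7} gives {8, 3, 7, 7}, while the same target on an NHWC src = {8, 32, 32, 3} gives {8, 7, 7, 3}.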
|
|
@@ -71,26 +79,47 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     using TensorND = megdnn::TensorND;
     auto&& src_layout = inputs[0]->layout();
     TensorLayout dst_layout = output_descs[0].layout;
+    auto param_format = pool.format;
     if (!validated) {
-        TensorShape tshp;
         dst_layout.ndim = src_layout.ndim;
-        dst_layout[0] = src_layout[0];
-        dst_layout[1] = src_layout[1];
+        const dt_int32* oshp2d = nullptr;
         if (inputs.size() == 2) {
             auto&& tshp_nd = inputs[1];
-            cg::copy_tensor_value_to_shape(
-                    tshp, tshp_nd->get_value().proxy_to_default_cpu());
-            dst_layout[2] = tshp[0];
-            dst_layout[3] = tshp[1];
+            oshp2d = tshp_nd->get_value().proxy_to_default_cpu().ptr<dt_int32>();
         } else {
-            dst_layout[2] = pool.shape[0];
-            dst_layout[3] = pool.shape[1];
+            oshp2d = pool.shape.data();
         }
+        if (param_format == opr::AdaptivePooling::Param::Format::NCHW) {
+            dst_layout[0] = src_layout[0];
+            dst_layout[1] = src_layout[1];
+            dst_layout[2] = oshp2d[0];
+            dst_layout[3] = oshp2d[1];
+        } else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) {
+            dst_layout[0] = src_layout[0];
+            dst_layout[1] = oshp2d[0];
+            dst_layout[2] = oshp2d[1];
+            dst_layout[3] = src_layout[3];
+        } else {
+            mgb_throw(
+                    MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
+        }
         dst_layout.init_contiguous_stride();
     }
 
-    size_t IH = src_layout[2], IW = src_layout[3], OH = dst_layout[2],
-           OW = dst_layout[3];
+    size_t IH, IW, OH, OW;
+    if (param_format == param::AdaptivePooling::Format::NCHW) {
+        IH = src_layout[2];
+        IW = src_layout[3];
+        OH = dst_layout[2];
+        OW = dst_layout[3];
+    } else if (param_format == param::AdaptivePooling::Format::NHWC) {
+        IH = src_layout[1];
+        IW = src_layout[2];
+        OH = dst_layout[1];
+        OW = dst_layout[2];
+    } else {
+        mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
+    }
     DnnOprCaller<megdnn::Pooling> dnn_opr(cn);
     auto&& param = dnn_opr.op->param();
     param.mode = pool.mode;
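Past the end of this hunk, the IH/IW/OH/OW extents configure the megdnn Pooling operator. The standard reduction from adaptive to fixed pooling (a hedged sketch of the idea, not the patch's own code, which is not shown here) chooses stride = floor(I / O) and window = I - (O - 1) * stride with zero padding, so the windows tile the input exactly:

    #include <cassert>
    #include <cstdio>

    struct WindowParam {
        size_t stride, window;
    };

    // Derive a fixed pooling stride/window from one input/output extent pair.
    // Illustrative names; assumes 0 < out <= in.
    WindowParam adaptive_window(size_t in, size_t out) {
        assert(out > 0 && out <= in);
        size_t stride = in / out;                 // floor division
        size_t window = in - (out - 1) * stride;  // absorbs the remainder
        return {stride, window};
    }

    int main() {
        WindowParam p = adaptive_window(32, 7);  // e.g. IH = 32, OH = 7
        std::printf("stride=%zu window=%zu\n", p.stride, p.window);  // stride=4 window=8
        return 0;
    }

With these choices (out - 1) * stride + window == in, so the final window ends exactly at the input border for both the NCHW and NHWC extents computed above.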
|
|
|