|
|
@@ -43,13 +43,22 @@ void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape( |
|
|
|
"shape mismatch for AdaptivePooling: src=%s, out2d=%s", |
|
|
|
src.to_string().c_str(), oshp2d.to_string().c_str()); |
|
|
|
|
|
|
|
mgb_assert( |
|
|
|
param().format == Param::Format::NCHW, "AdaptivePooling only support NCHW"); |
|
|
|
dest.ndim = 4; |
|
|
|
dest.shape[0] = src.shape[0]; |
|
|
|
dest.shape[1] = src.shape[1]; |
|
|
|
dest.shape[2] = oshp2d.shape[0]; |
|
|
|
dest.shape[3] = oshp2d.shape[1]; |
|
|
|
auto param_format = param().format; |
|
|
|
if (param_format == Param::Format::NCHW) { |
|
|
|
dest.ndim = 4; |
|
|
|
dest.shape[0] = src.shape[0]; |
|
|
|
dest.shape[1] = src.shape[1]; |
|
|
|
dest.shape[2] = oshp2d.shape[0]; |
|
|
|
dest.shape[3] = oshp2d.shape[1]; |
|
|
|
} else if (param_format == Param::Format::NHWC) { |
|
|
|
dest.ndim = 4; |
|
|
|
dest.shape[0] = src.shape[0]; |
|
|
|
dest.shape[1] = oshp2d.shape[0]; |
|
|
|
dest.shape[2] = oshp2d.shape[1]; |
|
|
|
dest.shape[3] = src.shape[3]; |
|
|
|
} else { |
|
|
|
mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format"); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
size_t AdaptivePoolingForward::get_workspace_size_bytes( |
|
|
|