diff --git a/dnn/include/megdnn/oprs/cv.h b/dnn/include/megdnn/oprs/cv.h index b46ac2ac..c20f8945 100644 --- a/dnn/include/megdnn/oprs/cv.h +++ b/dnn/include/megdnn/oprs/cv.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include "megdnn/internal/opr_header_prologue.h" @@ -21,23 +22,23 @@ class FlipBase : public OperatorBase { DEF_OPR_IMPL_CTOR(FlipBase, OperatorBase); DEF_OPR_PARAM(Flip); - protected: - void deduce_layout_fwd(const TensorLayout &src, TensorLayout &dst); - void check_layout_fwd(const TensorLayout &src, const TensorLayout &dst); +protected: + void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst); + void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst); }; class FlipForward : public FlipBase { DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1); - public: +public: virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout &src, TensorLayout &dst); - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst) = 0; + void deduce_layout(const TensorLayout& src, TensorLayout& dst); + virtual size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& dst) = 0; - protected: - void check_exec(const TensorLayout &src, const TensorLayout &dst, +protected: + void check_exec(const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes); }; using Flip = FlipForward; @@ -46,23 +47,23 @@ class RotateBase : public OperatorBase { DEF_OPR_IMPL_CTOR(RotateBase, OperatorBase); DEF_OPR_PARAM(Rotate); - protected: - void deduce_layout_fwd(const TensorLayout &src, TensorLayout &dst); - void check_layout_fwd(const 
TensorLayout &src, const TensorLayout &dst); +protected: + void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst); + void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst); }; class RotateForward : public RotateBase { DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1); - public: +public: virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout &src, TensorLayout &dst); - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst) = 0; + void deduce_layout(const TensorLayout& src, TensorLayout& dst); + virtual size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& dst) = 0; - protected: - void check_exec(const TensorLayout &src, const TensorLayout &dst, +protected: + void check_exec(const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes); }; using Rotate = RotateForward; @@ -71,23 +72,23 @@ class ROICopyBase : public OperatorBase { DEF_OPR_IMPL_CTOR(ROICopyBase, OperatorBase); DEF_OPR_PARAM(ROICopy); - protected: - void deduce_layout_fwd(const TensorLayout &src, TensorLayout &dst); - void check_layout_fwd(const TensorLayout &src, const TensorLayout &dst); +protected: + void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst); + void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst); }; class ROICopyForward : public ROICopyBase { DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1); - public: +public: virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout &src, TensorLayout &dst); - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst) = 0; + void deduce_layout(const TensorLayout& src, TensorLayout& dst); + virtual size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& dst) = 0; - protected: - void check_exec(const TensorLayout 
&src, const TensorLayout &dst, +protected: + void check_exec(const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes); }; using ROICopy = ROICopyForward; @@ -96,23 +97,23 @@ class CvtColorBase : public OperatorBase { DEF_OPR_IMPL_CTOR(CvtColorBase, OperatorBase); DEF_OPR_PARAM(CvtColor); - protected: - void deduce_layout_fwd(const TensorLayout &src, TensorLayout &dst); - void check_layout_fwd(const TensorLayout &src, const TensorLayout &dst); +protected: + void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst); + void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst); }; class CvtColorForward : public CvtColorBase { DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1); - public: +public: virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout &src, TensorLayout &dst); - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst) = 0; + void deduce_layout(const TensorLayout& src, TensorLayout& dst); + virtual size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& dst) = 0; - protected: - void check_exec(const TensorLayout &src, const TensorLayout &dst, +protected: + void check_exec(const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes); }; using CvtColor = CvtColorForward; @@ -124,20 +125,21 @@ class WarpAffineBase : public OperatorBase { DEF_OPR_IMPL_CTOR(WarpAffineBase, OperatorBase); DEF_OPR_PARAM(WarpAffine); - public: - using InterpolationMode = Param::InterpolationMode; - using BorderMode = Param::BorderMode; - protected: - void check_layout_fwd(const TensorLayout& src, const TensorLayout& trans, - const TensorLayout& dst); - std::string param_msg() const; - int get_real_coord(int p, int len); +public: + using InterpolationMode = Param::InterpolationMode; + using BorderMode = Param::BorderMode; + +protected: + void check_layout_fwd(const TensorLayout& src, 
const TensorLayout& trans, + const TensorLayout& dst); + std::string param_msg() const; + int get_real_coord(int p, int len); }; class WarpAffineForward : public WarpAffineBase { DEF_OPR_IMPL(WarpAffineForward, WarpAffineBase, 2, 1); - public: +public: /** * \param[in] src input tensor * \param[in] trans transform matrix tensor @@ -148,13 +150,13 @@ class WarpAffineForward : public WarpAffineBase { */ virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &trans, - const TensorLayout &dst) = 0; + virtual size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& trans, + const TensorLayout& dst) = 0; - protected: - void check_exec(const TensorLayout &src, const TensorLayout &trans, - const TensorLayout &dst, size_t workspace_in_bytes); +protected: + void check_exec(const TensorLayout& src, const TensorLayout& trans, + const TensorLayout& dst, size_t workspace_in_bytes); }; using WarpAffine = WarpAffineForward; @@ -162,23 +164,23 @@ class GaussianBlurBase : public OperatorBase { DEF_OPR_IMPL_CTOR(GaussianBlurBase, OperatorBase); DEF_OPR_PARAM(GaussianBlur); - protected: - void deduce_layout_fwd(const TensorLayout &src, TensorLayout &dst); - void check_layout_fwd(const TensorLayout &src, const TensorLayout &dst); +protected: + void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst); + void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst); }; class GaussianBlurForward : public GaussianBlurBase { DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1); - public: +public: virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout &src, TensorLayout &dst); - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst) = 0; + void deduce_layout(const TensorLayout& 
src, TensorLayout& dst); + virtual size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& dst) = 0; - protected: - void check_exec(const TensorLayout &src, const TensorLayout &dst, +protected: + void check_exec(const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes); }; using GaussianBlur = GaussianBlurForward; @@ -230,41 +232,75 @@ protected: size_t workspace_in_bytes); }; -class SeparableFilterBase: public OperatorBase { +/** + * \brief Remap opr. + */ +class RemapBase : public OperatorBase { + DEF_OPR_PARAM(Remap); + DEF_OPR_IMPL(RemapBase, OperatorBase, 2, 1); + +public: + using InterpolationMode = Param::InterpolationMode; + using BorderMode = Param::BorderMode; + +protected: + void check_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy, + const TensorLayout& dst); + void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy, + TensorLayout& dst); +}; + +class RemapForward : public RemapBase { + DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1); + +public: + virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, + _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; + + void deduce_layout(const TensorLayout& src, const TensorLayout& map_xy, + TensorLayout& dst); + + virtual size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& map_xy, + const TensorLayout& dst) = 0; + +protected: + void check_exec(const TensorLayout& src, const TensorLayout& map_xy, + const TensorLayout& dst, size_t workspace_in_bytes); +}; +using Remap = RemapForward; + +class SeparableFilterBase : public OperatorBase { DEF_OPR_IMPL_CTOR(SeparableFilterBase, OperatorBase); DEF_OPR_PARAM(SeparableFilter); - protected: - void deduce_layout_fwd(const TensorLayout &src, - const TensorLayout &filter_x, - const TensorLayout &filter_y, - TensorLayout &dst); - void check_layout_fwd(const TensorLayout &src, - const TensorLayout &filter_x, - const TensorLayout &filter_y, - const TensorLayout 
&dst); + +protected: + void deduce_layout_fwd(const TensorLayout& src, + const TensorLayout& filter_x, + const TensorLayout& filter_y, TensorLayout& dst); + void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, + const TensorLayout& dst); }; -class SeparableFilterForward: public SeparableFilterBase { +class SeparableFilterForward : public SeparableFilterBase { DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1); - public: - virtual void exec(_megdnn_tensor_in src, - _megdnn_tensor_in filter_x, - _megdnn_tensor_in filter_y, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout &src, - const TensorLayout &filter_x, - const TensorLayout &filter_y, - TensorLayout &dst); - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &filter_x, - const TensorLayout &filter_y, - const TensorLayout &dst) = 0; - protected: - void check_exec(const TensorLayout &src, - const TensorLayout &filter_x, - const TensorLayout &filter_y, - const TensorLayout &dst, size_t workspace_in_bytes); + +public: + virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x, + _megdnn_tensor_in filter_y, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout(const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, TensorLayout& dst); + virtual size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& filter_x, + const TensorLayout& filter_y, + const TensorLayout& dst) = 0; + +protected: + void check_exec(const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, const TensorLayout& dst, + size_t workspace_in_bytes); }; using SeparableFilter = SeparableFilterForward; diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py old mode 100644 new mode 100755 index 23bd1f48..46e77482 --- a/dnn/scripts/opr_param_defs.py +++ 
b/dnn/scripts/opr_param_defs.py @@ -35,10 +35,10 @@ pdef('Axis').add_fields('int32', 'axis', 0) ). add_enum(Doc('Format', 'convolution data/filter/output format; see ' ':class:`RelayoutFormat` for more details'), - 'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88', 'NCHW44', - Doc('NCHW_WINOGRAD', 'NCHW layout with weights tranformed by winograd'), - Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights tranformed by winograd'), - Doc('NCHW44_WINOGRAD', 'NCHW44 layout with weights tranformed by winograd'), + 'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88', 'NCHW44', + Doc('NCHW_WINOGRAD', 'NCHW layout with weights tranformed by winograd'), + Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights tranformed by winograd'), + Doc('NCHW44_WINOGRAD', 'NCHW44 layout with weights tranformed by winograd'), Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation ' 'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.')) ) @@ -699,6 +699,12 @@ pdef('UniformRNG').add_fields('uint64', 'seed', 0) .add_enum_alias('InterpolationMode', 'WarpPerspective', name_field='imode') .add_enum_alias('Format', 'ConvolutionV0', default=1)) +(pdef('Remap', version=0) + .add_enum_alias('InterpolationMode', 'WarpPerspective', name_field='imode') + .add_enum_alias('BorderMode', 'WarpPerspective', name_field='border_type') + .add_enum_alias('Format', 'ConvolutionV0', default=1) + .add_fields('float32', 'scalar', '0.f')) + (pdef('Convolution3D'). add_enum('Mode', 'CROSS_CORRELATION', 'CONVOLUTION'). add_fields( @@ -840,8 +846,8 @@ when the ``I`` suffix is present. 'INTER_WEIGHT_CHAN', 'INTER_WEIGHT_CHANI', 'INTER_WEIGHT_DENSEI_DOT', - 'INTER_WEIGHT_GROUPI_DOT', - 'NCHW4_CHWN4', + 'INTER_WEIGHT_GROUPI_DOT', + 'NCHW4_CHWN4', 'CHWN4_NCHW4', 'NCHW_NCHW88_CONV_DENSE_WEIGHT', 'NCHW_NCHW88_CONV_CHAN_WEIGHT', @@ -849,7 +855,7 @@ when the ``I`` suffix is present. 'NCHW_NCHW88', 'NCHW88_NCHW') ) - + (pdef('SeparableFilter'). 
add_enum_alias('Format', 'ConvolutionV0'). @@ -882,10 +888,10 @@ when the ``I`` suffix is present. add_enum_alias('Format', 'ConvolutionV0'). add_fields('float32', 'spatial_scale', '1.0'). add_fields('float32', 'offset', '0.0'). - add_fields('uint32', - 'pooled_height', '1', + add_fields('uint32', + 'pooled_height', '1', 'pooled_width', '1', - 'sample_height', '2', + 'sample_height', '2', 'sample_width', '2') ) (pdef('DeformablePSROIPooling'). diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index c8d654b7..168482fa 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -188,6 +188,7 @@ private: cb(ROIAlignForward) \ cb(ROIAlignBackward) \ cb(BatchConvBiasForward) \ + cb(Remap) \ /*! * \brief specialize HandleImpl::create_operator for a single opr type; diff --git a/dnn/src/common/remap.cpp b/dnn/src/common/remap.cpp new file mode 100644 index 00000000..ff76f866 --- /dev/null +++ b/dnn/src/common/remap.cpp @@ -0,0 +1,90 @@ +/** + * \file dnn/src/common/remap.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "megdnn/oprs.h" +#include "src/common/cv/common.h" +#include "src/common/cv/helper.h" +#include "src/common/utils.h" + +namespace megdnn { + +void RemapBase::deduce_layout_fwd(const TensorLayout& src, + const TensorLayout& map_xy, + TensorLayout& dst) { + dst.dtype = src.dtype; + dst.ndim = src.ndim; + dst.shape[0] = src.shape[0]; + size_t height_index, channel_index; + if (param().format == param::Remap::Format::NHWC) { + height_index = 1; + channel_index = 3; + } else { + megdnn_assert(param().format == param::Remap::Format::NCHW); + height_index = 2; + channel_index = 1; + } + dst.shape[height_index] = map_xy.shape[1]; + dst.shape[height_index + 1] = map_xy.shape[2]; + dst.shape[channel_index] = src.shape[channel_index]; +} + +void RemapBase::check_layout_fwd(const TensorLayout& src, + const TensorLayout& map_xy, + const TensorLayout& dst) { + auto errmsg = [&]() { + return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(map_xy) + + ", " + megdnn_layout_msg(dst); + }; + MEGDNN_MARK_USED_VAR(errmsg); + megdnn_assert(src.ndim == map_xy.ndim && src.ndim == dst.ndim && + src.ndim == 4); + megdnn_assert(dst.dtype == src.dtype); + megdnn_assert(dst.shape[0] == src.shape[0], "%s", errmsg().c_str()); + megdnn_assert(map_xy.shape[3] == 2); + megdnn_assert(map_xy.shape[0] == src.shape[0]); + + // map_xy only support floa32 type + // map_xy always in NHWC format + megdnn_assert(map_xy.dtype.enumv() == DTypeEnum::Float32); + + // In remap opr, H, W is same as H W in map_xy. 
+ if (param().format == param::Remap::Format::NHWC) { + megdnn_assert(src.shape[3] == dst.shape[3], "%s", errmsg().c_str()); + megdnn_assert(dst.shape[2] == map_xy.shape[2] && + dst.shape[1] == map_xy.shape[1], + "%s", errmsg().c_str()); + } else if (param().format == param::Remap::Format::NCHW) { + megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str()); + megdnn_assert(dst.shape[2] == map_xy.shape[1] && + dst.shape[3] == map_xy.shape[2], + "%s", errmsg().c_str()); + } else { + megdnn_throw( + "megdnn currently do not support other param.format except " + "NHWC and NCHW"); + } +} + +void Remap::deduce_layout(const TensorLayout& src, const TensorLayout& map_xy, + TensorLayout& dst) { + deduce_layout_fwd(src, map_xy, dst); +} + +void Remap::check_exec(const TensorLayout& src, const TensorLayout& map_xy, + const TensorLayout& dst, size_t workspace_in_bytes) { + check_layout_fwd(src, map_xy, dst); + auto required_workspace_in_bytes = get_workspace_in_bytes(src, map_xy, dst); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); +} + +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp index 890a21ce..e060a43d 100644 --- a/dnn/src/cuda/handle_create.cpp +++ b/dnn/src/cuda/handle_create.cpp @@ -74,6 +74,7 @@ #include "src/cuda/local_share/opr_impl.h" #include "src/cuda/roi_align/opr_impl.h" #include "src/cuda/batch_conv_bias/opr_impl.h" +#include "src/cuda/remap/opr_impl.h" namespace megdnn { namespace cuda { diff --git a/dnn/src/cuda/remap/opr_impl.cpp b/dnn/src/cuda/remap/opr_impl.cpp new file mode 100644 index 00000000..191f4cda --- /dev/null +++ b/dnn/src/cuda/remap/opr_impl.cpp @@ -0,0 +1,26 @@ +/** + * \file dnn/src/opencl/cuda/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/cuda/remap/opr_impl.h" +#include "megdnn/config/config.h" +#include "src/common/utils.h" + +using namespace megdnn; +using namespace cuda; + +void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy, + _megdnn_tensor_in dst, _megdnn_workspace workspace) { + check_exec(src.layout, map_xy.layout, dst.layout, workspace.size); + megdnn_throw("megdnn currently do not support remap in cuda"); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/remap/opr_impl.h b/dnn/src/cuda/remap/opr_impl.h new file mode 100644 index 00000000..f4fd4f31 --- /dev/null +++ b/dnn/src/cuda/remap/opr_impl.h @@ -0,0 +1,29 @@ +/** + * \file dnn/src/opencl/cuda/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace cuda { +class RemapImpl final : public Remap { + using Remap::Remap; + void exec(_megdnn_tensor_in, _megdnn_tensor_in, _megdnn_tensor_out, + _megdnn_workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + const TensorLayout&) override { + return 0; + } +}; +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index c6a58330..3de0244c 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -75,6 +75,7 @@ #include "src/naive/warp_affine/opr_impl.h" #include "src/naive/warp_perspective/opr_impl.h" #include "src/naive/winograd_filter_preprocess/opr_impl.h" +#include "src/naive/remap/opr_impl.h" static size_t g_image2d_pitch_alignment = 1; diff --git a/dnn/src/naive/remap/opr_impl.cpp b/dnn/src/naive/remap/opr_impl.cpp new file mode 100644 index 00000000..0fe67a41 --- /dev/null +++ b/dnn/src/naive/remap/opr_impl.cpp @@ -0,0 +1,194 @@ +/** + * \file dnn/src/naive/remap/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "src/naive/remap/opr_impl.h" +#include "src/common/cv/helper.h" +#include "src/common/utils.h" +#include "src/naive/handle.h" + +using namespace megdnn; +using namespace naive; + +namespace { +template +inline int get_offset(int, int, int, int, int, int); + +template <> +inline int get_offset(int height, int width, + int channel, int h, int w, + int) { + return channel * h * w + height * w + width; +} + +template <> +inline int get_offset(int height, int width, + int channel, int, int w, + int c) { + return height * w * c + width * c + channel; +} + +template +struct GetSrcData { + static inline DataType get(const DataType*, int, int, int, int, int, int, + int); +}; + +template +struct GetSrcData { + static inline DataType get(const DataType* src, int height, int width, + int channel, int h, int w, int c, float scalar) { + return (height >= 0 && height < h && width >= 0 && width < w) + ? src[get_offset(height, width, channel, h, w, + c)] + : static_cast(std::round(scalar)); + } +}; + +#define cb(bmode) \ + template \ + struct GetSrcData { \ + static inline DataType get(const DataType* src, int height, int width, \ + int channel, int h, int w, int c, float) { \ + height = megcv::border_interpolate< \ + param::Remap::BorderMode::bmode>(height, h); \ + width = megcv::border_interpolate< \ + param::Remap::BorderMode::bmode>(width, w); \ + return src[get_offset(height, width, channel, h, w, c)]; \ + } \ + }; + +cb(REPLICATE); +cb(REFLECT); +cb(REFLECT_101); +cb(WRAP); +#undef cb + +template +void remap_LINEAR(const DataType* src, const float* map_xy, DataType* dst, + int N, int C, int IH, int IW, int OH, int OW, float scalar, + std::function round) { + for (int n = 0; n < N; + ++n, src += C * IH * IW, dst += C * OH * OW, map_xy += OH * OW * 2) { + for (int h = 0; h < OH; ++h) { + for (int w = 0; w < OW; ++w) { + float index_col = map_xy[h * OW * 2 + w * 2 + 0]; + float index_row = map_xy[h * OW * 2 + w * 2 + 1]; + int col = 
static_cast(floor(index_col)); + int row = static_cast(floor(index_row)); + float v = index_col - col; + float u = index_row - row; + float one = 1.f; + for (int c = 0; c < C; ++c) { + DataType a00 = + GetSrcData::get( + src, row + 0, col + 0, c, IH, IW, C, + scalar); + DataType a01 = + GetSrcData::get( + src, row + 0, col + 1, c, IH, IW, C, + scalar); + DataType a10 = + GetSrcData::get( + src, row + 1, col + 0, c, IH, IW, C, + scalar); + DataType a11 = + GetSrcData::get( + src, row + 1, col + 1, c, IH, IW, C, + scalar); + + dst[get_offset(h, w, c, OH, OW, C)] = + static_cast( + round(a00 * (one - u) * (one - v) + + a01 * (one - u) * v + + a10 * (one - v) * u + a11 * u * v)); + } + } + } + } +} + +template +struct Round { + static inline DataType round(float x) { return std::round(x); } +}; + +template +struct Round { + static inline DataType round(float x) { return static_cast(x); } +}; + +} // namespace + +void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, + _megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec(src.layout, map_xy.layout, dst.layout, workspace.size); + int N, C, IH, IW, OH, OW; + if (param().format == param::Remap::Format::NCHW) { + N = src.layout.shape[0]; + C = src.layout.shape[1]; + IH = src.layout.shape[2]; + IW = src.layout.shape[3]; + } else { + N = src.layout.shape[0]; + C = src.layout.shape[3]; + IH = src.layout.shape[1]; + IW = src.layout.shape[2]; + } + OH = map_xy.layout.shape[1]; + OW = map_xy.layout.shape[2]; + switch (src.layout.dtype.enumv()) { +#define cb(dt, fmt, border, interpolation) \ + if (param().format == param::Remap::Format::fmt && \ + param().border_type == param::Remap::BorderMode::border && \ + param().imode == param::Remap::InterpolationMode::interpolation) { \ + using ctype = DTypeTrait
::ctype; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + (remap_##interpolation( \ + src.compatible_ptr(), \ + map_xy.compatible_ptr(), \ + dst.compatible_ptr(), N, C, IH, IW, OH, OW, \ + param().scalar, \ + Round::category>::round))); \ + break; \ + } + +#define support_dtype(dt) \ + case DTypeTrait
::enumv: { \ + cb(dt, NCHW, CONSTANT, LINEAR); \ + cb(dt, NCHW, REPLICATE, LINEAR); \ + cb(dt, NCHW, REFLECT, LINEAR); \ + cb(dt, NCHW, REFLECT_101, LINEAR); \ + cb(dt, NCHW, WRAP, LINEAR); \ + cb(dt, NHWC, CONSTANT, LINEAR); \ + cb(dt, NHWC, REPLICATE, LINEAR); \ + cb(dt, NHWC, REFLECT, LINEAR); \ + cb(dt, NHWC, REFLECT_101, LINEAR); \ + cb(dt, NHWC, WRAP, LINEAR); \ + megdnn_throw( \ + "format, border type or imode is incorrect in remap navie " \ + "with dtype = " #dt); \ + } + + support_dtype(dtype::Float32); + MEGDNN_INC_FLOAT16(support_dtype(dtype::Float16)); + support_dtype(dtype::Int8); + support_dtype(dtype::Uint8); +#undef cb +#undef support_dtype + + default: + megdnn_throw("unsupported dtype in remap naive\n"); + } +} diff --git a/dnn/src/naive/remap/opr_impl.h b/dnn/src/naive/remap/opr_impl.h new file mode 100644 index 00000000..5423d3cf --- /dev/null +++ b/dnn/src/naive/remap/opr_impl.h @@ -0,0 +1,29 @@ +/** + * \file dnn/src/naive/remap/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace naive { +class RemapImpl final : public Remap { + using Remap::Remap; + void exec(_megdnn_tensor_in, _megdnn_tensor_in, _megdnn_tensor_out, + _megdnn_workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + const TensorLayout&) override { + return 0; + } +}; +} // namespace naive +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/common/opr_trait.h b/dnn/test/common/opr_trait.h index 611afad6..7ff1c6c9 100644 --- a/dnn/test/common/opr_trait.h +++ b/dnn/test/common/opr_trait.h @@ -105,6 +105,7 @@ DEF(DeformableConvBackwardData, 8, true, false); DEF(DeformablePSROIPoolingForward, 5, true, true); DEF(DeformablePSROIPoolingBackward, 7, true, false); DEF(BatchConvBiasForward, 5, true, true); +DEF(Remap, 3, true, true); } // namespace test } // namespace megdnn diff --git a/dnn/test/common/remap.h b/dnn/test/common/remap.h new file mode 100644 index 00000000..9267364f --- /dev/null +++ b/dnn/test/common/remap.h @@ -0,0 +1,127 @@ +/** + * \file dnn/test/common/remap.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#pragma once +#include +#include "megdnn/basic_types.h" +#include "megdnn/opr_param_defs.h" + +#include "./rng.h" +namespace megdnn { +namespace test { +namespace remap { + +struct TestArg { + param::Remap param; + TensorShape src; + TensorShape map_xy; + TensorShape dst; + TestArg(param::Remap param_, TensorShape src_, TensorShape map_xy_, + TensorShape dst_) + : param(param_), src(src_), map_xy(map_xy_), dst(dst_) {} +}; + +static inline std::vector get_nchw_args() { + std::vector args; + + param::Remap param; + std::vector format_vec = {param::Remap::Format::NCHW}; + std::vector border_mode_vec = { + param::Remap::BorderMode::CONSTANT, + param::Remap::BorderMode::REFLECT_101, + param::Remap::BorderMode::REFLECT, + param::Remap::BorderMode::WRAP, + param::Remap::BorderMode::REPLICATE}; + // current do not test this. + std::vector scalar; + for (auto fmt : format_vec) { + for (auto border_type : border_mode_vec) { + param.format = fmt; + param.border_type = border_type; + args.emplace_back(param, TensorShape{1, 1, 2, 2}, + TensorShape{1, 2, 2, 2}, TensorShape{1, 1, 2, 2}); + + args.emplace_back(param, TensorShape{1, 3, 2, 2}, + TensorShape{1, 2, 2, 2}, TensorShape{1, 3, 2, 2}); + + args.emplace_back(param, TensorShape{1, 10, 100, 100}, + TensorShape{1, 100, 100, 2}, + TensorShape{1, 10, 100, 100}); + + args.emplace_back(param, TensorShape{2, 4, 100, 200}, + TensorShape{2, 100, 200, 2}, + TensorShape{2, 4, 100, 200}); + + args.emplace_back(param, TensorShape{2, 4, 100, 200}, + TensorShape{2, 20, 30, 2}, + TensorShape{2, 4, 20, 30}); + + args.emplace_back(param, TensorShape{2, 4, 10, 10}, + TensorShape{2, 20, 30, 2}, + TensorShape{2, 4, 20, 30}); + } + } + return args; +} + +static inline std::vector get_nhwc_args() { + std::vector args; + + param::Remap param; + std::vector format_vec = {param::Remap::Format::NHWC}; + std::vector border_mode_vec = { + param::Remap::BorderMode::CONSTANT, + param::Remap::BorderMode::REFLECT_101, + 
param::Remap::BorderMode::REFLECT, + param::Remap::BorderMode::WRAP, + param::Remap::BorderMode::REPLICATE}; + // current do not test this. + std::vector scalar; + for (auto fmt : format_vec) { + for (auto border_type : border_mode_vec) { + param.format = fmt; + param.border_type = border_type; + param.scalar = 12.f; + args.emplace_back(param, TensorShape{1, 2, 2, 1}, + TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 2, 1}); + + args.emplace_back(param, TensorShape{1, 2, 2, 3}, + TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 2, 3}); + + args.emplace_back(param, TensorShape{1, 2, 2, 66}, + TensorShape{1, 2, 2, 2}, + TensorShape{1, 2, 2, 66}); + + args.emplace_back(param, TensorShape{1, 100, 100, 10}, + TensorShape{1, 100, 100, 2}, + TensorShape{1, 100, 100, 10}); + + args.emplace_back(param, TensorShape{2, 100, 200, 4}, + TensorShape{2, 100, 200, 2}, + TensorShape{2, 100, 200, 4}); + + args.emplace_back(param, TensorShape{2, 100, 200, 4}, + TensorShape{2, 20, 30, 2}, + TensorShape{2, 20, 30, 4}); + + args.emplace_back(param, TensorShape{2, 10, 10, 4}, + TensorShape{2, 20, 30, 2}, + TensorShape{2, 20, 30, 4}); + } + } + return args; +} + +} // namespace remap +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index bbddd96e..68902d2f 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -842,6 +842,30 @@ std::unique_ptr ConvertF32ToF16Pass::make( return new_warp.node()->owner_opr(); }; + auto replace_remap_opr = [](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size() && + (new_inp.size() == 2)); + auto& remap_opr = opr->cast_final(); + // map tensor must be float32 + auto new_map = new_inp[1]; + if (new_inp[1]->dtype() != dtype::Float32()) { + if (try_cast_as_op(new_map->owner_opr()) && + new_map->owner_opr()->input(0)->dtype() == dtype::Float32()) + new_map = new_map->owner_opr()->input(0); + else + 
new_map = + opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}).node(); + } + SymbolVar new_remap; + + new_remap = opr::Remap::make(new_inp[0], new_map, + remap_opr.param(), + remap_opr.config()); + return new_remap.node()->owner_opr(); + }; + + auto ret = std::make_unique(); // don't check dtype ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^ @@ -855,6 +879,7 @@ std::unique_ptr ConvertF32ToF16Pass::make( replace_func[opr::ImmutableTensor::typeinfo()] = replace_imt_opr; replace_func[opr::TypeCvt::typeinfo()] = replace_cvt_opr; replace_func[opr::WarpPerspective::typeinfo()] = replace_warp_opr; + replace_func[opr::Remap::typeinfo()] = replace_remap_opr; return ret; #endif } diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 9ee21c1c..457941a0 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -693,6 +693,46 @@ TEST(TestGoptInference, Float16IOFloat32ComputeWarpPerspective) { MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); } +TEST(TestGoptInference, Float16IOFloat32ComputeRemap) { + auto cn = CompNode::load("cpu1"); + constexpr size_t INP_H = 10, INP_W = 10, N = 2; + HostTensorGenerator<> gen; + auto graph = ComputingGraph::make(); + auto mkvar = [&](const char* name, const TensorShape& shp) { + return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name); + }; + graph->options().graph_opt_level = 0; + auto a = mkvar("a", {N, 4, INP_H, INP_W}); + auto gen_map = [&](HostTensorND& mat) { + auto ptr = mat.ptr(); + for(size_t n = 0; n < N; ++n){ + for(int h = 0; h < 5; ++h){ + for(int w = 0; w < 5; ++w){ + *ptr++ = (h * 5 * 2) + 5 * 2 + 0; + *ptr++ = (h * 5 * 2) + 5 * 2 + 1; + } + } + } + mgb_assert(ptr == mat.ptr() + mat.shape().total_nr_elems()); + }; + auto map_host = std::make_shared( + a.node()->comp_node(), TensorShape{N, 5, 5, 2}, dtype::Float32()); + gen_map(*map_host); + auto map = opr::Host2DeviceCopy::make(*graph, map_host).rename("map"); + auto y = opr::Remap::make(a, map); + 
SymbolVar y_opt; + unpack_vector(gopt::optimize_for_inference( + {y}, gopt::OptimizeForInferenceOptions{} + .enable_f16_io_f32_comp()), + y_opt); + ASSERT_EQ(y_opt.dtype(), dtype::Float32()); + HostTensorND host_y, host_y_opt; + auto func = graph->compile({make_callback_copy(y, host_y), + make_callback_copy(y_opt, host_y_opt)}); + func->execute(); + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); +} + TEST(TestGoptInference, Uint8IOFloat16ComputeWarpPerspective) { constexpr size_t INP_H = 10, INP_W = 10, N = 2; HostTensorGenerator<dtype::Uint8> gen_uint8; @@ -1987,7 +2027,7 @@ TEST(TestGoptInference, EnableCHWN4WarpPespective) { auto y = opr::ConvBiasForward::make( x, w, b, param, {}, OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); - + opr::WarpPerspective::Param warp_param; warp_param.format = opr::WarpPerspective::Param::Format::NCHW4; auto y1 = opr::WarpPerspective::make(y, mat_var, TensorShape{16, 16}, warp_param); diff --git a/src/opr/impl/imgproc.cpp b/src/opr/impl/imgproc.cpp index c2dc8fc4..e5c407f9 100644 --- a/src/opr/impl/imgproc.cpp +++ b/src/opr/impl/imgproc.cpp @@ -316,4 +316,13 @@ void WarpAffineForward::record_execute_deps(ExecDependencyArray &deps) { record_megdnn_opr(deps); } +/* ======================= RemapForward ======================= */ + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemapForward); +MEGDNN_OPR_INIT2(RemapForward, "remap") + +void RemapForward::init_output_dtype(){ + output(0)->dtype(input(0)->dtype()); +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/imgproc.oprdecl b/src/opr/impl/imgproc.oprdecl index 0706e94d..38031ba5 100644 --- a/src/opr/impl/imgproc.oprdecl +++ b/src/opr/impl/imgproc.oprdecl @@ -79,4 +79,15 @@ decl_opr( 'for details on affine transformations.', version=1) +decl_opr( + 'Remap', + inputs=[ + Doc('src', 'input image, in NCHW format or NHWC format'), + Doc('map_xy', 'map matrix with NHWC format. C must be equal to 2. 
' + 'dst(x, y) = src(mapX(x, y), mapY(x, y)); ' + 'col in channel 0, and row in channel 1')], + params='Remap', + desc='Remap transformation to batched 2D images; ' + 'see https://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=remap ' + 'for details on remap transformations.') # vim: ft=python diff --git a/src/opr/impl/imgproc.sereg.h b/src/opr/impl/imgproc.sereg.h index bacf7daa..0b7e3bd6 100644 --- a/src/opr/impl/imgproc.sereg.h +++ b/src/opr/impl/imgproc.sereg.h @@ -50,6 +50,7 @@ namespace opr { MGB_SEREG_OPR(GaussianBlur, 1); MGB_SEREG_OPR(ResizeBackward, 2); + MGB_SEREG_OPR(Remap, 2); //! current warp affine version using WarpAffineV1 = opr::WarpAffine; diff --git a/src/opr/include/megbrain/opr/imgproc.h b/src/opr/include/megbrain/opr/imgproc.h index cbe6752c..18c4a03e 100644 --- a/src/opr/include/megbrain/opr/imgproc.h +++ b/src/opr/include/megbrain/opr/imgproc.h @@ -165,6 +165,21 @@ MGB_DEFINE_OPR_CLASS(ResizeBackward, const OperatorNodeConfig &config = {}); }; +MGB_DEFINE_OPR_CLASS(RemapForward, + intl::MegDNNOprWrapperFwd<megdnn::Remap>) // { + public: + RemapForward( + VarNode *in_tensor, VarNode* map, + const Param &param, const OperatorNodeConfig &config); + + static SymbolVar make(SymbolVar in_tensor, SymbolVar map, const Param &param = {}, + const OperatorNodeConfig &config = {}); + + private: + void init_output_dtype() override; +}; +using Remap = RemapForward; + /*! 
* \brief apply affine transformation to batched 2D images * diff --git a/src/opr/test/imgproc.cpp b/src/opr/test/imgproc.cpp index 44aebe3b..5a12ce69 100644 --- a/src/opr/test/imgproc.cpp +++ b/src/opr/test/imgproc.cpp @@ -636,4 +636,55 @@ TEST(TestOprImgproc, WarpAffineForward) { run({TensorShape{N, 10, 9, C}, {N, 2, 3}}, opt); } +TEST(TestOprImgproc, Remap_NCHW) { + constexpr size_t N = 2, C = 8; + + opr::Remap::Param param; + using Checker = AutoOprChecker<2, 1>; + TensorShape out_shp{N, C, 10, 10}; + param.format = opr::Remap::Param::Format::NCHW; + auto make_graph = [&](const Checker::SymInpArray &inputs) -> + Checker::SymOutArray { + return {opr::Remap::make(inputs[0], inputs[1], param)}; + }; + auto fwd = [&](Checker::NumOutArray &dest, Checker::NumInpArray inp) { + auto opr = megdnn_naive_handle()->create_operator(); + opr->param() = param; + dest[0].resize(out_shp); + opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), dest[0].as_megdnn(), {}); + }; + + Checker::RunOptions opt; + Checker(make_graph, fwd, CompNode::load("cpu1")) + .disable_grad_check() + .run({TensorShape{N, C, 3, 20}, TensorShape{N, 10, 10, 2}}, opt) + .run({TensorShape{N, C, 6, 5}, TensorShape{N, 10, 10, 2}}, opt) + .run({TensorShape{N, C, 20, 20}, TensorShape{N, 10, 10, 2}}, opt); +} + +TEST(TestOprImgproc, Remap_NHWC) { + constexpr size_t N = 2, C = 8; + + opr::Remap::Param param; + using Checker = AutoOprChecker<2, 1>; + TensorShape out_shp{N, 10, 10, C}; + param.format = opr::Remap::Param::Format::NHWC; + auto make_graph = [&](const Checker::SymInpArray &inputs) -> + Checker::SymOutArray { + return {opr::Remap::make(inputs[0], inputs[1], param)}; + }; + auto fwd = [&](Checker::NumOutArray &dest, Checker::NumInpArray inp) { + auto opr = megdnn_naive_handle()->create_operator(); + opr->param() = param; + dest[0].resize(out_shp); + opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), dest[0].as_megdnn(), {}); + }; + + Checker::RunOptions opt; + Checker(make_graph, fwd, 
CompNode::load("cpu1")) + .disable_grad_check() + .run({TensorShape{N, 3, 20, C}, TensorShape{N, 10, 10, 2}}, opt) + .run({TensorShape{N, 6, 5, C}, TensorShape{N, 10, 10, 2}}, opt) + .run({TensorShape{N, 20, 20, C}, TensorShape{N, 10, 10, 2}}, opt); +} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs index 12b730fe..acfc0d82 100644 --- a/src/serialization/impl/schema.fbs +++ b/src/serialization/impl/schema.fbs @@ -70,6 +70,7 @@ union OperatorParam { param.WarpAffine, param.GaussianBlur, param.Resize, + param.Remap, param.Convolution3D, param.Conv3DBias, param.SeparableConv3D,