
feat(mge/module): add python wrapper for unfold

commit 052a600f03 (release-1.4), Megvii Engine Team, 4 years ago
GitOrigin-RevId: 562103186f
14 changed files with 287 additions and 37 deletions

  1. dnn/scripts/opr_param_defs.py (+1, -1)
  2. dnn/src/common/images2neibs.cpp (+5, -1)
  3. dnn/src/cuda/images2neibs/kernel.cu (+20, -17)
  4. dnn/src/cuda/images2neibs/kernel.cuh (+2, -2)
  5. dnn/src/cuda/images2neibs/opr_impl.cpp (+4, -2)
  6. dnn/src/naive/images2neibs/opr_impl.cpp (+17, -8)
  7. dnn/test/common/images2neibs.h (+10, -6)
  8. dnn/test/naive/images2neibs.cpp (+59, -0)
  9. imperative/python/megengine/functional/nn.py (+39, -0)
  10. imperative/python/megengine/module/__init__.py (+1, -0)
  11. imperative/python/megengine/module/sliding_window.py (+88, -0)
  12. imperative/python/test/unit/functional/test_functional.py (+25, -0)
  13. imperative/src/impl/ops/specializations.cpp (+14, -0)
  14. src/core/include/megbrain/ir/ops.td (+2, -0)

dnn/scripts/opr_param_defs.py (+1, -1)

@@ -220,7 +220,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)

(pdef('Images2Neibs').
add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1,
- 'window_h', 3, 'window_w', 3))
+ 'dilate_h', 1, 'dilate_w', 1, 'window_h', 3, 'window_w', 3))

(pdef('Pooling', version=0, is_legacy=True).
add_enum(


dnn/src/common/images2neibs.cpp (+5, -1)

@@ -23,6 +23,8 @@ void Images2NeibsBase::deduce_layout_fwd(const TensorLayout &src,
"pad_w=" + std::to_string(param().pad_w) + ", " +
"stride_h=" + std::to_string(param().stride_h) + ", " +
"stride_w=" + std::to_string(param().stride_w) + ", " +
"dilate_h=" + std::to_string(param().dilate_h) + ", " +
"dilate_w=" + std::to_string(param().dilate_w) + ", " +
"window_h=" + std::to_string(param().window_h) + ", " +
"window_w=" + std::to_string(param().window_w);
};
@@ -34,11 +36,13 @@ void Images2NeibsBase::deduce_layout_fwd(const TensorLayout &src,
size_t pw = this->param().pad_w;
size_t sh = this->param().stride_h;
size_t sw = this->param().stride_w;
size_t dh = this->param().dilate_h;
size_t dw = this->param().dilate_w;
size_t wh = this->param().window_h;
size_t ww = this->param().window_w;
size_t oh, ow;

- infer_conv_shape2d(ih, iw, wh, ww, sh, sw, ph, pw, oh, ow);
+ infer_conv_shape2d(ih, iw, wh+(wh-1)*(dh-1), ww+(ww-1)*(dw-1), sh, sw, ph, pw, oh, ow);
dst = TensorLayout(TensorShape({n, ic, oh, ow, wh, ww}), src.dtype);
}
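The shape rule reads more clearly once the window is replaced by its dilated extent: wh + (wh-1)*(dh-1) simplifies to (wh-1)*dh + 1, after which the usual convolution shape inference applies. A small Python helper, illustrative only (not part of the commit):

def images2neibs_out_shape(ih, iw, ph, pw, sh, sw, dh, dw, wh, ww):
    # Dilated window extent: (wh-1)*dh + 1 == wh + (wh-1)*(dh-1).
    equ_wh = (wh - 1) * dh + 1
    equ_ww = (ww - 1) * dw + 1
    # Standard convolution output-shape rule on the enlarged window.
    oh = (ih + 2 * ph - equ_wh) // sh + 1
    ow = (iw + 2 * pw - equ_ww) // sw + 1
    return oh, ow

# e.g. a 6x7 input with pad 1, stride 2, dilation 2, window 3 (the naive
# test case below): images2neibs_out_shape(6, 7, 1, 1, 2, 2, 2, 2, 3, 3) == (2, 3)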



dnn/src/cuda/images2neibs/kernel.cu (+20, -17)

@@ -24,7 +24,7 @@ namespace images2neibs {
template <typename T>
__global__ void forward_kernel(const T *src, T *dst,
int N, int C, int IH, int IW, int OH, int OW,
- int ph, int pw, int sh, int sw, int WH, int WW)
+ int ph, int pw, int sh, int sw, int dh, int dw, int WH, int WW)
{
int NC = N * C;
int WP = WH*WW;
@@ -37,8 +37,8 @@ __global__ void forward_kernel(const T *src, T *dst,
if (op < OH * OW) {
int oh = op / OW;
int ow = op % OW;
- int ih = -ph + sh * oh + wh;
- int iw = -pw + sw * ow + ww;
+ int ih = -ph + sh * oh + wh * dh;
+ int iw = -pw + sw * ow + ww * dw;
int dst_pos = nc * OH * OW * WH * WW + op * WH * WW + wp;
int src_pos = nc * IH * IW + ih * IW + iw;
dst[dst_pos] = (ih >= 0 && ih < IH && iw >= 0 && iw < IW)
@@ -52,7 +52,7 @@ __global__ void forward_kernel(const T *src, T *dst,

template <typename T>
void forward(const T* src, T* dst, int N, int C, int IH, int IW, int OH, int OW,
- int ph, int pw, int sh, int sw, int wh, int ww,
+ int ph, int pw, int sh, int sw, int dh, int dw, int wh, int ww,
cudaStream_t stream) {
int spatial_size = OH * OW;
int kernel_size = wh * ww;
@@ -63,7 +63,7 @@ void forward(const T* src, T* dst, int N, int C, int IH, int IW, int OH, int OW,
int by = N * C;

forward_kernel<<<dim3(bx, std::min(grid_y_max, by)), dim3(tx, ty), 0,
- stream>>>(src, dst, N, C, IH, IW, OH, OW, ph, pw, sh, sw,
+ stream>>>(src, dst, N, C, IH, IW, OH, OW, ph, pw, sh, sw, dh, dw,
wh, ww);
after_kernel_launch();
}
@@ -73,7 +73,7 @@ void forward(const T* src, T* dst, int N, int C, int IH, int IW, int OH, int OW,
template <typename T>
__global__ void backward_kernel(const T *diff, T *grad,
int N, int C, int IH, int IW, int OH, int OW,
- int ph, int pw, int sh, int sw, int WH, int WW)
+ int ph, int pw, int sh, int sw, int dh, int dw, int WH, int WW)
{
int id = threadIdx.x + blockIdx.x * blockDim.x;
if (id < N*C*IH*IW) {
@@ -82,17 +82,20 @@ __global__ void backward_kernel(const T *diff, T *grad,
int iw = id % (IH*IW) % IW;
grad[nc*IH*IW + ih*IW + iw] = 0.0f;
int oh_max = min((ih+ph) / sh, OH-1);
- int oh_min = max((ih+ph-(WH-1)+sh-1) / sh, 0);
+ int oh_min = max((ih+ph-(WH-1)*dh+sh-1) / sh, 0);
int ow_max = min((iw+pw) / sw, OW-1);
- int ow_min = max((iw+pw-(WW-1)+sw-1) / sw, 0);
+ int ow_min = max((iw+pw-(WW-1)*dw+sw-1) / sw, 0);
for (int oh = oh_min; oh <= oh_max; ++oh)
for (int ow = ow_min; ow <= ow_max; ++ow)
{
- int wh = ih+ph - sh*oh;
- int ww = iw+pw - sw*ow;
- grad[nc*IH*IW + ih*IW + iw] +=
-     diff[nc*OH*OW*WH*WW + oh*OW*WH*WW + ow*WH*WW +
-          wh*WW + ww];
+ if ((ih+ph - sh*oh)%dh==0 && (iw+pw - sw*ow)%dw==0){
+     int wh = ih+ph - sh*oh - (ih+ph - sh*oh)/dh * (dh-1);
+     int ww = iw+pw - sw*ow - (iw+pw - sw*ow)/dw * (dw-1);
+     grad[nc*IH*IW + ih*IW + iw] +=
+         diff[nc*OH*OW*WH*WW + oh*OW*WH*WW + ow*WH*WW +
+              wh*WW + ww];
+ }
}
}
}
@@ -100,23 +103,23 @@ __global__ void backward_kernel(const T *diff, T *grad,
template <typename T>
void backward(const T *diff, T *grad,
int N, int C, int IH, int IW, int OH, int OW,
- int ph, int pw, int sh, int sw, int wh, int ww,
+ int ph, int pw, int sh, int sw, int dh, int dw, int wh, int ww,
cudaStream_t stream)
{
int threads = NR_THREADS;
int blocks = DIVUP(N*C*IH*IW, threads);
backward_kernel<<<blocks, threads, 0, stream>>>(diff, grad,
N, C, IH, IW, OH, OW,
- ph, pw, sh, sw, wh, ww);
+ ph, pw, sh, sw, dh, dw, wh, ww);
after_kernel_launch();
}

#define INST(T) \
template void forward<T>(const T *, T *, int, int, int, int, int, int, \
- int, int, int, int, int, int, \
+ int, int, int, int, int, int, int, int, \
cudaStream_t); \
template void backward<T>(const T *, T *, int, int, int, int, int, int, \
- int, int, int, int, int, int, \
+ int, int, int, int, int, int, int, int, \
cudaStream_t);
#define cb(DType) \
INST(DTypeTrait<DType>::ctype)
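The backward kernel is a gather: each input pixel enumerates the window positions (oh, ow) that could have read it, and the new %dh / %dw tests skip positions whose dilated taps miss the pixel; the div/mod arithmetic then recovers the in-window index. The equivalent scatter form is easier to follow; a NumPy sketch, illustrative only:

import numpy as np

def images2neibs_backward_ref(diff, IH, IW, ph, pw, sh, sw, dh, dw):
    # diff: (N, C, OH, OW, WH, WW), gradient w.r.t. the output.
    # Output element (oh, ow, kh, kw) was read from input pixel
    # (-ph + sh*oh + kh*dh, -pw + sw*ow + kw*dw), so its gradient
    # flows back to exactly that pixel.
    N, C, OH, OW, WH, WW = diff.shape
    grad = np.zeros((N, C, IH, IW), dtype=diff.dtype)
    for oh in range(OH):
        for ow in range(OW):
            for kh in range(WH):
                for kw in range(WW):
                    ih = -ph + sh * oh + kh * dh
                    iw = -pw + sw * ow + kw * dw
                    if 0 <= ih < IH and 0 <= iw < IW:
                        grad[:, :, ih, iw] += diff[:, :, oh, ow, kh, kw]
    return grad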


dnn/src/cuda/images2neibs/kernel.cuh (+2, -2)

@@ -18,13 +18,13 @@ namespace images2neibs {
template <typename T>
void forward(const T *src, T *dst,
int N, int C, int IH, int IW, int OH, int OW,
- int ph, int pw, int sh, int sw, int wh, int ww,
+ int ph, int pw, int sh, int sw, int dh, int dw, int wh, int ww,
cudaStream_t stream);

template <typename T>
void backward(const T *diff, T *grad,
int N, int C, int IH, int IW, int OH, int OW,
- int ph, int pw, int sh, int sw, int wh, int ww,
+ int ph, int pw, int sh, int sw, int dh, int dw, int wh, int ww,
cudaStream_t stream);
} // namespace images2neibs


dnn/src/cuda/images2neibs/opr_impl.cpp (+4, -2)

@@ -27,13 +27,14 @@ void Images2NeibsForwardImpl::exec(_megdnn_tensor_in src,
int OH = dst.layout[2], OW = dst.layout[3];
int ph = param().pad_h, pw = param().pad_w;
int sh = param().stride_h, sw = param().stride_w;
int dh = param().dilate_h, dw = param().dilate_w;
int wh = param().window_h, ww = param().window_w;
#define cb(DType) \
if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
using T = DTypeTrait<DType>::ctype; \
images2neibs::forward(src.ptr<T>(), dst.ptr<T>(), \
N, C, IH, IW, OH, OW, \
- ph, pw, sh, sw, wh, ww, \
+ ph, pw, sh, sw, dh, dw, wh, ww, \
stream); \
return; \
}
@@ -53,13 +54,14 @@ void Images2NeibsBackwardImpl::exec(_megdnn_tensor_in diff,
int OH = diff.layout[2], OW = diff.layout[3];
int ph = param().pad_h, pw = param().pad_w;
int sh = param().stride_h, sw = param().stride_w;
int dh = param().dilate_h, dw = param().dilate_w;
int wh = param().window_h, ww = param().window_w;
#define cb(DType) \
if (diff.layout.dtype == DType()) { \
using T = DTypeTrait<DType>::ctype; \
images2neibs::backward(diff.ptr<T>(), grad.ptr<T>(), \
N, C, IH, IW, OH, OW, \
- ph, pw, sh, sw, wh, ww, \
+ ph, pw, sh, sw, dh, dw, wh, ww, \
stream); \
return; \
}


dnn/src/naive/images2neibs/opr_impl.cpp (+17, -8)

@@ -33,20 +33,25 @@ void Images2NeibsForwardImpl::exec_internal(_megdnn_tensor_in src,
int pad_w = static_cast<int>(param().pad_w);
int stride_h = static_cast<int>(param().stride_h);
int stride_w = static_cast<int>(param().stride_w);
int dilate_h = static_cast<int>(param().dilate_h);
int dilate_w = static_cast<int>(param().dilate_w);
int equ_window_h = dilate_h * (window_h-1) + 1;
int equ_window_w = dilate_w * (window_w-1) + 1;
for (int n = 0; n < N; ++n)
for (int c = 0; c < C; ++c)
{
int ih = -pad_h;
- for (; ih+window_h <= IH+pad_h; ih += stride_h) {
+ for (; ih+equ_window_h <= IH+pad_h; ih += stride_h) {
int iw = -pad_w;
- for (; iw+window_w <= IW+pad_w; iw += stride_w) {
+ for (; iw+equ_window_w <= IW+pad_w; iw += stride_w) {
for (int kh = 0; kh < window_h; ++kh)
for (int kw = 0; kw < window_w; ++kw)
{
int ih2 = ih+dilate_h*kh, iw2 = iw+dilate_w*kw;
dptr[idx*window_h*window_w + kh*window_w + kw] =
- (ih+kh) >= 0 && (ih+kh) < IH &&
- (iw+kw) >= 0 && (iw+kw) < IW ?
- sptr[n*C*IH*IW + c*IH*IW + (ih+kh)*IW + (iw+kw)] : 0.0f;
+ ih2 >= 0 && ih2 < IH &&
+ iw2 >= 0 && iw2 < IW ?
+ sptr[n*C*IH*IW + c*IH*IW + ih2*IW + iw2] : 0.0f;
}
++idx;
}
@@ -86,18 +91,22 @@ void Images2NeibsBackwardImpl::exec_internal(_megdnn_tensor_in diff,
int pad_w = static_cast<int>(param().pad_w);
int stride_h = static_cast<int>(param().stride_h);
int stride_w = static_cast<int>(param().stride_w);
int dilate_h = static_cast<int>(param().dilate_h);
int dilate_w = static_cast<int>(param().dilate_w);
int equ_window_h = dilate_h * (window_h-1) + 1;
int equ_window_w = dilate_w * (window_w-1) + 1;
memset(sptr, 0, sizeof(T) * N*C*IH*IW);
for (int n = 0; n < N; ++n)
for (int c = 0; c < C; ++c)
{
int ih = -pad_h;
- for (; ih+window_h <= IH+pad_h; ih += stride_h) {
+ for (; ih+equ_window_h <= IH+pad_h; ih += stride_h) {
int iw = -pad_w;
- for (; iw+window_w <= IW+pad_w; iw += stride_w) {
+ for (; iw+equ_window_w <= IW+pad_w; iw += stride_w) {
for (int kh = 0; kh < window_h; ++kh)
for (int kw = 0; kw < window_w; ++kw)
{
- int ih2 = ih+kh, iw2 = iw+kw;
+ int ih2 = ih+dilate_h*kh, iw2 = iw+dilate_w*kw;
if (ih2 >= 0 && ih2 < IH && iw2 >= 0 && iw2 < IW) {
sptr[n*C*IH*IW + c*IH*IW + ih2*IW + iw2] +=
dptr[idx*window_h*window_w + kh*window_w + kw];
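In NumPy terms the dilated forward pass is a strided slice over a zero-padded input. A reference sketch equivalent to the naive loops above (illustrative, not part of the commit):

import numpy as np

def images2neibs_forward_ref(src, ph, pw, sh, sw, dh, dw, wh, ww):
    # src: (N, C, IH, IW). Zero-pad, then read each window with a dilated
    # (step dh/dw) slice -- the same taps the loops above visit one by one.
    N, C, IH, IW = src.shape
    equ_wh, equ_ww = (wh - 1) * dh + 1, (ww - 1) * dw + 1
    OH = (IH + 2 * ph - equ_wh) // sh + 1
    OW = (IW + 2 * pw - equ_ww) // sw + 1
    padded = np.zeros((N, C, IH + 2 * ph, IW + 2 * pw), dtype=src.dtype)
    padded[:, :, ph:ph + IH, pw:pw + IW] = src
    out = np.empty((N, C, OH, OW, wh, ww), dtype=src.dtype)
    for oh in range(OH):
        for ow in range(OW):
            ih, iw = oh * sh, ow * sw
            out[:, :, oh, ow] = padded[:, :, ih:ih + equ_wh:dh,
                                       iw:iw + equ_ww:dw]
    return out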


dnn/test/common/images2neibs.h (+10, -6)

@@ -31,17 +31,19 @@ inline std::vector<TestArg> get_args() {
for (uint32_t pw : {0, 1})
for (uint32_t sh : {1, 2})
for (uint32_t sw : {1, 2})
for (uint32_t dh : {1, 2, 3})
for (uint32_t dw : {1, 2, 3})
for (uint32_t wh : {3, 4})
for (uint32_t ww : {3, 4}) {
- args.emplace_back(param::Images2Neibs{ph, pw, sh, sw, wh, ww},
-                   TensorShape{2, 3, 5, 6});
+ args.emplace_back(param::Images2Neibs{ph, pw, sh, sw, dh, dw, wh, ww},
+                   TensorShape{2, 3, 19, 20});
}
// clang-format on
// large window case
- args.emplace_back(param::Images2Neibs{0, 0, 1, 1, 32, 64},
+ args.emplace_back(param::Images2Neibs{0, 0, 1, 1, 1, 1, 32, 64},
TensorShape{2, 3, 96, 128});
// large size
- args.emplace_back(param::Images2Neibs{0, 0, 1, 1, 1, 1},
+ args.emplace_back(param::Images2Neibs{0, 0, 1, 1, 1, 1, 1, 1},
TensorShape{128, 128, 28, 24});

return args;
@@ -54,17 +56,19 @@ inline std::vector<TestArg> get_benchmark_args() {
for (uint32_t pw : {0, 1})
for (uint32_t sh : {1, 2})
for (uint32_t sw : {1, 2})
for (uint32_t dh : {1, 2})
for (uint32_t dw : {1, 2})
for (uint32_t wh : {3, 4})
for (uint32_t ww : {3, 4})
for (uint32_t b : {1, 64})
for (uint32_t c : {64, 128})
for (uint32_t hw : {64, 128}) {
- args.emplace_back(param::Images2Neibs{ph, pw, sh, sw, wh, ww},
+ args.emplace_back(param::Images2Neibs{ph, pw, sh, sw, dh, dw, wh, ww},
TensorShape{b, c, hw, hw});
}
// clang-format on
// large size
- args.emplace_back(param::Images2Neibs{0, 0, 1, 1, 1, 1},
+ args.emplace_back(param::Images2Neibs{0, 0, 1, 1, 1, 1, 1, 1},
TensorShape{1024, 128, 28, 24});

return args;
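Note the forward-test input grows from {2, 3, 5, 6} to {2, 3, 19, 20}: with dilation up to 3 and a window of 4, the dilated window extent is (4-1)*3+1 = 10, which would not fit in a 5x6 map, so a larger input is presumably needed for every (stride, dilation, window) combination to produce a valid output.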


dnn/test/naive/images2neibs.cpp (+59, -0)

@@ -0,0 +1,59 @@
/**
* \file dnn/test/naive/images2neibs.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "test/naive/fixture.h"

#include "megdnn/oprs/nn.h"
#include "test/common/checker.h"

using namespace megdnn;
using namespace test;

TEST_F(NAIVE, IMAGES2NEIBS_FORWARD) {
Checker<Images2Neibs> checker(handle(), /* check_dispatch */false);
Images2Neibs::Param param(0,0,1,1,1,1,2,2);
checker.set_param(param).exect(
Testcase{TensorValue({1, 1, 3, 3}, dtype::Uint8(),
{0,1,2,
3,4,5,
6,7,8}), {}},
Testcase{{},
TensorValue({1, 1, 2, 2, 2, 2}, dtype::Uint8(),
{0,1,3,4,
1,2,4,5,
3,4,6,7,
4,5,7,8})});
param.pad_h = 1;
param.pad_w = 1;
param.stride_h = 2;
param.stride_w = 2;
param.dilate_h = 2;
param.dilate_w = 2;
param.window_h = 3;
param.window_w = 3;
checker.set_param(param).exect(
Testcase{TensorValue({1, 1, 6, 7}, dtype::Uint8(),
{0,1,2,3,4,5,6,
7,8,9,10,11,12,13,
14,15,16,17,18,19,20,
21,22,23,24,25,26,27,
28,29,30,31,32,33,34,
35,36,37,38,39,40,41}), {}},
Testcase{{},
TensorValue({1, 1, 2, 3, 3, 3}, dtype::Uint8(),
{0,0,0,0,8,10,0,22,24,
0,0,0,8,10,12,22,24,26,
0,0,0,10,12,0,24,26,0,
0,8,10,0,22,24,0,36,38,
8,10,12,22,24,26,36,38,40,
10,12,0,24,26,0,38,40,0})});
}
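The dilated case can be cross-checked against the NumPy sketch defined earlier (images2neibs_forward_ref, illustrative):

import numpy as np

src = np.arange(42, dtype=np.uint8).reshape(1, 1, 6, 7)
out = images2neibs_forward_ref(src, ph=1, pw=1, sh=2, sw=2,
                               dh=2, dw=2, wh=3, ww=3)
print(out.shape)        # (1, 1, 2, 3, 3, 3)
print(out[0, 0, 0, 0])  # [[ 0  0  0] [ 0  8 10] [ 0 22 24]], as in the test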

imperative/python/megengine/functional/nn.py (+39, -0)

@@ -70,6 +70,7 @@ __all__ = [
"remap",
"resize",
"sigmoid",
"sliding_window",
"softmax",
"softplus",
"sync_batch_norm",
@@ -1353,6 +1354,44 @@ def indexing_one_hot(
return result


def sliding_window(
inp: Tensor,
kernel_size: Union[int, Tuple[int, int]],
padding: Union[int, Tuple[int, int]] = 0,
stride: Union[int, Tuple[int, int]] = 1,
dilation: Union[int, Tuple[int, int]] = 1,
) -> Tensor:
"""
Extracts sliding local blocks from a batched input tensor.

Refer to :class:`~.SlidingWindow` for more information.

:param inp: input tensor.
:param kernel_size: size of the window.
:param padding: implicit zero padding added on both sides of input. Default: 0
:param stride: stride of the window. Default: 1
:param dilation: dilation of the window. Default: 1
:return: output tensor.
"""
padding_h, padding_w = _pair(padding)
stride_h, stride_w = _pair_nonzero(stride)
dilation_h, dilation_w = _pair_nonzero(dilation)
window_h, window_w = _pair_nonzero(kernel_size)

op = builtin.Images2Neibs(
pad_h=padding_h,
pad_w=padding_w,
stride_h=stride_h,
stride_w=stride_w,
dilate_h=dilation_h,
dilate_w=dilation_w,
window_h=window_h,
window_w=window_w,
)
(output,) = apply(op, inp)
return output


interpolate = deprecated_func("1.3", "megengine.functional.vision", "interpolate", True)
roi_pooling = deprecated_func("1.3", "megengine.functional.vision", "roi_pooling", True)
roi_align = deprecated_func("1.3", "megengine.functional.vision", "roi_align", True)
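A quick sanity check of the new functional API (shapes follow the rule above; the output values are the 2x2 patches of the input):

import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.arange(16, dtype="float32").reshape(1, 1, 4, 4))
y = F.sliding_window(x, kernel_size=2, stride=2)
print(y.shape)  # -> (1, 1, 2, 2, 2, 2): four non-overlapping 2x2 patches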


imperative/python/megengine/module/__init__.py (+1, -0)

@@ -34,3 +34,4 @@ from .normalization import GroupNorm, InstanceNorm, LayerNorm
from .pooling import AvgPool2d, MaxPool2d
from .quant_dequant import DequantStub, QuantStub
from .sequential import Sequential
from .sliding_window import SlidingWindow

imperative/python/megengine/module/sliding_window.py (+88, -0)

@@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union

from ..functional import sliding_window
from .module import Module


class SlidingWindow(Module):
r"""
Apply a sliding window to the input tensor, copying the content of each
window to the corresponding output location. Given an input of shape
:math:`(N, C, IH, IW)`, the output shape is
:math:`(N, C, OH, OW, window_h, window_w)`, where :math:`(OH, OW)` is
computed from padding, stride, window size and :math:`(IH, IW)`, as in
convolution. For each output location, we have:

.. math::

out_{n, c, oh, ow, wh, ww} &= src_{n, c, ih+wh, iw+ww} \\
\text{where } & ih=-pad_h+oh \times stride_h + (wh-1) \times (dilation_h-1) \\
& iw=-pad_w+ow \times stride_w + (ww-1) \times (dilation_w-1)


:param kernel_size: the size of the window.
:param padding: implicit zero padding to be added on both sides. Default: 0
:param stride: the stride of the window. Default: 1
:param dilation: the dilation of the window. Default: 1

Example:

.. testcode::

from megengine import tensor
import megengine.module as M
import numpy as np

inp = tensor(np.arange(30).reshape(1,1,5,6))
op = M.SlidingWindow(kernel_size=3, padding=1, stride=2, dilation=2)
out = op(inp)
print(out.numpy())

Outputs:

.. testoutput::

[[[[[[ 0 0 0]
[ 0 7 9]
[ 0 19 21]]

[[ 0 0 0]
[ 7 9 11]
[19 21 23]]]


[[[ 0 7 9]
[ 0 19 21]
[ 0 0 0]]

[[ 7 9 11]
[19 21 23]
[ 0 0 0]]]]]]

"""

def __init__(
self,
kernel_size: Union[int, Tuple[int, int]],
padding: Union[int, Tuple[int, int]] = 0,
stride: Union[int, Tuple[int, int]] = 1,
dilation: Union[int, Tuple[int, int]] = 1,
**kwargs
):
super(SlidingWindow, self).__init__(**kwargs)
self.kernel_size = kernel_size
self.padding = padding
self.stride = stride
self.dilation = dilation

def forward(self, inp):
return sliding_window(
inp, self.kernel_size, self.padding, self.stride, self.dilation
)

imperative/python/test/unit/functional/test_functional.py (+25, -0)

@@ -927,3 +927,28 @@ def test_neg_axis():
y = F.argmin(x, axis=(-1, -2))
yy = F.argmin(x, axis=(0, 1))
np.testing.assert_equal(y.numpy(), yy.numpy())


def test_sliding_window():
N, C, H, W = 2, 3, 7, 8
inp = np.random.normal(size=(N, C, H, W))
ph, pw = 1, 2
sh, sw = 2, 1
wh, ww = 3, 2
dh, dw = 1, 3
s = lambda i, p, s, d, w: (i + p * 2 - (w - 1) * d - 1) // s + 1
inp_pad = np.zeros((N, C, H + ph * 2, W + pw * 2))
inp_pad[:, :, ph : H + ph, pw : W + pw] = inp
gt_out = np.empty(
(N, C, s(H, ph, sh, dh, wh), s(W, pw, sw, dw, ww), wh, ww), dtype=np.float32
)
for n, c, oh, ow in itertools.product(*map(range, gt_out.shape[:4])):
ih, iw = oh * sh, ow * sw
gt_out[n, c, oh, ow, :] = inp_pad[
n, c, ih : ih + (wh - 1) * dh + 1 : dh, iw : iw + (ww - 1) * dw + 1 : dw
]

out = F.sliding_window(
tensor(inp), (wh, ww), padding=(ph, pw), stride=(sh, sw), dilation=(dh, dw)
)
np.testing.assert_equal(gt_out, out.numpy())

imperative/src/impl/ops/specializations.cpp (+14, -0)

@@ -32,6 +32,7 @@
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/dnn/images2neibs.h"

#include "../op_trait.h"

@@ -652,4 +653,17 @@ OP_TRAIT_REG(SVD, SVD)
.fallback();
}} // svd

namespace { namespace images2neibs {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const Images2Neibs&>(def);
OperatorNodeConfig config{op.make_name()};
return opr::Images2Neibs::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(Images2Neibs, Images2Neibs)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // images2neibs

} // namespace mgb::imperative
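The registration follows the usual imperative-op pattern: apply_on_var_node unpacks the Images2Neibs OpDef and lowers it onto the computing-graph operator opr::Images2Neibs, so the imperative runtime reuses the graph implementation (and, through it, the MegDNN kernels changed above).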

src/core/include/megbrain/ir/ops.td (+2, -0)

@@ -79,6 +79,8 @@ def BatchConvBias : MgbHashableOp<"BatchConvBias", [BatchConvBiasParam, Executio
);
}

def Images2Neibs : MgbHashableOp<"Images2Neibs", [Images2NeibsParam]>;

def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>;

def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>;
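This tablegen entry declares Images2Neibs as a hashable op carrying Images2NeibsParam; it is what produces the builtin.Images2Neibs constructor that sliding_window in nn.py instantiates.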

