From 5474b000a7686a36b23da6d0b8b5c48f78b4b1e8 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 9 Sep 2020 20:09:25 +0800 Subject: [PATCH] fix(mge/functional): fix convert_inputs before apply GitOrigin-RevId: ab41974a1f3f60f3261a5f34d76ff858c8ccd07b --- imperative/python/megengine/functional/nn.py | 7 +++++++ imperative/python/src/dispatcher.cpp | 4 ++++ 2 files changed, 11 insertions(+) diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index a934a077..27e1b917 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -146,6 +146,7 @@ def conv2d( compute_mode=compute_mode, sparse=sparse_type, ) + inp, weight = utils.convert_inputs(inp, weight) (output,) = apply(op, inp, weight) if bias is not None: output += bias @@ -209,6 +210,7 @@ def conv_transpose2d( dilate_w=dilate_w, strategy=get_conv_execution_strategy(), ) + weight, inp = utils.convert_inputs(weight, inp) (output,) = apply(op, weight, inp) if bias is not None: output += bias @@ -243,6 +245,7 @@ def local_conv2d( dilate_w=dilate_w, # strategy=get_conv_execution_strategy(), ) + inp, weight = utils.convert_inputs(inp, weight) (output,) = apply(op, inp, weight) if bias is not None: output += bias @@ -900,6 +903,7 @@ def warp_perspective( op = builtin.WarpPerspective( imode=interp_mode, bmode=border_mode, format="NCHW", border_val=border_val ) + inp, M = utils.convert_inputs(inp, M) (result,) = apply(op, inp, M, Tensor(dsize)) return result @@ -1004,6 +1008,7 @@ def matmul( format=format, ) + inp1, inp2 = utils.convert_inputs(inp1, inp2) (result,) = apply(op, inp1, inp2) if shp is not None: result = result.reshape(shp) @@ -1327,6 +1332,7 @@ def roi_pooling( output_shape = (output_shape, output_shape) op = builtin.ROIPooling(mode=mode, scale=scale) + inp, rois = utils.convert_inputs(inp, rois) result, _ = apply( op, inp, rois, Tensor(output_shape, dtype="int32", device=inp.device) ) @@ -1374,6 +1380,7 @@ def roi_align( sample_height=sample_height, sample_width=sample_width, ) + input, rois = utils.convert_inputs(input, rois) result, *_ = apply(op, input, rois) return result diff --git a/imperative/python/src/dispatcher.cpp b/imperative/python/src/dispatcher.cpp index 79e93e26..2d2cd844 100644 --- a/imperative/python/src/dispatcher.cpp +++ b/imperative/python/src/dispatcher.cpp @@ -104,6 +104,10 @@ struct Dispatcher { auto& frame = stack.back(); auto& mro = *frame.mro; auto& i = frame.mro_offset; + if (!mro.size()) { + PyErr_SetString(PyExc_NotImplementedError, "function not registered in dispatcher"); + return nullptr; + } for (; i < mro.size(); ++i) { if (mro[i]->enabled) { auto ret = caller(mro[i]->func);