Browse Source

fix(mge/functional): fix convert_inputs before apply

GitOrigin-RevId: ab41974a1f
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
5474b000a7
2 changed files with 11 additions and 0 deletions
  1. +7
    -0
      imperative/python/megengine/functional/nn.py
  2. +4
    -0
      imperative/python/src/dispatcher.cpp

+ 7
- 0
imperative/python/megengine/functional/nn.py View File

@@ -146,6 +146,7 @@ def conv2d(
compute_mode=compute_mode, compute_mode=compute_mode,
sparse=sparse_type, sparse=sparse_type,
) )
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight) (output,) = apply(op, inp, weight)
if bias is not None: if bias is not None:
output += bias output += bias
@@ -209,6 +210,7 @@ def conv_transpose2d(
dilate_w=dilate_w, dilate_w=dilate_w,
strategy=get_conv_execution_strategy(), strategy=get_conv_execution_strategy(),
) )
weight, inp = utils.convert_inputs(weight, inp)
(output,) = apply(op, weight, inp) (output,) = apply(op, weight, inp)
if bias is not None: if bias is not None:
output += bias output += bias
@@ -243,6 +245,7 @@ def local_conv2d(
dilate_w=dilate_w, dilate_w=dilate_w,
# strategy=get_conv_execution_strategy(), # strategy=get_conv_execution_strategy(),
) )
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight) (output,) = apply(op, inp, weight)
if bias is not None: if bias is not None:
output += bias output += bias
@@ -900,6 +903,7 @@ def warp_perspective(
op = builtin.WarpPerspective( op = builtin.WarpPerspective(
imode=interp_mode, bmode=border_mode, format="NCHW", border_val=border_val imode=interp_mode, bmode=border_mode, format="NCHW", border_val=border_val
) )
inp, M = utils.convert_inputs(inp, M)
(result,) = apply(op, inp, M, Tensor(dsize)) (result,) = apply(op, inp, M, Tensor(dsize))
return result return result


@@ -1004,6 +1008,7 @@ def matmul(
format=format, format=format,
) )


inp1, inp2 = utils.convert_inputs(inp1, inp2)
(result,) = apply(op, inp1, inp2) (result,) = apply(op, inp1, inp2)
if shp is not None: if shp is not None:
result = result.reshape(shp) result = result.reshape(shp)
@@ -1327,6 +1332,7 @@ def roi_pooling(
output_shape = (output_shape, output_shape) output_shape = (output_shape, output_shape)


op = builtin.ROIPooling(mode=mode, scale=scale) op = builtin.ROIPooling(mode=mode, scale=scale)
inp, rois = utils.convert_inputs(inp, rois)
result, _ = apply( result, _ = apply(
op, inp, rois, Tensor(output_shape, dtype="int32", device=inp.device) op, inp, rois, Tensor(output_shape, dtype="int32", device=inp.device)
) )
@@ -1374,6 +1380,7 @@ def roi_align(
sample_height=sample_height, sample_height=sample_height,
sample_width=sample_width, sample_width=sample_width,
) )
input, rois = utils.convert_inputs(input, rois)
result, *_ = apply(op, input, rois) result, *_ = apply(op, input, rois)
return result return result




+ 4
- 0
imperative/python/src/dispatcher.cpp View File

@@ -104,6 +104,10 @@ struct Dispatcher {
auto& frame = stack.back(); auto& frame = stack.back();
auto& mro = *frame.mro; auto& mro = *frame.mro;
auto& i = frame.mro_offset; auto& i = frame.mro_offset;
if (!mro.size()) {
PyErr_SetString(PyExc_NotImplementedError, "function not registered in dispatcher");
return nullptr;
}
for (; i < mro.size(); ++i) { for (; i < mro.size(); ++i) {
if (mro[i]->enabled) { if (mro[i]->enabled) {
auto ret = caller(mro[i]->func); auto ret = caller(mro[i]->func);


Loading…
Cancel
Save