Browse Source

test(imperative): test autodiff.Function with non tensor arguments

GitOrigin-RevId: 6114f48d21
tags/v1.6.0-rc1
Megvii Engine Team 3 years ago
parent
commit
2950dd8d69
2 changed files with 17 additions and 2 deletions
  1. +15
    -0
      imperative/python/test/unit/core/test_function.py
  2. +2
    -2
      imperative/src/impl/op_trait.h

+ 15
- 0
imperative/python/test/unit/core/test_function.py View File

@@ -8,6 +8,7 @@
import copy

import numpy as np
import pytest

import megengine.autodiff as ad
import megengine.functional as F
@@ -303,3 +304,17 @@ def test_zero_grad():
np.testing.assert_almost_equal(
net.a.numpy(), np.array([1.0 - 4.0], dtype=np.float32),
)


def test_throw_on_non_tensor_argument():
class NonTensorArg(Function):
def forward(self, inp, c):
return inp + c

def backward(self, grad):
return grad

x = tensor(np.array([2.33], dtype=np.float32))
func = NonTensorArg()
with pytest.raises(TypeError, match=r"op .* expect type Tensor as inputs"):
func(x, 1)

+ 2
- 2
imperative/src/impl/op_trait.h View File

@@ -108,7 +108,7 @@ struct OpMethNotImpl {
struct OpMethFallback : public OpMethNotImpl {
using OpMethNotImpl::impl;
static void impl(ApplyOnPhysicalTensor& func,
op_meth_tag::ApplyOnPhysicalTensor);
op_meth_tag::ApplyOnPhysicalTensor);
static void impl(Execute& func, op_meth_tag::Execute);
static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc);
static void impl(InferOutputAttrsFallible& func,
@@ -120,9 +120,9 @@ struct OpMethFallback : public OpMethNotImpl {
template <typename Tag, typename RType, typename... Args>
struct OpMeth<Tag, RType(Args...)> : public thin_function<RType(Args...)> {
using Base = thin_function<RType(Args...)>;
using Base::operator bool;
OpMeth() : Base{}, allow_fallback(false){};
explicit OpMeth(const Base& base) { this->Base::operator=(base); }
using Base::operator bool;
RType operator()(Args... args) const {
if (!this->Base::operator bool()) {
if (allow_fallback) {


Loading…
Cancel
Save