Browse Source

fix(mgb/opr): add non finite check

GitOrigin-RevId: a9fcd0a350
release-1.7
Megvii Engine Team 3 years ago
parent
commit
f5cb21ed3a
22 changed files with 79 additions and 65 deletions
  1. +3
    -3
      dnn/include/megdnn/oprs/general.h
  2. +3
    -3
      dnn/src/common/check_non_finite.cpp
  3. +1
    -1
      dnn/src/common/handle_impl.h
  4. +1
    -1
      dnn/src/common/opr_trait.h
  5. +4
    -4
      dnn/src/common/reduce_helper.h
  6. +2
    -2
      dnn/src/cuda/check_non_finite/kern.cu
  7. +7
    -7
      dnn/src/cuda/check_non_finite/opr_impl.cpp
  8. +3
    -3
      dnn/src/cuda/check_non_finite/opr_impl.h
  9. +1
    -1
      dnn/src/cuda/handle_create.cpp
  10. +5
    -5
      dnn/src/naive/check_non_finite/opr_impl.cpp
  11. +3
    -3
      dnn/src/naive/check_non_finite/opr_impl.h
  12. +1
    -1
      dnn/src/naive/handle.cpp
  13. +7
    -3
      dnn/test/cuda/check_non_finite.cpp
  14. +9
    -3
      dnn/test/naive/check_non_finite.cpp
  15. +7
    -7
      imperative/python/megengine/amp/grad_scaler.py
  16. +3
    -3
      imperative/python/megengine/functional/math.py
  17. +7
    -3
      imperative/python/test/unit/functional/test_math.py
  18. +5
    -5
      imperative/src/impl/ops/misc.cpp
  19. +1
    -1
      src/core/include/megbrain/ir/ops.td
  20. +4
    -4
      src/opr/impl/misc.cpp
  21. +1
    -1
      src/opr/impl/misc.sereg.h
  22. +1
    -1
      src/opr/include/megbrain/opr/misc.h

+ 3
- 3
dnn/include/megdnn/oprs/general.h View File

@@ -1319,11 +1319,11 @@ protected:
}; };


/*! /*!
* \brief check whether input contains inf value.
* \brief check whether input contains inf or nan value.
*/ */
class CheckHasInf: public OperatorBase {
class CheckNonFinite: public OperatorBase {
DEF_OPR_PARAM(Empty); DEF_OPR_PARAM(Empty);
DEF_OPR_IMPL(CheckHasInf, OperatorBase, 1, 1);
DEF_OPR_IMPL(CheckNonFinite, OperatorBase, 1, 1);


public: public:
virtual size_t get_workspace_in_bytes(const TensorLayout &src, virtual size_t get_workspace_in_bytes(const TensorLayout &src,


dnn/src/common/check_has_inf.cpp → dnn/src/common/check_non_finite.cpp View File

@@ -1,5 +1,5 @@
/** /**
* \file dnn/src/common/check_has_inf.cpp
* \file dnn/src/common/check_non_finite.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -14,7 +14,7 @@


namespace megdnn { namespace megdnn {


void CheckHasInf::check_exec(const TensorLayout& src, const TensorLayout& dst,
void CheckNonFinite::check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes) { size_t workspace_in_bytes) {
megdnn_assert_contiguous(src); megdnn_assert_contiguous(src);
megdnn_assert_contiguous(dst); megdnn_assert_contiguous(dst);
@@ -24,7 +24,7 @@ void CheckHasInf::check_exec(const TensorLayout& src, const TensorLayout& dst,
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
} }


void CheckHasInf::deduce_layout(const TensorLayout&, TensorLayout& dst) {
void CheckNonFinite::deduce_layout(const TensorLayout&, TensorLayout& dst) {
dst.shape[0] = 1; dst.shape[0] = 1;
dst.ndim = 1; dst.ndim = 1;
dst.dtype = dtype::Int32(); dst.dtype = dtype::Int32();

+ 1
- 1
dnn/src/common/handle_impl.h View File

@@ -216,7 +216,7 @@ private:
cb(FakeQuantBackward) \ cb(FakeQuantBackward) \
cb(TQTForward) \ cb(TQTForward) \
cb(TQTBackward) \ cb(TQTBackward) \
cb(CheckHasInf) \
cb(CheckNonFinite) \
cb(LSQForward) \ cb(LSQForward) \
cb(LSQBackward) \ cb(LSQBackward) \
cb(Fill) \ cb(Fill) \


+ 1
- 1
dnn/src/common/opr_trait.h View File

@@ -131,7 +131,7 @@ DEF(PermutationRNG, 1, true, true);
DEF(ShuffleRNGForward, 3, true, true); DEF(ShuffleRNGForward, 3, true, true);
DEF(ShuffleRNGBackward, 3, true, false); DEF(ShuffleRNGBackward, 3, true, false);
DEF(ChecksumForward, 1, true, false); DEF(ChecksumForward, 1, true, false);
DEF(CheckHasInf, 2, true, true);
DEF(CheckNonFinite, 2, true, true);
DEF(LSQForward, 5, true, true); DEF(LSQForward, 5, true, true);
DEF(LSQBackward, 7, true, false); DEF(LSQBackward, 7, true, false);
DEF(Fill, 1, true, false); DEF(Fill, 1, true, false);


+ 4
- 4
dnn/src/common/reduce_helper.h View File

@@ -152,7 +152,7 @@ struct MaxOp {
}; };


template <typename src_ctype, typename dst_ctype, typename wtype_> template <typename src_ctype, typename dst_ctype, typename wtype_>
struct CheckHasInfOp {
struct CheckNonFiniteOp {
typedef wtype_ wtype; typedef wtype_ wtype;
const wtype INIT; const wtype INIT;


@@ -162,9 +162,9 @@ struct CheckHasInfOp {


MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__)
return isinf(src[idx]);
return !isfinite(src[idx]);
#else #else
return std::isinf(src[idx]);
return !std::isfinite(src[idx]);
#endif #endif
} }
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
@@ -173,7 +173,7 @@ struct CheckHasInfOp {
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
return lhs | rhs; return lhs | rhs;
} }
MEGDNN_HOST MEGDNN_DEVICE CheckHasInfOp(src_ctype* src, dst_ctype* dst,
MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(src_ctype* src, dst_ctype* dst,
size_t B) size_t B)
: INIT(wtype(0)), src(src), dst(dst), B(B) {} : INIT(wtype(0)), src(src), dst(dst), B(B) {}
}; };


dnn/src/cuda/check_has_inf/kern.cu → dnn/src/cuda/check_non_finite/kern.cu View File

@@ -1,5 +1,5 @@
/** /**
* \file dnn/src/cuda/check_has_inf/kern.cu
* \file dnn/src/cuda/check_non_finite/kern.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -18,7 +18,7 @@ namespace cuda {


#define COMMA , #define COMMA ,


INST_REDUCE(reduce::CheckHasInfOp<dt_float32 COMMA dt_int32 COMMA dt_int32>, false);
INST_REDUCE(reduce::CheckNonFiniteOp<dt_float32 COMMA dt_int32 COMMA dt_int32>, false);


#undef COMMA #undef COMMA
} // namespace cuda } // namespace cuda

dnn/src/cuda/check_has_inf/opr_impl.cpp → dnn/src/cuda/check_non_finite/opr_impl.cpp View File

@@ -1,5 +1,5 @@
/** /**
* \file dnn/src/cuda/check_has_inf/opr_impl.cpp
* \file dnn/src/cuda/check_non_finite/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -9,7 +9,7 @@
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */


#include "src/cuda/check_has_inf/opr_impl.h"
#include "src/cuda/check_non_finite/opr_impl.h"
#include "src/cuda/reduce_helper.cuh" #include "src/cuda/reduce_helper.cuh"


#include "src/cuda/handle.h" #include "src/cuda/handle.h"
@@ -20,18 +20,18 @@
namespace megdnn { namespace megdnn {
namespace cuda { namespace cuda {


using reduce::CheckHasInfOp;
using reduce::CheckNonFiniteOp;


size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout& src,
size_t CheckNonFiniteImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) { const TensorLayout& dst) {
typedef CheckHasInfOp<dt_float32, dt_int32, dt_int32> Op;
typedef CheckNonFiniteOp<dt_float32, dt_int32, dt_int32> Op;
return get_reduce_workspace_in_bytes<Op>(1, src.total_nr_elems(), 1); return get_reduce_workspace_in_bytes<Op>(1, src.total_nr_elems(), 1);
} }


void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
void CheckNonFiniteImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) { _megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size); check_exec(src.layout, dst.layout, workspace.size);
typedef CheckHasInfOp<dt_float32, dt_int32, dt_int32> Op;
typedef CheckNonFiniteOp<dt_float32, dt_int32, dt_int32> Op;
auto stream = cuda_stream(this->handle()); auto stream = cuda_stream(this->handle());
auto B = src.layout.total_nr_elems(); auto B = src.layout.total_nr_elems();
return run_reduce<Op, false>( return run_reduce<Op, false>(

dnn/src/cuda/check_has_inf/opr_impl.h → dnn/src/cuda/check_non_finite/opr_impl.h View File

@@ -1,5 +1,5 @@
/** /**
* \file dnn/src/cuda/check_has_inf/opr_impl.h
* \file dnn/src/cuda/check_non_finite/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -17,9 +17,9 @@
namespace megdnn { namespace megdnn {
namespace cuda { namespace cuda {


class CheckHasInfImpl final : public CheckHasInf {
class CheckNonFiniteImpl final : public CheckNonFinite {
public: public:
using CheckHasInf::CheckHasInf;
using CheckNonFinite::CheckNonFinite;


size_t get_workspace_in_bytes(const TensorLayout& src, size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) override; const TensorLayout& dst) override;

+ 1
- 1
dnn/src/cuda/handle_create.cpp View File

@@ -20,7 +20,7 @@
#include "src/cuda/batch_conv_bias/opr_impl.h" #include "src/cuda/batch_conv_bias/opr_impl.h"
#include "src/cuda/batch_normalization/opr_impl.h" #include "src/cuda/batch_normalization/opr_impl.h"
#include "src/cuda/batched_matrix_mul/opr_impl.h" #include "src/cuda/batched_matrix_mul/opr_impl.h"
#include "src/cuda/check_has_inf/opr_impl.h"
#include "src/cuda/check_non_finite/opr_impl.h"
#include "src/cuda/checksum/opr_impl.h" #include "src/cuda/checksum/opr_impl.h"
#include "src/cuda/concat/opr_impl.h" #include "src/cuda/concat/opr_impl.h"
#include "src/cuda/cond_take/opr_impl.h" #include "src/cuda/cond_take/opr_impl.h"


dnn/src/naive/check_has_inf/opr_impl.cpp → dnn/src/naive/check_non_finite/opr_impl.cpp View File

@@ -1,5 +1,5 @@
/** /**
* \file dnn/src/naive/check_has_inf/opr_impl.cpp
* \file dnn/src/naive/check_non_finite/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -9,7 +9,7 @@
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */


#include "src/naive/check_has_inf/opr_impl.h"
#include "src/naive/check_non_finite/opr_impl.h"


#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/naive/handle.h" #include "src/naive/handle.h"
@@ -27,7 +27,7 @@ void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) {
size_t mid = l + (r - l) / 2; size_t mid = l + (r - l) / 2;
return func(l, mid) | func(mid, r); return func(l, mid) | func(mid, r);
} else { } else {
return static_cast<wtype>(std::isinf(sptr[l]));
return static_cast<wtype>(!std::isfinite(sptr[l]));
} }
}; };


@@ -39,12 +39,12 @@ void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) {
namespace megdnn { namespace megdnn {
namespace naive { namespace naive {


size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout&,
size_t CheckNonFiniteImpl::get_workspace_in_bytes(const TensorLayout&,
const TensorLayout&) { const TensorLayout&) {
return 0; return 0;
} }


void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
void CheckNonFiniteImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) { _megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size); check_exec(src.layout, dst.layout, workspace.size);



dnn/src/naive/check_has_inf/opr_impl.h → dnn/src/naive/check_non_finite/opr_impl.h View File

@@ -1,5 +1,5 @@
/** /**
* \file dnn/src/naive/check_has_inf/opr_impl.h
* \file dnn/src/naive/check_non_finite/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -16,9 +16,9 @@
namespace megdnn { namespace megdnn {
namespace naive { namespace naive {


class CheckHasInfImpl final : public CheckHasInf {
class CheckNonFiniteImpl final : public CheckNonFinite {
public: public:
using CheckHasInf::CheckHasInf;
using CheckNonFinite::CheckNonFinite;


bool is_thread_safe() const override { return true; } bool is_thread_safe() const override { return true; }



+ 1
- 1
dnn/src/naive/handle.cpp View File

@@ -22,7 +22,7 @@
#include "src/naive/batch_conv_bias/opr_impl.h" #include "src/naive/batch_conv_bias/opr_impl.h"
#include "src/naive/batch_normalization/opr_impl.h" #include "src/naive/batch_normalization/opr_impl.h"
#include "src/naive/batched_matrix_mul/opr_impl.h" #include "src/naive/batched_matrix_mul/opr_impl.h"
#include "src/naive/check_has_inf/opr_impl.h"
#include "src/naive/check_non_finite/opr_impl.h"
#include "src/naive/checksum/opr_impl.h" #include "src/naive/checksum/opr_impl.h"
#include "src/naive/concat/opr_impl.h" #include "src/naive/concat/opr_impl.h"
#include "src/naive/cond_take/opr_impl.h" #include "src/naive/cond_take/opr_impl.h"


dnn/test/cuda/check_has_inf.cpp → dnn/test/cuda/check_non_finite.cpp View File

@@ -1,5 +1,5 @@
/** /**
* \file dnn/test/cuda/check_has_inf.cpp
* \file dnn/test/cuda/check_non_finite.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -15,16 +15,20 @@
namespace megdnn { namespace megdnn {
namespace test { namespace test {


TEST_F(CUDA, CHECK_HAS_INF_BASIC) {
Checker<CheckHasInf> checker(handle_cuda());
TEST_F(CUDA, CHECK_NON_FINITE_BASIC) {
Checker<CheckNonFinite> checker(handle_cuda());
checker.set_allow_invalid_check(true); checker.set_allow_invalid_check(true);
const auto inf = std::numeric_limits<float>::infinity(); const auto inf = std::numeric_limits<float>::infinity();
const auto nan = std::numeric_limits<float>::quiet_NaN();
UniformFloatWithValueRNG rng(-1.0f, 1.0f, 0.1f, inf); UniformFloatWithValueRNG rng(-1.0f, 1.0f, 0.1f, inf);
checker.set_rng(0, &rng); checker.set_rng(0, &rng);
checker.execs({{512*16}, {1}}); checker.execs({{512*16}, {1}});
rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, inf); rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, inf);
checker.set_rng(0, &rng); checker.set_rng(0, &rng);
checker.execs({{512*16}, {1}}); checker.execs({{512*16}, {1}});
rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, nan);
checker.set_rng(0, &rng);
checker.execs({{512*16}, {1}});
} }


} // namespace test } // namespace test

dnn/test/naive/check_has_inf.cpp → dnn/test/naive/check_non_finite.cpp View File

@@ -1,5 +1,5 @@
/** /**
* \file test/naive/check_has_inf.cpp
* \file test/naive/check_non_finite.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -17,8 +17,8 @@
namespace megdnn { namespace megdnn {
namespace test { namespace test {


TEST_F(NAIVE, CHECK_HAS_INF_BASIC) {
Checker<CheckHasInf> checker(handle(), false);
TEST_F(NAIVE, CHECK_NON_FINITE_BASIC) {
Checker<CheckNonFinite> checker(handle(), false);
checker.exect(Testcase{TensorValue({4}, dtype::Float32(), checker.exect(Testcase{TensorValue({4}, dtype::Float32(),
{1.1, 2.2, 3.3, 4.3}), {1.1, 2.2, 3.3, 4.3}),
{}}, {}},
@@ -29,6 +29,12 @@ TEST_F(NAIVE, CHECK_HAS_INF_BASIC) {
std::numeric_limits<float>::infinity()}), std::numeric_limits<float>::infinity()}),
{}}, {}},
Testcase{{}, TensorValue({1}, dtype::Int32(), {1})}); Testcase{{}, TensorValue({1}, dtype::Int32(), {1})});
checker.exect(
Testcase{TensorValue({4}, dtype::Float32(),
{1.1f, 2.2f, 3.3f,
std::numeric_limits<float>::quiet_NaN()}),
{}},
Testcase{{}, TensorValue({1}, dtype::Int32(), {1})});
} }


} // namespace test } // namespace test

+ 7
- 7
imperative/python/megengine/amp/grad_scaler.py View File

@@ -11,7 +11,7 @@ import numpy as np


from ..autodiff import GradManager from ..autodiff import GradManager
from ..functional import full_like from ..functional import full_like
from ..functional.math import _has_inf
from ..functional.math import _check_non_finite
from ..tensor import Tensor from ..tensor import Tensor




@@ -76,7 +76,7 @@ class GradScaler:
self.growth_interval = growth_interval self.growth_interval = growth_interval


self._growth_tracker = 0 self._growth_tracker = 0
self._found_inf = False
self._found_non_finite = False


def backward( def backward(
self, self,
@@ -135,10 +135,10 @@ class GradScaler:
continue continue
# to support tracing, _check_gradients should be applied to every grad. # to support tracing, _check_gradients should be applied to every grad.
if self._check_gradients(tensor.grad): if self._check_gradients(tensor.grad):
self._found_inf = True
self._found_non_finite = True
tensor.grad *= inv_scale tensor.grad *= inv_scale


if self._found_inf:
if self._found_non_finite:
for tensor in grad_tensors: for tensor in grad_tensors:
if tensor is None or getattr(tensor, "grad", None) is None: if tensor is None or getattr(tensor, "grad", None) is None:
continue continue
@@ -148,7 +148,7 @@ class GradScaler:
def _check_gradients(self, grad): def _check_gradients(self, grad):
if self.growth_interval == 0: if self.growth_interval == 0:
return False return False
return _has_inf(grad)
return _check_non_finite(grad)


def update(self, new_scale: float = None): def update(self, new_scale: float = None):
r"""Update the scale factor according to whether encountered overflow grad. r"""Update the scale factor according to whether encountered overflow grad.
@@ -160,7 +160,7 @@ class GradScaler:
if new_scale is not None: if new_scale is not None:
self.scale_factor = float(new_scale) self.scale_factor = float(new_scale)
else: else:
if self._found_inf:
if self._found_non_finite:
self.scale_factor *= self.backoff_factor self.scale_factor *= self.backoff_factor
self._growth_tracker = 0 self._growth_tracker = 0
else: else:
@@ -168,7 +168,7 @@ class GradScaler:
if self._growth_tracker >= self.growth_interval: if self._growth_tracker >= self.growth_interval:
self.scale_factor *= self.growth_factor self.scale_factor *= self.growth_factor
self._growth_tracker = 0 self._growth_tracker = 0
self._found_inf = False
self._found_non_finite = False


def state_dict(self): def state_dict(self):
return { return {


+ 3
- 3
imperative/python/megengine/functional/math.py View File

@@ -1181,8 +1181,8 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
return U, sigma, V return U, sigma, V




def _has_inf(inp: Tensor) -> Tensor:
r"""Check whether input contains infinite value.
def _check_non_finite(inp: Tensor) -> Tensor:
r"""Check whether input contains infinite or nan value.


Args: Args:
inp: a tensor to be checked. inp: a tensor to be checked.
@@ -1190,7 +1190,7 @@ def _has_inf(inp: Tensor) -> Tensor:
Returns: Returns:
a int32 scalar tensor, 0 for False and 1 for True. a int32 scalar tensor, 0 for False and 1 for True.
""" """
op = builtin.CheckHasInf()
op = builtin.CheckNonFinite()
(oup,) = apply(op, inp.reshape(-1).astype("float32")) (oup,) = apply(op, inp.reshape(-1).astype("float32"))
oup._setscalar() oup._setscalar()
return oup return oup

+ 7
- 3
imperative/python/test/unit/functional/test_math.py View File

@@ -185,14 +185,18 @@ def test_sum_neg_axis():
F.sum(tensor(data), axis=(-1, 1)) F.sum(tensor(data), axis=(-1, 1))




def test_has_inf():
def test_non_finite():
shape = (32, 3, 32, 32) shape = (32, 3, 32, 32)
data = np.random.random(shape).astype(np.float32) data = np.random.random(shape).astype(np.float32)
rst = F.math._has_inf(tensor(data))
rst = F.math._check_non_finite(tensor(data))
np.testing.assert_equal(rst.numpy(), [0]) np.testing.assert_equal(rst.numpy(), [0])


data[0][0][0][0] = float("inf") data[0][0][0][0] = float("inf")
rst = F.math._has_inf(tensor(data))
rst = F.math._check_non_finite(tensor(data))
np.testing.assert_equal(rst.numpy(), [1])

data[0][0][0][0] = float("nan")
rst = F.math._check_non_finite(tensor(data))
np.testing.assert_equal(rst.numpy(), [1]) np.testing.assert_equal(rst.numpy(), [1])






+ 5
- 5
imperative/src/impl/ops/misc.cpp View File

@@ -16,17 +16,17 @@
namespace mgb { namespace mgb {
namespace imperative { namespace imperative {


namespace check_has_inf {
namespace check_non_finite {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<CheckHasInf>();
auto&& op = def.cast_final_safe<CheckNonFinite>();
mgb_assert(inputs.size() == 1); mgb_assert(inputs.size() == 1);
OperatorNodeConfig config{op.make_name()}; OperatorNodeConfig config{op.make_name()};
return opr::CheckHasInf::make(inputs[0], {}, config);
return opr::CheckNonFinite::make(inputs[0], {}, config);
} }
OP_TRAIT_REG(CheckHasInf, CheckHasInf)
OP_TRAIT_REG(CheckNonFinite, CheckNonFinite)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)
.fallback(); .fallback();
} // namespace check_has_inf
} // namespace check_non_finite


} // namespace imperative } // namespace imperative
} // namespace mgb } // namespace mgb


+ 1
- 1
src/core/include/megbrain/ir/ops.td View File

@@ -390,7 +390,7 @@ def CambriconRuntime: MgbHashableOp<"CambriconRuntime"> {


def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>; def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>;


def CheckHasInf: MgbHashableOp<"CheckHasInf", [EmptyParam]>;
def CheckNonFinite: MgbHashableOp<"CheckNonFinite", [EmptyParam]>;


def FastpathCopy: MgbHashableOp<"FastpathCopy">; def FastpathCopy: MgbHashableOp<"FastpathCopy">;




+ 4
- 4
src/opr/impl/misc.cpp View File

@@ -491,12 +491,12 @@ MGB_IMPL_OPR_GRAD(TopK) {
} }
#endif #endif


/* ================= CheckHasInf ================= */
/* ================= CheckNonFinite ================= */
namespace mgb { namespace mgb {
namespace opr { namespace opr {
namespace intl { namespace intl {
template<> template<>
struct MegDNNOprInitPostCtor<CheckHasInf> {
struct MegDNNOprInitPostCtor<CheckNonFinite> {
static void apply(cg::OperatorNodeBase &opr) { static void apply(cg::OperatorNodeBase &opr) {
opr.output(0)->dtype(dtype::Int32()); opr.output(0)->dtype(dtype::Int32());
} }
@@ -504,6 +504,6 @@ struct MegDNNOprInitPostCtor<CheckHasInf> {
} }
} }
} }
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckHasInf);
MEGDNN_OPR_INIT1(CheckHasInf, "check_has_inf")
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckNonFinite);
MEGDNN_OPR_INIT1(CheckNonFinite, "check_non_finite")
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 1
- 1
src/opr/impl/misc.sereg.h View File

@@ -72,7 +72,7 @@ namespace opr {
#if MGB_CUDA #if MGB_CUDA
MGB_SEREG_OPR(NvOf, 1); MGB_SEREG_OPR(NvOf, 1);
#endif #endif
MGB_SEREG_OPR(CheckHasInf, 1);
MGB_SEREG_OPR(CheckNonFinite, 1);


} // namespace opr } // namespace opr
} // namespace mgb } // namespace mgb


+ 1
- 1
src/opr/include/megbrain/opr/misc.h View File

@@ -185,7 +185,7 @@ public:
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});
}; };


MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckHasInf);
MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckNonFinite);


} // namespace opr } // namespace opr
} // namespace mgb } // namespace mgb


Loading…
Cancel
Save