@@ -1319,11 +1319,11 @@ protected: | |||||
}; | }; | ||||
/*! | /*! | ||||
* \brief check whether input contains inf value. | |||||
* \brief check whether input contains inf or nan value. | |||||
*/ | */ | ||||
class CheckHasInf: public OperatorBase { | |||||
class CheckNonFinite: public OperatorBase { | |||||
DEF_OPR_PARAM(Empty); | DEF_OPR_PARAM(Empty); | ||||
DEF_OPR_IMPL(CheckHasInf, OperatorBase, 1, 1); | |||||
DEF_OPR_IMPL(CheckNonFinite, OperatorBase, 1, 1); | |||||
public: | public: | ||||
virtual size_t get_workspace_in_bytes(const TensorLayout &src, | virtual size_t get_workspace_in_bytes(const TensorLayout &src, | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/common/check_has_inf.cpp | |||||
* \file dnn/src/common/check_non_finite.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -14,7 +14,7 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
void CheckHasInf::check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
void CheckNonFinite::check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
size_t workspace_in_bytes) { | size_t workspace_in_bytes) { | ||||
megdnn_assert_contiguous(src); | megdnn_assert_contiguous(src); | ||||
megdnn_assert_contiguous(dst); | megdnn_assert_contiguous(dst); | ||||
@@ -24,7 +24,7 @@ void CheckHasInf::check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | ||||
} | } | ||||
void CheckHasInf::deduce_layout(const TensorLayout&, TensorLayout& dst) { | |||||
void CheckNonFinite::deduce_layout(const TensorLayout&, TensorLayout& dst) { | |||||
dst.shape[0] = 1; | dst.shape[0] = 1; | ||||
dst.ndim = 1; | dst.ndim = 1; | ||||
dst.dtype = dtype::Int32(); | dst.dtype = dtype::Int32(); |
@@ -216,7 +216,7 @@ private: | |||||
cb(FakeQuantBackward) \ | cb(FakeQuantBackward) \ | ||||
cb(TQTForward) \ | cb(TQTForward) \ | ||||
cb(TQTBackward) \ | cb(TQTBackward) \ | ||||
cb(CheckHasInf) \ | |||||
cb(CheckNonFinite) \ | |||||
cb(LSQForward) \ | cb(LSQForward) \ | ||||
cb(LSQBackward) \ | cb(LSQBackward) \ | ||||
cb(Fill) \ | cb(Fill) \ | ||||
@@ -131,7 +131,7 @@ DEF(PermutationRNG, 1, true, true); | |||||
DEF(ShuffleRNGForward, 3, true, true); | DEF(ShuffleRNGForward, 3, true, true); | ||||
DEF(ShuffleRNGBackward, 3, true, false); | DEF(ShuffleRNGBackward, 3, true, false); | ||||
DEF(ChecksumForward, 1, true, false); | DEF(ChecksumForward, 1, true, false); | ||||
DEF(CheckHasInf, 2, true, true); | |||||
DEF(CheckNonFinite, 2, true, true); | |||||
DEF(LSQForward, 5, true, true); | DEF(LSQForward, 5, true, true); | ||||
DEF(LSQBackward, 7, true, false); | DEF(LSQBackward, 7, true, false); | ||||
DEF(Fill, 1, true, false); | DEF(Fill, 1, true, false); | ||||
@@ -152,7 +152,7 @@ struct MaxOp { | |||||
}; | }; | ||||
template <typename src_ctype, typename dst_ctype, typename wtype_> | template <typename src_ctype, typename dst_ctype, typename wtype_> | ||||
struct CheckHasInfOp { | |||||
struct CheckNonFiniteOp { | |||||
typedef wtype_ wtype; | typedef wtype_ wtype; | ||||
const wtype INIT; | const wtype INIT; | ||||
@@ -162,9 +162,9 @@ struct CheckHasInfOp { | |||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { | MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { | ||||
#if defined(__CUDA_ARCH__) | #if defined(__CUDA_ARCH__) | ||||
return isinf(src[idx]); | |||||
return !isfinite(src[idx]); | |||||
#else | #else | ||||
return std::isinf(src[idx]); | |||||
return !std::isfinite(src[idx]); | |||||
#endif | #endif | ||||
} | } | ||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { | MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { | ||||
@@ -173,7 +173,7 @@ struct CheckHasInfOp { | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | ||||
return lhs | rhs; | return lhs | rhs; | ||||
} | } | ||||
MEGDNN_HOST MEGDNN_DEVICE CheckHasInfOp(src_ctype* src, dst_ctype* dst, | |||||
MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(src_ctype* src, dst_ctype* dst, | |||||
size_t B) | size_t B) | ||||
: INIT(wtype(0)), src(src), dst(dst), B(B) {} | : INIT(wtype(0)), src(src), dst(dst), B(B) {} | ||||
}; | }; | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/cuda/check_has_inf/kern.cu | |||||
* \file dnn/src/cuda/check_non_finite/kern.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -18,7 +18,7 @@ namespace cuda { | |||||
#define COMMA , | #define COMMA , | ||||
INST_REDUCE(reduce::CheckHasInfOp<dt_float32 COMMA dt_int32 COMMA dt_int32>, false); | |||||
INST_REDUCE(reduce::CheckNonFiniteOp<dt_float32 COMMA dt_int32 COMMA dt_int32>, false); | |||||
#undef COMMA | #undef COMMA | ||||
} // namespace cuda | } // namespace cuda |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/cuda/check_has_inf/opr_impl.cpp | |||||
* \file dnn/src/cuda/check_non_finite/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -9,7 +9,7 @@ | |||||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/cuda/check_has_inf/opr_impl.h" | |||||
#include "src/cuda/check_non_finite/opr_impl.h" | |||||
#include "src/cuda/reduce_helper.cuh" | #include "src/cuda/reduce_helper.cuh" | ||||
#include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
@@ -20,18 +20,18 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace cuda { | namespace cuda { | ||||
using reduce::CheckHasInfOp; | |||||
using reduce::CheckNonFiniteOp; | |||||
size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
size_t CheckNonFiniteImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& dst) { | const TensorLayout& dst) { | ||||
typedef CheckHasInfOp<dt_float32, dt_int32, dt_int32> Op; | |||||
typedef CheckNonFiniteOp<dt_float32, dt_int32, dt_int32> Op; | |||||
return get_reduce_workspace_in_bytes<Op>(1, src.total_nr_elems(), 1); | return get_reduce_workspace_in_bytes<Op>(1, src.total_nr_elems(), 1); | ||||
} | } | ||||
void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
void CheckNonFiniteImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
check_exec(src.layout, dst.layout, workspace.size); | check_exec(src.layout, dst.layout, workspace.size); | ||||
typedef CheckHasInfOp<dt_float32, dt_int32, dt_int32> Op; | |||||
typedef CheckNonFiniteOp<dt_float32, dt_int32, dt_int32> Op; | |||||
auto stream = cuda_stream(this->handle()); | auto stream = cuda_stream(this->handle()); | ||||
auto B = src.layout.total_nr_elems(); | auto B = src.layout.total_nr_elems(); | ||||
return run_reduce<Op, false>( | return run_reduce<Op, false>( |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/cuda/check_has_inf/opr_impl.h | |||||
* \file dnn/src/cuda/check_non_finite/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -17,9 +17,9 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace cuda { | namespace cuda { | ||||
class CheckHasInfImpl final : public CheckHasInf { | |||||
class CheckNonFiniteImpl final : public CheckNonFinite { | |||||
public: | public: | ||||
using CheckHasInf::CheckHasInf; | |||||
using CheckNonFinite::CheckNonFinite; | |||||
size_t get_workspace_in_bytes(const TensorLayout& src, | size_t get_workspace_in_bytes(const TensorLayout& src, | ||||
const TensorLayout& dst) override; | const TensorLayout& dst) override; |
@@ -20,7 +20,7 @@ | |||||
#include "src/cuda/batch_conv_bias/opr_impl.h" | #include "src/cuda/batch_conv_bias/opr_impl.h" | ||||
#include "src/cuda/batch_normalization/opr_impl.h" | #include "src/cuda/batch_normalization/opr_impl.h" | ||||
#include "src/cuda/batched_matrix_mul/opr_impl.h" | #include "src/cuda/batched_matrix_mul/opr_impl.h" | ||||
#include "src/cuda/check_has_inf/opr_impl.h" | |||||
#include "src/cuda/check_non_finite/opr_impl.h" | |||||
#include "src/cuda/checksum/opr_impl.h" | #include "src/cuda/checksum/opr_impl.h" | ||||
#include "src/cuda/concat/opr_impl.h" | #include "src/cuda/concat/opr_impl.h" | ||||
#include "src/cuda/cond_take/opr_impl.h" | #include "src/cuda/cond_take/opr_impl.h" | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/naive/check_has_inf/opr_impl.cpp | |||||
* \file dnn/src/naive/check_non_finite/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -9,7 +9,7 @@ | |||||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/naive/check_has_inf/opr_impl.h" | |||||
#include "src/naive/check_non_finite/opr_impl.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
@@ -27,7 +27,7 @@ void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) { | |||||
size_t mid = l + (r - l) / 2; | size_t mid = l + (r - l) / 2; | ||||
return func(l, mid) | func(mid, r); | return func(l, mid) | func(mid, r); | ||||
} else { | } else { | ||||
return static_cast<wtype>(std::isinf(sptr[l])); | |||||
return static_cast<wtype>(!std::isfinite(sptr[l])); | |||||
} | } | ||||
}; | }; | ||||
@@ -39,12 +39,12 @@ void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) { | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace naive { | namespace naive { | ||||
size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout&, | |||||
size_t CheckNonFiniteImpl::get_workspace_in_bytes(const TensorLayout&, | |||||
const TensorLayout&) { | const TensorLayout&) { | ||||
return 0; | return 0; | ||||
} | } | ||||
void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
void CheckNonFiniteImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
check_exec(src.layout, dst.layout, workspace.size); | check_exec(src.layout, dst.layout, workspace.size); | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/naive/check_has_inf/opr_impl.h | |||||
* \file dnn/src/naive/check_non_finite/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -16,9 +16,9 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace naive { | namespace naive { | ||||
class CheckHasInfImpl final : public CheckHasInf { | |||||
class CheckNonFiniteImpl final : public CheckNonFinite { | |||||
public: | public: | ||||
using CheckHasInf::CheckHasInf; | |||||
using CheckNonFinite::CheckNonFinite; | |||||
bool is_thread_safe() const override { return true; } | bool is_thread_safe() const override { return true; } | ||||
@@ -22,7 +22,7 @@ | |||||
#include "src/naive/batch_conv_bias/opr_impl.h" | #include "src/naive/batch_conv_bias/opr_impl.h" | ||||
#include "src/naive/batch_normalization/opr_impl.h" | #include "src/naive/batch_normalization/opr_impl.h" | ||||
#include "src/naive/batched_matrix_mul/opr_impl.h" | #include "src/naive/batched_matrix_mul/opr_impl.h" | ||||
#include "src/naive/check_has_inf/opr_impl.h" | |||||
#include "src/naive/check_non_finite/opr_impl.h" | |||||
#include "src/naive/checksum/opr_impl.h" | #include "src/naive/checksum/opr_impl.h" | ||||
#include "src/naive/concat/opr_impl.h" | #include "src/naive/concat/opr_impl.h" | ||||
#include "src/naive/cond_take/opr_impl.h" | #include "src/naive/cond_take/opr_impl.h" | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/test/cuda/check_has_inf.cpp | |||||
* \file dnn/test/cuda/check_non_finite.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -15,16 +15,20 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace test { | namespace test { | ||||
TEST_F(CUDA, CHECK_HAS_INF_BASIC) { | |||||
Checker<CheckHasInf> checker(handle_cuda()); | |||||
TEST_F(CUDA, CHECK_NON_FINITE_BASIC) { | |||||
Checker<CheckNonFinite> checker(handle_cuda()); | |||||
checker.set_allow_invalid_check(true); | checker.set_allow_invalid_check(true); | ||||
const auto inf = std::numeric_limits<float>::infinity(); | const auto inf = std::numeric_limits<float>::infinity(); | ||||
const auto nan = std::numeric_limits<float>::quiet_NaN(); | |||||
UniformFloatWithValueRNG rng(-1.0f, 1.0f, 0.1f, inf); | UniformFloatWithValueRNG rng(-1.0f, 1.0f, 0.1f, inf); | ||||
checker.set_rng(0, &rng); | checker.set_rng(0, &rng); | ||||
checker.execs({{512*16}, {1}}); | checker.execs({{512*16}, {1}}); | ||||
rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, inf); | rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, inf); | ||||
checker.set_rng(0, &rng); | checker.set_rng(0, &rng); | ||||
checker.execs({{512*16}, {1}}); | checker.execs({{512*16}, {1}}); | ||||
rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, nan); | |||||
checker.set_rng(0, &rng); | |||||
checker.execs({{512*16}, {1}}); | |||||
} | } | ||||
} // namespace test | } // namespace test |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file test/naive/check_has_inf.cpp | |||||
* \file test/naive/check_non_finite.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -17,8 +17,8 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace test { | namespace test { | ||||
TEST_F(NAIVE, CHECK_HAS_INF_BASIC) { | |||||
Checker<CheckHasInf> checker(handle(), false); | |||||
TEST_F(NAIVE, CHECK_NON_FINITE_BASIC) { | |||||
Checker<CheckNonFinite> checker(handle(), false); | |||||
checker.exect(Testcase{TensorValue({4}, dtype::Float32(), | checker.exect(Testcase{TensorValue({4}, dtype::Float32(), | ||||
{1.1, 2.2, 3.3, 4.3}), | {1.1, 2.2, 3.3, 4.3}), | ||||
{}}, | {}}, | ||||
@@ -29,6 +29,12 @@ TEST_F(NAIVE, CHECK_HAS_INF_BASIC) { | |||||
std::numeric_limits<float>::infinity()}), | std::numeric_limits<float>::infinity()}), | ||||
{}}, | {}}, | ||||
Testcase{{}, TensorValue({1}, dtype::Int32(), {1})}); | Testcase{{}, TensorValue({1}, dtype::Int32(), {1})}); | ||||
checker.exect( | |||||
Testcase{TensorValue({4}, dtype::Float32(), | |||||
{1.1f, 2.2f, 3.3f, | |||||
std::numeric_limits<float>::quiet_NaN()}), | |||||
{}}, | |||||
Testcase{{}, TensorValue({1}, dtype::Int32(), {1})}); | |||||
} | } | ||||
} // namespace test | } // namespace test |
@@ -11,7 +11,7 @@ import numpy as np | |||||
from ..autodiff import GradManager | from ..autodiff import GradManager | ||||
from ..functional import full_like | from ..functional import full_like | ||||
from ..functional.math import _has_inf | |||||
from ..functional.math import _check_non_finite | |||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
@@ -76,7 +76,7 @@ class GradScaler: | |||||
self.growth_interval = growth_interval | self.growth_interval = growth_interval | ||||
self._growth_tracker = 0 | self._growth_tracker = 0 | ||||
self._found_inf = False | |||||
self._found_non_finite = False | |||||
def backward( | def backward( | ||||
self, | self, | ||||
@@ -135,10 +135,10 @@ class GradScaler: | |||||
continue | continue | ||||
# to support tracing, _check_gradients should be applied to every grad. | # to support tracing, _check_gradients should be applied to every grad. | ||||
if self._check_gradients(tensor.grad): | if self._check_gradients(tensor.grad): | ||||
self._found_inf = True | |||||
self._found_non_finite = True | |||||
tensor.grad *= inv_scale | tensor.grad *= inv_scale | ||||
if self._found_inf: | |||||
if self._found_non_finite: | |||||
for tensor in grad_tensors: | for tensor in grad_tensors: | ||||
if tensor is None or getattr(tensor, "grad", None) is None: | if tensor is None or getattr(tensor, "grad", None) is None: | ||||
continue | continue | ||||
@@ -148,7 +148,7 @@ class GradScaler: | |||||
def _check_gradients(self, grad): | def _check_gradients(self, grad): | ||||
if self.growth_interval == 0: | if self.growth_interval == 0: | ||||
return False | return False | ||||
return _has_inf(grad) | |||||
return _check_non_finite(grad) | |||||
def update(self, new_scale: float = None): | def update(self, new_scale: float = None): | ||||
r"""Update the scale factor according to whether encountered overflow grad. | r"""Update the scale factor according to whether encountered overflow grad. | ||||
@@ -160,7 +160,7 @@ class GradScaler: | |||||
if new_scale is not None: | if new_scale is not None: | ||||
self.scale_factor = float(new_scale) | self.scale_factor = float(new_scale) | ||||
else: | else: | ||||
if self._found_inf: | |||||
if self._found_non_finite: | |||||
self.scale_factor *= self.backoff_factor | self.scale_factor *= self.backoff_factor | ||||
self._growth_tracker = 0 | self._growth_tracker = 0 | ||||
else: | else: | ||||
@@ -168,7 +168,7 @@ class GradScaler: | |||||
if self._growth_tracker >= self.growth_interval: | if self._growth_tracker >= self.growth_interval: | ||||
self.scale_factor *= self.growth_factor | self.scale_factor *= self.growth_factor | ||||
self._growth_tracker = 0 | self._growth_tracker = 0 | ||||
self._found_inf = False | |||||
self._found_non_finite = False | |||||
def state_dict(self): | def state_dict(self): | ||||
return { | return { | ||||
@@ -1181,8 +1181,8 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor: | |||||
return U, sigma, V | return U, sigma, V | ||||
def _has_inf(inp: Tensor) -> Tensor: | |||||
r"""Check whether input contains infinite value. | |||||
def _check_non_finite(inp: Tensor) -> Tensor: | |||||
r"""Check whether input contains infinite or nan value. | |||||
Args: | Args: | ||||
inp: a tensor to be checked. | inp: a tensor to be checked. | ||||
@@ -1190,7 +1190,7 @@ def _has_inf(inp: Tensor) -> Tensor: | |||||
Returns: | Returns: | ||||
a int32 scalar tensor, 0 for False and 1 for True. | a int32 scalar tensor, 0 for False and 1 for True. | ||||
""" | """ | ||||
op = builtin.CheckHasInf() | |||||
op = builtin.CheckNonFinite() | |||||
(oup,) = apply(op, inp.reshape(-1).astype("float32")) | (oup,) = apply(op, inp.reshape(-1).astype("float32")) | ||||
oup._setscalar() | oup._setscalar() | ||||
return oup | return oup |
@@ -185,14 +185,18 @@ def test_sum_neg_axis(): | |||||
F.sum(tensor(data), axis=(-1, 1)) | F.sum(tensor(data), axis=(-1, 1)) | ||||
def test_has_inf(): | |||||
def test_non_finite(): | |||||
shape = (32, 3, 32, 32) | shape = (32, 3, 32, 32) | ||||
data = np.random.random(shape).astype(np.float32) | data = np.random.random(shape).astype(np.float32) | ||||
rst = F.math._has_inf(tensor(data)) | |||||
rst = F.math._check_non_finite(tensor(data)) | |||||
np.testing.assert_equal(rst.numpy(), [0]) | np.testing.assert_equal(rst.numpy(), [0]) | ||||
data[0][0][0][0] = float("inf") | data[0][0][0][0] = float("inf") | ||||
rst = F.math._has_inf(tensor(data)) | |||||
rst = F.math._check_non_finite(tensor(data)) | |||||
np.testing.assert_equal(rst.numpy(), [1]) | |||||
data[0][0][0][0] = float("nan") | |||||
rst = F.math._check_non_finite(tensor(data)) | |||||
np.testing.assert_equal(rst.numpy(), [1]) | np.testing.assert_equal(rst.numpy(), [1]) | ||||
@@ -16,17 +16,17 @@ | |||||
namespace mgb { | namespace mgb { | ||||
namespace imperative { | namespace imperative { | ||||
namespace check_has_inf { | |||||
namespace check_non_finite { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | ||||
auto&& op = def.cast_final_safe<CheckHasInf>(); | |||||
auto&& op = def.cast_final_safe<CheckNonFinite>(); | |||||
mgb_assert(inputs.size() == 1); | mgb_assert(inputs.size() == 1); | ||||
OperatorNodeConfig config{op.make_name()}; | OperatorNodeConfig config{op.make_name()}; | ||||
return opr::CheckHasInf::make(inputs[0], {}, config); | |||||
return opr::CheckNonFinite::make(inputs[0], {}, config); | |||||
} | } | ||||
OP_TRAIT_REG(CheckHasInf, CheckHasInf) | |||||
OP_TRAIT_REG(CheckNonFinite, CheckNonFinite) | |||||
.apply_on_var_node(apply_on_var_node) | .apply_on_var_node(apply_on_var_node) | ||||
.fallback(); | .fallback(); | ||||
} // namespace check_has_inf | |||||
} // namespace check_non_finite | |||||
} // namespace imperative | } // namespace imperative | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -390,7 +390,7 @@ def CambriconRuntime: MgbHashableOp<"CambriconRuntime"> { | |||||
def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>; | def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>; | ||||
def CheckHasInf: MgbHashableOp<"CheckHasInf", [EmptyParam]>; | |||||
def CheckNonFinite: MgbHashableOp<"CheckNonFinite", [EmptyParam]>; | |||||
def FastpathCopy: MgbHashableOp<"FastpathCopy">; | def FastpathCopy: MgbHashableOp<"FastpathCopy">; | ||||
@@ -491,12 +491,12 @@ MGB_IMPL_OPR_GRAD(TopK) { | |||||
} | } | ||||
#endif | #endif | ||||
/* ================= CheckHasInf ================= */ | |||||
/* ================= CheckNonFinite ================= */ | |||||
namespace mgb { | namespace mgb { | ||||
namespace opr { | namespace opr { | ||||
namespace intl { | namespace intl { | ||||
template<> | template<> | ||||
struct MegDNNOprInitPostCtor<CheckHasInf> { | |||||
struct MegDNNOprInitPostCtor<CheckNonFinite> { | |||||
static void apply(cg::OperatorNodeBase &opr) { | static void apply(cg::OperatorNodeBase &opr) { | ||||
opr.output(0)->dtype(dtype::Int32()); | opr.output(0)->dtype(dtype::Int32()); | ||||
} | } | ||||
@@ -504,6 +504,6 @@ struct MegDNNOprInitPostCtor<CheckHasInf> { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckHasInf); | |||||
MEGDNN_OPR_INIT1(CheckHasInf, "check_has_inf") | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckNonFinite); | |||||
MEGDNN_OPR_INIT1(CheckNonFinite, "check_non_finite") | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -72,7 +72,7 @@ namespace opr { | |||||
#if MGB_CUDA | #if MGB_CUDA | ||||
MGB_SEREG_OPR(NvOf, 1); | MGB_SEREG_OPR(NvOf, 1); | ||||
#endif | #endif | ||||
MGB_SEREG_OPR(CheckHasInf, 1); | |||||
MGB_SEREG_OPR(CheckNonFinite, 1); | |||||
} // namespace opr | } // namespace opr | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -185,7 +185,7 @@ public: | |||||
const OperatorNodeConfig& config = {}); | const OperatorNodeConfig& config = {}); | ||||
}; | }; | ||||
MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckHasInf); | |||||
MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckNonFinite); | |||||
} // namespace opr | } // namespace opr | ||||
} // namespace mgb | } // namespace mgb | ||||