@@ -1319,11 +1319,11 @@ protected:
 };

 /*!
- * \brief check whether input contains inf value.
+ * \brief check whether input contains inf or nan value.
  */
-class CheckHasInf: public OperatorBase {
+class CheckNonFinite: public OperatorBase {
    DEF_OPR_PARAM(Empty);
-    DEF_OPR_IMPL(CheckHasInf, OperatorBase, 1, 1);
+    DEF_OPR_IMPL(CheckNonFinite, OperatorBase, 1, 1);

 public:
    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
@@ -1,5 +1,5 @@
 /**
- * \file dnn/src/common/check_has_inf.cpp
+ * \file dnn/src/common/check_non_finite.cpp
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -14,7 +14,7 @@
 namespace megdnn {

-void CheckHasInf::check_exec(const TensorLayout& src, const TensorLayout& dst,
+void CheckNonFinite::check_exec(const TensorLayout& src, const TensorLayout& dst,
                              size_t workspace_in_bytes) {
    megdnn_assert_contiguous(src);
    megdnn_assert_contiguous(dst);
@@ -24,7 +24,7 @@ void CheckHasInf::check_exec(const TensorLayout& src, const TensorLayout& dst,
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
 }

-void CheckHasInf::deduce_layout(const TensorLayout&, TensorLayout& dst) {
+void CheckNonFinite::deduce_layout(const TensorLayout&, TensorLayout& dst) {
    dst.shape[0] = 1;
    dst.ndim = 1;
    dst.dtype = dtype::Int32();
@@ -216,7 +216,7 @@ private:
    cb(FakeQuantBackward) \
    cb(TQTForward) \
    cb(TQTBackward) \
-    cb(CheckHasInf) \
+    cb(CheckNonFinite) \
    cb(LSQForward) \
    cb(LSQBackward) \
    cb(Fill) \
@@ -131,7 +131,7 @@ DEF(PermutationRNG, 1, true, true);
 DEF(ShuffleRNGForward, 3, true, true);
 DEF(ShuffleRNGBackward, 3, true, false);
 DEF(ChecksumForward, 1, true, false);
-DEF(CheckHasInf, 2, true, true);
+DEF(CheckNonFinite, 2, true, true);
 DEF(LSQForward, 5, true, true);
 DEF(LSQBackward, 7, true, false);
 DEF(Fill, 1, true, false);
@@ -152,7 +152,7 @@ struct MaxOp {
 };

 template <typename src_ctype, typename dst_ctype, typename wtype_>
-struct CheckHasInfOp {
+struct CheckNonFiniteOp {
    typedef wtype_ wtype;
    const wtype INIT;
@@ -162,9 +162,9 @@ struct CheckHasInfOp {
    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
 #if defined(__CUDA_ARCH__)
-        return isinf(src[idx]);
+        return !isfinite(src[idx]);
 #else
-        return std::isinf(src[idx]);
+        return !std::isfinite(src[idx]);
 #endif
    }
    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
@@ -173,7 +173,7 @@ struct CheckHasInfOp {
    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
        return lhs | rhs;
    }
-    MEGDNN_HOST MEGDNN_DEVICE CheckHasInfOp(src_ctype* src, dst_ctype* dst,
+    MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(src_ctype* src, dst_ctype* dst,
                                             size_t B)
            : INIT(wtype(0)), src(src), dst(dst), B(B) {}
 };
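For intuition, here is a minimal NumPy sketch (hypothetical, not part of this patch) of what `CheckNonFiniteOp` computes: `read` maps each element to `!isfinite(x)` and `apply` folds the partial flags together with bitwise OR, so the reduction yields 1 iff any element is inf or nan.

```python
import numpy as np

def check_non_finite(src):
    # read(): each element maps to 1 if it is inf or nan, else 0.
    flags = (~np.isfinite(src.ravel())).astype(np.int32)
    # apply(): partial results combine with bitwise OR; INIT = 0
    # is the identity of that reduction, matching wtype(0) above.
    return np.bitwise_or.reduce(flags, initial=0)

x = np.random.uniform(-1.0, 1.0, 512 * 16).astype(np.float32)
assert check_non_finite(x) == 0
x[123] = np.inf
x[456] = np.nan
assert check_non_finite(x) == 1
```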
@@ -1,5 +1,5 @@
 /**
- * \file dnn/src/cuda/check_has_inf/kern.cu
+ * \file dnn/src/cuda/check_non_finite/kern.cu
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -18,7 +18,7 @@ namespace cuda {
 #define COMMA ,
-INST_REDUCE(reduce::CheckHasInfOp<dt_float32 COMMA dt_int32 COMMA dt_int32>, false);
+INST_REDUCE(reduce::CheckNonFiniteOp<dt_float32 COMMA dt_int32 COMMA dt_int32>, false);
 #undef COMMA

 } // namespace cuda
@@ -1,5 +1,5 @@
 /**
- * \file dnn/src/cuda/check_has_inf/opr_impl.cpp
+ * \file dnn/src/cuda/check_non_finite/opr_impl.cpp
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -9,7 +9,7 @@
  * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */

-#include "src/cuda/check_has_inf/opr_impl.h"
+#include "src/cuda/check_non_finite/opr_impl.h"

 #include "src/cuda/reduce_helper.cuh"
 #include "src/cuda/handle.h"
@@ -20,18 +20,18 @@
 namespace megdnn {
 namespace cuda {

-using reduce::CheckHasInfOp;
+using reduce::CheckNonFiniteOp;

-size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout& src,
+size_t CheckNonFiniteImpl::get_workspace_in_bytes(const TensorLayout& src,
                                                const TensorLayout& dst) {
-    typedef CheckHasInfOp<dt_float32, dt_int32, dt_int32> Op;
+    typedef CheckNonFiniteOp<dt_float32, dt_int32, dt_int32> Op;
    return get_reduce_workspace_in_bytes<Op>(1, src.total_nr_elems(), 1);
 }

-void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
+void CheckNonFiniteImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                            _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);
-    typedef CheckHasInfOp<dt_float32, dt_int32, dt_int32> Op;
+    typedef CheckNonFiniteOp<dt_float32, dt_int32, dt_int32> Op;
    auto stream = cuda_stream(this->handle());
    auto B = src.layout.total_nr_elems();
    return run_reduce<Op, false>(
@@ -1,5 +1,5 @@
 /**
- * \file dnn/src/cuda/check_has_inf/opr_impl.h
+ * \file dnn/src/cuda/check_non_finite/opr_impl.h
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -17,9 +17,9 @@
 namespace megdnn {
 namespace cuda {

-class CheckHasInfImpl final : public CheckHasInf {
+class CheckNonFiniteImpl final : public CheckNonFinite {
 public:
-    using CheckHasInf::CheckHasInf;
+    using CheckNonFinite::CheckNonFinite;

    size_t get_workspace_in_bytes(const TensorLayout& src,
                                  const TensorLayout& dst) override;
@@ -20,7 +20,7 @@
 #include "src/cuda/batch_conv_bias/opr_impl.h"
 #include "src/cuda/batch_normalization/opr_impl.h"
 #include "src/cuda/batched_matrix_mul/opr_impl.h"
-#include "src/cuda/check_has_inf/opr_impl.h"
+#include "src/cuda/check_non_finite/opr_impl.h"
 #include "src/cuda/checksum/opr_impl.h"
 #include "src/cuda/concat/opr_impl.h"
 #include "src/cuda/cond_take/opr_impl.h"
@@ -1,5 +1,5 @@
 /**
- * \file dnn/src/naive/check_has_inf/opr_impl.cpp
+ * \file dnn/src/naive/check_non_finite/opr_impl.cpp
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -9,7 +9,7 @@
  * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */

-#include "src/naive/check_has_inf/opr_impl.h"
+#include "src/naive/check_non_finite/opr_impl.h"

 #include "src/common/utils.h"
 #include "src/naive/handle.h"
@@ -27,7 +27,7 @@ void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) {
            size_t mid = l + (r - l) / 2;
            return func(l, mid) | func(mid, r);
        } else {
-            return static_cast<wtype>(std::isinf(sptr[l]));
+            return static_cast<wtype>(!std::isfinite(sptr[l]));
        }
    };
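The naive kernel's divide-and-conquer recursion can be mirrored in a few lines of Python (a hypothetical sketch of the same pairwise split, not the MegEngine API): the range is halved until a single element remains, and each leaf contributes `!isfinite`.

```python
import math

def reduce_fwd(sptr, l, r):
    # Pairwise OR-reduction over the half-open range [l, r),
    # mirroring the recursive `func` above.
    if r - l > 1:
        mid = l + (r - l) // 2
        return reduce_fwd(sptr, l, mid) | reduce_fwd(sptr, mid, r)
    # Leaf: 1 if the element is inf or nan, 0 otherwise.
    return int(not math.isfinite(sptr[l]))

assert reduce_fwd([1.1, 2.2, 3.3, 4.3], 0, 4) == 0
assert reduce_fwd([1.1, float("nan"), 3.3, float("inf")], 0, 4) == 1
```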
@@ -39,12 +39,12 @@ void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) {
 namespace megdnn {
 namespace naive {

-size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout&,
+size_t CheckNonFiniteImpl::get_workspace_in_bytes(const TensorLayout&,
                                                const TensorLayout&) {
    return 0;
 }

-void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
+void CheckNonFiniteImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                            _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);
@@ -1,5 +1,5 @@
 /**
- * \file dnn/src/naive/check_has_inf/opr_impl.h
+ * \file dnn/src/naive/check_non_finite/opr_impl.h
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -16,9 +16,9 @@
 namespace megdnn {
 namespace naive {

-class CheckHasInfImpl final : public CheckHasInf {
+class CheckNonFiniteImpl final : public CheckNonFinite {
 public:
-    using CheckHasInf::CheckHasInf;
+    using CheckNonFinite::CheckNonFinite;

    bool is_thread_safe() const override { return true; }
@@ -22,7 +22,7 @@
 #include "src/naive/batch_conv_bias/opr_impl.h"
 #include "src/naive/batch_normalization/opr_impl.h"
 #include "src/naive/batched_matrix_mul/opr_impl.h"
-#include "src/naive/check_has_inf/opr_impl.h"
+#include "src/naive/check_non_finite/opr_impl.h"
 #include "src/naive/checksum/opr_impl.h"
 #include "src/naive/concat/opr_impl.h"
 #include "src/naive/cond_take/opr_impl.h"
@@ -1,5 +1,5 @@
 /**
- * \file dnn/test/cuda/check_has_inf.cpp
+ * \file dnn/test/cuda/check_non_finite.cpp
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -15,16 +15,20 @@
 namespace megdnn {
 namespace test {

-TEST_F(CUDA, CHECK_HAS_INF_BASIC) {
-    Checker<CheckHasInf> checker(handle_cuda());
+TEST_F(CUDA, CHECK_NON_FINITE_BASIC) {
+    Checker<CheckNonFinite> checker(handle_cuda());
    checker.set_allow_invalid_check(true);
    const auto inf = std::numeric_limits<float>::infinity();
+    const auto nan = std::numeric_limits<float>::quiet_NaN();
    UniformFloatWithValueRNG rng(-1.0f, 1.0f, 0.1f, inf);
    checker.set_rng(0, &rng);
    checker.execs({{512*16}, {1}});

    rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, inf);
    checker.set_rng(0, &rng);
    checker.execs({{512*16}, {1}});
+
+    rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, nan);
+    checker.set_rng(0, &rng);
+    checker.execs({{512*16}, {1}});
 }

 } // namespace test
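The test leans on `UniformFloatWithValueRNG`; assuming its arguments are (low, high, proportion, value) — an inference from the calls above, not a documented signature — a NumPy sketch of the data it would generate:

```python
import numpy as np

def uniform_with_value(low, high, proportion, value, size, rng=np.random):
    # Uniform samples in [low, high); each element is replaced by
    # `value` with probability `proportion` (so proportion=1.0 fills
    # the whole tensor with inf or nan, as in the test).
    out = rng.uniform(low, high, size).astype(np.float32)
    out[rng.random(size) < proportion] = value
    return out

x = uniform_with_value(-1.0, 1.0, 0.1, np.inf, 512 * 16)
```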
@@ -1,5 +1,5 @@
 /**
- * \file test/naive/check_has_inf.cpp
+ * \file test/naive/check_non_finite.cpp
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -17,8 +17,8 @@
 namespace megdnn {
 namespace test {

-TEST_F(NAIVE, CHECK_HAS_INF_BASIC) {
-    Checker<CheckHasInf> checker(handle(), false);
+TEST_F(NAIVE, CHECK_NON_FINITE_BASIC) {
+    Checker<CheckNonFinite> checker(handle(), false);
    checker.exect(Testcase{TensorValue({4}, dtype::Float32(),
                                       {1.1, 2.2, 3.3, 4.3}),
                           {}},
@@ -29,6 +29,12 @@ TEST_F(NAIVE, CHECK_HAS_INF_BASIC) {
                                       std::numeric_limits<float>::infinity()}),
                           {}},
                  Testcase{{}, TensorValue({1}, dtype::Int32(), {1})});
+    checker.exect(
+            Testcase{TensorValue({4}, dtype::Float32(),
+                                 {1.1f, 2.2f, 3.3f,
+                                  std::numeric_limits<float>::quiet_NaN()}),
+                     {}},
+            Testcase{{}, TensorValue({1}, dtype::Int32(), {1})});
 }

 } // namespace test
@@ -11,7 +11,7 @@ import numpy as np

 from ..autodiff import GradManager
 from ..functional import full_like
-from ..functional.math import _has_inf
+from ..functional.math import _check_non_finite
 from ..tensor import Tensor

@@ -76,7 +76,7 @@ class GradScaler:
        self.growth_interval = growth_interval

        self._growth_tracker = 0
-        self._found_inf = False
+        self._found_non_finite = False

    def backward(
        self,
@@ -135,10 +135,10 @@ class GradScaler:
                continue
            # to support tracing, _check_gradients should be applied to every grad.
            if self._check_gradients(tensor.grad):
-                self._found_inf = True
+                self._found_non_finite = True
            tensor.grad *= inv_scale

-        if self._found_inf:
+        if self._found_non_finite:
            for tensor in grad_tensors:
                if tensor is None or getattr(tensor, "grad", None) is None:
                    continue
@@ -148,7 +148,7 @@ class GradScaler:
    def _check_gradients(self, grad):
        if self.growth_interval == 0:
            return False
-        return _has_inf(grad)
+        return _check_non_finite(grad)

    def update(self, new_scale: float = None):
        r"""Update the scale factor according to whether encountered overflow grad.
@@ -160,7 +160,7 @@ class GradScaler:
        if new_scale is not None:
            self.scale_factor = float(new_scale)
        else:
-            if self._found_inf:
+            if self._found_non_finite:
                self.scale_factor *= self.backoff_factor
                self._growth_tracker = 0
            else:
@@ -168,7 +168,7 @@ class GradScaler:
            if self._growth_tracker >= self.growth_interval:
                self.scale_factor *= self.growth_factor
                self._growth_tracker = 0
-        self._found_inf = False
+        self._found_non_finite = False

    def state_dict(self):
        return {
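For readers skimming the rename, the update policy being touched here boils down to the following standalone sketch (hypothetical code, not the MegEngine class): a non-finite gradient backs the scale off and resets the growth streak, while `growth_interval` consecutive clean steps grow it.

```python
class ScalePolicy:
    """Minimal sketch of the GradScaler.update() bookkeeping."""

    def __init__(self, scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale_factor = scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._growth_tracker = 0
        self._found_non_finite = False

    def update(self):
        if self._found_non_finite:
            # Overflow seen: shrink the scale, restart the streak.
            self.scale_factor *= self.backoff_factor
            self._growth_tracker = 0
        else:
            self._growth_tracker += 1
            if self._growth_tracker >= self.growth_interval:
                self.scale_factor *= self.growth_factor
                self._growth_tracker = 0
        self._found_non_finite = False
```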
@@ -1181,8 +1181,8 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
    return U, sigma, V

-def _has_inf(inp: Tensor) -> Tensor:
-    r"""Check whether input contains infinite value.
+def _check_non_finite(inp: Tensor) -> Tensor:
+    r"""Check whether input contains infinite or nan value.

    Args:
        inp: a tensor to be checked.
@@ -1190,7 +1190,7 @@ def _has_inf(inp: Tensor) -> Tensor:
    Returns:
        a int32 scalar tensor, 0 for False and 1 for True.
    """
-    op = builtin.CheckHasInf()
+    op = builtin.CheckNonFinite()
    (oup,) = apply(op, inp.reshape(-1).astype("float32"))
    oup._setscalar()
    return oup
@@ -185,14 +185,18 @@ def test_sum_neg_axis():
        F.sum(tensor(data), axis=(-1, 1))

-def test_has_inf():
+def test_non_finite():
    shape = (32, 3, 32, 32)
    data = np.random.random(shape).astype(np.float32)
-    rst = F.math._has_inf(tensor(data))
+    rst = F.math._check_non_finite(tensor(data))
    np.testing.assert_equal(rst.numpy(), [0])

    data[0][0][0][0] = float("inf")
-    rst = F.math._has_inf(tensor(data))
+    rst = F.math._check_non_finite(tensor(data))
    np.testing.assert_equal(rst.numpy(), [1])
+
+    data[0][0][0][0] = float("nan")
+    rst = F.math._check_non_finite(tensor(data))
+    np.testing.assert_equal(rst.numpy(), [1])
@@ -16,17 +16,17 @@
 namespace mgb {
 namespace imperative {
-namespace check_has_inf {
+namespace check_non_finite {

 auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
-    auto&& op = def.cast_final_safe<CheckHasInf>();
+    auto&& op = def.cast_final_safe<CheckNonFinite>();
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
-    return opr::CheckHasInf::make(inputs[0], {}, config);
+    return opr::CheckNonFinite::make(inputs[0], {}, config);
 }

-OP_TRAIT_REG(CheckHasInf, CheckHasInf)
+OP_TRAIT_REG(CheckNonFinite, CheckNonFinite)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
-} // namespace check_has_inf
+} // namespace check_non_finite

 } // namespace imperative
 } // namespace mgb
@@ -390,7 +390,7 @@ def CambriconRuntime: MgbHashableOp<"CambriconRuntime"> {
 def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>;

-def CheckHasInf: MgbHashableOp<"CheckHasInf", [EmptyParam]>;
+def CheckNonFinite: MgbHashableOp<"CheckNonFinite", [EmptyParam]>;

 def FastpathCopy: MgbHashableOp<"FastpathCopy">;
@@ -491,12 +491,12 @@ MGB_IMPL_OPR_GRAD(TopK) {
 }
 #endif

-/* ================= CheckHasInf =================  */
+/* ================= CheckNonFinite =================  */
 namespace mgb {
 namespace opr {
 namespace intl {
 template<>
-struct MegDNNOprInitPostCtor<CheckHasInf> {
+struct MegDNNOprInitPostCtor<CheckNonFinite> {
    static void apply(cg::OperatorNodeBase &opr) {
        opr.output(0)->dtype(dtype::Int32());
    }
@@ -504,6 +504,6 @@ struct MegDNNOprInitPostCtor<CheckHasInf> {
 }
 }
 }
-MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckHasInf);
-MEGDNN_OPR_INIT1(CheckHasInf, "check_has_inf")
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckNonFinite);
+MEGDNN_OPR_INIT1(CheckNonFinite, "check_non_finite")

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -72,7 +72,7 @@ namespace opr {
 #if MGB_CUDA
 MGB_SEREG_OPR(NvOf, 1);
 #endif
-MGB_SEREG_OPR(CheckHasInf, 1);
+MGB_SEREG_OPR(CheckNonFinite, 1);

 } // namespace opr
 } // namespace mgb
@@ -185,7 +185,7 @@ public:
                                    const OperatorNodeConfig& config = {});
 };

-MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckHasInf);
+MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckNonFinite);

 } // namespace opr
 } // namespace mgb