From a5060a2bfe914bba9dd84af16655f54033ab0749 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 23 Jun 2021 18:38:22 +0800 Subject: [PATCH] feat(mgb/opr): add check_has_inf kernel and opr GitOrigin-RevId: 0d042dbfce8baa51245f4189e197bf800347c6b9 --- dnn/include/megdnn/oprs/general.h | 21 ++++++++ dnn/src/common/check_has_inf.cpp | 36 +++++++++++++ dnn/src/common/handle_impl.h | 3 +- dnn/src/common/opr_trait.h | 1 + dnn/src/common/reduce_helper.h | 33 ++++++++++-- dnn/src/cuda/check_has_inf/kern.cu | 27 ++++++++++ dnn/src/cuda/check_has_inf/opr_impl.cpp | 45 +++++++++++++++++ dnn/src/cuda/check_has_inf/opr_impl.h | 36 +++++++++++++ dnn/src/cuda/handle_create.cpp | 1 + dnn/src/cuda/reduce/reduce.cu | 15 +++--- dnn/src/naive/check_has_inf/opr_impl.cpp | 59 ++++++++++++++++++++++ dnn/src/naive/check_has_inf/opr_impl.h | 35 +++++++++++++ dnn/src/naive/handle.cpp | 1 + dnn/src/rocm/reduce/reduce.cpp.hip | 15 +++--- dnn/test/common/checker.cpp | 39 ++++++++++---- dnn/test/common/checker.h | 19 ++++++- dnn/test/common/exec_proxy.h | 34 ++++++------- dnn/test/common/rng.cpp | 8 ++- dnn/test/common/rng.h | 34 ++++++++----- dnn/test/cuda/check_has_inf.cpp | 33 ++++++++++++ dnn/test/naive/check_has_inf.cpp | 37 ++++++++++++++ imperative/python/megengine/functional/math.py | 13 +++++ .../python/test/unit/functional/test_math.py | 11 ++++ imperative/src/impl/ops/misc.cpp | 34 +++++++++++++ src/core/include/megbrain/ir/ops.td | 2 + src/opr/impl/misc.cpp | 15 ++++++ src/opr/impl/misc.sereg.h | 1 + src/opr/include/megbrain/opr/misc.h | 2 + 28 files changed, 544 insertions(+), 66 deletions(-) create mode 100644 dnn/src/common/check_has_inf.cpp create mode 100644 dnn/src/cuda/check_has_inf/kern.cu create mode 100644 dnn/src/cuda/check_has_inf/opr_impl.cpp create mode 100644 dnn/src/cuda/check_has_inf/opr_impl.h create mode 100644 dnn/src/naive/check_has_inf/opr_impl.cpp create mode 100644 dnn/src/naive/check_has_inf/opr_impl.h create mode 100644 dnn/test/cuda/check_has_inf.cpp create mode 100644 dnn/test/naive/check_has_inf.cpp create mode 100644 imperative/src/impl/ops/misc.cpp diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h index 4b1c2887..829d5395 100644 --- a/dnn/include/megdnn/oprs/general.h +++ b/dnn/include/megdnn/oprs/general.h @@ -1317,6 +1317,27 @@ protected: TensorLayout& exec_workspace, TensorLayout& exec_src, TensorLayout& exec_dst); }; + +/*! + * \brief check whether input contains inf value. + */ +class CheckHasInf: public OperatorBase { + DEF_OPR_PARAM(Empty); + DEF_OPR_IMPL(CheckHasInf, OperatorBase, 1, 1); + + public: + virtual size_t get_workspace_in_bytes(const TensorLayout &src, + const TensorLayout &dst) = 0; + + void deduce_layout(const TensorLayout &src, TensorLayout &dst); + + virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + + protected: + void check_exec(const TensorLayout &src, const TensorLayout &dst, + size_t workspace_in_bytes); +}; } // namespace megdnn #include "megdnn/internal/opr_header_epilogue.h" diff --git a/dnn/src/common/check_has_inf.cpp b/dnn/src/common/check_has_inf.cpp new file mode 100644 index 00000000..66f1a63c --- /dev/null +++ b/dnn/src/common/check_has_inf.cpp @@ -0,0 +1,36 @@ +/** + * \file dnn/src/common/check_has_inf.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megdnn/oprs.h" +#include "src/common/utils.h" + +namespace megdnn { + +void CheckHasInf::check_exec(const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes) { + megdnn_assert_contiguous(src); + megdnn_assert_contiguous(dst); + megdnn_assert(src.ndim == 1); + megdnn_assert(src.dtype == dtype::Float32()); + auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); +} + +void CheckHasInf::deduce_layout(const TensorLayout&, TensorLayout& dst) { + dst.shape[0] = 1; + dst.ndim = 1; + dst.dtype = dtype::Int32(); + dst.init_contiguous_stride(); +} + +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index e05449ab..4d6354e3 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -207,7 +207,8 @@ private: cb(FakeQuantForward) \ cb(FakeQuantBackward) \ cb(TQTForward) \ - cb(TQTBackward) + cb(TQTBackward) \ + cb(CheckHasInf) /*! * \brief specialize HandleImpl::create_operator for a single opr type; diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h index 0642345e..313bac68 100644 --- a/dnn/src/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -120,6 +120,7 @@ DEF(PowC, 2, false, true); DEF(UniformRNG, 1, true, true); DEF(GaussianRNG, 1, true, true); DEF(ChecksumForward, 1, true, false); +DEF(CheckHasInf, 2, true, true); } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/reduce_helper.h b/dnn/src/common/reduce_helper.h index c7cd6890..08d4fc47 100644 --- a/dnn/src/common/reduce_helper.h +++ b/dnn/src/common/reduce_helper.h @@ -4,9 +4,9 @@ * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once #include "megdnn/dtype.h" @@ -151,6 +151,33 @@ struct MaxOp { : INIT(wtype(DTypeTrait::min())), src(src), dst(dst), B(B) {} }; +template +struct CheckHasInfOp { + typedef wtype_ wtype; + const wtype INIT; + + src_ctype* src; + dst_ctype* dst; + const size_t B; + + MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { +#if defined(__CUDA_ARCH__) + return isinf(src[idx]); +#else + return std::isinf(src[idx]); +#endif + } + MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { + dst[idx] = val; + } + static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { + return lhs | rhs; + } + MEGDNN_HOST MEGDNN_DEVICE CheckHasInfOp(src_ctype* src, dst_ctype* dst, + size_t B) + : INIT(wtype(0)), src(src), dst(dst), B(B) {} +}; + #if MEGDNN_CC_HOST void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis); diff --git a/dnn/src/cuda/check_has_inf/kern.cu b/dnn/src/cuda/check_has_inf/kern.cu new file mode 100644 index 00000000..cb3d1049 --- /dev/null +++ b/dnn/src/cuda/check_has_inf/kern.cu @@ -0,0 +1,27 @@ +/** + * \file dnn/src/cuda/check_has_inf/kern.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/common/reduce_helper.h" + +#include "megdnn/dtype.h" +#include "src/cuda/reduce_helper.cuh" + +namespace megdnn { +namespace cuda { + +#define COMMA , + +INST_REDUCE(reduce::CheckHasInfOp, false); + +#undef COMMA +} // namespace cuda +} // namespace megdnn + +// vim: ft=cpp syntax=cpp.doxygen diff --git a/dnn/src/cuda/check_has_inf/opr_impl.cpp b/dnn/src/cuda/check_has_inf/opr_impl.cpp new file mode 100644 index 00000000..bf44be61 --- /dev/null +++ b/dnn/src/cuda/check_has_inf/opr_impl.cpp @@ -0,0 +1,45 @@ +/** + * \file dnn/src/cuda/check_has_inf/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/cuda/check_has_inf/opr_impl.h" +#include "src/cuda/reduce_helper.cuh" + +#include "src/cuda/handle.h" +#include "src/cuda/utils.h" + +#include "src/common/reduce_helper.h" + +namespace megdnn { +namespace cuda { + +using reduce::CheckHasInfOp; + +size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& dst) { + typedef CheckHasInfOp Op; + return get_reduce_workspace_in_bytes(1, src.total_nr_elems(), 1); +} + +void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { + check_exec(src.layout, dst.layout, workspace.size); + typedef CheckHasInfOp Op; + auto stream = cuda_stream(this->handle()); + auto B = src.layout.total_nr_elems(); + return run_reduce( + workspace.ptr(), 1, B, 1, stream, + Op(src.ptr(), dst.ptr(), B)); +} + +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/check_has_inf/opr_impl.h b/dnn/src/cuda/check_has_inf/opr_impl.h new file mode 100644 index 00000000..32d60f66 --- /dev/null +++ b/dnn/src/cuda/check_has_inf/opr_impl.h @@ -0,0 +1,36 @@ +/** + * \file dnn/src/cuda/check_has_inf/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "megdnn/oprs/utils.h" +#include "src/cuda/utils.h" + +namespace megdnn { +namespace cuda { + +class CheckHasInfImpl final : public CheckHasInf { +public: + using CheckHasInf::CheckHasInf; + + size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& dst) override; + + bool is_thread_safe() const override { return true; } + + void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; +}; + +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp index 80ac5963..afbd1dee 100644 --- a/dnn/src/cuda/handle_create.cpp +++ b/dnn/src/cuda/handle_create.cpp @@ -17,6 +17,7 @@ #include "src/cuda/argsort/opr_impl.h" #include "src/cuda/batch_normalization/opr_impl.h" #include "src/cuda/batched_matrix_mul/opr_impl.h" +#include "src/cuda/check_has_inf/opr_impl.h" #include "src/cuda/checksum/opr_impl.h" #include "src/cuda/concat/opr_impl.h" #include "src/cuda/cond_take/opr_impl.h" diff --git a/dnn/src/cuda/reduce/reduce.cu b/dnn/src/cuda/reduce/reduce.cu index 992513cb..5f6ab475 100644 --- a/dnn/src/cuda/reduce/reduce.cu +++ b/dnn/src/cuda/reduce/reduce.cu @@ -18,15 +18,15 @@ namespace cuda { using namespace reduce; -#define COMMOA , +#define COMMA , #define INST(sctype, dctype, wtype) \ - INST_REDUCE(SumOp, false); \ - INST_REDUCE(SumSqrOp, false); \ - INST_REDUCE(ProdOp, false); \ - INST_REDUCE(MinOp, false); \ - INST_REDUCE(MaxOp, false); \ - INST_REDUCE(MeanOp, false); + INST_REDUCE(SumOp, false); \ + INST_REDUCE(SumSqrOp, false); \ + INST_REDUCE(ProdOp, false); \ + INST_REDUCE(MinOp, false); \ + INST_REDUCE(MaxOp, false); \ + INST_REDUCE(MeanOp, false); #define cb(_dt) \ INST(DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype) @@ -40,6 +40,7 @@ INST(int, float, float) #undef cb #undef INST +#undef COMMA } // namespace cuda } // namespace megdnn diff --git a/dnn/src/naive/check_has_inf/opr_impl.cpp b/dnn/src/naive/check_has_inf/opr_impl.cpp new file mode 100644 index 00000000..1a30910c --- /dev/null +++ b/dnn/src/naive/check_has_inf/opr_impl.cpp @@ -0,0 +1,59 @@ +/** + * \file dnn/src/naive/check_has_inf/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/naive/check_has_inf/opr_impl.h" + +#include "src/common/utils.h" +#include "src/naive/handle.h" + +namespace { +using namespace megdnn; + +#define src_ctype dt_float32 +#define wtype dt_int32 + +void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) { + std::function func; + func = [&](size_t l, size_t r) -> wtype { + if (l + 1 < r) { + size_t mid = l + (r - l) / 2; + return func(l, mid) | func(mid, r); + } else { + return static_cast(std::isinf(sptr[l])); + } + }; + + dptr[0] = func(0, size); +} + +} // namespace + +namespace megdnn { +namespace naive { + +size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout&, + const TensorLayout&) { + return 0; +} + +void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { + check_exec(src.layout, dst.layout, workspace.size); + + auto handle = static_cast(this->handle()); + MEGDNN_DISPATCH_CPU_KERN( + handle, reduce_fwd(src.ptr(), dst.ptr(), + src.layout.total_nr_elems())); +} +} // namespace naive +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/check_has_inf/opr_impl.h b/dnn/src/naive/check_has_inf/opr_impl.h new file mode 100644 index 00000000..53e9c635 --- /dev/null +++ b/dnn/src/naive/check_has_inf/opr_impl.h @@ -0,0 +1,35 @@ +/** + * \file dnn/src/naive/check_has_inf/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "megdnn/oprs.h" + +namespace megdnn { +namespace naive { + +class CheckHasInfImpl final : public CheckHasInf { +public: + using CheckHasInf::CheckHasInf; + + bool is_thread_safe() const override { return true; } + + size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& dst) override; + + void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; +}; + +} // namespace naive +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index e6138c4a..7cb31fb0 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -21,6 +21,7 @@ #include "src/naive/batch_conv_bias/opr_impl.h" #include "src/naive/batch_normalization/opr_impl.h" #include "src/naive/batched_matrix_mul/opr_impl.h" +#include "src/naive/check_has_inf/opr_impl.h" #include "src/naive/checksum/opr_impl.h" #include "src/naive/concat/opr_impl.h" #include "src/naive/cond_take/opr_impl.h" diff --git a/dnn/src/rocm/reduce/reduce.cpp.hip b/dnn/src/rocm/reduce/reduce.cpp.hip index 88dfa632..6a3eb727 100644 --- a/dnn/src/rocm/reduce/reduce.cpp.hip +++ b/dnn/src/rocm/reduce/reduce.cpp.hip @@ -18,15 +18,15 @@ namespace rocm { using namespace reduce; -#define COMMOA , +#define COMMA , #define INST(sctype, dctype, wtype) \ - INST_REDUCE(SumOp, false); \ - INST_REDUCE(SumSqrOp, false); \ - INST_REDUCE(ProdOp, false); \ - INST_REDUCE(MinOp, false); \ - INST_REDUCE(MaxOp, false); \ - INST_REDUCE(MeanOp, false); + INST_REDUCE(SumOp, false); \ + INST_REDUCE(SumSqrOp, false); \ + INST_REDUCE(ProdOp, false); \ + INST_REDUCE(MinOp, false); \ + INST_REDUCE(MaxOp, false); \ + INST_REDUCE(MeanOp, false); #define cb(_dt) \ INST(DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype) @@ -39,6 +39,7 @@ INST(float, dt_float16, float) INST(int, float, float) #undef cb #undef INST +#undef COMMA } // namespace rocm } // namespace megdnn diff --git a/dnn/test/common/checker.cpp b/dnn/test/common/checker.cpp index 8c5bbf90..656175e5 100644 --- a/dnn/test/common/checker.cpp +++ b/dnn/test/common/checker.cpp @@ -23,7 +23,7 @@ namespace { ::testing::AssertionResult assert_tensor_eq_with_iter( const char *expr0, const char *expr1, Iter it0, Iter it1, const TensorLayout &layout, - float maxerr, float maxerr_avg, float maxerr_avg_biased) { + float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) { auto nr_elem = layout.total_nr_elems(); double error_sum = 0; @@ -33,8 +33,8 @@ namespace { float err = diff(iv0, iv1); error_sum += std::abs(err); error_sum_biased += err; - if (!good_float(iv0) || !good_float(iv1) || - std::abs(err) > maxerr) { + if (!allow_invalid && (!good_float(iv0) || !good_float(iv1) || + std::abs(err) > maxerr)) { Index index(layout, i); return ::testing::AssertionFailure() << "Unequal value\n" @@ -82,14 +82,14 @@ namespace { ::testing::AssertionResult assert_tensor_eq_with_dtype( const char *expr0, const char *expr1, const TensorND &v0, const TensorND &v1, - float maxerr, float maxerr_avg, float maxerr_avg_biased) { + float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) { if (!std::is_same::value && !std::is_same::value) { if (v0.layout.is_physical_contiguous() && v1.layout.is_physical_contiguous()) { return assert_tensor_eq_with_iter( expr0, expr1, v0.ptr(), v1.ptr(), - v0.layout, maxerr, maxerr_avg, maxerr_avg_biased); + v0.layout, maxerr, maxerr_avg, maxerr_avg_biased, allow_invalid); } } @@ -98,7 +98,7 @@ namespace { return assert_tensor_eq_with_iter(expr0, expr1, it0, it1, v0.layout, maxerr, maxerr_avg, - maxerr_avg_biased); + maxerr_avg_biased, allow_invalid); } template @@ -136,7 +136,7 @@ namespace { const char* /*expr_maxerr_avg*/, const char* /*expr_maxerr_avg*/, const TensorND &v0, const TensorND &v1, - float maxerr, float maxerr_avg, float maxerr_avg_biased) { + float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) { if (!v0.layout.eq_shape(v1.layout)) { return ::testing::AssertionFailure() @@ -160,7 +160,7 @@ namespace { #define cb(_dt) \ case DTypeTrait<_dt>::enumv: \ return assert_tensor_eq_with_dtype::ctype>( \ - expr0, expr1, v0, v1, maxerr, maxerr_avg, maxerr_avg_biased); + expr0, expr1, v0, v1, maxerr, maxerr_avg, maxerr_avg_biased, allow_invalid); MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) //! In order to avoid an unnecessary increase in binary size, we just @@ -174,6 +174,17 @@ namespace { } +::testing::AssertionResult test::__assert_tensor_eq_allow_invalid( + const char* expr0, const char* expr1, const char* expr_maxerr, + const char* expr_maxerr_avg, const char* expr_maxerr_avg_biased, + const TensorND& v0, const TensorND& v1, float maxerr, float maxerr_avg, + float maxerr_avg_biased) { + return __assert_tensor_eq(expr0, expr1, expr_maxerr, expr_maxerr_avg, + expr_maxerr_avg_biased, v0, v1, maxerr, + maxerr_avg, maxerr_avg_biased, true); +}; + + CheckerHelper::CheckerHelper(Handle *handle, bool check_dispatch): m_handle_cur(handle), m_default_rng(new NormalRNG()) @@ -411,9 +422,15 @@ void CheckerHelper::check_tensors(const TensorValueArray& expected, for (size_t i = 0; i < expected.size(); ++i) { if (expected[i].layout.ndim == 0) continue; - MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(expected[i], computed[i], m_epsilon, - m_max_avg_error, - m_max_avg_biased_error); + if (m_allow_invalid_check) { + MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG_ALLOW_INVALID( + expected[i], computed[i], m_epsilon, m_max_avg_error, + m_max_avg_biased_error); + } else { + MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(expected[i], computed[i], m_epsilon, + m_max_avg_error, + m_max_avg_biased_error); + } } } diff --git a/dnn/test/common/checker.h b/dnn/test/common/checker.h index c00632f2..48337100 100644 --- a/dnn/test/common/checker.h +++ b/dnn/test/common/checker.h @@ -79,6 +79,7 @@ protected: bool m_no_naive_and_check = false; bool m_stable_check = false; bool m_force_deduce_dst = true; + bool m_allow_invalid_check = false; /** * the offset from the start of malloc memory * @@ -248,6 +249,11 @@ public: return *this; } + Checker& set_allow_invalid_check(bool allow_invalid_check) { + m_allow_invalid_check = allow_invalid_check; + return *this; + } + //! load input tensors from file for next run Checker& load_input_tensors(const char* fpath) { m_input_tensors_fpath = fpath; @@ -329,6 +335,12 @@ private: const char* expr0, const char* expr1, const char* expr_maxerr, const char* expr_maxerr_avg, const char* expr_maxerr_avg_biased, const TensorND& v0, const TensorND& v1, float maxerr, float maxerr_avg, + float maxerr_avg_biased, bool allow_invalid = false); + +::testing::AssertionResult __assert_tensor_eq_allow_invalid( + const char* expr0, const char* expr1, const char* expr_maxerr, + const char* expr_maxerr_avg, const char* expr_maxerr_avg_biased, + const TensorND& v0, const TensorND& v1, float maxerr, float maxerr_avg, float maxerr_avg_biased); #define MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(v0, v1, maxerr, maxerr_avg, \ @@ -336,6 +348,11 @@ private: ASSERT_PRED_FORMAT5(::megdnn::test::__assert_tensor_eq, v0, v1, maxerr, \ maxerr_avg, maxerr_avg_biased) +#define MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG_ALLOW_INVALID( \ + v0, v1, maxerr, maxerr_avg, maxerr_avg_biased) \ + ASSERT_PRED_FORMAT5(::megdnn::test::__assert_tensor_eq_allow_invalid, v0, \ + v1, maxerr, maxerr_avg, maxerr_avg_biased) + #define MEGDNN_ASSERT_TENSOR_EQ_EPS(v0, v1, maxerr) \ MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(v0, v1, maxerr, maxerr, maxerr) @@ -435,7 +452,7 @@ TensorND TensorValue(const TensorShape& shape, T dtype, template TensorND TensorValueLowbit4(const TensorShape& shape, T dtype, - std::vector values) { + std::vector values) { TensorND tensor; tensor.layout = {shape, dtype}; tensor.raw_ptr = diff --git a/dnn/test/common/exec_proxy.h b/dnn/test/common/exec_proxy.h index 69ae544c..1393f357 100644 --- a/dnn/test/common/exec_proxy.h +++ b/dnn/test/common/exec_proxy.h @@ -39,6 +39,22 @@ struct ExecProxy { }; template +struct ExecProxy { + WorkspaceWrapper W; + void exec(Opr* opr, const TensorNDArray& tensors) { + if (!W.valid()) { + W = WorkspaceWrapper(opr->handle(), 0); + } + W.update(opr->get_workspace_in_bytes( + tensors[0].layout, tensors[1].layout, tensors[2].layout, + tensors[3].layout, tensors[4].layout, tensors[5].layout, + tensors[6].layout)); + opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], + tensors[5], tensors[6], W.workspace()); + } +}; + +template struct ExecProxy { WorkspaceWrapper W; void exec(Opr* opr, const TensorNDArray& tensors) { @@ -149,24 +165,6 @@ struct ExecProxy { } }; -template -struct ExecProxy { - WorkspaceWrapper W; - - void exec(Opr* opr, const TensorNDArray& tensors) { - if (!W.valid()) { - W = WorkspaceWrapper(opr->handle(), 0); - } - W.update(opr->get_workspace_in_bytes( - tensors[0].layout, tensors[1].layout, tensors[2].layout, - tensors[3].layout, tensors[4].layout, tensors[5].layout, - tensors[6].layout)); - - opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], - tensors[5], tensors[6], W.workspace()); - } -}; - } // namespace test } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/test/common/rng.cpp b/dnn/test/common/rng.cpp index 6e6e9fbf..70921a87 100644 --- a/dnn/test/common/rng.cpp +++ b/dnn/test/common/rng.cpp @@ -301,9 +301,8 @@ void UniformFloatNonZeroRNG::fill_fast_float32(dt_float32* dest, size_t size) { } } -void UniformFloatWithZeroRNG::fill_fast_float32(dt_float32 *dest, size_t size) { +void UniformFloatWithValueRNG::fill_fast_float32(dt_float32 *dest, size_t size) { RNGxorshf gen{RandomState::generator()}; - printf("a %f, b %f \n", m_dist.a(), m_dist.b()); auto k = double(m_dist.b() - m_dist.a()) / double(RNGxorshf::max() - RNGxorshf::min() + 1.0); auto b = m_dist.a() - RNGxorshf::min() * k; @@ -312,9 +311,8 @@ void UniformFloatWithZeroRNG::fill_fast_float32(dt_float32 *dest, size_t size) { auto pb = 0.f - RNGxorshf::min() * p; for (size_t i = 0; i < size; ++ i) { float rnd = gen() * p + pb; - //printf("%.3f \n", rnd); - if(rnd < zero_val_proportion_) { - dest[i] = 0.f; + if(rnd < val_proportion_) { + dest[i] = val_; } else { dest[i] = gen() * k + b; } diff --git a/dnn/test/common/rng.h b/dnn/test/common/rng.h index 324a0518..6b0d2cc0 100644 --- a/dnn/test/common/rng.h +++ b/dnn/test/common/rng.h @@ -11,10 +11,10 @@ #pragma once #include "megdnn/dtype.h" -#include "test/common/utils.h" -#include "test/common/random_state.h" #include #include +#include "test/common/random_state.h" +#include "test/common/utils.h" namespace megdnn { namespace test { @@ -80,7 +80,8 @@ public: } void gen(const TensorND& tensor) override { - megdnn_assert(tensor.layout.dtype.enumv() == DTypeTrait::enumv); + megdnn_assert(tensor.layout.dtype.enumv() == + DTypeTrait::enumv); size_t nr_elems = tensor.layout.span().dist_elem(); auto offset = tensor.layout.span().low_elem; for (size_t i = 0; i < nr_elems; ++i) { @@ -185,24 +186,31 @@ public: void fill_fast_float32(dt_float32* dest, size_t size) override; }; -class UniformFloatWithZeroRNG final : public UniformFloatRNG { +class UniformFloatWithValueRNG : public UniformFloatRNG { public: - UniformFloatWithZeroRNG(dt_float32 a, dt_float32 b, - float zero_val_proportion) - : UniformFloatRNG(a, b) { - if (zero_val_proportion < 0.f) - zero_val_proportion_ = 0.f; - else if (zero_val_proportion > 1.f) - zero_val_proportion_ = 1.f; + UniformFloatWithValueRNG(dt_float32 a, dt_float32 b, float val_proportion, + float val) + : UniformFloatRNG(a, b), val_(val) { + if (val_proportion < 0.f) + val_proportion_ = 0.f; + else if (val_proportion > 1.f) + val_proportion_ = 1.f; else - zero_val_proportion_ = zero_val_proportion; + val_proportion_ = val_proportion; } private: - float zero_val_proportion_; + float val_proportion_, val_; void fill_fast_float32(dt_float32* dest, size_t size) override; }; +class UniformFloatWithZeroRNG final : public UniformFloatWithValueRNG { +public: + UniformFloatWithZeroRNG(dt_float32 a, dt_float32 b, + float zero_val_proportion) + : UniformFloatWithValueRNG(a, b, zero_val_proportion, 0.f) {} +}; + class BernoulliRNG final : public IIDRNG { public: BernoulliRNG(dt_float32 probability_); diff --git a/dnn/test/cuda/check_has_inf.cpp b/dnn/test/cuda/check_has_inf.cpp new file mode 100644 index 00000000..1e452514 --- /dev/null +++ b/dnn/test/cuda/check_has_inf.cpp @@ -0,0 +1,33 @@ +/** + * \file dnn/test/cuda/check_has_inf.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "megdnn/oprs.h" +#include "test/common/checker.h" +#include "test/cuda/fixture.h" + +namespace megdnn { +namespace test { + +TEST_F(CUDA, CHECK_HAS_INF_BASIC) { + Checker checker(handle_cuda()); + checker.set_allow_invalid_check(true); + const auto inf = std::numeric_limits::infinity(); + UniformFloatWithValueRNG rng(-1.0f, 1.0f, 0.1f, inf); + checker.set_rng(0, &rng); + checker.execs({{512*16}, {1}}); + rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, inf); + checker.set_rng(0, &rng); + checker.execs({{512*16}, {1}}); +} + +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/test/naive/check_has_inf.cpp b/dnn/test/naive/check_has_inf.cpp new file mode 100644 index 00000000..1532a7c3 --- /dev/null +++ b/dnn/test/naive/check_has_inf.cpp @@ -0,0 +1,37 @@ +/** + * \file test/naive/check_has_inf.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "test/naive/fixture.h" + +#include "megdnn/oprs.h" +#include "test/common/checker.h" + +namespace megdnn { +namespace test { + +TEST_F(NAIVE, CHECK_HAS_INF_BASIC) { + Checker checker(handle(), false); + checker.exect(Testcase{TensorValue({4}, dtype::Float32(), + {1.1, 2.2, 3.3, 4.3}), + {}}, + Testcase{{}, TensorValue({1}, dtype::Int32(), {0})}); + checker.exect( + Testcase{TensorValue({4}, dtype::Float32(), + {1.1f, 2.2f, 3.3f, + std::numeric_limits::infinity()}), + {}}, + Testcase{{}, TensorValue({1}, dtype::Int32(), {1})}); +} + +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index f2187b3f..fbc021a6 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -959,3 +959,16 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor: op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv) U, sigma, V = apply(op, inp) return U, sigma, V + + +def _has_inf(inp: Tensor) -> Tensor: + """ + Check whether input contains infinite value. + + :param inp: a tensor to be checked. + :return: a int32 scalar tensor, 0 for False and 1 for True. + """ + op = builtin.CheckHasInf() + (oup,) = apply(op, inp.reshape(-1).astype("float32")) + oup._setscalar() + return oup diff --git a/imperative/python/test/unit/functional/test_math.py b/imperative/python/test/unit/functional/test_math.py index 661e40d1..a14e8b54 100644 --- a/imperative/python/test/unit/functional/test_math.py +++ b/imperative/python/test/unit/functional/test_math.py @@ -157,3 +157,14 @@ def test_sum_neg_axis(): np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6) with pytest.raises(AssertionError): F.sum(tensor(data), axis=(-1, 1)) + + +def test_has_inf(): + shape = (32, 3, 32, 32) + data = np.random.random(shape).astype(np.float32) + rst = F.math._has_inf(tensor(data)) + np.testing.assert_equal(rst.numpy(), [0]) + + data[0][0][0][0] = float("inf") + rst = F.math._has_inf(tensor(data)) + np.testing.assert_equal(rst.numpy(), [1]) diff --git a/imperative/src/impl/ops/misc.cpp b/imperative/src/impl/ops/misc.cpp new file mode 100644 index 00000000..d08e09b8 --- /dev/null +++ b/imperative/src/impl/ops/misc.cpp @@ -0,0 +1,34 @@ +/** + * \file imperative/src/impl/ops/tensor_manip.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "../op_trait.h" + +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/opr/misc.h" + +namespace mgb { +namespace imperative { + +namespace check_has_inf { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& op = def.cast_final_safe(); + mgb_assert(inputs.size() == 1); + OperatorNodeConfig config{op.make_name()}; + return opr::CheckHasInf::make(inputs[0], {}, config); +} +OP_TRAIT_REG(CheckHasInf, CheckHasInf) + .apply_on_var_node(apply_on_var_node) + .fallback(); +} // namespace check_has_inf + +} // namespace imperative +} // namespace mgb + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 3a3fa718..8da5647e 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -307,4 +307,6 @@ def CambriconRuntime: MgbHashableOp<"CambriconRuntime"> { def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>; +def CheckHasInf: MgbHashableOp<"CheckHasInf", [EmptyParam]>; + #endif // MGB_OPS diff --git a/src/opr/impl/misc.cpp b/src/opr/impl/misc.cpp index 558ad0a2..b8c11278 100644 --- a/src/opr/impl/misc.cpp +++ b/src/opr/impl/misc.cpp @@ -437,4 +437,19 @@ MGB_IMPL_OPR_GRAD(TopK) { } #endif +/* ================= CheckHasInf ================= */ +namespace mgb { +namespace opr { +namespace intl { +template<> +struct MegDNNOprInitPostCtor { + static void apply(cg::OperatorNodeBase &opr) { + opr.output(0)->dtype(dtype::Int32()); + } +}; +} +} +} +MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckHasInf); +MEGDNN_OPR_INIT1(CheckHasInf, "check_has_inf") // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/misc.sereg.h b/src/opr/impl/misc.sereg.h index 3e4c16e3..15fc8aae 100644 --- a/src/opr/impl/misc.sereg.h +++ b/src/opr/impl/misc.sereg.h @@ -73,6 +73,7 @@ namespace opr { #if MGB_CUDA MGB_SEREG_OPR(NvOf, 1); #endif + MGB_SEREG_OPR(CheckHasInf, 1); } // namespace opr } // namespace mgb diff --git a/src/opr/include/megbrain/opr/misc.h b/src/opr/include/megbrain/opr/misc.h index f3559d95..79c0a177 100644 --- a/src/opr/include/megbrain/opr/misc.h +++ b/src/opr/include/megbrain/opr/misc.h @@ -178,6 +178,8 @@ public: const OperatorNodeConfig& config = {}); }; +MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckHasInf); + } // namespace opr } // namespace mgb