GitOrigin-RevId: 0d042dbfce
release-1.5
@@ -1317,6 +1317,27 @@ protected: | |||
TensorLayout& exec_workspace, | |||
TensorLayout& exec_src, TensorLayout& exec_dst); | |||
}; | |||
/*! | |||
* \brief check whether input contains inf value. | |||
*/ | |||
class CheckHasInf: public OperatorBase { | |||
DEF_OPR_PARAM(Empty); | |||
DEF_OPR_IMPL(CheckHasInf, OperatorBase, 1, 1); | |||
public: | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||
const TensorLayout &dst) = 0; | |||
void deduce_layout(const TensorLayout &src, TensorLayout &dst); | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &src, const TensorLayout &dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
} // namespace megdnn | |||
#include "megdnn/internal/opr_header_epilogue.h" | |||
@@ -0,0 +1,36 @@ | |||
/** | |||
* \file dnn/src/common/check_has_inf.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
void CheckHasInf::check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
size_t workspace_in_bytes) { | |||
megdnn_assert_contiguous(src); | |||
megdnn_assert_contiguous(dst); | |||
megdnn_assert(src.ndim == 1); | |||
megdnn_assert(src.dtype == dtype::Float32()); | |||
auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
} | |||
void CheckHasInf::deduce_layout(const TensorLayout&, TensorLayout& dst) { | |||
dst.shape[0] = 1; | |||
dst.ndim = 1; | |||
dst.dtype = dtype::Int32(); | |||
dst.init_contiguous_stride(); | |||
} | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -207,7 +207,8 @@ private: | |||
cb(FakeQuantForward) \ | |||
cb(FakeQuantBackward) \ | |||
cb(TQTForward) \ | |||
cb(TQTBackward) | |||
cb(TQTBackward) \ | |||
cb(CheckHasInf) | |||
/*! | |||
* \brief specialize HandleImpl::create_operator for a single opr type; | |||
@@ -120,6 +120,7 @@ DEF(PowC, 2, false, true); | |||
DEF(UniformRNG, 1, true, true); | |||
DEF(GaussianRNG, 1, true, true); | |||
DEF(ChecksumForward, 1, true, false); | |||
DEF(CheckHasInf, 2, true, true); | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -4,9 +4,9 @@ | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/dtype.h" | |||
@@ -151,6 +151,33 @@ struct MaxOp { | |||
: INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {} | |||
}; | |||
template <typename src_ctype, typename dst_ctype, typename wtype_> | |||
struct CheckHasInfOp { | |||
typedef wtype_ wtype; | |||
const wtype INIT; | |||
src_ctype* src; | |||
dst_ctype* dst; | |||
const size_t B; | |||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { | |||
#if defined(__CUDA_ARCH__) | |||
return isinf(src[idx]); | |||
#else | |||
return std::isinf(src[idx]); | |||
#endif | |||
} | |||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { | |||
dst[idx] = val; | |||
} | |||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||
return lhs | rhs; | |||
} | |||
MEGDNN_HOST MEGDNN_DEVICE CheckHasInfOp(src_ctype* src, dst_ctype* dst, | |||
size_t B) | |||
: INIT(wtype(0)), src(src), dst(dst), B(B) {} | |||
}; | |||
#if MEGDNN_CC_HOST | |||
void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, | |||
size_t axis); | |||
@@ -0,0 +1,27 @@ | |||
/** | |||
* \file dnn/src/cuda/check_has_inf/kern.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/common/reduce_helper.h" | |||
#include "megdnn/dtype.h" | |||
#include "src/cuda/reduce_helper.cuh" | |||
namespace megdnn { | |||
namespace cuda { | |||
#define COMMA , | |||
INST_REDUCE(reduce::CheckHasInfOp<dt_float32 COMMA dt_int32 COMMA dt_int32>, false); | |||
#undef COMMA | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: ft=cpp syntax=cpp.doxygen |
@@ -0,0 +1,45 @@ | |||
/** | |||
* \file dnn/src/cuda/check_has_inf/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/cuda/check_has_inf/opr_impl.h" | |||
#include "src/cuda/reduce_helper.cuh" | |||
#include "src/cuda/handle.h" | |||
#include "src/cuda/utils.h" | |||
#include "src/common/reduce_helper.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
using reduce::CheckHasInfOp; | |||
size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& dst) { | |||
typedef CheckHasInfOp<dt_float32, dt_int32, dt_int32> Op; | |||
return get_reduce_workspace_in_bytes<Op>(1, src.total_nr_elems(), 1); | |||
} | |||
void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) { | |||
check_exec(src.layout, dst.layout, workspace.size); | |||
typedef CheckHasInfOp<dt_float32, dt_int32, dt_int32> Op; | |||
auto stream = cuda_stream(this->handle()); | |||
auto B = src.layout.total_nr_elems(); | |||
return run_reduce<Op, false>( | |||
workspace.ptr<dt_int32>(), 1, B, 1, stream, | |||
Op(src.ptr<dt_float32>(), dst.ptr<dt_int32>(), B)); | |||
} | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,36 @@ | |||
/** | |||
* \file dnn/src/cuda/check_has_inf/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs/utils.h" | |||
#include "src/cuda/utils.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
class CheckHasInfImpl final : public CheckHasInf { | |||
public: | |||
using CheckHasInf::CheckHasInf; | |||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& dst) override; | |||
bool is_thread_safe() const override { return true; } | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) override; | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -17,6 +17,7 @@ | |||
#include "src/cuda/argsort/opr_impl.h" | |||
#include "src/cuda/batch_normalization/opr_impl.h" | |||
#include "src/cuda/batched_matrix_mul/opr_impl.h" | |||
#include "src/cuda/check_has_inf/opr_impl.h" | |||
#include "src/cuda/checksum/opr_impl.h" | |||
#include "src/cuda/concat/opr_impl.h" | |||
#include "src/cuda/cond_take/opr_impl.h" | |||
@@ -18,15 +18,15 @@ namespace cuda { | |||
using namespace reduce; | |||
#define COMMOA , | |||
#define COMMA , | |||
#define INST(sctype, dctype, wtype) \ | |||
INST_REDUCE(SumOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||
INST_REDUCE(SumSqrOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||
INST_REDUCE(ProdOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||
INST_REDUCE(MinOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||
INST_REDUCE(MaxOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||
INST_REDUCE(MeanOp<sctype COMMOA dctype COMMOA wtype>, false); | |||
INST_REDUCE(SumOp<sctype COMMA dctype COMMA wtype>, false); \ | |||
INST_REDUCE(SumSqrOp<sctype COMMA dctype COMMA wtype>, false); \ | |||
INST_REDUCE(ProdOp<sctype COMMA dctype COMMA wtype>, false); \ | |||
INST_REDUCE(MinOp<sctype COMMA dctype COMMA wtype>, false); \ | |||
INST_REDUCE(MaxOp<sctype COMMA dctype COMMA wtype>, false); \ | |||
INST_REDUCE(MeanOp<sctype COMMA dctype COMMA wtype>, false); | |||
#define cb(_dt) \ | |||
INST(DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype) | |||
@@ -40,6 +40,7 @@ INST(int, float, float) | |||
#undef cb | |||
#undef INST | |||
#undef COMMA | |||
} // namespace cuda | |||
} // namespace megdnn | |||
@@ -0,0 +1,59 @@ | |||
/** | |||
* \file dnn/src/naive/check_has_inf/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/naive/check_has_inf/opr_impl.h" | |||
#include "src/common/utils.h" | |||
#include "src/naive/handle.h" | |||
namespace { | |||
using namespace megdnn; | |||
#define src_ctype dt_float32 | |||
#define wtype dt_int32 | |||
void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) { | |||
std::function<wtype(size_t, size_t)> func; | |||
func = [&](size_t l, size_t r) -> wtype { | |||
if (l + 1 < r) { | |||
size_t mid = l + (r - l) / 2; | |||
return func(l, mid) | func(mid, r); | |||
} else { | |||
return static_cast<wtype>(std::isinf(sptr[l])); | |||
} | |||
}; | |||
dptr[0] = func(0, size); | |||
} | |||
} // namespace | |||
namespace megdnn { | |||
namespace naive { | |||
size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout&, | |||
const TensorLayout&) { | |||
return 0; | |||
} | |||
void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) { | |||
check_exec(src.layout, dst.layout, workspace.size); | |||
auto handle = static_cast<HandleImpl*>(this->handle()); | |||
MEGDNN_DISPATCH_CPU_KERN( | |||
handle, reduce_fwd(src.ptr<dt_float32>(), dst.ptr<dt_int32>(), | |||
src.layout.total_nr_elems())); | |||
} | |||
} // namespace naive | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,35 @@ | |||
/** | |||
* \file dnn/src/naive/check_has_inf/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace naive { | |||
class CheckHasInfImpl final : public CheckHasInf { | |||
public: | |||
using CheckHasInf::CheckHasInf; | |||
bool is_thread_safe() const override { return true; } | |||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& dst) override; | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) override; | |||
}; | |||
} // namespace naive | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -21,6 +21,7 @@ | |||
#include "src/naive/batch_conv_bias/opr_impl.h" | |||
#include "src/naive/batch_normalization/opr_impl.h" | |||
#include "src/naive/batched_matrix_mul/opr_impl.h" | |||
#include "src/naive/check_has_inf/opr_impl.h" | |||
#include "src/naive/checksum/opr_impl.h" | |||
#include "src/naive/concat/opr_impl.h" | |||
#include "src/naive/cond_take/opr_impl.h" | |||
@@ -18,15 +18,15 @@ namespace rocm { | |||
using namespace reduce; | |||
#define COMMOA , | |||
#define COMMA , | |||
#define INST(sctype, dctype, wtype) \ | |||
INST_REDUCE(SumOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||
INST_REDUCE(SumSqrOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||
INST_REDUCE(ProdOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||
INST_REDUCE(MinOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||
INST_REDUCE(MaxOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||
INST_REDUCE(MeanOp<sctype COMMOA dctype COMMOA wtype>, false); | |||
INST_REDUCE(SumOp<sctype COMMA dctype COMMA wtype>, false); \ | |||
INST_REDUCE(SumSqrOp<sctype COMMA dctype COMMA wtype>, false); \ | |||
INST_REDUCE(ProdOp<sctype COMMA dctype COMMA wtype>, false); \ | |||
INST_REDUCE(MinOp<sctype COMMA dctype COMMA wtype>, false); \ | |||
INST_REDUCE(MaxOp<sctype COMMA dctype COMMA wtype>, false); \ | |||
INST_REDUCE(MeanOp<sctype COMMA dctype COMMA wtype>, false); | |||
#define cb(_dt) \ | |||
INST(DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype) | |||
@@ -39,6 +39,7 @@ INST(float, dt_float16, float) | |||
INST(int, float, float) | |||
#undef cb | |||
#undef INST | |||
#undef COMMA | |||
} // namespace rocm | |||
} // namespace megdnn | |||
@@ -23,7 +23,7 @@ namespace { | |||
::testing::AssertionResult assert_tensor_eq_with_iter( | |||
const char *expr0, const char *expr1, | |||
Iter it0, Iter it1, const TensorLayout &layout, | |||
float maxerr, float maxerr_avg, float maxerr_avg_biased) { | |||
float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) { | |||
auto nr_elem = layout.total_nr_elems(); | |||
double error_sum = 0; | |||
@@ -33,8 +33,8 @@ namespace { | |||
float err = diff(iv0, iv1); | |||
error_sum += std::abs(err); | |||
error_sum_biased += err; | |||
if (!good_float(iv0) || !good_float(iv1) || | |||
std::abs(err) > maxerr) { | |||
if (!allow_invalid && (!good_float(iv0) || !good_float(iv1) || | |||
std::abs(err) > maxerr)) { | |||
Index index(layout, i); | |||
return ::testing::AssertionFailure() | |||
<< "Unequal value\n" | |||
@@ -82,14 +82,14 @@ namespace { | |||
::testing::AssertionResult assert_tensor_eq_with_dtype( | |||
const char *expr0, const char *expr1, | |||
const TensorND &v0, const TensorND &v1, | |||
float maxerr, float maxerr_avg, float maxerr_avg_biased) { | |||
float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) { | |||
if (!std::is_same<ctype, dt_qint4>::value && | |||
!std::is_same<ctype, dt_quint4>::value) { | |||
if (v0.layout.is_physical_contiguous() && | |||
v1.layout.is_physical_contiguous()) { | |||
return assert_tensor_eq_with_iter<ctype>( | |||
expr0, expr1, v0.ptr<ctype>(), v1.ptr<ctype>(), | |||
v0.layout, maxerr, maxerr_avg, maxerr_avg_biased); | |||
v0.layout, maxerr, maxerr_avg, maxerr_avg_biased, allow_invalid); | |||
} | |||
} | |||
@@ -98,7 +98,7 @@ namespace { | |||
return assert_tensor_eq_with_iter<ctype>(expr0, expr1, it0, it1, | |||
v0.layout, maxerr, maxerr_avg, | |||
maxerr_avg_biased); | |||
maxerr_avg_biased, allow_invalid); | |||
} | |||
template<class Impl> | |||
@@ -136,7 +136,7 @@ namespace { | |||
const char* /*expr_maxerr_avg*/, | |||
const char* /*expr_maxerr_avg*/, | |||
const TensorND &v0, const TensorND &v1, | |||
float maxerr, float maxerr_avg, float maxerr_avg_biased) { | |||
float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) { | |||
if (!v0.layout.eq_shape(v1.layout)) { | |||
return ::testing::AssertionFailure() | |||
@@ -160,7 +160,7 @@ namespace { | |||
#define cb(_dt) \ | |||
case DTypeTrait<_dt>::enumv: \ | |||
return assert_tensor_eq_with_dtype<DTypeTrait<_dt>::ctype>( \ | |||
expr0, expr1, v0, v1, maxerr, maxerr_avg, maxerr_avg_biased); | |||
expr0, expr1, v0, v1, maxerr, maxerr_avg, maxerr_avg_biased, allow_invalid); | |||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||
//! In order to avoid an unnecessary increase in binary size, we just | |||
@@ -174,6 +174,17 @@ namespace { | |||
} | |||
::testing::AssertionResult test::__assert_tensor_eq_allow_invalid( | |||
const char* expr0, const char* expr1, const char* expr_maxerr, | |||
const char* expr_maxerr_avg, const char* expr_maxerr_avg_biased, | |||
const TensorND& v0, const TensorND& v1, float maxerr, float maxerr_avg, | |||
float maxerr_avg_biased) { | |||
return __assert_tensor_eq(expr0, expr1, expr_maxerr, expr_maxerr_avg, | |||
expr_maxerr_avg_biased, v0, v1, maxerr, | |||
maxerr_avg, maxerr_avg_biased, true); | |||
}; | |||
CheckerHelper::CheckerHelper(Handle *handle, bool check_dispatch): | |||
m_handle_cur(handle), | |||
m_default_rng(new NormalRNG()) | |||
@@ -411,9 +422,15 @@ void CheckerHelper::check_tensors(const TensorValueArray& expected, | |||
for (size_t i = 0; i < expected.size(); ++i) { | |||
if (expected[i].layout.ndim == 0) | |||
continue; | |||
MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(expected[i], computed[i], m_epsilon, | |||
m_max_avg_error, | |||
m_max_avg_biased_error); | |||
if (m_allow_invalid_check) { | |||
MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG_ALLOW_INVALID( | |||
expected[i], computed[i], m_epsilon, m_max_avg_error, | |||
m_max_avg_biased_error); | |||
} else { | |||
MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(expected[i], computed[i], m_epsilon, | |||
m_max_avg_error, | |||
m_max_avg_biased_error); | |||
} | |||
} | |||
} | |||
@@ -79,6 +79,7 @@ protected: | |||
bool m_no_naive_and_check = false; | |||
bool m_stable_check = false; | |||
bool m_force_deduce_dst = true; | |||
bool m_allow_invalid_check = false; | |||
/** | |||
* the offset from the start of malloc memory | |||
* | |||
@@ -248,6 +249,11 @@ public: | |||
return *this; | |||
} | |||
Checker& set_allow_invalid_check(bool allow_invalid_check) { | |||
m_allow_invalid_check = allow_invalid_check; | |||
return *this; | |||
} | |||
//! load input tensors from file for next run | |||
Checker& load_input_tensors(const char* fpath) { | |||
m_input_tensors_fpath = fpath; | |||
@@ -329,6 +335,12 @@ private: | |||
const char* expr0, const char* expr1, const char* expr_maxerr, | |||
const char* expr_maxerr_avg, const char* expr_maxerr_avg_biased, | |||
const TensorND& v0, const TensorND& v1, float maxerr, float maxerr_avg, | |||
float maxerr_avg_biased, bool allow_invalid = false); | |||
::testing::AssertionResult __assert_tensor_eq_allow_invalid( | |||
const char* expr0, const char* expr1, const char* expr_maxerr, | |||
const char* expr_maxerr_avg, const char* expr_maxerr_avg_biased, | |||
const TensorND& v0, const TensorND& v1, float maxerr, float maxerr_avg, | |||
float maxerr_avg_biased); | |||
#define MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(v0, v1, maxerr, maxerr_avg, \ | |||
@@ -336,6 +348,11 @@ private: | |||
ASSERT_PRED_FORMAT5(::megdnn::test::__assert_tensor_eq, v0, v1, maxerr, \ | |||
maxerr_avg, maxerr_avg_biased) | |||
#define MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG_ALLOW_INVALID( \ | |||
v0, v1, maxerr, maxerr_avg, maxerr_avg_biased) \ | |||
ASSERT_PRED_FORMAT5(::megdnn::test::__assert_tensor_eq_allow_invalid, v0, \ | |||
v1, maxerr, maxerr_avg, maxerr_avg_biased) | |||
#define MEGDNN_ASSERT_TENSOR_EQ_EPS(v0, v1, maxerr) \ | |||
MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(v0, v1, maxerr, maxerr, maxerr) | |||
@@ -435,7 +452,7 @@ TensorND TensorValue(const TensorShape& shape, T dtype, | |||
template <typename T, typename U> | |||
TensorND TensorValueLowbit4(const TensorShape& shape, T dtype, | |||
std::vector<U> values) { | |||
std::vector<U> values) { | |||
TensorND tensor; | |||
tensor.layout = {shape, dtype}; | |||
tensor.raw_ptr = | |||
@@ -39,6 +39,22 @@ struct ExecProxy<Opr, 8, true> { | |||
}; | |||
template <typename Opr> | |||
struct ExecProxy<Opr, 7, true> { | |||
WorkspaceWrapper W; | |||
void exec(Opr* opr, const TensorNDArray& tensors) { | |||
if (!W.valid()) { | |||
W = WorkspaceWrapper(opr->handle(), 0); | |||
} | |||
W.update(opr->get_workspace_in_bytes( | |||
tensors[0].layout, tensors[1].layout, tensors[2].layout, | |||
tensors[3].layout, tensors[4].layout, tensors[5].layout, | |||
tensors[6].layout)); | |||
opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], | |||
tensors[5], tensors[6], W.workspace()); | |||
} | |||
}; | |||
template <typename Opr> | |||
struct ExecProxy<Opr, 6, true> { | |||
WorkspaceWrapper W; | |||
void exec(Opr* opr, const TensorNDArray& tensors) { | |||
@@ -149,24 +165,6 @@ struct ExecProxy<Opr, 2, false> { | |||
} | |||
}; | |||
template <typename Opr> | |||
struct ExecProxy<Opr, 7, true> { | |||
WorkspaceWrapper W; | |||
void exec(Opr* opr, const TensorNDArray& tensors) { | |||
if (!W.valid()) { | |||
W = WorkspaceWrapper(opr->handle(), 0); | |||
} | |||
W.update(opr->get_workspace_in_bytes( | |||
tensors[0].layout, tensors[1].layout, tensors[2].layout, | |||
tensors[3].layout, tensors[4].layout, tensors[5].layout, | |||
tensors[6].layout)); | |||
opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], | |||
tensors[5], tensors[6], W.workspace()); | |||
} | |||
}; | |||
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -301,9 +301,8 @@ void UniformFloatNonZeroRNG::fill_fast_float32(dt_float32* dest, size_t size) { | |||
} | |||
} | |||
void UniformFloatWithZeroRNG::fill_fast_float32(dt_float32 *dest, size_t size) { | |||
void UniformFloatWithValueRNG::fill_fast_float32(dt_float32 *dest, size_t size) { | |||
RNGxorshf gen{RandomState::generator()}; | |||
printf("a %f, b %f \n", m_dist.a(), m_dist.b()); | |||
auto k = double(m_dist.b() - m_dist.a()) / | |||
double(RNGxorshf::max() - RNGxorshf::min() + 1.0); | |||
auto b = m_dist.a() - RNGxorshf::min() * k; | |||
@@ -312,9 +311,8 @@ void UniformFloatWithZeroRNG::fill_fast_float32(dt_float32 *dest, size_t size) { | |||
auto pb = 0.f - RNGxorshf::min() * p; | |||
for (size_t i = 0; i < size; ++ i) { | |||
float rnd = gen() * p + pb; | |||
//printf("%.3f \n", rnd); | |||
if(rnd < zero_val_proportion_) { | |||
dest[i] = 0.f; | |||
if(rnd < val_proportion_) { | |||
dest[i] = val_; | |||
} else { | |||
dest[i] = gen() * k + b; | |||
} | |||
@@ -11,10 +11,10 @@ | |||
#pragma once | |||
#include "megdnn/dtype.h" | |||
#include "test/common/utils.h" | |||
#include "test/common/random_state.h" | |||
#include <random> | |||
#include <set> | |||
#include "test/common/random_state.h" | |||
#include "test/common/utils.h" | |||
namespace megdnn { | |||
namespace test { | |||
@@ -80,7 +80,8 @@ public: | |||
} | |||
void gen(const TensorND& tensor) override { | |||
megdnn_assert(tensor.layout.dtype.enumv() == DTypeTrait<dt_bfloat16>::enumv); | |||
megdnn_assert(tensor.layout.dtype.enumv() == | |||
DTypeTrait<dt_bfloat16>::enumv); | |||
size_t nr_elems = tensor.layout.span().dist_elem(); | |||
auto offset = tensor.layout.span().low_elem; | |||
for (size_t i = 0; i < nr_elems; ++i) { | |||
@@ -185,24 +186,31 @@ public: | |||
void fill_fast_float32(dt_float32* dest, size_t size) override; | |||
}; | |||
class UniformFloatWithZeroRNG final : public UniformFloatRNG { | |||
class UniformFloatWithValueRNG : public UniformFloatRNG { | |||
public: | |||
UniformFloatWithZeroRNG(dt_float32 a, dt_float32 b, | |||
float zero_val_proportion) | |||
: UniformFloatRNG(a, b) { | |||
if (zero_val_proportion < 0.f) | |||
zero_val_proportion_ = 0.f; | |||
else if (zero_val_proportion > 1.f) | |||
zero_val_proportion_ = 1.f; | |||
UniformFloatWithValueRNG(dt_float32 a, dt_float32 b, float val_proportion, | |||
float val) | |||
: UniformFloatRNG(a, b), val_(val) { | |||
if (val_proportion < 0.f) | |||
val_proportion_ = 0.f; | |||
else if (val_proportion > 1.f) | |||
val_proportion_ = 1.f; | |||
else | |||
zero_val_proportion_ = zero_val_proportion; | |||
val_proportion_ = val_proportion; | |||
} | |||
private: | |||
float zero_val_proportion_; | |||
float val_proportion_, val_; | |||
void fill_fast_float32(dt_float32* dest, size_t size) override; | |||
}; | |||
class UniformFloatWithZeroRNG final : public UniformFloatWithValueRNG { | |||
public: | |||
UniformFloatWithZeroRNG(dt_float32 a, dt_float32 b, | |||
float zero_val_proportion) | |||
: UniformFloatWithValueRNG(a, b, zero_val_proportion, 0.f) {} | |||
}; | |||
class BernoulliRNG final : public IIDRNG { | |||
public: | |||
BernoulliRNG(dt_float32 probability_); | |||
@@ -0,0 +1,33 @@ | |||
/** | |||
* \file dnn/test/cuda/check_has_inf.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megdnn/oprs.h" | |||
#include "test/common/checker.h" | |||
#include "test/cuda/fixture.h" | |||
namespace megdnn { | |||
namespace test { | |||
TEST_F(CUDA, CHECK_HAS_INF_BASIC) { | |||
Checker<CheckHasInf> checker(handle_cuda()); | |||
checker.set_allow_invalid_check(true); | |||
const auto inf = std::numeric_limits<float>::infinity(); | |||
UniformFloatWithValueRNG rng(-1.0f, 1.0f, 0.1f, inf); | |||
checker.set_rng(0, &rng); | |||
checker.execs({{512*16}, {1}}); | |||
rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, inf); | |||
checker.set_rng(0, &rng); | |||
checker.execs({{512*16}, {1}}); | |||
} | |||
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,37 @@ | |||
/** | |||
* \file test/naive/check_has_inf.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "test/naive/fixture.h" | |||
#include "megdnn/oprs.h" | |||
#include "test/common/checker.h" | |||
namespace megdnn { | |||
namespace test { | |||
TEST_F(NAIVE, CHECK_HAS_INF_BASIC) { | |||
Checker<CheckHasInf> checker(handle(), false); | |||
checker.exect(Testcase{TensorValue({4}, dtype::Float32(), | |||
{1.1, 2.2, 3.3, 4.3}), | |||
{}}, | |||
Testcase{{}, TensorValue({1}, dtype::Int32(), {0})}); | |||
checker.exect( | |||
Testcase{TensorValue({4}, dtype::Float32(), | |||
{1.1f, 2.2f, 3.3f, | |||
std::numeric_limits<float>::infinity()}), | |||
{}}, | |||
Testcase{{}, TensorValue({1}, dtype::Int32(), {1})}); | |||
} | |||
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -959,3 +959,16 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor: | |||
op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv) | |||
U, sigma, V = apply(op, inp) | |||
return U, sigma, V | |||
def _has_inf(inp: Tensor) -> Tensor: | |||
""" | |||
Check whether input contains infinite value. | |||
:param inp: a tensor to be checked. | |||
:return: a int32 scalar tensor, 0 for False and 1 for True. | |||
""" | |||
op = builtin.CheckHasInf() | |||
(oup,) = apply(op, inp.reshape(-1).astype("float32")) | |||
oup._setscalar() | |||
return oup |
@@ -157,3 +157,14 @@ def test_sum_neg_axis(): | |||
np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6) | |||
with pytest.raises(AssertionError): | |||
F.sum(tensor(data), axis=(-1, 1)) | |||
def test_has_inf(): | |||
shape = (32, 3, 32, 32) | |||
data = np.random.random(shape).astype(np.float32) | |||
rst = F.math._has_inf(tensor(data)) | |||
np.testing.assert_equal(rst.numpy(), [0]) | |||
data[0][0][0][0] = float("inf") | |||
rst = F.math._has_inf(tensor(data)) | |||
np.testing.assert_equal(rst.numpy(), [1]) |
@@ -0,0 +1,34 @@ | |||
/** | |||
* \file imperative/src/impl/ops/tensor_manip.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "../op_trait.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/misc.h" | |||
namespace mgb { | |||
namespace imperative { | |||
namespace check_has_inf { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = def.cast_final_safe<CheckHasInf>(); | |||
mgb_assert(inputs.size() == 1); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::CheckHasInf::make(inputs[0], {}, config); | |||
} | |||
OP_TRAIT_REG(CheckHasInf, CheckHasInf) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
} // namespace check_has_inf | |||
} // namespace imperative | |||
} // namespace mgb | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -307,4 +307,6 @@ def CambriconRuntime: MgbHashableOp<"CambriconRuntime"> { | |||
def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>; | |||
def CheckHasInf: MgbHashableOp<"CheckHasInf", [EmptyParam]>; | |||
#endif // MGB_OPS |
@@ -437,4 +437,19 @@ MGB_IMPL_OPR_GRAD(TopK) { | |||
} | |||
#endif | |||
/* ================= CheckHasInf ================= */ | |||
namespace mgb { | |||
namespace opr { | |||
namespace intl { | |||
template<> | |||
struct MegDNNOprInitPostCtor<CheckHasInf> { | |||
static void apply(cg::OperatorNodeBase &opr) { | |||
opr.output(0)->dtype(dtype::Int32()); | |||
} | |||
}; | |||
} | |||
} | |||
} | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckHasInf); | |||
MEGDNN_OPR_INIT1(CheckHasInf, "check_has_inf") | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -73,6 +73,7 @@ namespace opr { | |||
#if MGB_CUDA | |||
MGB_SEREG_OPR(NvOf, 1); | |||
#endif | |||
MGB_SEREG_OPR(CheckHasInf, 1); | |||
} // namespace opr | |||
} // namespace mgb | |||
@@ -178,6 +178,8 @@ public: | |||
const OperatorNodeConfig& config = {}); | |||
}; | |||
MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckHasInf); | |||
} // namespace opr | |||
} // namespace mgb | |||