/** * \file dnn/src/cuda/check_non_finite/opr_impl.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #include "src/cuda/check_non_finite/opr_impl.h" #include "src/cuda/reduce_helper.cuh" #include "src/cuda/handle.h" #include "src/cuda/utils.h" #include "src/common/reduce_helper.h" namespace megdnn { namespace cuda { using reduce::CheckNonFiniteOp; size_t CheckNonFiniteImpl::get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst) { typedef CheckNonFiniteOp Op; return get_reduce_workspace_in_bytes(1, src.total_nr_elems(), 1); } void CheckNonFiniteImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); typedef CheckNonFiniteOp Op; auto stream = cuda_stream(this->handle()); auto B = src.layout.total_nr_elems(); return run_reduce( workspace.ptr(), 1, B, 1, stream, Op(src.ptr(), dst.ptr(), B)); } } // namespace cuda } // namespace megdnn // vim: syntax=cpp.doxygen