From 11d75fecb53a51be94e81bd83eaa7f8e2684cccb Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 28 Oct 2021 11:18:55 +0800
Subject: [PATCH] feat(dnn/check_non_finite): add batch check_non_finite

GitOrigin-RevId: e108133282cb2c9129292715ae6eab1e396cd0bc
---
 dnn/include/megdnn/oprs/general.h              | 13 ++--
 dnn/src/common/check_non_finite.cpp            | 13 ++--
 dnn/src/common/reduce_helper.h                 | 26 +++++--
 dnn/src/common/reduce_helper_device.h          | 25 +++++--
 dnn/src/cuda/check_non_finite/kern.cu          |  3 +-
 dnn/src/cuda/check_non_finite/opr_impl.cpp     | 79 ++++++++++++++++++---
 dnn/src/cuda/check_non_finite/opr_impl.h       |  6 +-
 dnn/src/naive/check_non_finite/opr_impl.cpp    | 49 ++++++-------
 dnn/src/naive/check_non_finite/opr_impl.h      | 10 ++-
 dnn/test/common/opr_proxy.h                    | 21 ++++++
 dnn/test/cuda/check_non_finite.cpp             |  9 ++-
 dnn/test/naive/check_non_finite.cpp            | 13 ++--
 imperative/python/megengine/amp/grad_scaler.py | 19 ++---
 imperative/python/megengine/functional/math.py |  6 +-
 .../python/test/unit/amp/test_grad_scaler.py   | 31 ++++----
 .../python/test/unit/functional/test_math.py   | 13 ++--
 imperative/src/impl/ops/misc.cpp               | 48 ++++++++++++-
 src/opr/impl/misc.cpp                          | 82 ++++++++++++++++++----
 src/opr/impl/misc.sereg.h                      |  6 +-
 src/opr/include/megbrain/opr/misc.h            | 16 ++++-
 20 files changed, 366 insertions(+), 122 deletions(-)

diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h
index 8de736e4..b58fabba 100644
--- a/dnn/include/megdnn/oprs/general.h
+++ b/dnn/include/megdnn/oprs/general.h
@@ -1345,22 +1345,23 @@ protected:
  */
 class CheckNonFinite : public OperatorBase {
     DEF_OPR_PARAM(Empty);
-    DEF_OPR_IMPL(CheckNonFinite, OperatorBase, 1, 1);
+    DEF_OPR_IMPL(CheckNonFinite, OperatorBase, -1, 1);
+    size_t m_size = 0;
 
 public:
     virtual size_t get_workspace_in_bytes(
-            const TensorLayout& src, const TensorLayout& dst) = 0;
+            const TensorNDArray& srcs, const TensorLayout& dst) = 0;
 
-    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
+    void deduce_layout(const TensorLayoutArray& srcs, TensorLayout& dst);
 
     virtual void exec(
-            _megdnn_tensor_in src, _megdnn_tensor_out dst,
+            _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
             _megdnn_workspace workspace) = 0;
 
 protected:
     void check_exec(
-            const TensorLayout& src, const TensorLayout& dst,
-            size_t workspace_in_bytes);
+            const TensorNDArray& srcs, const TensorND& dst, size_t workspace_in_bytes);
+    virtual size_t _get_workspace_in_bytes() = 0;
 };
 
 /*!
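
NOTE (illustrative sketch, not part of the patch): with the arity in DEF_OPR_IMPL changed from 1 to -1, the operator now accepts a variable-length batch of tensors. A minimal calling sequence against the new interface could look as follows; `handle`, `t0` and `t1` (contiguous Float32 TensorNDs) are assumed to exist, and the workspace query must precede exec() because the backends cache per-batch state (m_size) there:

    // sketch: batch-check two tensors with one operator call
    auto opr = handle->create_operator<megdnn::CheckNonFinite>();
    megdnn::TensorLayout dst_layout;
    opr->deduce_layout({t0.layout, t1.layout}, dst_layout);  // -> {1}, Int32
    megdnn::TensorNDArray srcs{t0, t1};
    size_t wk = opr->get_workspace_in_bytes(srcs, dst_layout);  // sets m_size
    // allocate a dst tensor with dst_layout and a wk-byte workspace, then:
    // opr->exec(srcs, dst, workspace);
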
diff --git a/dnn/src/common/check_non_finite.cpp b/dnn/src/common/check_non_finite.cpp
index e6ea7b28..e03d7800 100644
--- a/dnn/src/common/check_non_finite.cpp
+++ b/dnn/src/common/check_non_finite.cpp
@@ -15,16 +15,15 @@
 namespace megdnn {
 
 void CheckNonFinite::check_exec(
-        const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
-    megdnn_assert_contiguous(src);
-    megdnn_assert_contiguous(dst);
-    megdnn_assert(src.ndim == 1);
-    megdnn_assert(src.dtype == dtype::Float32());
-    auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
+        const TensorNDArray& srcs, const TensorND& dst, size_t workspace_in_bytes) {
+    megdnn_assert_contiguous(dst.layout);
+    megdnn_assert(srcs.size() > 0);
+    megdnn_assert(srcs.begin()->layout.dtype == dtype::Float32());
+    auto required_workspace_in_bytes = _get_workspace_in_bytes();
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
 }
 
-void CheckNonFinite::deduce_layout(const TensorLayout&, TensorLayout& dst) {
+void CheckNonFinite::deduce_layout(const TensorLayoutArray&, TensorLayout& dst) {
     dst.shape[0] = 1;
     dst.ndim = 1;
     dst.dtype = dtype::Int32();
diff --git a/dnn/src/common/reduce_helper.h b/dnn/src/common/reduce_helper.h
index 46fac414..0a300e91 100644
--- a/dnn/src/common/reduce_helper.h
+++ b/dnn/src/common/reduce_helper.h
@@ -156,21 +156,35 @@ struct MaxOp {
             : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
 };
 
-template <typename src_ctype, typename wtype_>
+template <typename src_ctype, typename index_ctype, typename wtype_>
 struct CheckNonFiniteOp {
     typedef wtype_ wtype;
     const wtype INIT;
 
-    RefPtr src;
+    RefPtr* srcs;
+    RefPtr srcs_total_nr_elems;
     RefPtr dst;
     const size_t B;
 
-    wtype read(uint32_t idx) { return !std::isfinite(src.ptr<src_ctype>()[idx]); }
+    wtype read(uint32_t idx) {
+        size_t x = idx / B;
+        size_t y = idx % B;
+        if (y < srcs_total_nr_elems.ptr<index_ctype>()[x]) {
+            RefPtr src = srcs[x];
+            return !std::isfinite(src.ptr<src_ctype>()[y]);
+        }
+        return 0;
+    }
     void write(uint32_t idx, wtype val) { dst.ptr<wtype>()[idx] = val; }
     static wtype apply(wtype lhs, wtype rhs) { return lhs | rhs; }
-    MEGDNN_HOST MEGDNN_DEVICE
-    CheckNonFiniteOp(const RefPtr& src, const RefPtr& dst, size_t B)
-            : INIT(wtype(0)), src(src), dst(dst), B(B) {}
+    CheckNonFiniteOp(
+            RefPtr* srcs, const RefPtr& srcs_total_nr_elems, const RefPtr& dst,
+            size_t B)
+            : INIT(wtype(0)),
+              srcs(srcs),
+              srcs_total_nr_elems(srcs_total_nr_elems),
+              dst(dst),
+              B(B) {}
 };
 
 void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis);
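
NOTE (illustrative, not part of the patch): read() above views the whole batch as a dense m_size x B grid: chunk index x = idx / B, offset y = idx % B, and offsets past a chunk's valid length (srcs_total_nr_elems[x]) read as 0, the identity of the OR-reduction, so the padding in each tensor's last chunk cannot flip the result. The same indexing in a self-contained form, with hypothetical inputs:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // model of the chunked OR-reduction: each chunk holds up to B floats;
    // reads past valid[x] contribute 0 ("finite"), the OR identity
    int any_non_finite(const std::vector<const float*>& chunks,
                       const std::vector<size_t>& valid, size_t B) {
        int acc = 0;  // corresponds to INIT == wtype(0)
        for (size_t idx = 0; idx < chunks.size() * B; ++idx) {
            size_t x = idx / B, y = idx % B;
            acc |= (y < valid[x]) ? !std::isfinite(chunks[x][y]) : 0;
        }
        return acc;
    }
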
diff --git a/dnn/src/common/reduce_helper_device.h b/dnn/src/common/reduce_helper_device.h
index 31ceb194..c68d0bff 100644
--- a/dnn/src/common/reduce_helper_device.h
+++ b/dnn/src/common/reduce_helper_device.h
@@ -185,28 +185,41 @@ struct MaxOp {
             : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
 };
 
-template <typename src_ctype, typename dst_ctype, typename wtype_>
+template <
+        typename src_ctype, typename index_ctype, typename dst_ctype,
+        typename wtype_>
 struct CheckNonFiniteOp {
     typedef wtype_ wtype;
     const wtype INIT;
 
-    src_ctype* src;
+    src_ctype** srcs;
+    index_ctype* srcs_total_nr_elems;
     dst_ctype* dst;
    const size_t B;
 
     MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
+        size_t x = idx / B;
+        size_t y = idx % B;
+        if (y < srcs_total_nr_elems[x]) {
 #if defined(__CUDA_ARCH__)
-        return !isfinite(src[idx]);
+            wtype val = isfinite(srcs[x][y]);
 #else
-        return !std::isfinite(src[idx]);
+            wtype val = std::isfinite(srcs[x][y]);
 #endif
+            return !val;
+        }
+        return 0;
     }
     MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }
     static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
         return lhs | rhs;
     }
-    MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(src_ctype* src, dst_ctype* dst, size_t B)
-            : INIT(wtype(0)), src(src), dst(dst), B(B) {}
+    MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(
+            src_ctype** srcs, index_ctype* srcs_total_nr_elems, dst_ctype* dst,
+            size_t B)
+            : INIT(wtype(0)),
+              srcs(srcs),
+              srcs_total_nr_elems(srcs_total_nr_elems),
+              dst(dst),
+              B(B) {}
 };
 
 }  // namespace device_reduce
diff --git a/dnn/src/cuda/check_non_finite/kern.cu b/dnn/src/cuda/check_non_finite/kern.cu
index 5cd9f1d8..8251090a 100644
--- a/dnn/src/cuda/check_non_finite/kern.cu
+++ b/dnn/src/cuda/check_non_finite/kern.cu
@@ -19,7 +19,8 @@ namespace cuda {
 #define COMMA ,
 
 INST_REDUCE(
-        device_reduce::CheckNonFiniteOp<dt_float32 COMMA dt_int32 COMMA dt_int32>,
+        device_reduce::CheckNonFiniteOp<
+                dt_float32 COMMA size_t COMMA dt_int32 COMMA dt_int32>,
         false);
 
 #undef COMMA
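
NOTE (worked example, not part of the patch): the instantiation above reduces over a virtual 1 x (m_size * total_nr_elems_max) x 1 problem, where m_size is the total chunk count that the CUDA implementation below accumulates with DIVUP. For two hypothetical inputs of 5000 and 100 elements and the fixed chunk length of 2048:

    #include <cstddef>
    // standalone stand-in for megdnn's DIVUP macro (ceiling division)
    constexpr size_t divup(size_t a, size_t b) { return (a + b - 1) / b; }
    static_assert(divup(5000, 2048) == 3, "chunks of 2048, 2048, 904");
    static_assert(divup(100, 2048) == 1, "one chunk, 100 valid elements");
    // m_size = 3 + 1 = 4, so the kernel reduces over 4 * 2048 = 8192 slots,
    // of which only 5100 map to real data; the rest read as 0.
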
diff --git a/dnn/src/cuda/check_non_finite/opr_impl.cpp b/dnn/src/cuda/check_non_finite/opr_impl.cpp
index 94657921..9e344a0b 100644
--- a/dnn/src/cuda/check_non_finite/opr_impl.cpp
+++ b/dnn/src/cuda/check_non_finite/opr_impl.cpp
@@ -21,22 +21,83 @@ namespace megdnn {
 namespace cuda {
 
 using device_reduce::CheckNonFiniteOp;
+#define total_nr_elems_max 2048
+size_t CheckNonFiniteImpl::_get_workspace_in_bytes() {
+    // split from get_workspace_in_bytes() so check_exec() can refetch the
+    // size without looping over the inputs again; requires m_size to be set
+    typedef CheckNonFiniteOp<dt_float32, size_t, dt_int32, dt_int32> Op;
+    megdnn_assert(m_size > 0);
+    WorkspaceBundle bundle(
+            nullptr, {
+                             sizeof(dt_float32*) * m_size,
+                             sizeof(size_t) * m_size,
+                     });
+    return get_reduce_workspace_in_bytes<Op>(1, m_size * total_nr_elems_max, 1) +
+           bundle.total_size_in_bytes();
+}
 
 size_t CheckNonFiniteImpl::get_workspace_in_bytes(
-        const TensorLayout& src, const TensorLayout& dst) {
-    typedef CheckNonFiniteOp<dt_float32, dt_int32, dt_int32> Op;
-    return get_reduce_workspace_in_bytes<Op>(1, src.total_nr_elems(), 1);
+        const TensorNDArray& srcs, const TensorLayout&) {
+    m_size = 0;
+    for (const auto& src : srcs) {
+        m_size += DIVUP(src.layout.total_nr_elems(), total_nr_elems_max);
+    }
+    return _get_workspace_in_bytes();
 }
 
 void CheckNonFiniteImpl::exec(
-        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
-    check_exec(src.layout, dst.layout, workspace.size);
-    typedef CheckNonFiniteOp<dt_float32, dt_int32, dt_int32> Op;
+        _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
+        _megdnn_workspace workspace) {
+    check_exec(srcs, dst, workspace.size);
+    typedef CheckNonFiniteOp<dt_float32, size_t, dt_int32, dt_int32> Op;
     auto stream = cuda_stream(this->handle());
-    auto B = src.layout.total_nr_elems();
+    SmallVector<size_t> workspace_sizes{
+            sizeof(dt_float32*) * m_size,
+            sizeof(size_t) * m_size,
+    };
+    WorkspaceBundle workspace_cpu(nullptr, workspace_sizes),
+            workspace_gpu(nullptr, workspace_sizes);
+    auto total_workspace_size = workspace_cpu.total_size_in_bytes();
+    void* workspace_cpu_raw = malloc(total_workspace_size);
+    megdnn_assert_internal(workspace_cpu_raw);
+    void* workspace_gpu_raw = workspace.raw_ptr;
+    workspace_cpu = WorkspaceBundle(workspace_cpu_raw, workspace_sizes);
+    workspace_gpu = WorkspaceBundle(workspace_gpu_raw, workspace_sizes);
+
+    auto srcs_cpu = static_cast<dt_float32**>(workspace_cpu.get(0));
+    auto srcs_gpu = static_cast<dt_float32**>(workspace_gpu.get(0));
+    auto srcs_total_nr_elems_cpu = static_cast<size_t*>(workspace_cpu.get(1));
+    auto srcs_total_nr_elems_gpu = static_cast<size_t*>(workspace_gpu.get(1));
+
+    // srcs
+    // cut each tensor into chunks of at most total_nr_elems_max elements
+    size_t i = 0;
+    for (const auto& src : srcs) {
+        size_t src_nr_elems = src.layout.total_nr_elems();
+        size_t nr_elems = DIVUP(src_nr_elems, total_nr_elems_max);
+        for (size_t j = 0; j < nr_elems; ++j, ++i) {
+            srcs_cpu[i] = src.ptr<dt_float32>() + j * total_nr_elems_max;
+            if (j + 1 == nr_elems && src_nr_elems % total_nr_elems_max) {
+                srcs_total_nr_elems_cpu[i] = src_nr_elems % total_nr_elems_max;
+            } else {
+                srcs_total_nr_elems_cpu[i] = total_nr_elems_max;
+            }
+        }
+    }
+    for (size_t i = 0; i < workspace_cpu.nr_workspace(); ++i) {
+        cuda_check(cudaMemcpyAsync(
+                workspace_gpu.get(i), workspace_cpu.get(i), workspace_cpu.get_size(i),
+                cudaMemcpyHostToDevice, stream));
+    }
+    cuda_check(cudaStreamAddCallback(
+            stream, callback_free, static_cast<void*>(workspace_cpu_raw), 0));
+
     return run_reduce<Op, false>(
-            workspace.ptr<dt_int32>(), 1, B, 1, stream,
-            Op(src.ptr<dt_float32>(), dst.ptr<dt_int32>(), B));
+            static_cast<dt_int32*>(
+                    (void*)((char*)workspace_gpu_raw +
+                            workspace_gpu.total_size_in_bytes())),
+            1, m_size * total_nr_elems_max, 1, stream,
+            Op(srcs_gpu, srcs_total_nr_elems_gpu, dst.ptr<dt_int32>(),
+               total_nr_elems_max));
 }
 
 }  // namespace cuda
diff --git a/dnn/src/cuda/check_non_finite/opr_impl.h b/dnn/src/cuda/check_non_finite/opr_impl.h
index 8c89b61a..7392c062 100644
--- a/dnn/src/cuda/check_non_finite/opr_impl.h
+++ b/dnn/src/cuda/check_non_finite/opr_impl.h
@@ -18,16 +18,18 @@ namespace megdnn {
 namespace cuda {
 
 class CheckNonFiniteImpl final : public CheckNonFinite {
+    size_t _get_workspace_in_bytes() override;
+
 public:
     using CheckNonFinite::CheckNonFinite;
 
     size_t get_workspace_in_bytes(
-            const TensorLayout& src, const TensorLayout& dst) override;
+            const TensorNDArray& srcs, const TensorLayout& dst) override;
 
     bool is_thread_safe() const override { return true; }
 
     void exec(
-            _megdnn_tensor_in src, _megdnn_tensor_out dst,
+            _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
             _megdnn_workspace workspace) override;
 };
 
diff --git a/dnn/src/naive/check_non_finite/opr_impl.cpp b/dnn/src/naive/check_non_finite/opr_impl.cpp
index 3518f589..dad4bc83 100644
--- a/dnn/src/naive/check_non_finite/opr_impl.cpp
+++ b/dnn/src/naive/check_non_finite/opr_impl.cpp
@@ -17,21 +17,25 @@ namespace {
 
 using namespace megdnn;
 
-#define src_ctype dt_float32
-#define wtype     dt_int32
-
-void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) {
-    std::function<wtype(size_t, size_t)> func;
-    func = [&](size_t l, size_t r) -> wtype {
-        if (l + 1 < r) {
-            size_t mid = l + (r - l) / 2;
-            return func(l, mid) | func(mid, r);
-        } else {
-            return static_cast<wtype>(!std::isfinite(sptr[l]));
-        }
-    };
-
-    dptr[0] = func(0, size);
+#define wtype dt_int32
+
+void reduce_fwd(const TensorNDArray& srcs, wtype* dptr) {
+    dptr[0] = 0;
+    for (auto src : srcs) {
+        auto sptr = src.ptr<dt_float32>();
+        size_t size = src.layout.total_nr_elems();
+        std::function<wtype(wtype, wtype)> func;
+        func = [&](wtype l, wtype r) -> wtype {
+            if (l + 1 < r) {
+                wtype mid = l + (r - l) / 2;
+                return func(l, mid) | func(mid, r);
+            } else {
+                auto val = std::isfinite(sptr[l]);
+                return static_cast<wtype>(!val);
+            }
+        };
+        dptr[0] |= func(0, size);
+    }
 }
 
 }  // namespace
@@ -39,20 +43,13 @@ void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) {
 namespace megdnn {
 namespace naive {
 
-size_t CheckNonFiniteImpl::get_workspace_in_bytes(
-        const TensorLayout&, const TensorLayout&) {
-    return 0;
-}
-
 void CheckNonFiniteImpl::exec(
-        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
-    check_exec(src.layout, dst.layout, workspace.size);
+        _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
+        _megdnn_workspace workspace) {
+    check_exec(srcs, dst, workspace.size);
     auto handle = static_cast<HandleImpl*>(this->handle());
-    MEGDNN_DISPATCH_CPU_KERN(
-            handle, reduce_fwd(
-                            src.ptr<dt_float32>(), dst.ptr<dt_int32>(),
-                            src.layout.total_nr_elems()));
+    MEGDNN_DISPATCH_CPU_KERN(handle, reduce_fwd(srcs, dst.ptr<dt_int32>()));
 }
 
 }  // namespace naive
 }  // namespace megdnn
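
NOTE (illustrative sketch, not part of the patch): the CUDA exec() above stages the chunk-pointer tables in ordinary pageable host memory, enqueues an asynchronous host-to-device copy on the operator's stream, and registers a stream callback so the host buffer is freed only after the copy has drained; exec() itself never synchronizes. The isolated pattern, assuming a callback_free like the one the patch references:

    #include <cstdlib>
    #include <cuda_runtime.h>

    // hypothetical stand-in for the callback_free used above: frees the
    // host staging buffer once all prior work on the stream has finished
    static void CUDART_CB callback_free(cudaStream_t, cudaError_t, void* p) {
        free(p);
    }

    void stage_async(void* host_buf, void* dev_buf, size_t bytes,
                     cudaStream_t stream) {
        // copy is ordered on `stream`; host_buf must stay alive until it runs
        cudaMemcpyAsync(dev_buf, host_buf, bytes, cudaMemcpyHostToDevice, stream);
        // fires after the copy completes, so freeing host_buf is then safe
        cudaStreamAddCallback(stream, callback_free, host_buf, 0);
    }

One caveat: cudaMemcpyAsync from pageable (non-pinned) memory is not fully asynchronous on every platform, which is harmless here since correctness only relies on the copy being ordered before the reduction on the same stream.
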
diff --git a/dnn/src/naive/check_non_finite/opr_impl.h b/dnn/src/naive/check_non_finite/opr_impl.h
index 9a4528d5..2360a719 100644
--- a/dnn/src/naive/check_non_finite/opr_impl.h
+++ b/dnn/src/naive/check_non_finite/opr_impl.h
@@ -17,16 +17,20 @@ namespace megdnn {
 namespace naive {
 
 class CheckNonFiniteImpl final : public CheckNonFinite {
+    size_t _get_workspace_in_bytes() override { return 0; }
+
 public:
     using CheckNonFinite::CheckNonFinite;
 
     bool is_thread_safe() const override { return true; }
 
-    size_t get_workspace_in_bytes(
-            const TensorLayout& src, const TensorLayout& dst) override;
+    size_t get_workspace_in_bytes(const TensorNDArray&, const TensorLayout&) override {
+        m_size = 0;
+        return _get_workspace_in_bytes();
+    }
 
     void exec(
-            _megdnn_tensor_in src, _megdnn_tensor_out dst,
+            _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
             _megdnn_workspace workspace) override;
 };
 
diff --git a/dnn/test/common/opr_proxy.h b/dnn/test/common/opr_proxy.h
index 42c854ce..9744b4eb 100644
--- a/dnn/test/common/opr_proxy.h
+++ b/dnn/test/common/opr_proxy.h
@@ -203,6 +203,27 @@ struct OprProxy {
 };
 
 template <>
+struct OprProxy<CheckNonFinite> {
+    static void deduce_layout(CheckNonFinite* opr, TensorLayoutArray& layouts) {
+        megdnn_assert(layouts.size() >= 2);
+        auto inp = layouts;
+        inp.pop_back();
+        opr->deduce_layout(inp, layouts.back());
+    }
+
+    static void exec(CheckNonFinite* opr, const TensorNDArray& tensors) {
+        megdnn_assert(tensors.size() >= 2);
+        auto inps = tensors;
+        inps.pop_back();
+
+        WorkspaceWrapper W(
+                opr->handle(),
+                opr->get_workspace_in_bytes(inps, tensors.back().layout));
+        opr->exec(inps, tensors.back(), W.workspace());
+    }
+};
+
+template <>
 struct OprProxy<SplitForward> : DeduceLayoutProxy<SplitForward, 0, false> {
     WorkspaceWrapper W;
     void exec(SplitForward* opr, const TensorNDArray& tensors) {
diff --git a/dnn/test/cuda/check_non_finite.cpp b/dnn/test/cuda/check_non_finite.cpp
index 4c15a848..3a228466 100644
--- a/dnn/test/cuda/check_non_finite.cpp
+++ b/dnn/test/cuda/check_non_finite.cpp
@@ -22,13 +22,16 @@ TEST_F(CUDA, CHECK_NON_FINITE_BASIC) {
     const auto nan = std::numeric_limits<float>::quiet_NaN();
     UniformFloatWithValueRNG rng(-1.0f, 1.0f, 0.1f, inf);
     checker.set_rng(0, &rng);
-    checker.execs({{512 * 16}, {1}});
+    checker.execs({{512 * 4}, {4}, {1}});
     rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, inf);
     checker.set_rng(0, &rng);
-    checker.execs({{512 * 16}, {1}});
+    checker.execs({{4}, {512 * 4}, {1}});
     rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, nan);
     checker.set_rng(0, &rng);
-    checker.execs({{512 * 16}, {1}});
+    checker.execs({{32}, {256}, {1}});
+    rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 0.f, nan);
+    checker.set_rng(0, &rng);
+    checker.execs({{16}, {16}, {2}, {1}});
 }
 
 }  // namespace test
diff --git a/dnn/test/naive/check_non_finite.cpp b/dnn/test/naive/check_non_finite.cpp
index a28f1b47..c5f6b501 100644
--- a/dnn/test/naive/check_non_finite.cpp
+++ b/dnn/test/naive/check_non_finite.cpp
@@ -20,23 +20,28 @@ namespace test {
 
 TEST_F(NAIVE, CHECK_NON_FINITE_BASIC) {
     Checker<CheckNonFinite> checker(handle(), false);
     checker.exect(
-            Testcase{TensorValue({4}, dtype::Float32(), {1.1, 2.2, 3.3, 4.3}), {}},
-            Testcase{{}, TensorValue({1}, dtype::Int32(), {0})});
+            Testcase{
+                    TensorValue({4}, dtype::Float32(), {1.1, 2.2, 3.3, 4.3}),
+                    TensorValue({4}, dtype::Float32(), {1.1, 2.2, 3.3, 4.3}),
+                    {}},
+            Testcase{{}, {}, TensorValue({1}, dtype::Int32(), {0})});
     checker.exect(
             Testcase{
+                    TensorValue({4}, dtype::Float32(), {1.1, 2.2, 3.3, 4.3}),
                     TensorValue(
                             {4}, dtype::Float32(),
                             {1.1f, 2.2f, 3.3f,
                              std::numeric_limits<float>::infinity()}),
                     {}},
-            Testcase{{}, TensorValue({1}, dtype::Int32(), {1})});
Testcase{{}, TensorValue({1}, dtype::Int32(), {1})}); + Testcase{{}, {}, TensorValue({1}, dtype::Int32(), {1})}); checker.exect( Testcase{ + TensorValue({4}, dtype::Float32(), {1.1, 2.2, 3.3, 4.3}), TensorValue( {4}, dtype::Float32(), {1.1f, 2.2f, 3.3f, std::numeric_limits::quiet_NaN()}), {}}, - Testcase{{}, TensorValue({1}, dtype::Int32(), {1})}); + Testcase{{}, {}, TensorValue({1}, dtype::Int32(), {1})}); } } // namespace test diff --git a/imperative/python/megengine/amp/grad_scaler.py b/imperative/python/megengine/amp/grad_scaler.py index f1b64cfe..5f159951 100644 --- a/imperative/python/megengine/amp/grad_scaler.py +++ b/imperative/python/megengine/amp/grad_scaler.py @@ -128,21 +128,22 @@ class GradScaler: grad_tensors: Tensors needed to unscale grads. Should be all tensors that are affected by ``target`` tensor in GradManager's backward. """ - # use float64 for better precision - inv_scale = Tensor(1.0 / self.scale_factor) - for tensor in grad_tensors: - if tensor is None or getattr(tensor, "grad", None) is None: - continue - # to support tracing, _check_gradients should be applied to every grad. - if self._check_gradients(tensor.grad): - self._found_non_finite = True - tensor.grad *= inv_scale + # to support tracing, _check_gradients should be applied to every grad. + if self._check_gradients([x.grad for x in grad_tensors]): + self._found_non_finite = True if self._found_non_finite: for tensor in grad_tensors: if tensor is None or getattr(tensor, "grad", None) is None: continue tensor.grad = None + else: + # use float64 for better precision + inv_scale = Tensor(1.0 / self.scale_factor) + for tensor in grad_tensors: + if tensor is None or getattr(tensor, "grad", None) is None: + continue + tensor.grad *= inv_scale return self def _check_gradients(self, grad): diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index 440caafd..65d3edba 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -9,7 +9,7 @@ import collections import math from functools import lru_cache -from typing import Optional, Sequence, Tuple, Union +from typing import Iterable, Optional, Sequence, Tuple, Union from ..core import _config from ..core._imperative_rt.core2 import apply, dtype_promotion @@ -1183,7 +1183,7 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor: return U, sigma, V -def _check_non_finite(inp: Tensor) -> Tensor: +def _check_non_finite(inps: Iterable[Tensor]) -> Tensor: r"""Check whether input contains infinite or nan value. Args: @@ -1193,6 +1193,6 @@ def _check_non_finite(inp: Tensor) -> Tensor: a int32 scalar tensor, 0 for False and 1 for True. 
""" op = builtin.CheckNonFinite() - (oup,) = apply(op, inp.reshape(-1).astype("float32")) + (oup,) = apply(op, *inps) oup._setscalar() return oup diff --git a/imperative/python/test/unit/amp/test_grad_scaler.py b/imperative/python/test/unit/amp/test_grad_scaler.py index 9303b516..c554aa11 100644 --- a/imperative/python/test/unit/amp/test_grad_scaler.py +++ b/imperative/python/test/unit/amp/test_grad_scaler.py @@ -10,21 +10,26 @@ import numpy as np import megengine as mge from megengine.amp import GradScaler from megengine.autodiff import GradManager +from megengine.jit import trace def test_grad_scaler(): - gm = GradManager() - scaler = GradScaler() + def f(): + gm = GradManager() + scaler = GradScaler() - x = mge.tensor(1.0) - for _ in range(3): - with gm: - y = x + 1 - gm.attach(y) - loss = y + 1 - scaler.backward(gm, loss, unscale_grad=False) - np.testing.assert_equal(y.grad.numpy(), scaler.scale_factor) + x = mge.tensor(1.0) + for _ in range(3): + with gm: + y = x + 1 + gm.attach(y) + loss = y + 1 + scaler.backward(gm, loss, unscale_grad=False) + np.testing.assert_equal(y.grad.numpy(), scaler.scale_factor) + scaler.unscale(gm.attached_tensors()) + np.testing.assert_equal(y.grad.numpy(), 1) + # test handle None elements scaler.unscale(gm.attached_tensors()) - np.testing.assert_equal(y.grad.numpy(), 1) - # test handle None elements - scaler.unscale(gm.attached_tensors()) + + f() + trace(f)() diff --git a/imperative/python/test/unit/functional/test_math.py b/imperative/python/test/unit/functional/test_math.py index a59e6493..16651f8a 100644 --- a/imperative/python/test/unit/functional/test_math.py +++ b/imperative/python/test/unit/functional/test_math.py @@ -191,16 +191,17 @@ def test_sum_neg_axis(): def test_non_finite(): shape = (32, 3, 32, 32) - data = np.random.random(shape).astype(np.float32) - rst = F.math._check_non_finite(tensor(data)) + data1 = np.random.random(shape).astype(np.float32) + data2 = np.random.random(shape).astype(np.float32) + rst = F.math._check_non_finite([tensor(data1), tensor(data2)]) np.testing.assert_equal(rst.numpy(), [0]) - data[0][0][0][0] = float("inf") - rst = F.math._check_non_finite(tensor(data)) + data2[0][0][0][0] = float("inf") + rst = F.math._check_non_finite([tensor(data1), tensor(data2)]) np.testing.assert_equal(rst.numpy(), [1]) - data[0][0][0][0] = float("nan") - rst = F.math._check_non_finite(tensor(data)) + data2[0][0][0][0] = float("nan") + rst = F.math._check_non_finite([tensor(data1), tensor(data2)]) np.testing.assert_equal(rst.numpy(), [1]) diff --git a/imperative/src/impl/ops/misc.cpp b/imperative/src/impl/ops/misc.cpp index a29f0ef6..b4961661 100644 --- a/imperative/src/impl/ops/misc.cpp +++ b/imperative/src/impl/ops/misc.cpp @@ -17,14 +17,56 @@ namespace mgb { namespace imperative { namespace check_non_finite { -auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { +SymbolVar apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = def.cast_final_safe(); - mgb_assert(inputs.size() == 1); OperatorNodeConfig config{op.make_name()}; - return opr::CheckNonFinite::make(inputs[0], {}, config); + return opr::CheckNonFinite::make(inputs, {}, config); +} + +SmallVector apply_on_physical_tensor( + const OpDef& def, const SmallVector& inputs) { + size_t size = inputs.size(); + + auto dest = Tensor::make( + TensorLayout(TensorShape({1}), dtype::Int32()), inputs[0]->comp_node()); + auto cn = dest->comp_node(); + auto&& dnn_opr = opr::intl::create_megdnn_opr(cn); + size_t wk_size = 0; + SmallVector srcs(size); + 
diff --git a/imperative/src/impl/ops/misc.cpp b/imperative/src/impl/ops/misc.cpp
index a29f0ef6..b4961661 100644
--- a/imperative/src/impl/ops/misc.cpp
+++ b/imperative/src/impl/ops/misc.cpp
@@ -17,14 +17,56 @@ namespace mgb {
 namespace imperative {
 
 namespace check_non_finite {
 SymbolVar apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& op = def.cast_final_safe<CheckNonFinite>();
-    mgb_assert(inputs.size() == 1);
     OperatorNodeConfig config{op.make_name()};
-    return opr::CheckNonFinite::make(inputs[0], {}, config);
+    return opr::CheckNonFinite::make(inputs, {}, config);
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
+    size_t size = inputs.size();
+
+    auto dest = Tensor::make(
+            TensorLayout(TensorShape({1}), dtype::Int32()), inputs[0]->comp_node());
+    auto cn = dest->comp_node();
+    auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::CheckNonFinite>(cn);
+    size_t wk_size = 0;
+    SmallVector<megdnn::TensorND> srcs(size);
+    for (size_t i = 0; i < size; ++i) {
+        srcs[i] = inputs[i]->dev_tensor().as_megdnn();
+    }
+    wk_size = dnn_opr->get_workspace_in_bytes(srcs, dest->layout());
+    auto wk = Blob::make(cn, wk_size);
+    megdnn::Workspace dnn_wk(wk->storage().get(), wk_size);
+    dnn_opr->exec(srcs, dest->dev_tensor().as_megdnn(), dnn_wk);
+    return {dest};
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    SmallVector<LogicalTensorDesc> dests(1);
+    dests[0].comp_node = inputs[0].comp_node;
+    dests[0].layout = TensorLayout(TensorShape({1}), dtype::Int32());
+    return {dests, true};
+}
+
+SmallVector<LogicalTensorDesc> infer_output_attrs(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
+    SmallVector<LogicalTensorDesc> dests(1);
+    dests[0].comp_node = inputs[0]->comp_node();
+    dests[0].layout = TensorLayout(TensorShape({1}), dtype::Int32());
+    return dests;
+}
+
+std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
+        const SmallVector<MemoryDesc>& inputs_mems) {
+    return {{}, {}};
 }
+
 OP_TRAIT_REG(CheckNonFinite, CheckNonFinite)
         .apply_on_var_node(apply_on_var_node)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .infer_output_mem_desc(infer_output_mem_desc)
         .fallback();
 }  // namespace check_non_finite
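
NOTE: in the imperative path above, apply_on_physical_tensor wraps each input once as a megdnn TensorND, queries get_workspace_in_bytes before exec (on CUDA this is also what fixes the chunk count m_size), and allocates the workspace as a Blob on the output's comp node; infer_output_attrs_fallible always reports the fixed {1} Int32 layout without inspecting input values, so the op never blocks static shape inference.
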
diff --git a/src/opr/impl/misc.cpp b/src/opr/impl/misc.cpp
index 584ecb5e..d6d0a353 100644
--- a/src/opr/impl/misc.cpp
+++ b/src/opr/impl/misc.cpp
@@ -482,18 +482,74 @@ MGB_IMPL_OPR_GRAD(TopK) {
 #endif
 
 /* ================= CheckNonFinite ================= */
-namespace mgb {
-namespace opr {
-namespace intl {
-template <>
-struct MegDNNOprInitPostCtor<CheckNonFinite> {
-    static void apply(cg::OperatorNodeBase& opr) {
-        opr.output(0)->dtype(dtype::Int32());
-    }
-};
-}  // namespace intl
-}  // namespace opr
-}  // namespace mgb
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckNonFinite);
-MEGDNN_OPR_INIT1(CheckNonFinite, "check_non_finite")
+
+CheckNonFinite::CheckNonFinite(
+        const VarNodeArrayView& inp, const Param& param,
+        const OperatorNodeConfig& config)
+        : Super(OperatorNodeBaseCtorParam{
+                  inp[0]->owner_graph(), config, "check_non_finite", inp}) {
+    mgb_assert(!inp.empty());
+    for (auto&& i : inp) {
+        add_input({i});
+    }
+    add_output(None)->dtype(dtype::Int32()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+    cg::add_workspace_output(this);
+}
+
+SymbolVar CheckNonFinite::make(
+        const VarNodeArrayView& inp, const Param& param,
+        const OperatorNodeConfig& config) {
+    mgb_assert(!inp.empty());
+    intl::BatchedDTypePromotion dtp{inp};
+    return SymbolVar{inp[0]}.insert_single_output_opr<CheckNonFinite>(
+            dtp.get_vars(), param, config);
+}
+
+void CheckNonFinite::scn_do_execute() {
+    megdnn::TensorNDArray inp_arr(input().size());
+    for (size_t i = 0; i < input().size(); ++i) {
+        inp_arr[i] = input()[i]->dev_tensor().as_megdnn();
+    }
+    megdnn_opr()->exec(
+            inp_arr, output(0)->dev_tensor().as_megdnn(),
+            intl::get_megdnn_workspace_from_var(output(1)));
+}
+
+void CheckNonFinite::init_output_static_infer_desc() {
+    using namespace cg::static_infer;
+
+    auto&& mgr = owner_graph()->static_infer_manager();
+
+    auto infer_oshp = [](TensorShape& dest, const InpVal& iv) {
+        TensorLayout dst;
+        dst.shape[0] = 1;
+        dst.ndim = 1;
+        dst.dtype = dtype::Int32();
+        dst.init_contiguous_stride();
+        dest = dst;
+        return true;
+    };
+    DepVal deps;
+    for (auto i : input())
+        deps.push_back({i, DepType::SHAPE});
+    mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_oshp});
+
+    auto infer_wk = [this](TensorShape& dest, const InpVal& inp) {
+        dest.ndim = 1;
+        megdnn::TensorNDArray inp_arr(input().size());
+        for (size_t i = 0; i < input().size(); ++i) {
+            inp_arr[i] = {NULL, {inp.val.at(i).shape(), input(0)->dtype()}};
+        }
+        dest.shape[0] = megdnn_opr()->get_workspace_in_bytes(
+                inp_arr, {output(0)->shape(), output(0)->dtype()});
+        return true;
+    };
+    mgr.register_shape_infer(output(1), {SourceType::DEP, deps, infer_wk});
+}
+
+void CheckNonFinite::add_input_layout_constraint() {
+    for (auto i : input()) {
+        i->add_layout_constraint_contiguous();
+    }
+}
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/opr/impl/misc.sereg.h b/src/opr/impl/misc.sereg.h
index 64ef9772..9dbc1dfb 100644
--- a/src/opr/impl/misc.sereg.h
+++ b/src/opr/impl/misc.sereg.h
@@ -55,6 +55,10 @@ struct OprMaker {
     }
 };
 
+template <>
+struct OprMaker<opr::CheckNonFinite, 0>
+        : public OprMakerVariadic<opr::CheckNonFinite> {};
+
 }  // namespace serialization
 
 namespace opr {
@@ -72,7 +76,7 @@ MGB_SEREG_OPR(CumsumV1, 1);
 #if MGB_CUDA
 MGB_SEREG_OPR(NvOf, 1);
 #endif
-MGB_SEREG_OPR(CheckNonFinite, 1);
+MGB_SEREG_OPR(CheckNonFinite, 0);
 }  // namespace opr
 }  // namespace mgb
diff --git a/src/opr/include/megbrain/opr/misc.h b/src/opr/include/megbrain/opr/misc.h
index 33bc6c70..69705f4b 100644
--- a/src/opr/include/megbrain/opr/misc.h
+++ b/src/opr/include/megbrain/opr/misc.h
@@ -142,6 +142,8 @@ namespace intl {
 using CondTakeBase = cg::SingleCNOperatorNode<
         cg::OperatorNodeBase, mixin::MegDNNOprHolderImpl<megdnn::CondTake>>;
 using TopKBase = cg::SingleCNOperatorNode<
         cg::OperatorNodeBase, mixin::MegDNNOprHolderImpl<megdnn::TopK>>;
+using CheckNonFiniteBase = cg::SingleCNOperatorNode<
+        cg::OperatorNodeBase, mixin::MegDNNOprHolderImpl<megdnn::CheckNonFinite>>;
 }  // namespace intl
 
 /*!
@@ -181,7 +183,19 @@ public:
             const OperatorNodeConfig& config = {});
 };
 
-MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckNonFinite);
+MGB_DEFINE_OPR_CLASS(CheckNonFinite, intl::CheckNonFiniteBase) // {
+    void scn_do_execute() override;
+    void init_output_static_infer_desc() override;
+    void add_input_layout_constraint() override;
+
+public:
+    MGE_WIN_DECLSPEC_FUC CheckNonFinite(
+            const VarNodeArrayView& inp, const Param& param,
+            const OperatorNodeConfig& config);
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
+            const VarNodeArrayView& inp, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+};
 
 }  // namespace opr
 }  // namespace mgb
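
NOTE (illustrative sketch, not part of the patch): MGB_SEREG_OPR(CheckNonFinite, 0) registers the opr with arity 0, i.e. a variable number of inputs, which is why deserialization now goes through OprMakerVariadic. At graph-construction level the batched opr can then be built over any number of vars; assuming SymbolVars w0, w1 and w2 already exist:

    // sketch: one fused finiteness flag for several variables
    auto flag = opr::CheckNonFinite::make({w0, w1, w2});
    // `flag` is a 1-element Int32 var: 0 when every element of every input
    // is finite, 1 as soon as any input contains an inf or nan.
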