From 2881934cb8c99c08568f57bc9e2848e2da7aa9d1 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 7 Dec 2021 16:49:52 +0800
Subject: [PATCH] feat(dnn/check_non_finite): addmul scale to check_non_finite
 opr

GitOrigin-RevId: c35a219e52f852b37a311c6de505c6eb08e2ab8a
---
 dnn/include/megdnn/oprs/general.h              |  2 +-
 dnn/scripts/opr_param_defs.py                  |  2 +
 dnn/src/common/reduce_helper.h                 | 31 --------------
 dnn/src/common/reduce_helper_device.h          |  8 +++-
 dnn/src/cuda/check_non_finite/opr_impl.cpp     |  2 +-
 dnn/src/naive/check_non_finite/opr_impl.cpp    |  8 ++--
 imperative/python/megengine/amp/grad_scaler.py | 28 ++++++-------
 imperative/python/megengine/functional/math.py | 14 ++++---
 .../python/test/unit/functional/test_math.py   | 18 ++++----
 imperative/src/impl/ops/misc.cpp               | 42 +++++++++++++------
 src/core/include/megbrain/ir/ops.td            |  2 +-
 src/opr/impl/misc.cpp                          | 48 ++++++++++++++++------
 src/opr/impl/misc.sereg.h                      | 11 ++++-
 src/opr/include/megbrain/opr/misc.h            | 21 +++++-----
 src/serialization/impl/schema.fbs              |  1 +
 15 files changed, 137 insertions(+), 101 deletions(-)

diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h
index b58fabba..30df4dd3 100644
--- a/dnn/include/megdnn/oprs/general.h
+++ b/dnn/include/megdnn/oprs/general.h
@@ -1344,7 +1344,7 @@ protected:
  * \brief check whether input contains inf or nan value.
  */
 class CheckNonFinite : public OperatorBase {
-    DEF_OPR_PARAM(Empty);
+    DEF_OPR_PARAM(CheckNonFinite);
     DEF_OPR_IMPL(CheckNonFinite, OperatorBase, -1, 1);

     size_t m_size = 0;
diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py
index f825a598..8eba9199 100755
--- a/dnn/scripts/opr_param_defs.py
+++ b/dnn/scripts/opr_param_defs.py
@@ -1176,6 +1176,8 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
 )

 pdef('Fill').add_fields('float32', 'value', '0')

+pdef('CheckNonFinite').add_fields('float32', 'scale', '1.0')
+
 PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
                  Doc('REFLECT = 1', 'fedcba|abcdefgh|hgfedcb'),
diff --git a/dnn/src/common/reduce_helper.h b/dnn/src/common/reduce_helper.h
index 0a300e91..50855823 100644
--- a/dnn/src/common/reduce_helper.h
+++ b/dnn/src/common/reduce_helper.h
@@ -156,37 +156,6 @@ struct MaxOp {
             : INIT(wtype(DTypeTrait::min())), src(src), dst(dst), B(B) {}
 };

-template
-struct CheckNonFiniteOp {
-    typedef wtype_ wtype;
-    const wtype INIT;
-
-    RefPtr* srcs;
-    RefPtr srcs_total_nr_elems;
-    RefPtr dst;
-    const size_t B;
-
-    wtype read(uint32_t idx) {
-        size_t x = idx / B;
-        size_t y = idx % B;
-        if (y < srcs_total_nr_elems.ptr()[x]) {
-            RefPtr src = srcs[x];
-            return !std::isfinite(src.ptr()[y]);
-        }
-        return 0;
-    }
-    void write(uint32_t idx, wtype val) { dst.ptr()[idx] = val; }
-    static wtype apply(wtype lhs, wtype rhs) { return lhs | rhs; }
-    CheckNonFiniteOp(
-            RefPtr* srcs, const RefPtr& srcs_total_nr_elems, const RefPtr& dst,
-            size_t B)
-            : INIT(wtype(0)),
-              srcs(srcs),
-              srcs_total_nr_elems(srcs_total_nr_elems),
-              dst(dst),
-              B(B) {}
-};
-
 void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis);

 } // namespace reduce
diff --git a/dnn/src/common/reduce_helper_device.h b/dnn/src/common/reduce_helper_device.h
index c68d0bff..31521261 100644
--- a/dnn/src/common/reduce_helper_device.h
+++ b/dnn/src/common/reduce_helper_device.h
@@ -194,6 +194,7 @@ struct CheckNonFiniteOp {
     index_ctype* srcs_total_nr_elems;
     dst_ctype* dst;
     const size_t B;
+    const src_ctype scale;

     MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
         size_t x = idx / B;
@@ -204,6 +205,8 @@ struct CheckNonFiniteOp {
 #else
             wtype val = std::isfinite(srcs[x][y]);
 #endif
+            if (val)
+                srcs[x][y] *= scale;
             return !val;
         }
         return 0;
@@ -214,12 +217,13 @@ struct CheckNonFiniteOp {
     }
     MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(
             src_ctype** srcs, index_ctype* srcs_total_nr_elems, dst_ctype* dst,
-            size_t B)
+            size_t B, src_ctype scale)
             : INIT(wtype(0)),
               srcs(srcs),
               srcs_total_nr_elems(srcs_total_nr_elems),
               dst(dst),
-              B(B) {}
+              B(B),
+              scale(scale) {}
 };

 } // namespace device_reduce
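The device-side reduction above now folds a multiply into the finiteness check: every finite element is scaled by `scale` in place, and the reduced flag is 1 whenever any element is inf or nan. A minimal NumPy sketch of that contract, for reference only; it is not code from this patch and the function name is made up:

    import numpy as np

    def check_non_finite_reference(arrays, scale=1.0):
        """Return 1 if any element of any array is inf/nan, else 0.

        Mirrors the fused kernel: finite elements are multiplied by
        ``scale`` in place, non-finite elements are left untouched.
        """
        found = 0
        for arr in arrays:
            finite = np.isfinite(arr)
            found |= int(not finite.all())
            arr[finite] *= scale  # in-place scale of the finite values only
        return found

With scale=1.0 this degenerates to the old behaviour, a pure inf/nan check.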
diff --git a/dnn/src/cuda/check_non_finite/opr_impl.cpp b/dnn/src/cuda/check_non_finite/opr_impl.cpp
index 9e344a0b..54ac36e8 100644
--- a/dnn/src/cuda/check_non_finite/opr_impl.cpp
+++ b/dnn/src/cuda/check_non_finite/opr_impl.cpp
@@ -97,7 +97,7 @@ void CheckNonFiniteImpl::exec(
                     workspace_gpu.total_size_in_bytes())),
             1, m_size * total_nr_elems_max, 1, stream,
             Op(srcs_gpu, srcs_total_nr_elems_gpu, dst.ptr(),
-               total_nr_elems_max));
+               total_nr_elems_max, param().scale));
 }

 } // namespace cuda
diff --git a/dnn/src/naive/check_non_finite/opr_impl.cpp b/dnn/src/naive/check_non_finite/opr_impl.cpp
index dad4bc83..fc7c1cd9 100644
--- a/dnn/src/naive/check_non_finite/opr_impl.cpp
+++ b/dnn/src/naive/check_non_finite/opr_impl.cpp
@@ -19,7 +19,7 @@ using namespace megdnn;

 #define wtype dt_int32

-void reduce_fwd(const TensorNDArray& srcs, wtype* dptr) {
+void reduce_fwd(const TensorNDArray& srcs, wtype* dptr, dt_float32 scale) {
     dptr[0] = 0;
     for (auto src : srcs) {
         auto sptr = src.ptr();
@@ -31,6 +31,8 @@ void reduce_fwd(const TensorNDArray& srcs, wtype* dptr) {
                 return func(l, mid) | func(mid, r);
             } else {
                 auto val = std::isfinite(sptr[l]);
+                if (val)
+                    sptr[l] *= scale;
                 return static_cast(!val);
             }
         };
@@ -47,9 +49,9 @@ void CheckNonFiniteImpl::exec(
         _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
         _megdnn_workspace workspace) {
     check_exec(srcs, dst, workspace.size);
-
+    float scale = param().scale;
     auto handle = static_cast(this->handle());
-    MEGDNN_DISPATCH_CPU_KERN(handle, reduce_fwd(srcs, dst.ptr()));
+    MEGDNN_DISPATCH_CPU_KERN(handle, reduce_fwd(srcs, dst.ptr(), scale));
 }

 } // namespace naive
 } // namespace megdnn
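The naive CPU backend applies the same rule through a divide-and-conquer OR over each buffer. The split condition and midpoint sit outside the hunk above, so the Python sketch below fills them in with the obvious halving scheme; treat it as an approximation of the shipped C++, not a copy of it:

    import math

    def reduce_fwd_sketch(buf, scale):
        # buf is a mutable flat list of floats, standing in for one src tensor.
        def func(l, r):
            if l + 1 < r:  # more than one element: recurse and OR the halves
                mid = l + (r - l) // 2
                return func(l, mid) | func(mid, r)
            val = math.isfinite(buf[l])
            if val:
                buf[l] *= scale  # scale finite leaves in place
            return int(not val)
        return func(0, len(buf)) if buf else 0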
diff --git a/imperative/python/megengine/amp/grad_scaler.py b/imperative/python/megengine/amp/grad_scaler.py
index 5f159951..b5b8d42b 100644
--- a/imperative/python/megengine/amp/grad_scaler.py
+++ b/imperative/python/megengine/amp/grad_scaler.py
@@ -128,28 +128,28 @@ class GradScaler:
             grad_tensors: Tensors needed to unscale grads. Should be all tensors
                 that are affected by ``target`` tensor in GradManager's backward.
         """
-        # to support tracing, _check_gradients should be applied to every grad.
-        if self._check_gradients([x.grad for x in grad_tensors]):
-            self._found_non_finite = True
-
-        if self._found_non_finite:
-            for tensor in grad_tensors:
-                if tensor is None or getattr(tensor, "grad", None) is None:
-                    continue
-                tensor.grad = None
-        else:
+        if self.growth_interval == 0:
             # use float64 for better precision
             inv_scale = Tensor(1.0 / self.scale_factor)
             for tensor in grad_tensors:
                 if tensor is None or getattr(tensor, "grad", None) is None:
                     continue
                 tensor.grad *= inv_scale
+            return self
+
+        # to support tracing, _check_gradients should be applied to every grad.
+        if self._check_gradients(
+            [x.grad for x in grad_tensors], 1.0 / self.scale_factor
+        ):
+            self._found_non_finite = True
+            for tensor in grad_tensors:
+                if tensor is None or getattr(tensor, "grad", None) is None:
+                    continue
+                tensor.grad = None
         return self

-    def _check_gradients(self, grad):
-        if self.growth_interval == 0:
-            return False
-        return _check_non_finite(grad)
+    def _check_gradients(self, grad, scale):
+        return _check_non_finite(grad, scale)

     def update(self, new_scale: float = None):
         r"""Update the scale factor according to whether encountered overflow grad.
diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py
index 65d3edba..facd8206 100644
--- a/imperative/python/megengine/functional/math.py
+++ b/imperative/python/megengine/functional/math.py
@@ -1183,7 +1183,7 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
     return U, sigma, V


-def _check_non_finite(inps: Iterable[Tensor]) -> Tensor:
+def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor:
     r"""Check whether input contains infinite or nan value.

     Args:
@@ -1192,7 +1192,11 @@ def _check_non_finite(inps: Iterable[Tensor]) -> Tensor:
     Returns:
         a int32 scalar tensor, 0 for False and 1 for True.
     """
-    op = builtin.CheckNonFinite()
-    (oup,) = apply(op, *inps)
-    oup._setscalar()
-    return oup
+    op = builtin.CheckNonFinite(scale=scale)
+    oups = apply(op, *inps)
+    out = oups[-1]
+    for i in range(len(inps)):
+        inps[i]._reset(oups[i])
+
+    out._setscalar()
+    return out
diff --git a/imperative/python/test/unit/functional/test_math.py b/imperative/python/test/unit/functional/test_math.py
index 16651f8a..12f37ade 100644
--- a/imperative/python/test/unit/functional/test_math.py
+++ b/imperative/python/test/unit/functional/test_math.py
@@ -191,17 +191,21 @@ def test_sum_neg_axis():

 def test_non_finite():
     shape = (32, 3, 32, 32)
-    data1 = np.random.random(shape).astype(np.float32)
-    data2 = np.random.random(shape).astype(np.float32)
-    rst = F.math._check_non_finite([tensor(data1), tensor(data2)])
+    data = []
+    for i in range(2):
+        data.append(np.random.random(shape).astype(np.float32))
+    tensorList = [tensor(x) for x in data]
+    rst = F.math._check_non_finite(tensorList, 0.7)
     np.testing.assert_equal(rst.numpy(), [0])
+    for i in range(len(tensorList)):
+        np.testing.assert_allclose(tensorList[i].numpy() / 0.7, data[i], rtol=1e-6)

-    data2[0][0][0][0] = float("inf")
-    rst = F.math._check_non_finite([tensor(data1), tensor(data2)])
+    data[1][0][0][0][0] = float("inf")
+    rst = F.math._check_non_finite([tensor(x) for x in data], 0.7)
     np.testing.assert_equal(rst.numpy(), [1])

-    data2[0][0][0][0] = float("nan")
-    rst = F.math._check_non_finite([tensor(data1), tensor(data2)])
+    data[1][0][0][0][0] = float("nan")
+    rst = F.math._check_non_finite([tensor(x) for x in data], 0.7)
     np.testing.assert_equal(rst.numpy(), [1])
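On the Python side, GradScaler.unscale now performs the inf/nan check and the multiplication by 1 / scale_factor in one fused CheckNonFinite call, and drops the gradients when an overflow is found. A hedged usage sketch, assuming a MegEngine build that contains this patch; the toy model, input, and loss are placeholders and not part of the change:

    import numpy as np
    import megengine as mge
    import megengine.autodiff as ad
    import megengine.module as M
    from megengine import amp

    model = M.Linear(4, 4)
    params = list(model.parameters())
    gm = ad.GradManager()
    gm.attach(params)
    scaler = amp.GradScaler()

    x = mge.tensor(np.random.randn(8, 4).astype("float32"))
    with gm:
        loss = (model(x) ** 2).mean()
        gm.backward(loss * scaler.scale_factor)  # scaled loss -> scaled grads

    # One kernel both checks the grads and multiplies the finite ones by
    # 1 / scale_factor; on overflow the grads are set to None instead.
    scaler.unscale(params)
    scaler.update()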
diff --git a/imperative/src/impl/ops/misc.cpp b/imperative/src/impl/ops/misc.cpp
index b4961661..ae70e16d 100644
--- a/imperative/src/impl/ops/misc.cpp
+++ b/imperative/src/impl/ops/misc.cpp
@@ -17,44 +17,62 @@ namespace mgb {
 namespace imperative {
 namespace check_non_finite {

-SymbolVar apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+SymbolVarArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& op = def.cast_final_safe();
     OperatorNodeConfig config{op.make_name()};
-    return opr::CheckNonFinite::make(inputs, {}, config);
+    return opr::CheckNonFinite::make(inputs, op.param(), config);
 }

 SmallVector apply_on_physical_tensor(
         const OpDef& def, const SmallVector& inputs) {
     size_t size = inputs.size();
-
-    auto dest = Tensor::make(
+    auto&& op = def.cast_final_safe();
+    SmallVector outputs(size + 1);
+    outputs[size] = Tensor::make(
             TensorLayout(TensorShape({1}), dtype::Int32()), inputs[0]->comp_node());
+
+    auto dest = outputs[size];
     auto cn = dest->comp_node();
     auto&& dnn_opr = opr::intl::create_megdnn_opr(cn);
     size_t wk_size = 0;
     SmallVector srcs(size);
+    // copy an outputs to the dnn for inplace
     for (size_t i = 0; i < size; ++i) {
-        srcs[i] = inputs[i]->dev_tensor().as_megdnn();
+        outputs[i] = Tensor::make(inputs[i]->layout(), inputs[0]->comp_node());
+        outputs[i]->dev_tensor().copy_from_fixlayout(inputs[i]->dev_tensor());
+        srcs[i] = outputs[i]->dev_tensor().as_megdnn();
     }
+    megdnn::CheckNonFinite::Param param({op.scale});
+    dnn_opr->param() = param;
     wk_size = dnn_opr->get_workspace_in_bytes(srcs, dest->layout());
     auto wk = Blob::make(cn, wk_size);
     megdnn::Workspace dnn_wk(wk->storage().get(), wk_size);
     dnn_opr->exec(srcs, dest->dev_tensor().as_megdnn(), dnn_wk);
-    return {dest};
+    return outputs;
 }

 std::tuple, bool> infer_output_attrs_fallible(
         const OpDef& def, const SmallVector& inputs) {
-    SmallVector dests(1);
-    dests[0].comp_node = inputs[0].comp_node;
-    dests[0].layout = TensorLayout(TensorShape({1}), dtype::Int32());
+    size_t size = inputs.size();
+    SmallVector dests(size + 1);
+    for (size_t i = 0; i < size; ++i) {
+        dests[i].comp_node = inputs[i].comp_node;
+        dests[i].layout = inputs[i].layout;
+    }
+    dests[size].comp_node = inputs[0].comp_node;
+    dests[size].layout = TensorLayout(TensorShape({1}), dtype::Int32());
     return {dests, true};
 }

 SmallVector infer_output_attrs(
         const OpDef& def, const SmallVector& inputs) {
-    SmallVector dests(1);
-    dests[0].comp_node = inputs[0]->comp_node();
-    dests[0].layout = TensorLayout(TensorShape({1}), dtype::Int32());
+    size_t size = inputs.size();
+    SmallVector dests(size + 1);
+    for (size_t i = 0; i < size; ++i) {
+        dests[i].comp_node = inputs[i]->comp_node();
+        dests[i].layout = inputs[i]->layout();
+    }
+    dests[size].comp_node = inputs[0]->comp_node();
+    dests[size].layout = TensorLayout(TensorShape({1}), dtype::Int32());
     return dests;
 }

 std::tuple, SmallVector> infer_output_mem_desc(
diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td
index ceb8d777..233c99f3 100644
--- a/src/core/include/megbrain/ir/ops.td
+++ b/src/core/include/megbrain/ir/ops.td
@@ -397,7 +397,7 @@ def MagicMindRuntime: MgbHashableOp<"MagicMindRuntime"> {

 def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>;

-def CheckNonFinite: MgbHashableOp<"CheckNonFinite", [EmptyParam]>;
+def CheckNonFinite: MgbHashableOp<"CheckNonFinite", [CheckNonFiniteParam]>;

 def FastpathCopy: MgbHashableOp<"FastpathCopy">;
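At the dispatcher level the operator now produces one scaled copy per input plus a trailing int32 flag, which is the contract apply_on_physical_tensor above builds and _check_non_finite consumes. A sketch of that contract through MegEngine's low-level apply; the internal module paths mirror the ones used by functional/math.py at the time of this patch and may differ in other releases:

    import numpy as np
    import megengine as mge
    from megengine.core._imperative_rt.core2 import apply
    from megengine.core.ops import builtin

    xs = [mge.tensor(np.random.rand(4).astype("float32")) for _ in range(2)]
    outs = apply(builtin.CheckNonFinite(scale=0.5), *xs)

    scaled, flag = outs[:-1], outs[-1]
    assert len(scaled) == len(xs)  # one scaled tensor per input
    print(flag.numpy())  # [0] when every element is finite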
diff --git a/src/opr/impl/misc.cpp b/src/opr/impl/misc.cpp
index d6d0a353..439d4cc6 100644
--- a/src/opr/impl/misc.cpp
+++ b/src/opr/impl/misc.cpp
@@ -487,39 +487,60 @@ CheckNonFinite::CheckNonFinite(
         const VarNodeArrayView& inp, const Param& param,
         const OperatorNodeConfig& config)
         : Super(OperatorNodeBaseCtorParam{
-                  inp[0]->owner_graph(), config, "check_non_finite", inp}) {
+                  inp[0]->owner_graph(), config, "check_non_finite", inp}),
+          m_scale(param.scale) {
     mgb_assert(!inp.empty());
+
     for (auto&& i : inp) {
         add_input({i});
+        add_output(None)
+                ->dtype(dtype::Float32())
+                .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     }
     add_output(None)->dtype(dtype::Int32()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     cg::add_workspace_output(this);
 }

-SymbolVar CheckNonFinite::make(
+SymbolVarArray CheckNonFinite::make(
         const VarNodeArrayView& inp, const Param& param,
         const OperatorNodeConfig& config) {
     mgb_assert(!inp.empty());
     intl::BatchedDTypePromotion dtp{inp};
-    return SymbolVar{inp[0]}.insert_single_output_opr(
-            dtp.get_vars(), param, config);
+    auto outputs =
+            inp[0]->owner_graph()
+                    ->insert_opr(std::make_unique(inp, param, config))
+                    ->output();
+    mgb_assert(outputs.size() == inp.size() + 2);
+    SymbolVarArray ret(outputs.size() - 1);
+    for (size_t i = 0; i < ret.size(); ++i)
+        ret[i] = outputs[i];
+    return ret;
 }

 void CheckNonFinite::scn_do_execute() {
-    megdnn::TensorNDArray inp_arr(input().size());
-    for (size_t i = 0; i < input().size(); ++i) {
-        inp_arr[i] = input()[i]->dev_tensor().as_megdnn();
+    size_t size = input().size();
+    megdnn::TensorNDArray oup_arr(size);
+    // copy an outputs to the dnn for inplace
+    for (size_t i = 0; i < size; ++i) {
+        oup_arr[i] = output(i)
+                             ->dev_tensor()
+                             .copy_from_fixlayout(input(i)->dev_tensor())
+                             .as_megdnn();
     }
+    megdnn_opr()->param().scale = m_scale;
     megdnn_opr()->exec(
-            inp_arr, output(0)->dev_tensor().as_megdnn(),
-            intl::get_megdnn_workspace_from_var(output(1)));
+            oup_arr, output(size)->dev_tensor().as_megdnn(),
+            intl::get_megdnn_workspace_from_var(output(size + 1)));
 }

 void CheckNonFinite::init_output_static_infer_desc() {
     using namespace cg::static_infer;
     auto&& mgr = owner_graph()->static_infer_manager();
-
+    size_t size = input().size();
+    for (size_t i = 0; i < size; ++i) {
+        mgr.register_shape_infer(output(i), ShapeInferDesc::make_identity(input(i)));
+    }
     auto infer_oshp = [](TensorShape& dest, const InpVal& iv) {
         TensorLayout dst;
         dst.shape[0] = 1;
@@ -532,7 +553,7 @@ void CheckNonFinite::init_output_static_infer_desc() {
     DepVal deps;
     for (auto i : input())
         deps.push_back({i, DepType::SHAPE});
-    mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_oshp});
+    mgr.register_shape_infer(output(size), {SourceType::DEP, deps, infer_oshp});

     auto infer_wk = [this](TensorShape& dest, const InpVal& inp) {
         dest.ndim = 1;
@@ -541,10 +562,11 @@ void CheckNonFinite::init_output_static_infer_desc() {
             inp_arr[i] = {NULL, {inp.val.at(i).shape(), input(0)->dtype()}};
         }
         dest.shape[0] = megdnn_opr()->get_workspace_in_bytes(
-                inp_arr, {output(0)->shape(), output(0)->dtype()});
+                inp_arr, {output(input().size() + 1)->shape(),
+                          output(input().size() + 1)->dtype()});
         return true;
     };
-    mgr.register_shape_infer(output(1), {SourceType::DEP, deps, infer_wk});
+    mgr.register_shape_infer(output(size + 1), {SourceType::DEP, deps, infer_wk});
 }

 void CheckNonFinite::add_input_layout_constraint() {
diff --git a/src/opr/impl/misc.sereg.h b/src/opr/impl/misc.sereg.h
index 9dbc1dfb..f9c16a01 100644
--- a/src/opr/impl/misc.sereg.h
+++ b/src/opr/impl/misc.sereg.h
@@ -56,7 +56,16 @@ struct OprMaker {
 };

 template <>
-struct OprMaker : public OprMakerVariadic {
+struct OprMaker {
+    using Opr = opr::CheckNonFinite;
+    using Param = Opr::Param;
+    static cg::OperatorNodeBase* make(
+            const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph,
+            const OperatorNodeConfig& config) {
+        MGB_MARK_USED_VAR(graph);
+        auto out = Opr::make(inputs, param, config);
+        return out[0].node()->owner_opr();
+    }
 };

 } // namespace serialization
diff --git a/src/opr/include/megbrain/opr/misc.h b/src/opr/include/megbrain/opr/misc.h
index 69705f4b..f8fd7c37 100644
--- a/src/opr/include/megbrain/opr/misc.h
+++ b/src/opr/include/megbrain/opr/misc.h
@@ -183,18 +183,19 @@ public:
             const OperatorNodeConfig& config = {});
 };

-MGB_DEFINE_OPR_CLASS(CheckNonFinite, intl::CheckNonFiniteBase) //{
-void scn_do_execute() override;
-void init_output_static_infer_desc() override;
-void add_input_layout_constraint() override;
+MGB_DEFINE_OPR_CLASS(CheckNonFinite, intl::CheckNonFiniteBase) // {
+    void scn_do_execute() override;
+    void init_output_static_infer_desc() override;
+    void add_input_layout_constraint() override;
+    float m_scale = 1;

 public:
-MGE_WIN_DECLSPEC_FUC CheckNonFinite(
-        const VarNodeArrayView& inp, const Param& param,
-        const OperatorNodeConfig& config);
-MGE_WIN_DECLSPEC_FUC static SymbolVar make(
-        const VarNodeArrayView& inp, const Param& param = {},
-        const OperatorNodeConfig& config = {});
+    MGE_WIN_DECLSPEC_FUC CheckNonFinite(
+            const VarNodeArrayView& inp, const Param& param,
+            const OperatorNodeConfig& config);
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
+            const VarNodeArrayView& inp, const Param& param = {},
+            const OperatorNodeConfig& config = {});
 };

 } // namespace opr
diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs
index 9eff6038..f91477e6 100644
--- a/src/serialization/impl/schema.fbs
+++ b/src/serialization/impl/schema.fbs
@@ -115,6 +115,7 @@ union OperatorParam {
     param.SlidingWindowTranspose = 81,
     param.Padding = 82,
     param.ShuffleRNG = 83,
+    param.CheckNonFinite = 84,
 }

 table Operator {