From da91e650a56ce6fcc0f0ba31cd0fd499a4110d27 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 24 May 2022 21:33:48 +0800
Subject: [PATCH] refactor(ops/layer_norm): speed up the host side of
 layer_norm

GitOrigin-RevId: 6f359b5b295f3d340947e0f6ea948c0fc1c19886
---
 dnn/include/megdnn/oprs/nn.h                   |   5 +
 dnn/src/common/layer_norm.cpp                  |  19 ++--
 imperative/python/megengine/functional/nn.py   |   4 +-
 imperative/src/impl/ops/layer_norm.cpp         | 115 +++++++++++++++++++++
 imperative/src/impl/ops/specializations.cpp    |  25 -----
 .../src/impl/transformations/dtype_promote.cpp |  23 +++++
 6 files changed, 157 insertions(+), 34 deletions(-)
 create mode 100644 imperative/src/impl/ops/layer_norm.cpp

diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h
index 814212f8..b117b7e1 100644
--- a/dnn/include/megdnn/oprs/nn.h
+++ b/dnn/include/megdnn/oprs/nn.h
@@ -1939,6 +1939,11 @@ class LayerNormBase : public OperatorBase {
     DEF_OPR_IMPL_CTOR(LayerNormBase, OperatorBase);
     DEF_OPR_PARAM(LayerNorm);
 
+public:
+    MGE_WIN_DECLSPEC_FUC static void deduce_layout_fwd_impl(
+            const TensorLayout& data, const Param& p, TensorLayout& dst,
+            TensorLayout& mean, TensorLayout& rstd);
+
 protected:
     void deduce_layout_fwd(
             const TensorLayout& data, const TensorLayout& weight,
diff --git a/dnn/src/common/layer_norm.cpp b/dnn/src/common/layer_norm.cpp
index e33acc1b..c052c527 100644
--- a/dnn/src/common/layer_norm.cpp
+++ b/dnn/src/common/layer_norm.cpp
@@ -4,12 +4,11 @@
 
 namespace megdnn {
 
-void LayerNormBase::deduce_layout_fwd(
-        const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
-        TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) {
-    MEGDNN_MARK_USED_VAR(weight);
-    MEGDNN_MARK_USED_VAR(bias);
-    auto p = param();
+using Param = LayerNormBase::Param;
+
+void LayerNormBase::deduce_layout_fwd_impl(
+        const TensorLayout& data, const Param& p, TensorLayout& dst, TensorLayout& mean,
+        TensorLayout& rstd) {
     TensorShape unnormalized_shape;
     unnormalized_shape.ndim = data.ndim - p.normalized_dim;
     for (size_t i = 0; i < unnormalized_shape.ndim; ++i) {
@@ -22,6 +21,14 @@ void LayerNormBase::deduce_layout_fwd(
     rstd = unnormalized_layout;
 }
 
+void LayerNormBase::deduce_layout_fwd(
+        const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
+        TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) {
+    MEGDNN_MARK_USED_VAR(weight);
+    MEGDNN_MARK_USED_VAR(bias);
+    deduce_layout_fwd_impl(data, param(), dst, mean, rstd);
+}
+
 void LayerNormBase::check_layout_fwd(
         const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
         const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd) {
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 6f7cc79a..ff1d468c 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -63,6 +63,7 @@ __all__ = [
     "hsigmoid",
     "hswish",
     "indexing_one_hot",
+    "layer_norm",
     "leaky_relu",
     "linear",
     "local_conv2d",
@@ -1135,9 +1136,6 @@ def layer_norm(
         bias: must not be None when the affine is true
         eps: a value added to the denominator for numerical stability.
            Default: 1e-5
    """
-    if amp._enabled:
-        inp, weight, bias = cast_tensors(inp, weight, bias, promote=True)
-
     if isinstance(normalized_shape, int):
         normalized_shape = [normalized_shape]
diff --git a/imperative/src/impl/ops/layer_norm.cpp b/imperative/src/impl/ops/layer_norm.cpp
new file mode 100644
index 00000000..53633c8e
--- /dev/null
+++ b/imperative/src/impl/ops/layer_norm.cpp
@@ -0,0 +1,115 @@
+#include "megbrain/opr/dnn/layer_norm.h"
+#include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
+
+#include "../blob_manager_impl.h"
+#include "../dnn_op_helper.h"
+#include "../op_trait.h"
+
+namespace mgb::imperative {
+namespace layer_norm {
+
+cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& op = static_cast<const LayerNorm&>(def);
+    size_t nr_inp = inputs.size();
+    auto p = op.param();
+    mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine));
+    OperatorNodeConfig config{op.make_name()};
+    if (nr_inp == 3) {
+        return opr::LayerNorm::make(
+                       inputs[0], inputs[1], inputs[2], op.param(), config)[0]
+                .node()
+                ->owner_opr();
+    } else {
+        return opr::LayerNorm::make(inputs[0], op.param(), config)[0]
+                .node()
+                ->owner_opr();
+    }
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& op_def = def.cast_final_safe<LayerNorm>();
+    size_t nr_inp = inputs.size();
+    auto p = op_def.param();
+    mgb_assert(
+            (nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine),
+            "num of inputs of layer_norm should be 1 or 3 but got %zu",
+            inputs.size());
+
+    auto&& inp = inputs[0];
+    auto& inp_cn = inp.comp_node;
+
+    // input shape not known yet: return placeholder descriptors and mark the
+    // inference as incomplete
+    if (inp.layout.ndim == 0) {
+        return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}},
+                 {TensorLayout{dtype::Float32()}, inp_cn, {}},
+                 {TensorLayout{dtype::Float32()}, inp_cn, {}}},
+                false};
+    }
+
+    TensorLayout oup_layout, mean_layout, rstd_layout;
+    megdnn::LayerNorm::deduce_layout_fwd_impl(
+            inp.layout, p, oup_layout, mean_layout, rstd_layout);
+    return {{{oup_layout, inp_cn, {}},
+             {mean_layout, inp_cn, {}},
+             {rstd_layout, inp_cn, {}}},
+            true};
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
+    auto&& op_def = def.cast_final_safe<LayerNorm>();
+    size_t nr_inp = inputs.size();
+    auto p = op_def.param();
+
+    mgb_assert(
+            (nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine),
+            "num of inputs of layer_norm should be 1 or 3 but got %zu",
+            inputs.size());
+
+    auto cn = inputs[0]->comp_node();
+    DnnOprCaller<megdnn::LayerNorm> caller(cn);
+    auto&& dnn_opr = caller.op;
+    dnn_opr->param() = p;
+
+    // deduce the output layouts on the host, then allocate outputs up front
+    TensorLayout oup_layout, mean_layout, rstd_layout;
+    megdnn::LayerNorm::deduce_layout_fwd_impl(
+            inputs[0]->dnn_tensor().layout, p, oup_layout, mean_layout, rstd_layout);
+
+    DeviceTensorND out_devtensor =
+            BlobManager::inst()->alloc_workspace_with_defrag(cn, oup_layout);
+    DeviceTensorND mean_devtensor =
+            BlobManager::inst()->alloc_workspace_with_defrag(cn, mean_layout);
+    DeviceTensorND rstd_devtensor =
+            BlobManager::inst()->alloc_workspace_with_defrag(cn, rstd_layout);
+
+    megdnn::Workspace dnn_wk;
+    auto wk_size = caller.op->get_workspace_in_bytes(
+            inputs[0]->dnn_tensor().layout,
+            p.affine ? inputs[1]->dnn_tensor().layout : TensorLayout(),
+            p.affine ? inputs[2]->dnn_tensor().layout : TensorLayout(), oup_layout,
+            mean_layout, rstd_layout);
+    if (wk_size != 0) {
+        TensorLayout w_layout({wk_size}, dtype::Byte());
+        dnn_wk = caller.create_workspace(w_layout);
+    }
+
+    dnn_opr->exec(
+            inputs[0]->dnn_tensor(),
+            p.affine ? inputs[1]->dnn_tensor() : megdnn::TensorND(),
+            p.affine ? inputs[2]->dnn_tensor() : megdnn::TensorND(),
+            out_devtensor.as_megdnn(), mean_devtensor.as_megdnn(),
+            rstd_devtensor.as_megdnn(), dnn_wk);
+    return {Tensor::make(out_devtensor), Tensor::make(mean_devtensor),
+            Tensor::make(rstd_devtensor)};
+}
+
+OP_TRAIT_REG(LayerNorm, LayerNorm)
+        .apply_on_var_node(apply_on_var_node)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
+        .fallback();
+
+} // namespace layer_norm
+} // namespace mgb::imperative
\ No newline at end of file
diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp
index e842ca91..528d3ae7 100644
--- a/imperative/src/impl/ops/specializations.cpp
+++ b/imperative/src/impl/ops/specializations.cpp
@@ -8,7 +8,6 @@
 #include "megbrain/opr/dnn/correlation.h"
 #include "megbrain/opr/dnn/fake_quant.h"
 #include "megbrain/opr/dnn/images2neibs.h"
-#include "megbrain/opr/dnn/layer_norm.h"
 #include "megbrain/opr/dnn/local.h"
 #include "megbrain/opr/dnn/lrn.h"
 #include "megbrain/opr/dnn/lsq.h"
@@ -729,28 +728,4 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
 OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback();
 } // namespace lrn
 
-namespace layer_norm {
-
-cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
-    auto&& op = static_cast<const LayerNorm&>(def);
-    size_t nr_inp = inputs.size();
-    auto p = op.param();
-    mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine));
-    OperatorNodeConfig config{op.make_name()};
-    if (nr_inp == 3) {
-        return opr::LayerNorm::make(
-                       inputs[0], inputs[1], inputs[2], op.param(), config)[0]
-                .node()
-                ->owner_opr();
-    } else {
-        return opr::LayerNorm::make(inputs[0], op.param(), config)[0]
-                .node()
-                ->owner_opr();
-    }
-}
-
-OP_TRAIT_REG(LayerNorm, LayerNorm).apply_on_var_node(apply_on_var_node).fallback();
-
-} // namespace layer_norm
-
 } // namespace mgb::imperative
diff --git a/imperative/src/impl/transformations/dtype_promote.cpp b/imperative/src/impl/transformations/dtype_promote.cpp
index 5501bcf9..274ddf0e 100644
--- a/imperative/src/impl/transformations/dtype_promote.cpp
+++ b/imperative/src/impl/transformations/dtype_promote.cpp
@@ -289,6 +289,28 @@ ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
     return imperative::apply(op, inputs);
 }
 
+ValueRefList layer_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
+    // when amp autocast is enabled, bypass the generic cast and promote every
+    // input to the high-precision dtype before applying the op
+    if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
+        SmallVector<DType> dtypes = get_value_dtypes(inputs);
+        ValueRefList converted(inputs.size());
+
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            mgb::DType target_dtype = DTypePromoteCfg::amp_high_prec_dtype;
+            if (dtypes[i] != target_dtype) {
+                converted[i] = imperative::apply(
+                        ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
+            } else {
+                converted[i] = inputs[i];
+            }
+        }
+
+        return imperative::apply(op, converted);
+    }
+
+    return imperative::apply(op, inputs);
+}
+
 ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) {
     SmallVector<DType> dtypes = get_value_dtypes(inputs);
     mgb::DType target_dtype = get_promoted_dtype(dtypes);
@@ -319,6 +341,7 @@ struct DTypePromoteRuleRegistry {
         register_dtype_promote_rule<BatchNorm>(batch_norm_rule);
         register_dtype_promote_rule<Convolution3D>(naive_promote_rule);
         register_dtype_promote_rule<Convolution3DBackwardData>(naive_promote_rule);
+        register_dtype_promote_rule<LayerNorm>(layer_norm_rule);
     }
 } register_helper;
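
Usage note (not part of the patch): after this change, the AMP cast for
layer_norm no longer happens in Python (the cast_tensors call removed from
nn.py) but in the C++ dtype-promote pass, where layer_norm_rule promotes every
input to DTypePromoteCfg::amp_high_prec_dtype when autocast is enabled, so the
per-call Python overhead goes away. A minimal sketch of calling the op from
Python, assuming a standard MegEngine build; the shapes, dtypes, and printed
check below are illustrative only:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.random.randn(2, 4, 16).astype("float32"))
    w = mge.tensor(np.ones(16, dtype="float32"))   # scale, required when affine=True
    b = mge.tensor(np.zeros(16, dtype="float32"))  # shift, required when affine=True

    # normalize over the last dim; mean/rstd have one entry per unnormalized
    # position, i.e. shape (2, 4) here
    y = F.nn.layer_norm(x, normalized_shape=(16,), affine=True, weight=w, bias=b, eps=1e-5)
    print(y.shape)  # (2, 4, 16)

    # under amp autocast (assuming megengine.amp.autocast is available in this
    # build), float16 inputs are now promoted to float32 by layer_norm_rule
    # before the kernel runs
    with mge.amp.autocast():
        y16 = F.nn.layer_norm(
            x.astype("float16"), normalized_shape=(16,), affine=True,
            weight=w.astype("float16"), bias=b.astype("float16"), eps=1e-5)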