GitOrigin-RevId: 6f359b5b29
release-1.10
@@ -1939,6 +1939,11 @@ class LayerNormBase : public OperatorBase { | |||
DEF_OPR_IMPL_CTOR(LayerNormBase, OperatorBase); | |||
DEF_OPR_PARAM(LayerNorm); | |||
public: | |||
MGE_WIN_DECLSPEC_FUC static void deduce_layout_fwd_impl( | |||
const TensorLayout& data, const Param& p, TensorLayout& dst, | |||
TensorLayout& mean, TensorLayout& rstd); | |||
protected: | |||
void deduce_layout_fwd( | |||
const TensorLayout& data, const TensorLayout& weight, | |||
@@ -4,12 +4,11 @@ | |||
namespace megdnn { | |||
void LayerNormBase::deduce_layout_fwd( | |||
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||
TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) { | |||
MEGDNN_MARK_USED_VAR(weight); | |||
MEGDNN_MARK_USED_VAR(bias); | |||
auto p = param(); | |||
using Param = LayerNormBase::Param; | |||
void LayerNormBase::deduce_layout_fwd_impl( | |||
const TensorLayout& data, const Param& p, TensorLayout& dst, TensorLayout& mean, | |||
TensorLayout& rstd) { | |||
TensorShape unnormalized_shape; | |||
unnormalized_shape.ndim = data.ndim - p.normalized_dim; | |||
for (size_t i = 0; i < unnormalized_shape.ndim; ++i) { | |||
@@ -22,6 +21,14 @@ void LayerNormBase::deduce_layout_fwd( | |||
rstd = unnormalized_layout; | |||
} | |||
void LayerNormBase::deduce_layout_fwd( | |||
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||
TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) { | |||
MEGDNN_MARK_USED_VAR(weight); | |||
MEGDNN_MARK_USED_VAR(bias); | |||
deduce_layout_fwd_impl(data, param(), dst, mean, rstd); | |||
} | |||
void LayerNormBase::check_layout_fwd( | |||
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||
const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd) { | |||
@@ -63,6 +63,7 @@ __all__ = [ | |||
"hsigmoid", | |||
"hswish", | |||
"indexing_one_hot", | |||
"layer_norm", | |||
"leaky_relu", | |||
"linear", | |||
"local_conv2d", | |||
@@ -1135,9 +1136,6 @@ def layer_norm( | |||
bias: must not be None when the affine is true | |||
eps: a value added to the denominator for numerical stability. Default: 1e-5 | |||
""" | |||
if amp._enabled: | |||
inp, weight, bias = cast_tensors(inp, weight, bias, promote=True) | |||
if isinstance(normalized_shape, int): | |||
normalized_shape = [normalized_shape] | |||
@@ -0,0 +1,115 @@ | |||
#include "megbrain/opr/dnn/layer_norm.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
#include "../blob_manager_impl.h" | |||
#include "../dnn_op_helper.h" | |||
#include "../op_trait.h" | |||
namespace mgb::imperative { | |||
namespace layer_norm { | |||
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const LayerNorm&>(def); | |||
size_t nr_inp = inputs.size(); | |||
auto p = op.param(); | |||
mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine)); | |||
OperatorNodeConfig config{op.make_name()}; | |||
if (nr_inp == 3) { | |||
return opr::LayerNorm::make( | |||
inputs[0], inputs[1], inputs[2], op.param(), config)[0] | |||
.node() | |||
->owner_opr(); | |||
} else { | |||
return opr::LayerNorm::make(inputs[0], op.param(), config)[0] | |||
.node() | |||
->owner_opr(); | |||
} | |||
} | |||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||
auto&& op_def = def.cast_final_safe<LayerNorm>(); | |||
size_t nr_inp = inputs.size(); | |||
auto p = op_def.param(); | |||
mgb_assert( | |||
(nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine), | |||
"num of inputs of pooling should be 1 or 3 but you give %zu", | |||
inputs.size()); | |||
auto&& inp = inputs[0]; | |||
auto& inp_cn = inp.comp_node; | |||
if (inp.layout.ndim == 0) { | |||
return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}}, | |||
{TensorLayout{dtype::Float32()}, inp_cn, {}}, | |||
{TensorLayout{dtype::Float32()}, inp_cn, {}}}, | |||
false}; | |||
} | |||
TensorLayout oup_layout, mean_layout, rstd_layout; | |||
megdnn::LayerNorm::deduce_layout_fwd_impl( | |||
inp.layout, p, oup_layout, mean_layout, rstd_layout); | |||
return {{{oup_layout, inp_cn, {}}, | |||
{mean_layout, inp_cn, {}}, | |||
{rstd_layout, inp_cn, {}}}, | |||
true}; | |||
} | |||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||
auto&& op_def = def.cast_final_safe<LayerNorm>(); | |||
size_t nr_inp = inputs.size(); | |||
auto p = op_def.param(); | |||
mgb_assert( | |||
(nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine), | |||
"num of inputs of pooling should be 1 or 3 but you give %zu", | |||
inputs.size()); | |||
auto cn = inputs[0]->comp_node(); | |||
DnnOprCaller<megdnn::LayerNorm> caller(cn); | |||
auto&& dnn_opr = caller.op; | |||
dnn_opr->param() = p; | |||
TensorLayout oup_layout, mean_layout, rstd_layout; | |||
megdnn::LayerNorm::deduce_layout_fwd_impl( | |||
inputs[0]->dnn_tensor().layout, p, oup_layout, mean_layout, rstd_layout); | |||
DeviceTensorND out_devtensor = | |||
BlobManager::inst()->alloc_workspace_with_defrag(cn, oup_layout); | |||
DeviceTensorND mean_devtensor = | |||
BlobManager::inst()->alloc_workspace_with_defrag(cn, mean_layout); | |||
DeviceTensorND rstd_devtensor = | |||
BlobManager::inst()->alloc_workspace_with_defrag(cn, rstd_layout); | |||
megdnn::Workspace dnn_wk; | |||
auto wk_size = caller.op->get_workspace_in_bytes( | |||
inputs[0]->dnn_tensor().layout, | |||
p.affine ? inputs[1]->dnn_tensor().layout : TensorLayout(), | |||
p.affine ? inputs[2]->dnn_tensor().layout : TensorLayout(), oup_layout, | |||
mean_layout, rstd_layout); | |||
if (wk_size != 0) { | |||
TensorLayout w_layout({wk_size}, dtype::Byte()); | |||
dnn_wk = caller.create_workspace(w_layout); | |||
} | |||
dnn_opr->exec( | |||
inputs[0]->dnn_tensor(), | |||
p.affine ? inputs[1]->dnn_tensor() : megdnn::TensorND(), | |||
p.affine ? inputs[2]->dnn_tensor() : megdnn::TensorND(), | |||
out_devtensor.as_megdnn(), mean_devtensor.as_megdnn(), | |||
rstd_devtensor.as_megdnn(), dnn_wk); | |||
return {Tensor::make(out_devtensor), Tensor::make(mean_devtensor), | |||
Tensor::make(rstd_devtensor)}; | |||
} | |||
OP_TRAIT_REG(LayerNorm, LayerNorm) | |||
.apply_on_var_node(apply_on_var_node) | |||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||
.fallback(); | |||
} // namespace layer_norm | |||
} // namespace mgb::imperative |
@@ -8,7 +8,6 @@ | |||
#include "megbrain/opr/dnn/correlation.h" | |||
#include "megbrain/opr/dnn/fake_quant.h" | |||
#include "megbrain/opr/dnn/images2neibs.h" | |||
#include "megbrain/opr/dnn/layer_norm.h" | |||
#include "megbrain/opr/dnn/local.h" | |||
#include "megbrain/opr/dnn/lrn.h" | |||
#include "megbrain/opr/dnn/lsq.h" | |||
@@ -729,28 +728,4 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace lrn | |||
namespace layer_norm { | |||
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const LayerNorm&>(def); | |||
size_t nr_inp = inputs.size(); | |||
auto p = op.param(); | |||
mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine)); | |||
OperatorNodeConfig config{op.make_name()}; | |||
if (nr_inp == 3) { | |||
return opr::LayerNorm::make( | |||
inputs[0], inputs[1], inputs[2], op.param(), config)[0] | |||
.node() | |||
->owner_opr(); | |||
} else { | |||
return opr::LayerNorm::make(inputs[0], op.param(), config)[0] | |||
.node() | |||
->owner_opr(); | |||
} | |||
} | |||
OP_TRAIT_REG(LayerNorm, LayerNorm).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace layer_norm | |||
} // namespace mgb::imperative |
@@ -289,6 +289,28 @@ ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) { | |||
return imperative::apply(op, inputs); | |||
} | |||
ValueRefList layer_norm_rule(const OpDef& op, Span<ValueRef> inputs) { | |||
// avoid the amp_dtype_autocast | |||
if (DTypePromoteCfg::amp_dtype_autocast_enabled) { | |||
SmallVector<DType> dtypes = get_value_dtypes(inputs); | |||
ValueRefList converted(inputs.size()); | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
mgb::DType target_dtype = DTypePromoteCfg::amp_high_prec_dtype; | |||
if (dtypes[i] != target_dtype) { | |||
converted[i] = imperative::apply( | |||
ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; | |||
} else { | |||
converted[i] = inputs[i]; | |||
} | |||
} | |||
return imperative::apply(op, converted); | |||
} | |||
return imperative::apply(op, inputs); | |||
} | |||
ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) { | |||
SmallVector<DType> dtypes = get_value_dtypes(inputs); | |||
mgb::DType target_dtype = get_promoted_dtype(dtypes); | |||
@@ -319,6 +341,7 @@ struct DTypePromoteRuleRegistry { | |||
register_dtype_promote_rule<BatchNorm>(batch_norm_rule); | |||
register_dtype_promote_rule<Convolution3D>(naive_promote_rule); | |||
register_dtype_promote_rule<Convolution3DBackwardData>(naive_promote_rule); | |||
register_dtype_promote_rule<LayerNorm>(layer_norm_rule); | |||
} | |||
} register_helper; | |||