|
|
@@ -44,7 +44,7 @@ BatchNormForward::BatchNormForward(VarNode *x,
         m_force_inplace = false;
     }
 
-    if (m_force_inplace) {
+    if (m_force_inplace && param.fwd_mode == Param::FwdMode::TRAINING) {
         auto check_dest = [&](VarNode* dest) {
             auto dest_opr = dest->owner_opr();
             mgb_throw_if(!(dest_opr->same_type<SharedDeviceTensor>() ||
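Note: the SharedDeviceTensor check on mean and variance now fires only when the statistics will actually be overwritten in place, i.e. under forced inplace in TRAINING mode; an inference-only graph no longer has to back these inputs with shared device tensors.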
|
|
@@ -62,7 +62,14 @@ BatchNormForward::BatchNormForward(VarNode *x,
 
     add_input({x, scale, bias, mean, variance});
 
-    if (m_force_inplace) {
+    if (param.fwd_mode == Param::FwdMode::INFERENCE) {
+        auto mark_empty_var = [&](VarNode *var) {
+            var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
+                .add_flag(VarNode::Flag::VOLATILE_CONTENT);
+        };
+        mark_empty_var(output(0));
+        mark_empty_var(output(1));
+    } else if (m_force_inplace) {
         output(0)->
             set_fwd_in2out_writable_force(input(3)).
             add_flag(VarNode::Flag::NO_MEM_RECLAIM);
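Note: in INFERENCE mode the updated running statistics (outputs 0 and 1) are never produced, so they are marked ALLOW_EMPTY_SHAPE and VOLATILE_CONTENT instead of being force-forwarded; the forced aliasing of input(3)/input(4) into output(0)/output(1) is now reserved for the training path. For callers, the mode is selected through the operator's Param. A hypothetical usage sketch follows; the argument order is assumed from the make() fragment visible in the next hunk header and should be checked against the actual declaration:

    // Hypothetical usage sketch; the argument order is an assumption, not
    // something this diff shows. Per scn_do_execute() below, output 4 is y
    // and outputs 0-3 are the updated/saved statistics.
    BatchNormForward::Param param;
    param.fwd_mode = BatchNormForward::Param::FwdMode::INFERENCE;
    auto outs = BatchNormForward::make(x, scale, bias, mean, variance, param);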
|
|
@@ -129,7 +136,7 @@ SymbolVarArray BatchNormForward::make(SymbolVar x,
 cg::OperatorNodeBase::NodeProp*
 BatchNormForward::do_make_node_prop() const {
     auto ret = Super::do_make_node_prop();
-    if (input().size() == 5) {
+    if (need_stats()) {
         ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR);
     }
     return ret;
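Note: need_stats() itself is not defined in any of these hunks. Judging from the conditions it replaces here and in the hunks below (input().size() == 5, combined with the TRAINING/INFERENCE split), a plausible definition is the following sketch; the real helper in the tree may differ:

    // Sketch only, inferred from the replaced conditions rather than taken
    // from this diff: running statistics are needed iff the running
    // mean/variance inputs are present (5-input form) and the operator
    // runs in training mode.
    bool BatchNormForward::need_stats() const {
        return input().size() == 5 &&
               param().fwd_mode == Param::FwdMode::TRAINING;
    }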
|
|
@@ -140,7 +147,7 @@ void BatchNormForward::scn_do_execute() {
     auto &&y = output(4)->dev_tensor();
     mgb_assert(x.layout().is_contiguous() &&
                y.layout().is_contiguous());
-    if (input().size() == 5) { // need running mean/variance
+    if (need_stats()) {
         auto &&o0 = output(0)->dev_tensor(),
              &&o1 = output(1)->dev_tensor(),
              &&i0 = input(3)->dev_tensor(),
|
|
@@ -164,8 +171,14 @@ void BatchNormForward::scn_do_execute() {
     }
     auto scale = input(1)->dev_tensor().as_megdnn();
     auto bias = input(2)->dev_tensor().as_megdnn();
-    auto mean = output(0)->dev_tensor().as_megdnn();
-    auto variance = output(1)->dev_tensor().as_megdnn();
+    megdnn::TensorND mean, variance;
+    if (param().fwd_mode == Param::FwdMode::INFERENCE) {
+        mean = input(3)->dev_tensor().as_megdnn();
+        variance = input(4)->dev_tensor().as_megdnn();
+    } else {
+        mean = output(0)->dev_tensor().as_megdnn();
+        variance = output(1)->dev_tensor().as_megdnn();
+    }
     auto save_mean = output(2)->dev_tensor().as_megdnn();
     auto save_variance = output(3)->dev_tensor().as_megdnn();
     auto workspace = intl::get_megdnn_workspace_from_var(output().back());
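Note: the megdnn kernel's mean/variance arguments are now bound to the running-statistics inputs (input(3), input(4)) in INFERENCE mode, and to the updated-statistics outputs (output(0), output(1)) otherwise; save_mean and save_variance (outputs 2 and 3) remain bound unconditionally.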
|
|
@@ -180,12 +193,11 @@ void BatchNormForward::add_input_layout_constraint() {
 void BatchNormForward::get_output_var_shape(
         const TensorShapeArray &inp_shape,
         TensorShapeArray &out_shape) const {
-    size_t nr_inp = input().size();
     out_shape[4] = inp_shape[0];
     for (size_t i = 0; i < 4; ++ i) {
         out_shape[i] = inp_shape[1];
     }
-    if (nr_inp == 3) {
+    if (!need_stats()) {
         out_shape[0] = out_shape[1] = {0};
     }
 }
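Note: the running-statistics outputs now get the empty shape {0} whenever need_stats() is false, which covers both the 3-input form (the old nr_inp == 3 test) and the 5-input INFERENCE form, matching the ALLOW_EMPTY_SHAPE flag set in the constructor.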
|
|
@@ -221,7 +233,7 @@ void BatchNormForward::init_output_dtype() {
 }
 
 void BatchNormForward::mem_plan_fwd_in2out_writable() {
-    if (!m_force_inplace && input().size() == 5) {
+    if (need_stats() && !m_force_inplace) {
         // TODO: testing
         output(0)->set_fwd_in2out_writable(input(3));
         output(1)->set_fwd_in2out_writable(input(4));
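Note: the optional (non-forced) in-place forwarding is likewise gated on need_stats() now, so the memory planner only attempts to alias the running statistics when they will actually be written.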
|
|
|