|
@@ -62,6 +62,7 @@ BatchNormForward::BatchNormForward(VarNode *x,
     }
 
     init_megdnn_opr(*this, param);
+    output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     add_input({x, scale, bias, mean, variance});
@@ -91,6 +92,7 @@ BatchNormForward::BatchNormForward(VarNode *x,
           {x, scale, bias}}
 {
     init_megdnn_opr(*this, param);
+    output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     add_input({x, scale, bias});
 
     auto mark_empty_var = [&](VarNode *var) {
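Note, not part of the patch: output(4) is the y tensor that scn_do_execute() reads later in this diff, so ALLOW_EMPTY_SHAPE on it is what allows the forward result itself to take an empty shape, while the remaining outputs appear to hold per-channel statistics and follow the scale shape (see the get_output_var_shape() hunk at the end). A small standalone illustration of that shape rule, written in plain C++ with std::array stand-ins rather than MegBrain types:

#include <array>
#include <cstddef>
#include <cstdio>

// 4-d NCHW shape stand-in; not a MegBrain TensorShape.
using Shape4 = std::array<std::size_t, 4>;

int main() {
    Shape4 x{0, 32, 8, 8};       // zero-sized batch: the case the new flag permits
    Shape4 scale{1, 32, 1, 1};   // per-channel parameter, assumed {1, C, 1, 1} layout

    // Mirrors get_output_var_shape(): slot 4 (y) follows the input and is
    // therefore empty too; slots 0-3 follow the scale shape.
    Shape4 out[5];
    out[4] = x;
    for (int i = 0; i < 4; ++i)
        out[i] = scale;

    std::printf("y shape: {%zu, %zu, %zu, %zu}\n",
                out[4][0], out[4][1], out[4][2], out[4][3]);
    return 0;
}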
|
@@ -139,6 +141,8 @@ SymbolVarArray BatchNormForward::make(SymbolVar x,
 cg::OperatorNodeBase::NodeProp*
 BatchNormForward::do_make_node_prop() const {
     auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+            NodeProp::DepType::VALUE_ALLOW_EMPTY);
     if (need_stats() && m_force_inplace) {
         ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR);
     }
@@ -148,8 +152,6 @@ BatchNormForward::do_make_node_prop() const {
 void BatchNormForward::scn_do_execute() {
     auto &&x = input(0)->dev_tensor();
     auto &&y = output(4)->dev_tensor();
-    mgb_assert(x.layout().is_contiguous() &&
-               y.layout().is_contiguous());
     if (need_stats()) {
         auto &&o0 = output(0)->dev_tensor(),
              &&o1 = output(1)->dev_tensor(),
@@ -172,6 +174,12 @@ void BatchNormForward::scn_do_execute() {
                     && o1.raw_ptr() == i1.raw_ptr());
         }
     }
+    mgb_assert(x.layout().eq_layout(y.layout()));
+    if (x.layout().is_empty()) {
+        return;
+    }
+    mgb_assert(x.layout().is_contiguous() &&
+               y.layout().is_contiguous());
     auto scale = input(1)->dev_tensor().as_megdnn();
     auto bias = input(2)->dev_tensor().as_megdnn();
     megdnn::TensorND mean, variance;
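The net effect of the previous two hunks is a reordering of the checks in scn_do_execute(): the layouts of x and y must match, an empty tensor returns before the contiguity assert and the kernel call, and only non-empty tensors are required to be contiguous. A minimal sketch of that control flow, in plain C++ with a trivial layout stand-in instead of MegBrain's TensorLayout:

#include <cassert>
#include <cstddef>
#include <vector>

// Trivial layout stand-in: only the dimension sizes matter for this sketch.
struct Layout {
    std::vector<std::size_t> dims;
    bool is_empty() const {
        for (std::size_t d : dims)
            if (d == 0) return true;   // any zero-sized axis means zero elements
        return false;
    }
};

// Mirrors the reordered checks: shape equality first, then the empty-tensor
// early return, and only after that the contiguity requirement / kernel call.
void forward(const Layout& x, const Layout& y) {
    assert(x.dims == y.dims);          // analogue of x.layout().eq_layout(y.layout())
    if (x.is_empty())
        return;                        // nothing to compute, skip the kernel
    // the contiguity assert and the actual batch-norm kernel would run here
}

int main() {
    forward(Layout{{0, 32, 8, 8}}, Layout{{0, 32, 8, 8}});  // empty batch returns early
    return 0;
}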
|
@@ -196,6 +204,18 @@ void BatchNormForward::add_input_layout_constraint() {
 void BatchNormForward::get_output_var_shape(
         const TensorShapeArray &inp_shape,
         TensorShapeArray &out_shape) const {
+    mgb_assert(inp_shape[0].ndim == 4 && inp_shape[1].ndim == 4 && inp_shape[2].ndim == 4,
+               "expect input, scale and bias to be 4 dim tensor, but "
+               "got input dim: %zu, scale dim: %zu, bias dim: %zu",
+               inp_shape[0].ndim, inp_shape[1].ndim, inp_shape[2].ndim);
+
+    size_t inp_c = inp_shape[0][1],
+           scale_c = inp_shape[1][1],
+           bias_c = inp_shape[2][1];
+    mgb_assert(inp_c == scale_c && inp_c == bias_c,
+               "inconsistent channel size, input channel: %zu, scale channel: %zu, bias channel: %zu",
+               inp_c, scale_c, bias_c);
+
     out_shape[4] = inp_shape[0];
     for (size_t i = 0; i < 4; ++ i) {
         out_shape[i] = inp_shape[1];
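To make the new asserts in get_output_var_shape() concrete, here is the same channel-consistency rule as a standalone check with one passing and one failing call. Plain C++; the {1, C, 1, 1} layout shown for scale and bias is an assumption implied by the patch reading the channel count from index 1:

#include <array>
#include <cstddef>
#include <cstdio>

using Shape4 = std::array<std::size_t, 4>;

// Same rule as the new mgb_assert calls: all three tensors are 4-d (enforced
// here by the type) and must agree on the channel size stored at index 1.
bool channels_consistent(const Shape4& x, const Shape4& scale, const Shape4& bias) {
    std::size_t inp_c = x[1], scale_c = scale[1], bias_c = bias[1];
    if (inp_c != scale_c || inp_c != bias_c) {
        std::printf("inconsistent channel size, input channel: %zu, "
                    "scale channel: %zu, bias channel: %zu\n",
                    inp_c, scale_c, bias_c);
        return false;
    }
    return true;
}

int main() {
    channels_consistent({16, 32, 24, 24}, {1, 32, 1, 1}, {1, 32, 1, 1});  // passes
    channels_consistent({16, 32, 24, 24}, {1, 64, 1, 1}, {1, 32, 1, 1});  // rejected, channel mismatch
    return 0;
}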
|
|