@@ -10,6 +10,8 @@
  */
 
 #include "megbrain/opr/dnn/batch_norm.h"
+#include "../blob_manager_impl.h"
+#include "../dnn_op_helper.h"
 #include "../op_trait.h"
 #include "megbrain/imperative/graph_builder.h"
 #include "megbrain/imperative/ops/autogen.h"
@@ -138,7 +140,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     SmallVector<LogicalTensorDesc> out_shapes(nr_out);
     auto&& i0 = inputs[0];
     auto&& i1 = inputs[1];
-    // [running_mean, running_var,] save_mean, save_var
+    // [running_mean, running_var,] save_mean, save_variance
     for (size_t i = 0; i < nr_out - 2; ++i) {
         out_shapes[i] = {i1.layout, i1.comp_node};
     }
@@ -148,10 +150,122 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     return {out_shapes, out_shapes[nr_out - 1].layout.ndim != 0};
 }
 
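+// Eager execution path added by this patch: BatchNorm is dispatched directly
+// to the megdnn kernel on the physical (device) tensors.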
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
+    auto&& op_def = def.cast_final_safe<BatchNorm>();
+    auto&& comp_node = inputs[0]->comp_node();
+
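+    // Convert the input TensorPtrs to raw megdnn::TensorNDs for the kernel call.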
|
using TensorND = megdnn::TensorND; |
|
|
|
|
|
|
|
|
|
|
|
SmallVector<TensorND> inp_tensornds(inputs.size()); |
|
|
|
|
|
for (size_t i = 0; i < inputs.size(); ++i) { |
|
|
|
|
|
inp_tensornds[i] = inputs[i]->dnn_tensor(); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
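+    // Build a megdnn BN operator on the input's comp node and forward the
+    // BatchNorm parameters from the OpDef.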
+    DnnOprCaller<megdnn::BN> dnn_opr(comp_node);
+    dnn_opr.op->param() = op_def.param();
+
+    TensorLayout src_layout = inputs[0]->layout();
+    TensorLayout scale_layout = inputs[1]->layout();
+    bool empty_input = src_layout.is_empty();
+    size_t nr_inp = inputs.size();
+
+    DeviceTensorND ws, reserve;
+    size_t sz = 0, rsz = 0;
+
+    TensorLayout w_layout({sz}, dtype::Byte());
+    TensorLayout r_layout({rsz}, dtype::Byte());
+
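+    // Workspace/reserve sizes are only queried for a non-empty input;
+    // otherwise the zero-sized layouts above are kept.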
|
if (!empty_input) { |
|
|
|
|
|
sz = dnn_opr.op->get_workspace_in_bytes( |
|
|
|
|
|
src_layout, src_layout, src_layout, src_layout, src_layout, src_layout, |
|
|
|
|
|
src_layout, src_layout, src_layout); |
|
|
|
|
|
rsz = dnn_opr.op->get_reserve_in_bytes(src_layout); |
|
|
|
|
|
|
|
|
|
|
|
w_layout = TensorLayout({sz}, dtype::Byte()); |
|
|
|
|
|
r_layout = TensorLayout({rsz}, dtype::Byte()); |
|
|
|
|
|
} |
|
|
|
|
|
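+    // Allocate the megdnn workspace and the reserve tensor.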
+    auto wk = Blob::make(comp_node, sz);
+    auto ptr = wk->storage().get();
+    megdnn::Workspace dnn_wk(ptr, sz);
+    reserve = BlobManager::inst()->alloc_workspace_with_defrag(comp_node, r_layout);
+
+    // alloc output memory: y, save_mean, save_variance
+    DeviceTensorND y =
+            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, src_layout);
+
+    DeviceTensorND save_mean =
+            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, scale_layout);
+    DeviceTensorND save_variance =
+            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, scale_layout);
+
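+    // In INFERENCE mode the running statistics are consumed, not updated, so
+    // inputs[3]/inputs[4] are returned unchanged.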
+    if (op_def.fwd_mode == ::megdnn::param::BN::FwdMode::INFERENCE) {
+        if (!empty_input)
+            dnn_opr.op->exec(
+                    inp_tensornds[0], inp_tensornds[1], inp_tensornds[2],
+                    inp_tensornds[3], inp_tensornds[4], save_mean.as_megdnn(),
+                    save_variance.as_megdnn(), reserve.as_megdnn(), y.as_megdnn(),
+                    dnn_wk);
+        return {inputs[3], inputs[4], Tensor::make(reserve), Tensor::make(y)};
+    } else {
+        DeviceTensorND mean, variance;
+        if (nr_inp == 5) {
+            mean = BlobManager::inst()->alloc_workspace_with_defrag(
+                    comp_node, scale_layout);
+            variance = BlobManager::inst()->alloc_workspace_with_defrag(
+                    comp_node, scale_layout);
+
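+            // Copy the incoming running_mean/running_var into the freshly
+            // allocated buffers so that the kernel updates the copies and the
+            // input tensors are left untouched.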
+            megdnn::RefPtr src_ptr1(
+                    inp_tensornds[3].get_ref_ptr().get_ptr(), inputs[3]->offset());
+            megdnn::RefPtr dst_ptr1(
+                    mean.storage().get_ref_ptr(), mean.storage().offset(), false);
+            comp_node.peer_copy_to_ref(
+                    comp_node, dst_ptr1, src_ptr1, scale_layout.span().high_byte);
+
+            megdnn::RefPtr src_ptr2(
+                    inp_tensornds[4].get_ref_ptr().get_ptr(), inputs[4]->offset());
+            megdnn::RefPtr dst_ptr2(
+                    variance.storage().get_ref_ptr(), variance.storage().offset(),
+                    false);
+            comp_node.peer_copy_to_ref(
+                    comp_node, dst_ptr2, src_ptr2, scale_layout.span().high_byte);
+
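+            // Training with running statistics: run the kernel and return all
+            // six outputs (updated running stats, saved stats, reserve, y).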
+            if (!empty_input)
+                dnn_opr.op->exec(
+                        inp_tensornds[0], inp_tensornds[1], inp_tensornds[2],
+                        mean.as_megdnn(), variance.as_megdnn(), save_mean.as_megdnn(),
+                        save_variance.as_megdnn(), reserve.as_megdnn(), y.as_megdnn(),
+                        dnn_wk);
+
+            return {Tensor::make(mean), Tensor::make(variance),
+                    Tensor::make(save_mean), Tensor::make(save_variance),
+                    Tensor::make(reserve), Tensor::make(y)};
+        }
+
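+        // No running statistics were supplied (nr_inp != 5): pass zero-sized
+        // placeholders for the running mean/variance outputs.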
+        TensorLayout m_layout({0}, scale_layout.dtype);
+        mean = BlobManager::inst()->alloc_workspace_with_defrag(comp_node, m_layout);
+        variance =
+                BlobManager::inst()->alloc_workspace_with_defrag(comp_node, m_layout);
+
+        if (!empty_input) {
+            dnn_opr.op->exec(
+                    inp_tensornds[0], inp_tensornds[1], inp_tensornds[2],
+                    mean.as_megdnn(), variance.as_megdnn(), save_mean.as_megdnn(),
+                    save_variance.as_megdnn(), reserve.as_megdnn(), y.as_megdnn(),
+                    dnn_wk);
+        }
+
+        return {Tensor::make(save_mean), Tensor::make(save_variance),
+                Tensor::make(reserve), Tensor::make(y)};
+    }
+}
+
 OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
         .make_from_op_node(make_from_op_node)
         .apply_on_var_node(apply_on_var_node)
         .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
         .fallback();
 
 }  // namespace bn