Browse Source

perf(imperative/src): improve elemwise

GitOrigin-RevId: 78aa487277
tags/v1.9.0
Megvii Engine Team 3 years ago
parent
commit
8446626193
5 changed files with 53 additions and 7 deletions
  1. +36
    -7
      imperative/src/impl/ops/elemwise.cpp
  2. +5
    -0
      imperative/src/impl/physical_tensor.cpp
  3. +2
    -0
      imperative/src/include/megbrain/imperative/physical_tensor.h
  4. +6
    -0
      src/opr/impl/basic_arith.cpp
  5. +4
    -0
      src/opr/include/megbrain/opr/basic_arith.h

+ 36
- 7
imperative/src/impl/ops/elemwise.cpp View File

@@ -114,15 +114,44 @@ void apply_on_device_tensornd(
// NOTE(review): this span is a commit-diff rendering with the +/- markers
// stripped. The hunk header reports +36/-7, and exactly seven lines below
// (the old `SmallVector<DeviceTensorND>` declaration, the old
// `inp_tensornds[i] = inputs[i]->dev_tensor()` assignment, and the old
// five-line tail that re-declares `out`) are the *removed* pre-change
// implementation interleaved with the added one. As plain C++ this text does
// not compile (duplicate declarations of `inp_tensornds` and `out`, two
// unconditional returns); the comments below mark each side of the diff.
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
// [removed] old code materialized a full DeviceTensorND per input:
SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
auto comp_node = inputs[0]->comp_node();
using Mode = Elemwise::Mode;
using TensorND = megdnn::TensorND;
auto&& op_def = def.cast_final_safe<Elemwise>();
// [added] new code collects lightweight megdnn::TensorND views instead
// (filled via push_back after reserve):
SmallVector<TensorND> inp_tensornds;
TensorShapeArray inp_shapes(inputs.size());
inp_tensornds.reserve(inputs.size());

// [added] deduce the output layout directly from the input shapes, and
// remember whether any input is empty so execution can be skipped.
TensorLayout layout{inputs[0]->layout().dtype};
bool is_empty = false;
for (unsigned i = 0; i < inputs.size(); ++i) {
// [removed] old per-input DeviceTensorND fetch:
inp_tensornds[i] = inputs[i]->dev_tensor();
if (inputs[i]->layout().is_empty()) {
is_empty = true;
}
inp_tensornds.push_back(inputs[i]->dnn_tensor());
inp_shapes[i] = inputs[i]->layout();
}
megdnn::Elemwise::deduce_shape(inp_shapes, layout);
layout.init_contiguous_stride();

// [added] allocate the output from the deduced layout; for an empty input
// the (empty) output is returned without running the opr.
DeviceTensorND out =
BlobManager::inst()->alloc_workspace_with_defrag(comp_node, layout);
if (is_empty) {
return {Tensor::make(out)};
}
auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(comp_node);

dnn_opr->param() = op_def.param();
// [added] FUSE_MUL_ADD3/4 and quantized dtypes are routed through the new
// Elemwise::perform_dnn helper; every other mode calls the megdnn opr
// directly on the raw TensorNDs.
if (dnn_opr->param().mode == Mode::FUSE_MUL_ADD3 ||
dnn_opr->param().mode == Mode::FUSE_MUL_ADD4 ||
(inp_tensornds.size() &&
inp_tensornds[0].layout.dtype.category() == DTypeCategory::QUANTIZED)) {
opr::Elemwise::perform_dnn(comp_node, out, inp_tensornds, dnn_opr);
} else {
dnn_opr->exec(inp_tensornds, out.as_megdnn());
}
// [removed] old tail: allocate from output_descs and dispatch through
// apply_on_device_tensornd:
DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(
inp_tensornds[0].comp_node(), output_descs[0].layout);
SmallVector<DeviceTensorND> oup_tensornds = {out};
apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds);
return {Tensor::make(oup_tensornds[0])};

return {Tensor::make(out)};
}

MGB_DEFINE_OPR_CLASS(


+ 5
- 0
imperative/src/impl/physical_tensor.cpp View File

@@ -212,6 +212,11 @@ DeviceTensorND Tensor::dev_tensor(bool contiguous) {
return ret;
}

megdnn::TensorND Tensor::dnn_tensor() {
    // Expose this tensor as a raw megdnn::TensorND view: same layout,
    // pointing at the blob's storage at `m_offset`. The blob keeps owning
    // the memory — the returned view must not outlive it.
    mgb_assert(m_blob, "uninitialized tensor.");
    auto storage_ptr = m_blob->storage().get();
    return megdnn::TensorND{m_layout, {storage_ptr, m_offset}};
}

void Tensor::fetch_value() {
MGB_LOCK_GUARD(m_blob_mtx);
MGB_LOCK_GUARD(m_value_mtx);


+ 2
- 0
imperative/src/include/megbrain/imperative/physical_tensor.h View File

@@ -110,6 +110,8 @@ public:

void assign_from_dev_tensor(DeviceTensorND);

//! Non-owning megdnn::TensorND view of this tensor's blob storage
//! (see the definition in physical_tensor.cpp); asserts the blob exists.
megdnn::TensorND dnn_tensor();

static TensorPtr make_scalar(DTypeScalar value, CompNode cn);

TensorPtr make_scalar(DTypeScalar value) const {


+ 6
- 0
src/opr/impl/basic_arith.cpp View File

@@ -268,6 +268,12 @@ void Elemwise::perform(
call_megdnn_opr_exec(out_cn, dnn_inputs, dest.as_megdnn(), opr.get(), nullptr);
}

void Elemwise::perform_dnn(
        CompNode cn, DeviceTensorND& dest, megdnn::TensorNDArray& inputs,
        intl::UniqPtrWithCN<megdnn::Elemwise>& opr) {
    // Thin wrapper around call_megdnn_opr_exec: run the caller-supplied
    // megdnn opr on raw TensorNDs, writing into `dest`. The trailing nullptr
    // mirrors Elemwise::perform above (presumably no extra workspace —
    // confirm against call_megdnn_opr_exec's signature).
    auto dnn_dest = dest.as_megdnn();
    call_megdnn_opr_exec(cn, inputs, dnn_dest, opr.get(), nullptr);
}

TensorLayoutArray Elemwise::collective_collapse(const TensorLayoutArray& layouts) {
TensorLayoutPtrArray inp(layouts.size());
TensorLayoutArray result(inp.size());


+ 4
- 0
src/opr/include/megbrain/opr/basic_arith.h View File

@@ -88,6 +88,10 @@ public:
Mode mode, DeviceTensorND& dest, const SmallVector<DeviceTensorND>& inputs,
intl::UniqPtrWithCN<megdnn::Elemwise>& opr);

//! Execute a prebuilt megdnn Elemwise opr on raw megdnn tensors, writing the
//! result into `dest`; unlike perform(), inputs are already megdnn::TensorNDs.
MGE_WIN_DECLSPEC_FUC static void perform_dnn(
CompNode cn, DeviceTensorND& dest, megdnn::TensorNDArray& inputs,
intl::UniqPtrWithCN<megdnn::Elemwise>& opr);

using TensorLayoutPtrArray = SmallVector<TensorLayout*>;

/*!


Loading…
Cancel
Save