Browse Source

perf(ops): specialize Broadcast

GitOrigin-RevId: 0cba3e6e93
tags/v1.6.0-rc1
Megvii Engine Team 3 years ago
parent
commit
cd60d26852
3 changed files with 85 additions and 25 deletions
  1. +38
    -0
      imperative/src/impl/ops/broadcast.cpp
  2. +47
    -0
      imperative/src/impl/ops/reduce.cpp
  3. +0
    -25
      imperative/src/impl/ops/specializations.cpp

+ 38
- 0
imperative/src/impl/ops/broadcast.cpp View File

@@ -12,6 +12,8 @@
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/tensor_manip.h"

#include "megbrain/graph/helper.h"

#include "../op_trait.h"

namespace mgb {
@@ -83,10 +85,46 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def,
const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto& input = inputs_tensors[0];
TensorShape target_shape;
cg::copy_tensor_value_to_shape(target_shape, inputs_tensors[1]->get_value().proxy_to_default_cpu());
// TODO: memory forward
// if (input->shape().eq_shape(target_shape)) {
// return {{{input->layout(), 0, input->comp_node(), StorageIdentifier::make(&inputs_mems[0])}}, {}};
// }
return {{{{target_shape, input->dtype()}, 0, input->comp_node(), StorageIdentifier::make(0)}}, {}};
}

void execute(
const OpDef& def,
SmallVector<TensorPtr> inputs,
SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
if (outputs[0]->layout().is_empty()) {
return;
}
if (inputs[0]->shape().eq_shape(outputs[0]->shape())) {
mgb_assert(inputs[0]->layout().eq_layout(outputs[0]->layout()));
// TODO: memory forward
// mgb_assert(inputs[0]->offset() == outputs[0]->offset());
// mgb_assert(inputs[0]->blob() == outputs[0]->blob());
outputs[0]->dev_tensor().copy_from_fixlayout(inputs[0]->dev_tensor());
} else {
TensorLayout input_layout = inputs[0]->layout().broadcast(outputs[0]->shape());
outputs[0]->dev_tensor().copy_from_fixlayout(inputs[0]->dev_tensor().sub(SubTensorSpec::make_from_layout(input_layout)));
}
}

OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();
} // broadcast



+ 47
- 0
imperative/src/impl/ops/reduce.cpp View File

@@ -0,0 +1,47 @@
/**
* \file imperative/src/impl/ops/reduce.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/basic_arith.h"

#include "../op_trait.h"
#include "../dnn_op_helper.h"

namespace mgb {
namespace imperative {
namespace {
namespace reduce {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& reduce = static_cast<const Reduce&>(def);
OperatorNodeConfig config{reduce.make_name()};
if (inputs.size() > 1) {
return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config);
} else {
return opr::Reduce::make(inputs[0], reduce.param(),
(cg::VarNode*)nullptr, config);
}
}

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::Reduce>();
return Reduce::make(node->param());
}

OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.fallback();
} // namespace reduce
} // namespace
} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 0
- 25
imperative/src/impl/ops/specializations.cpp View File

@@ -117,31 +117,6 @@ OP_TRAIT_REG(TopK, TopK).apply_on_var_node(apply_on_var_node).fallback();
} // namespace

namespace {
namespace reduce {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& reduce = static_cast<const Reduce&>(def);
OperatorNodeConfig config{reduce.make_name()};
if (inputs.size() > 1) {
return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config);
} else {
return opr::Reduce::make(inputs[0], reduce.param(),
(cg::VarNode*)nullptr, config);
}
}

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::Reduce>();
return Reduce::make(node->param());
}

OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.fallback();
} // namespace reduce
} // namespace

namespace {
namespace adaptive_pooling {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& pool = static_cast<const AdaptivePooling&>(def);


Loading…
Cancel
Save