@@ -12,6 +12,8 @@ | |||||
#include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
#include "megbrain/opr/tensor_manip.h" | #include "megbrain/opr/tensor_manip.h" | ||||
#include "megbrain/graph/helper.h" | |||||
#include "../op_trait.h" | #include "../op_trait.h" | ||||
namespace mgb { | namespace mgb { | ||||
@@ -83,10 +85,46 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true}; | return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true}; | ||||
} | } | ||||
/*!
 * \brief Describe the output memory of a Broadcast op.
 *
 * The target shape is decoded from the value of inputs_tensors[1] (obtained
 * through get_value().proxy_to_default_cpu()). The single output descriptor
 * uses that shape with the dtype and comp node of inputs_tensors[0], offset 0
 * and StorageIdentifier::make(0).
 *
 * \param def unused here
 * \param inputs_tensors [0] is the tensor to broadcast, [1] holds the target
 *      shape as its value
 * \param inputs_mems unused until memory forwarding is implemented
 * \return a pair of descriptor lists; the second one is always empty
 *      (presumably the workspace descriptors -- confirm against op_trait.h)
 */
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems) {
    const TensorPtr& src = inputs_tensors[0];
    const TensorPtr& shape_tensor = inputs_tensors[1];

    TensorShape out_shape;
    cg::copy_tensor_value_to_shape(
            out_shape, shape_tensor->get_value().proxy_to_default_cpu());

    // TODO: memory forward -- when src already has out_shape, the output could
    // reuse the input storage instead of getting its own:
    // if (src->shape().eq_shape(out_shape)) {
    //     return {{{src->layout(), 0, src->comp_node(), StorageIdentifier::make(&inputs_mems[0])}}, {}};
    // }

    MemoryDesc out_desc{
            {out_shape, src->dtype()},
            0,
            src->comp_node(),
            StorageIdentifier::make(0)};
    return {{out_desc}, {}};
}
void execute( | |||||
const OpDef& def, | |||||
SmallVector<TensorPtr> inputs, | |||||
SmallVector<TensorPtr> outputs, | |||||
SmallVector<TensorPtr> workspace) { | |||||
if (outputs[0]->layout().is_empty()) { | |||||
return; | |||||
} | |||||
if (inputs[0]->shape().eq_shape(outputs[0]->shape())) { | |||||
mgb_assert(inputs[0]->layout().eq_layout(outputs[0]->layout())); | |||||
// TODO: memory forward | |||||
// mgb_assert(inputs[0]->offset() == outputs[0]->offset()); | |||||
// mgb_assert(inputs[0]->blob() == outputs[0]->blob()); | |||||
outputs[0]->dev_tensor().copy_from_fixlayout(inputs[0]->dev_tensor()); | |||||
} else { | |||||
TensorLayout input_layout = inputs[0]->layout().broadcast(outputs[0]->shape()); | |||||
outputs[0]->dev_tensor().copy_from_fixlayout(inputs[0]->dev_tensor().sub(SubTensorSpec::make_from_layout(input_layout))); | |||||
} | |||||
} | |||||
// Trait registration for the Broadcast op: hooks the functions named
// make_from_op_node, apply_on_var_node, infer_output_attrs_fallible,
// infer_output_mem_desc and execute into the trait created by OP_TRAIT_REG.
// NOTE(review): fallback() presumably fills the remaining trait entries with
// defaults -- confirm against ../op_trait.h.
OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast) | OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast) | ||||
.make_from_op_node(make_from_op_node) | .make_from_op_node(make_from_op_node) | ||||
.apply_on_var_node(apply_on_var_node) | .apply_on_var_node(apply_on_var_node) | ||||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | .infer_output_attrs_fallible(infer_output_attrs_fallible) | ||||
.infer_output_mem_desc(infer_output_mem_desc) | |||||
.execute(execute) | |||||
.fallback(); | .fallback(); | ||||
} // broadcast | } // broadcast | ||||
@@ -0,0 +1,47 @@ | |||||
/** | |||||
* \file imperative/src/impl/ops/reduce.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
#include "megbrain/opr/basic_arith.h" | |||||
#include "../op_trait.h" | |||||
#include "../dnn_op_helper.h" | |||||
namespace mgb { | |||||
namespace imperative { | |||||
namespace { | |||||
namespace reduce { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& reduce = static_cast<const Reduce&>(def); | |||||
OperatorNodeConfig config{reduce.make_name()}; | |||||
if (inputs.size() > 1) { | |||||
return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config); | |||||
} else { | |||||
return opr::Reduce::make(inputs[0], reduce.param(), | |||||
(cg::VarNode*)nullptr, config); | |||||
} | |||||
} | |||||
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||||
auto* node = &node_->cast_final_safe<opr::Reduce>(); | |||||
return Reduce::make(node->param()); | |||||
} | |||||
OP_TRAIT_REG(Reduce, Reduce, opr::Reduce) | |||||
.make_from_op_node(make_from_op_node) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace reduce | |||||
} // namespace | |||||
} // namespace imperative | |||||
} // namespace mgb | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -117,31 +117,6 @@ OP_TRAIT_REG(TopK, TopK).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace | } // namespace | ||||
namespace { | namespace { | ||||
namespace reduce { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& reduce = static_cast<const Reduce&>(def); | |||||
OperatorNodeConfig config{reduce.make_name()}; | |||||
if (inputs.size() > 1) { | |||||
return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config); | |||||
} else { | |||||
return opr::Reduce::make(inputs[0], reduce.param(), | |||||
(cg::VarNode*)nullptr, config); | |||||
} | |||||
} | |||||
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||||
auto* node = &node_->cast_final_safe<opr::Reduce>(); | |||||
return Reduce::make(node->param()); | |||||
} | |||||
OP_TRAIT_REG(Reduce, Reduce, opr::Reduce) | |||||
.make_from_op_node(make_from_op_node) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace reduce | |||||
} // namespace | |||||
namespace { | |||||
namespace adaptive_pooling { | namespace adaptive_pooling { | ||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | ||||
auto&& pool = static_cast<const AdaptivePooling&>(def); | auto&& pool = static_cast<const AdaptivePooling&>(def); | ||||