Browse Source

perf(opdef/reshape): specialize Reshape

GitOrigin-RevId: 26d0e151ca
tags/v1.6.0-rc1
Megvii Engine Team 3 years ago
parent
commit
0ef5183c65
1 changed files with 34 additions and 0 deletions
  1. +34
    -0
      imperative/src/impl/ops/broadcast.cpp

+ 34
- 0
imperative/src/impl/ops/broadcast.cpp View File

@@ -152,9 +152,43 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def,
const SmallVector<TensorPtr>& inputs,
const SmallVector<MemoryDesc>& inputs_mems) {
auto&& op_def = def.cast_final_safe<Reshape>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
auto&& src = inputs[0];
auto&& tshp_nd = inputs[1];
auto slayout = src->layout();

TensorShape tshp;
cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
mgb_assert(tshp[op_def.axis] == -1);
tshp[op_def.axis] = 1;
tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
}
TensorLayout tlayout = slayout.reshape(tshp);
// memory forward
return {{{tlayout, 0, src->comp_node(), StorageIdentifier::make(&inputs_mems[0])}}, {}};
}

void execute(
const OpDef& def,
SmallVector<TensorPtr> inputs,
SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
mgb_assert(inputs[0]->offset() == outputs[0]->offset());
mgb_assert(inputs[0]->blob() == outputs[0]->blob());
}

OP_TRAIT_REG(Reshape, Reshape)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();
} // reshape



Loading…
Cancel
Save