@@ -152,9 +152,43 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}

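// infer_output_mem_desc: read the target shape from the runtime value of the
// second input, resolve a possible wildcard axis, and describe the output as a
// zero-copy view of the first input's memory.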
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs,
        const SmallVector<MemoryDesc>& inputs_mems) {
    auto&& op_def = def.cast_final_safe<Reshape>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Reshape expects 2 inputs, got %lu", nr_inp);
    auto&& src = inputs[0];
    auto&& tshp_nd = inputs[1];
    auto slayout = src->layout();

    TensorShape tshp;
    cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
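    // If the target shape contains a wildcard dimension, its index is recorded in
    // op_def.axis and its value must be -1; it is resolved below so that the
    // reshaped layout keeps the source's total element count.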
    if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
        mgb_assert(tshp[op_def.axis] == -1);
        tshp[op_def.axis] = 1;
        tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
    }
    TensorLayout tlayout = slayout.reshape(tshp);
    // memory forward: alias the input storage instead of allocating a new buffer
    return {{{tlayout, 0, src->comp_node(), StorageIdentifier::make(&inputs_mems[0])}}, {}};
}

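// Reshape is executed as a pure metadata operation: infer_output_mem_desc already
// forwarded the input storage, so execute only verifies that the output aliases
// the input blob and performs no computation.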
void execute(
        const OpDef& def,
        SmallVector<TensorPtr> inputs,
        SmallVector<TensorPtr> outputs,
        SmallVector<TensorPtr> workspace) {
    mgb_assert(inputs[0]->offset() == outputs[0]->offset());
    mgb_assert(inputs[0]->blob() == outputs[0]->blob());
}

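// Register the Reshape op trait with the shape inference, memory-desc inference
// and execute hooks defined above.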
OP_TRAIT_REG(Reshape, Reshape)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .infer_output_mem_desc(infer_output_mem_desc)
        .execute(execute)
        .fallback();
} // reshape