|
|
@@ -81,10 +81,33 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( |
|
|
|
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true}; |
|
|
|
} |
|
|
|
|
|
|
|
SmallVector<TensorPtr> apply_on_physical_tensor( |
|
|
|
const OpDef& def, const SmallVector<TensorPtr>& inputs) { |
|
|
|
auto& input = inputs[0]; |
|
|
|
TensorShape target_shape; |
|
|
|
cg::copy_tensor_value_to_shape( |
|
|
|
target_shape, inputs[1]->get_value().proxy_to_default_cpu()); |
|
|
|
TensorPtr output = Tensor::make( |
|
|
|
TensorLayout(target_shape, input->dtype()), input->comp_node()); |
|
|
|
if (output->layout().is_empty()) { |
|
|
|
return {output}; |
|
|
|
} |
|
|
|
if (input->shape().eq_shape(output->shape())) { |
|
|
|
mgb_assert(input->layout().eq_layout(output->layout())); |
|
|
|
output->dev_tensor().copy_from_fixlayout(input->dev_tensor()); |
|
|
|
} else { |
|
|
|
TensorLayout input_layout = input->layout().broadcast(output->shape()); |
|
|
|
output->dev_tensor().copy_from_fixlayout( |
|
|
|
input->dev_tensor().sub(SubTensorSpec::make_from_layout(input_layout))); |
|
|
|
} |
|
|
|
return {output}; |
|
|
|
} |
|
|
|
|
|
|
|
OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast) |
|
|
|
.make_from_op_node(make_from_op_node) |
|
|
|
.apply_on_var_node(apply_on_var_node) |
|
|
|
.infer_output_attrs_fallible(infer_output_attrs_fallible) |
|
|
|
.apply_on_physical_tensor(apply_on_physical_tensor) |
|
|
|
.fallback(); |
|
|
|
} // namespace broadcast |
|
|
|
|
|
|
@@ -147,9 +170,31 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( |
|
|
|
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true}; |
|
|
|
} |
|
|
|
|
|
|
|
SmallVector<TensorPtr> apply_on_physical_tensor( |
|
|
|
const OpDef& def, const SmallVector<TensorPtr>& inputs) { |
|
|
|
auto&& op_def = def.cast_final_safe<Reshape>(); |
|
|
|
size_t nr_inp = inputs.size(); |
|
|
|
mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp); |
|
|
|
auto&& src = inputs[0]; |
|
|
|
auto&& tshp_nd = inputs[1]; |
|
|
|
auto slayout = src->layout(); |
|
|
|
|
|
|
|
TensorShape tshp; |
|
|
|
cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu()); |
|
|
|
if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) { |
|
|
|
mgb_assert(tshp[op_def.axis] == -1); |
|
|
|
tshp[op_def.axis] = 1; |
|
|
|
tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems(); |
|
|
|
} |
|
|
|
TensorLayout tlayout = slayout.reshape(tshp); |
|
|
|
// memory forward |
|
|
|
return {Tensor::make(src->blob(), 0, tlayout)}; |
|
|
|
} |
|
|
|
|
|
|
|
OP_TRAIT_REG(Reshape, Reshape) |
|
|
|
.apply_on_var_node(apply_on_var_node) |
|
|
|
.infer_output_attrs_fallible(infer_output_attrs_fallible) |
|
|
|
.apply_on_physical_tensor(apply_on_physical_tensor) |
|
|
|
.fallback(); |
|
|
|
} // namespace reshape |
|
|
|
|
|
|
|