|
|
@@ -58,10 +58,24 @@ SmallVector<TensorPtr> apply_on_physical_tensor( |
|
|
|
return proxy_graph_detail::apply_on_physical_tensor(def, inputs); |
|
|
|
} |
|
|
|
|
|
|
|
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( |
|
|
|
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { |
|
|
|
auto [output_descs, validated] = |
|
|
|
proxy_graph_detail::infer_output_attrs_fallible(def, inputs); |
|
|
|
if (inputs.size() == 2 && !output_descs[0].layout.ndim) { |
|
|
|
if (!inputs[1].value.empty()) { |
|
|
|
cg::copy_tensor_value_to_shape(output_descs[0].layout, inputs[1].value); |
|
|
|
output_descs[0].layout.init_contiguous_stride(); |
|
|
|
} |
|
|
|
} |
|
|
|
return {output_descs, validated}; |
|
|
|
} |
|
|
|
|
|
|
|
OP_TRAIT_REG(Reduce, Reduce, opr::Reduce) |
|
|
|
.make_from_op_node(make_from_op_node) |
|
|
|
.apply_on_var_node(apply_on_var_node) |
|
|
|
.apply_on_physical_tensor(apply_on_physical_tensor) |
|
|
|
.infer_output_attrs_fallible(infer_output_attrs_fallible) |
|
|
|
.fallback(); |
|
|
|
} // namespace reduce |
|
|
|
} // namespace |
|
|
|