diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp
index 480fc59f..3971073c 100644
--- a/imperative/src/impl/interpreter/interpreter_impl.cpp
+++ b/imperative/src/impl/interpreter/interpreter_impl.cpp
@@ -178,7 +178,11 @@ TensorInfo* ChannelImpl::put_impl(
     auto _ = StackManager::Guard{"Put", &state.stack_manager};
     auto info = alloc();
     MGB_RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandKind::Put);
+    constexpr int size_threshold = TensorShape::MAX_NDIM;
     init(info, {data.layout(), data.comp_node()});
+    if ((!hvalue.empty()) && info->desc.layout.total_nr_elems() <= size_threshold) {
+        info->desc.value = hvalue.proxy_to_default_cpu();
+    }
     info->ptr = Tensor::make(data, hvalue);
     MGB_RECORD_EVENT(
             TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node,
diff --git a/imperative/src/impl/ops/reduce.cpp b/imperative/src/impl/ops/reduce.cpp
index 030709ee..16f2d816 100644
--- a/imperative/src/impl/ops/reduce.cpp
+++ b/imperative/src/impl/ops/reduce.cpp
@@ -58,10 +58,24 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     return proxy_graph_detail::apply_on_physical_tensor(def, inputs);
 }
 
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    auto [output_descs, validated] =
+            proxy_graph_detail::infer_output_attrs_fallible(def, inputs);
+    if (inputs.size() == 2 && !output_descs[0].layout.ndim) {
+        if (!inputs[1].value.empty()) {
+            cg::copy_tensor_value_to_shape(output_descs[0].layout, inputs[1].value);
+            output_descs[0].layout.init_contiguous_stride();
+        }
+    }
+    return {output_descs, validated};
+}
+
 OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
         .make_from_op_node(make_from_op_node)
         .apply_on_var_node(apply_on_var_node)
         .apply_on_physical_tensor(apply_on_physical_tensor)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
         .fallback();
 }  // namespace reduce
 }  // namespace