@@ -1509,7 +1509,7 @@ def sync_batch_norm(
     """
     _eps_mode = eps_mode.lower()
     assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode)
-    if _eps_mode == "additive" and not (is_distributed() or training):
+    if _eps_mode == "additive" and not (is_distributed() and training):
         return batch_norm(
             inp,
             running_mean,
@@ -717,7 +717,6 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
     if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) {
         ptr->to_contiguous_inplace();
     }
-    dest->desc.layout = ptr->layout();
     dest->desc.comp_node = ptr->comp_node();
     dest->memory = ptr->blob()->size();
     dest->ptr = std::move(ptr);
@@ -205,6 +205,12 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     size_t size = inputs.size();
     SmallVector<LogicalTensorDesc> dests(size);
+    for (size_t i = 0; i < size; i++) {
+        if (inputs[i].layout.ndim == 0) {
+            return {{{TensorLayout(inputs[0].layout.dtype), inputs[0].comp_node}},
+                    false};
+        }
+    }
     if (size > 1) {
         auto [output_descs, validated] =
                 proxy_graph_detail::infer_output_attrs_fallible(def, inputs);
@@ -115,6 +115,9 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         TensorShapeArray src(inputs.size());
         for (size_t i = 0; i < inputs.size(); ++i) {
             src[i] = inputs[i].layout;
+            if (!src[i].ndim) {
+                return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
+            }
         }
         megdnn::Elemwise::deduce_shape(src, shp);
     }