GitOrigin-RevId: 6fab3b1402
tags/v1.3.0
@@ -176,7 +176,7 @@ TensorShape ChannelImpl::get_shape(void* handle) { | |||||
m_buffer.enqueue(Flush{info}); | m_buffer.enqueue(Flush{info}); | ||||
m_cv.wait(lock, [&]() { | m_cv.wait(lock, [&]() { | ||||
check_worker_exc_unsafe(); | check_worker_exc_unsafe(); | ||||
return bool(info->ptr); | |||||
return static_cast<bool>(info->ptr); | |||||
}); | }); | ||||
m_waitee = nullptr; | m_waitee = nullptr; | ||||
TensorShape ret = info->ptr->layout(); | TensorShape ret = info->ptr->layout(); | ||||
@@ -212,7 +212,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) { | |||||
m_buffer.enqueue(Flush{info}); | m_buffer.enqueue(Flush{info}); | ||||
m_cv.wait(lock, [&]() { | m_cv.wait(lock, [&]() { | ||||
check_worker_exc_unsafe(); | check_worker_exc_unsafe(); | ||||
return bool(info->ptr); | |||||
return static_cast<bool>(info->ptr); | |||||
}); | }); | ||||
m_waitee = nullptr; | m_waitee = nullptr; | ||||
return info->ptr->dev_tensor(); | return info->ptr->dev_tensor(); | ||||
@@ -232,7 +232,7 @@ void ChannelImpl::close() { | |||||
} | } | ||||
void ChannelImpl::config_async_level(int level) { | void ChannelImpl::config_async_level(int level) { | ||||
mgb_assert(level <= 2 and level >= 0, "async_level should be 0, 1 or 2"); | |||||
mgb_assert(level <= 2 && level >= 0, "async_level should be 0, 1 or 2"); | |||||
m_async_level = level; | m_async_level = level; | ||||
} | } | ||||
@@ -49,7 +49,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::i | |||||
expr_input_descs.push_back(node2attr.at(inp)); | expr_input_descs.push_back(node2attr.at(inp)); | ||||
} | } | ||||
auto[expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible( | |||||
auto [expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible( | |||||
*expr_op, expr_input_descs); | *expr_op, expr_input_descs); | ||||
validated = validated && expr_validated; | validated = validated && expr_validated; | ||||
@@ -54,16 +54,13 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
SmallVector<LogicalTensorDesc> out_shapes(nr_out); | SmallVector<LogicalTensorDesc> out_shapes(nr_out); | ||||
auto&& i0 = inputs[0]; | auto&& i0 = inputs[0]; | ||||
auto&& i1 = inputs[1]; | auto&& i1 = inputs[1]; | ||||
size_t i = 0; | |||||
if (!need_stat) { | |||||
out_shapes[0] = out_shapes[1] = {TensorLayout({0}, i0.layout.dtype, i0.layout.format), i0.comp_node}; | |||||
i = 2; | |||||
} | |||||
for (; i < nr_out-1; ++ i) { | |||||
// [running_mean, running_var,] save_mean, save_var | |||||
for (size_t i = 0; i < nr_out-1; ++ i) { | |||||
out_shapes[i] = {i1.layout, i1.comp_node}; | out_shapes[i] = {i1.layout, i1.comp_node}; | ||||
} | } | ||||
// output tensor | |||||
out_shapes[nr_out-1] = {i0.layout, i0.comp_node}; | out_shapes[nr_out-1] = {i0.layout, i0.comp_node}; | ||||
return {out_shapes, true}; | |||||
return {out_shapes, out_shapes[nr_out-1].layout.ndim != 0}; | |||||
} | } | ||||
OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm) | OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm) | ||||
@@ -61,17 +61,17 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
TensorLayout out_layout = src.layout; | TensorLayout out_layout = src.layout; | ||||
if (tshp.layout.ndim == 0 || tshp.value.empty()) { | if (tshp.layout.ndim == 0 || tshp.value.empty()) { | ||||
out_layout.ndim = 0; | out_layout.ndim = 0; | ||||
return {{{out_layout, src.comp_node}}, true}; | |||||
return {{{out_layout, src.comp_node}}, false}; | |||||
} | } | ||||
mgb_assert( | mgb_assert( | ||||
tshp.layout.ndim == 1, | |||||
"target shape of Broadcast expects ndim=1; got ndim=%lu actually", | |||||
tshp.layout.ndim == 1, | |||||
"target shape of Broadcast expects ndim=1; got ndim=%lu actually", | |||||
tshp.layout.ndim); | tshp.layout.ndim); | ||||
size_t target_ndim = tshp.layout.shape[0]; | size_t target_ndim = tshp.layout.shape[0]; | ||||
out_layout.ndim = target_ndim; | out_layout.ndim = target_ndim; | ||||
auto* ptr = tshp.value.ptr<dt_int32>(); | auto* ptr = tshp.value.ptr<dt_int32>(); | ||||
for(size_t i=0; i<target_ndim; ++i) { | |||||
for (size_t i = 0; i < target_ndim; ++i) { | |||||
out_layout.shape[i] = ptr[i]; | out_layout.shape[i] = ptr[i]; | ||||
} | } | ||||
mgb_assert(valid_broadcast(src.layout, out_layout), | mgb_assert(valid_broadcast(src.layout, out_layout), | ||||
@@ -76,7 +76,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
return {{ | return {{ | ||||
{TensorLayout(inputs[0].layout.dtype), cn}, | {TensorLayout(inputs[0].layout.dtype), cn}, | ||||
{TensorLayout(dtype::Int32()), cn} | {TensorLayout(dtype::Int32()), cn} | ||||
}, true}; | |||||
}, false}; | |||||
} | } | ||||
OP_TRAIT_REG(CondTake, CondTake, opr::CondTake) | OP_TRAIT_REG(CondTake, CondTake, opr::CondTake) | ||||
@@ -60,7 +60,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
TensorLayout out_layout; | TensorLayout out_layout; | ||||
out_layout.ndim = 0; | out_layout.ndim = 0; | ||||
out_layout.dtype = out_dt; | out_layout.dtype = out_dt; | ||||
return {{{out_layout, out_cn}}, true}; | |||||
return {{{out_layout, out_cn}}, false}; | |||||
} | } | ||||
} | } | ||||
@@ -59,7 +59,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size()); | mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size()); | ||||
auto&& desc = inputs[0]; | auto&& desc = inputs[0]; | ||||
if (!desc.layout.ndim) { | if (!desc.layout.ndim) { | ||||
return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, true}; | |||||
return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false}; | |||||
} | } | ||||
DeviceTensorND value; | DeviceTensorND value; | ||||
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){ | if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){ | ||||