|
|
@@ -448,24 +448,32 @@ Status GraphBuilder::SetInputSize(const ge::NodePtr &node_ptr) { |
|
|
|
auto node_op_desc = node_ptr->GetOpDesc(); |
|
|
|
GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); |
|
|
|
// set dst_node.input_desc = src_node.output_desc |
|
|
|
ge::GeTensorDesc desc_temp(src_op->GetOutputDesc(peer_out_anchor->GetIdx())); |
|
|
|
|
|
|
|
auto output_desc = src_op->GetOutputDescPtr(peer_out_anchor->GetIdx()); |
|
|
|
int64_t size = 0; |
|
|
|
GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc_temp, size) != SUCCESS, GELOGI("Get size failed!")); |
|
|
|
GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(*output_desc, size) != SUCCESS, GELOGI("Get size failed!")); |
|
|
|
GELOGD("src node %s output desc, dim_size: %zu, mem_size: %ld, format: %s, type: %s.", src_node->GetName().c_str(), |
|
|
|
desc_temp.GetShape().GetDimNum(), size, TypeUtils::FormatToSerialString(desc_temp.GetFormat()).c_str(), |
|
|
|
TypeUtils::DataTypeToSerialString(desc_temp.GetDataType()).c_str()); |
|
|
|
for (size_t i = 0; i < desc_temp.GetShape().GetDimNum(); ++i) { |
|
|
|
GELOGD("dims[%zu]: %ld", i, desc_temp.GetShape().GetDim(i)); |
|
|
|
output_desc->GetShape().GetDimNum(), size, TypeUtils::FormatToSerialString(output_desc->GetFormat()).c_str(), |
|
|
|
TypeUtils::DataTypeToSerialString(output_desc->GetDataType()).c_str()); |
|
|
|
for (size_t i = 0; i < output_desc->GetShape().GetDimNum(); ++i) { |
|
|
|
GELOGD("dims[%zu]: %ld", i, output_desc->GetShape().GetDim(i)); |
|
|
|
} |
|
|
|
|
|
|
|
auto input_desc = node_op_desc->GetInputDescPtr(in_data_anchor->GetIdx()); |
|
|
|
auto input_desc = node_op_desc->MutableInputDesc(in_data_anchor->GetIdx()); |
|
|
|
GE_CHECK_NOTNULL(input_desc); |
|
|
|
ge::TensorUtils::SetSize(const_cast<GeTensorDesc &>(*input_desc), size); |
|
|
|
(void) ge::TensorUtils::SetSize(*input_desc, size); |
|
|
|
GE_CHK_STATUS_RET(node_op_desc->UpdateInputDesc(in_data_anchor->GetIdx(), *input_desc)); |
|
|
|
GELOGD("%s input desc, dim_size: %zu, mem_size: %ld, format: %s, type: %s.", node_ptr->GetName().c_str(), |
|
|
|
input_desc->GetShape().GetDimNum(), size, TypeUtils::FormatToSerialString(input_desc->GetFormat()).c_str(), |
|
|
|
TypeUtils::DataTypeToSerialString(input_desc->GetDataType()).c_str()); |
|
|
|
// inherit some attr |
|
|
|
int64_t tensor_size_attr; |
|
|
|
if (AttrUtils::GetInt(output_desc, ATTR_NAME_SPECIAL_OUTPUT_SIZE, tensor_size_attr) && (tensor_size_attr > 0)) { |
|
|
|
GE_IF_BOOL_EXEC(!AttrUtils::SetInt(*input_desc, ATTR_NAME_SPECIAL_OUTPUT_SIZE, tensor_size_attr), |
|
|
|
GELOGW("Set size attr failed!"); continue); |
|
|
|
GELOGD("node[%s] [%d]th output has sepcial size[%ld], and update to node[%s] [%d]th input", |
|
|
|
src_op->GetName().c_str(), peer_out_anchor->GetIdx(), tensor_size_attr, |
|
|
|
node_op_desc->GetName().c_str(), in_data_anchor->GetIdx()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|