From: @zhengyuanhua Reviewed-by: @wan_xuelei,@wqtshg,@xchu42 Signed-off-by: @ljl0711tags/v1.2.0
@@ -444,31 +444,20 @@ Status HybridModelAsyncExecutor::Execute(const std::vector<DataBuffer> &inputs, | |||||
TensorValue tensor_value(inputs[i].data, inputs[i].length); | TensorValue tensor_value(inputs[i].data, inputs[i].length); | ||||
args.inputs[i] = tensor_value; | args.inputs[i] = tensor_value; | ||||
} | } | ||||
for (size_t i = 0; i < outputs.size(); ++i) { | |||||
args.outputs.emplace_back(TensorValue(outputs[i].data, outputs[i].length)); | |||||
} | |||||
// usr must designate input tensorDesc when input shape is dynamic in inference | |||||
for (size_t i = 0; i < input_desc.size(); ++i) { | |||||
ConstGeTensorDescPtr tensor_desc_ptr = MakeShared<GeTensorDesc>(input_desc[i]); | |||||
args.input_desc.emplace_back(tensor_desc_ptr); | |||||
} | |||||
GE_CHK_STATUS_RET(executor_->Execute(args), "Failed to execute model."); | GE_CHK_STATUS_RET(executor_->Execute(args), "Failed to execute model."); | ||||
for (const auto &output_tensor_desc : args.output_desc) { | for (const auto &output_tensor_desc : args.output_desc) { | ||||
output_desc.emplace_back(*output_tensor_desc); | output_desc.emplace_back(*output_tensor_desc); | ||||
} | } | ||||
for (size_t i = 0; i < args.outputs.size(); ++i) { | |||||
int64_t output_real_size = 0; | |||||
ge::graphStatus graph_status = TensorUtils::GetTensorSizeInBytes(output_desc[i], output_real_size); | |||||
if (graph_status != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "Get tensor size in bytes failed."); | |||||
return FAILED; | |||||
} | |||||
if (output_real_size > 0) { | |||||
if (outputs[i].length < static_cast<uint64_t>(output_real_size)) { | |||||
GELOGE(FAILED, "output idx[%zu], the memory size of output[%lu] given by " | |||||
"user should be greater than or equal to the real size of output[%ld]", | |||||
i, outputs[i].length, output_real_size); | |||||
return FAILED; | |||||
} | |||||
GE_CHK_RT_RET(rtMemcpy(outputs[i].data, outputs[i].length, args.outputs[i].GetData(), output_real_size, | |||||
RT_MEMCPY_DEVICE_TO_DEVICE)); | |||||
} | |||||
outputs[i].length = output_real_size; | |||||
} | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -44,6 +44,27 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item( | |||||
} | } | ||||
} | } | ||||
Status ShapeInferenceState::CheckInputShapeByShapeRange(const GeTensorDesc &tensor_desc, | |||||
const GeTensorDesc &target_tensor_desc) const { | |||||
std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
if (tensor_desc.GetShapeRange(shape_range) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Get shape range failed."); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (shape_range.empty()) { | |||||
GELOGD("Shape range is empty, no need to check input shape."); | |||||
return SUCCESS; | |||||
} | |||||
GeShape target_shape = target_tensor_desc.GetShape(); | |||||
if (TensorUtils::CheckShapeByShapeRange(target_shape, shape_range) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Check shape by shape range failed."); | |||||
return PARAM_INVALID; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target) { | Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target) { | ||||
if (node_item.IsInputShapeStatic(idx)) { | if (node_item.IsInputShapeStatic(idx)) { | ||||
GELOGD("[%s] Trying to update static shape, idx = %d. old shape = [%s], new shape = [%s]", | GELOGD("[%s] Trying to update static shape, idx = %d. old shape = [%s], new shape = [%s]", | ||||
@@ -54,19 +75,31 @@ Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
std::lock_guard<std::mutex> lk(mu_); | |||||
auto &input_desc = input_tensor_desc[idx]; | |||||
if (CheckInputShapeByShapeRange(input_desc, target) != SUCCESS) { | |||||
GELOGE(FAILED, "[%s] Check input shape by shape range failed.", node_item.NodeName().c_str()); | |||||
return FAILED; | |||||
} | |||||
GeShape shape = target.GetShape(); | |||||
input_desc.SetShape(shape); | |||||
input_desc.SetOriginShape(target.GetOriginShape()); | |||||
int64_t tensor_size = -1; | int64_t tensor_size = -1; | ||||
(void) TensorUtils::GetSize(target, tensor_size); | (void) TensorUtils::GetSize(target, tensor_size); | ||||
if (tensor_size <= 0) { | |||||
Format format = input_desc.GetFormat(); | |||||
DataType data_type = input_desc.GetDataType(); | |||||
if (TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size) != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "[%s] Calculate tensor memory size failed.", node_item.NodeName().c_str()); | |||||
return FAILED; | |||||
} | |||||
} | |||||
GELOGD("[%s] Update input shape [%d] with Shape: [%s] and OriginalShape: [%s], size = %ld", | GELOGD("[%s] Update input shape [%d] with Shape: [%s] and OriginalShape: [%s], size = %ld", | ||||
node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
idx, | idx, | ||||
target.GetShape().ToString().c_str(), | |||||
shape.ToString().c_str(), | |||||
target.GetOriginShape().ToString().c_str(), | target.GetOriginShape().ToString().c_str(), | ||||
tensor_size); | tensor_size); | ||||
std::lock_guard<std::mutex> lk(mu_); | |||||
auto &input_desc = input_tensor_desc[idx]; | |||||
input_desc.SetShape(target.GetShape()); | |||||
input_desc.SetOriginShape(target.GetOriginShape()); | |||||
(void) TensorUtils::SetSize(input_desc, tensor_size); | (void) TensorUtils::SetSize(input_desc, tensor_size); | ||||
if (--num_pending_shapes_ <= 0) { | if (--num_pending_shapes_ <= 0) { | ||||
ready_cv_.notify_all(); | ready_cv_.notify_all(); | ||||
@@ -58,6 +58,8 @@ struct ShapeInferenceState { | |||||
const vector<GeTensorDesc> &GetOutputTensorDesc() const; | const vector<GeTensorDesc> &GetOutputTensorDesc() const; | ||||
Status CheckInputShapeByShapeRange(const GeTensorDesc &tensor_desc, const GeTensorDesc &target_tensor_desc) const; | |||||
const NodeItem &node_item; | const NodeItem &node_item; | ||||
private: | private: | ||||
@@ -225,23 +225,19 @@ Status HybridModel::GetInputDescInfo(vector<InputOutputDescInfo> &input_desc, st | |||||
GE_CHECK_NOTNULL(op_desc->GetInputDescPtr(0)); | GE_CHECK_NOTNULL(op_desc->GetInputDescPtr(0)); | ||||
Format format = op_desc->GetInputDescPtr(0)->GetFormat(); | Format format = op_desc->GetInputDescPtr(0)->GetFormat(); | ||||
input.data_type = op_desc->GetInputDescPtr(0)->GetDataType(); | |||||
DataType data_type = op_desc->GetInputDescPtr(0)->GetDataType(); | |||||
input.data_type = static_cast<uint32_t>(data_type); | |||||
input.name = op_desc->GetName(); | input.name = op_desc->GetName(); | ||||
int64_t input_size = 0; | |||||
GE_CHK_STATUS_RET(TensorUtils::GetSize(*op_desc->GetInputDescPtr(0), input_size), "get input size failed."); | |||||
// support dynamic shape | |||||
if (input_size < 0) { | |||||
GELOGD("dynamic shape scene, input size is unknown. " | |||||
"format=%d, data_type=%d, input_size=%ld", | |||||
format, input.data_type, input_size); | |||||
input_size = kMemSizeUnknownShape; // -1 | |||||
GeShape shape = op_desc->GetInputDescPtr(0)->GetShape(); | |||||
int64_t tensor_size = 0; | |||||
if (TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size) != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "Calculate tensor mem size failed."); | |||||
return FAILED; | |||||
} | } | ||||
// not support dynamic shape input for now, so input_size here will be not less than zero. | |||||
input.size = input_size; | |||||
if (tensor_size == kMemSizeUnknownShape) { | |||||
tensor_size = 0; | |||||
} | |||||
input.size = static_cast<uint64_t>(tensor_size); | |||||
CreateInputDimsInfo(op_desc, input); | CreateInputDimsInfo(op_desc, input); | ||||
formats.push_back(format); | formats.push_back(format); | ||||
@@ -284,6 +280,9 @@ void HybridModel::CreateOutput(ConstGeTensorDescPtr &output_desc, | |||||
} | } | ||||
int64_t tensor_size = 0; | int64_t tensor_size = 0; | ||||
(void)TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size); | (void)TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size); | ||||
if (tensor_size == kMemSizeUnknownShape) { | |||||
tensor_size = 0; | |||||
} | |||||
output_desc_info.size = static_cast<uint64_t>(tensor_size); | output_desc_info.size = static_cast<uint64_t>(tensor_size); | ||||
output_desc_info.data_type = output_desc->GetDataType(); | output_desc_info.data_type = output_desc->GetDataType(); | ||||
} | } | ||||
@@ -19,7 +19,9 @@ | |||||
#include "framework/common/string_util.h" | #include "framework/common/string_util.h" | ||||
#include "framework/common/types.h" | #include "framework/common/types.h" | ||||
#include "framework/common/util.h" | #include "framework/common/util.h" | ||||
#include "graph/compute_graph.h" | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "graph/utils/tensor_utils.h" | |||||
using std::pair; | using std::pair; | ||||
using std::string; | using std::string; | ||||
@@ -52,6 +54,11 @@ const char *const kCompressWeightError = "it must be appointed when appoint para | |||||
const char *const kSelectImplmodeError = "only support high_performance, high_precision"; | const char *const kSelectImplmodeError = "only support high_performance, high_precision"; | ||||
const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \""; | const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \""; | ||||
const char *const kKeepDtypeError = "file not found"; | const char *const kKeepDtypeError = "file not found"; | ||||
const char *const kInputShapeRangeInvalid = "format of shape range is invalid"; | |||||
const char *const kShapeRangeValueConvertError = "transfer from string to int64 error"; | |||||
const char *const kInputShapeRangeSample1 = "\"input_name1:[n1~n2,c1,h1,w1]\""; | |||||
const char *const kInputShapeRangeSample2 = "\"[]\""; | |||||
const char *const kInputShapeRangeSample3 = "\"[1~20,3,3~6,-1]\""; | |||||
vector<string> SplitInputShape(const std::string &input_shape) { | vector<string> SplitInputShape(const std::string &input_shape) { | ||||
vector<string> shape_pair_vec; | vector<string> shape_pair_vec; | ||||
@@ -257,8 +264,132 @@ bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims | |||||
return true; | return true; | ||||
} | } | ||||
bool StringToLongNoThrow(const string &str, long &val) { | |||||
try { | |||||
val = std::stol(str); | |||||
return true; | |||||
} catch (const std::invalid_argument) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||||
{str, kShapeRangeValueConvertError, kInputShapeRangeSample3}); | |||||
GELOGE(PARAM_INVALID, | |||||
"Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.", | |||||
str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3); | |||||
} catch (const std::out_of_range) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||||
{str, kShapeRangeValueConvertError, kInputShapeRangeSample3}); | |||||
GELOGE(PARAM_INVALID, | |||||
"Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.", | |||||
str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3); | |||||
} | |||||
return false; | |||||
} | |||||
bool ParseSingleShapeRange(std::string &shape_range, vector<pair<int64_t, int64_t>> &shape_range_vec) { | |||||
vector<char> square_brackets; | |||||
for (auto ch : shape_range) { | |||||
if (ch == '[' || ch == ']') { | |||||
square_brackets.push_back(ch); | |||||
} | |||||
} | |||||
bool is_square_brackets = (square_brackets[0] == '[') && (square_brackets[1] == ']') && (square_brackets.size() == 2); | |||||
if (!is_square_brackets) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||||
{shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample2}); | |||||
GELOGE(PARAM_INVALID, | |||||
"Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.", | |||||
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample2); | |||||
return false; | |||||
} | |||||
// trim start bytes, after that, single input should be "1~20,3,3~6,-1" | |||||
if (ge::StringUtils::StartWith(shape_range, "[")) { | |||||
shape_range = shape_range.substr(1, shape_range.size() - 1); | |||||
} | |||||
// parse shape_range of single input. eg. "1~20,3,3~6,-1" | |||||
vector<string> dim_range_set = ge::StringUtils::Split(shape_range, ','); | |||||
for (const auto &range_pair_str : dim_range_set) { | |||||
vector<string> range_pair_set = ge::StringUtils::Split(range_pair_str, '~'); | |||||
pair<int64_t, int64_t> range_pair; | |||||
if (range_pair_set.size() == 1) { | |||||
long range_value = 0; | |||||
if (!StringToLongNoThrow(range_pair_set.at(0), range_value)) { | |||||
return false; | |||||
} | |||||
if (range_value < 0) { | |||||
range_pair = std::make_pair(1, range_value); | |||||
} else { | |||||
range_pair = std::make_pair(range_value, range_value); | |||||
} | |||||
} else if (range_pair_set.size() == 2) { | |||||
// unknown dim, should get range. | |||||
long range_left = 0; | |||||
if (!StringToLongNoThrow(range_pair_set.at(0), range_left)) { | |||||
return false; | |||||
} | |||||
long range_right = 0; | |||||
if (!StringToLongNoThrow(range_pair_set.at(1), range_right)) { | |||||
return false; | |||||
} | |||||
if (range_left < 0 || (range_right < 0)) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||||
{shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3}); | |||||
GELOGE(PARAM_INVALID, | |||||
"Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.", | |||||
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3); | |||||
return false; | |||||
} | |||||
range_pair = std::make_pair(range_left, range_right); | |||||
} else { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||||
{shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3}); | |||||
GELOGE(PARAM_INVALID, | |||||
"Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.", | |||||
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3); | |||||
return false; | |||||
} | |||||
shape_range_vec.emplace_back(range_pair); | |||||
} | |||||
return true; | |||||
} | |||||
bool ParseInputShapeRange(const std::string &shape_range, | |||||
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map) { | |||||
GELOGD("Input shape range %s", shape_range.c_str()); | |||||
vector<string> shape_range_vec = StringUtils::Split(shape_range, ';'); | |||||
const int DEFAULT_SHAPE_RANGE_PAIR_SIZE = 2; | |||||
for (const auto &shape_range_item : shape_range_vec) { | |||||
vector<string> shape_range_pair_vec = SplitInputShape(shape_range_item); | |||||
if (shape_range_pair_vec.size() != DEFAULT_SHAPE_RANGE_PAIR_SIZE) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||||
{shape_range, kSplitError1, kInputShapeRangeSample1}); | |||||
GELOGE(PARAM_INVALID, "Parse input parameter [--input_shape_range]'s shape range[%s] failed, " | |||||
"reason: %s, correct sample is %s.", shape_range.c_str(), kSplitError1, kInputShapeRangeSample1); | |||||
return false; | |||||
} | |||||
if (shape_range_pair_vec[1].empty()) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape", "reason", "sample"}, | |||||
{shape_range, kEmptyError, kInputShapeRangeSample1}); | |||||
GELOGE(PARAM_INVALID, "Parse input parameter [--input_shape_range]'s shape range[%s] failed," | |||||
"reason: %s, correct sample is %s.", shape_range.c_str(), kEmptyError, kInputShapeRangeSample1); | |||||
return false; | |||||
} | |||||
string shape_range_str = shape_range_pair_vec[1]; | |||||
vector<pair<int64_t, int64_t>> shape_range_val; | |||||
if (!ParseSingleShapeRange(shape_range_str, shape_range_val)) { | |||||
GELOGE(PARAM_INVALID, "Parse single shape range %s error.", shape_range_str.c_str()); | |||||
return false; | |||||
} | |||||
shape_range_map.emplace(make_pair(StringUtils::Trim(shape_range_pair_vec[0]), shape_range_val)); | |||||
} | |||||
return true; | |||||
} | |||||
Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_image_size, string &dynamic_dims, | Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_image_size, string &dynamic_dims, | ||||
const string input_shape, const string input_format, bool &is_dynamic_input) { | |||||
const string input_shape, const string input_shape_range, const string input_format, | |||||
bool &is_dynamic_input) { | |||||
int32_t param_size = static_cast<int32_t>(!dynamic_batch_size.empty()) + | int32_t param_size = static_cast<int32_t>(!dynamic_batch_size.empty()) + | ||||
static_cast<int32_t>(!dynamic_image_size.empty()) + static_cast<int32_t>(!dynamic_dims.empty()); | static_cast<int32_t>(!dynamic_image_size.empty()) + static_cast<int32_t>(!dynamic_dims.empty()); | ||||
if (param_size > 1) { | if (param_size > 1) { | ||||
@@ -269,6 +400,13 @@ Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_i | |||||
} | } | ||||
if (param_size == 0) { | if (param_size == 0) { | ||||
if (!input_shape_range.empty()) { | |||||
std::map<string, std::vector<std::pair<int64_t, int64_t>>> shape_range_map; | |||||
if(!ParseInputShapeRange(input_shape_range, shape_range_map)) { | |||||
GELOGE(ge::PARAM_INVALID, "Failed to parse input shape range: %s", input_shape_range.c_str()); | |||||
return ge::PARAM_INVALID; | |||||
} | |||||
} | |||||
return ge::SUCCESS; | return ge::SUCCESS; | ||||
} | } | ||||
@@ -546,4 +684,91 @@ void EraseEndSemicolon(string ¶m) { | |||||
param.erase(param.end() - 1); | param.erase(param.end() - 1); | ||||
} | } | ||||
} | } | ||||
Status UpdateDataOpShape(const OpDescPtr &op, map<string, vector<int64_t>> &shape_map) { | |||||
GE_CHECK_NOTNULL(op); | |||||
if (shape_map.empty()) { | |||||
GELOGI("Shape map of data op [%s] is empty, no need to update.", op->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
auto tensor_input = op->MutableInputDesc(0); | |||||
auto tensor_output = op->MutableOutputDesc(0); | |||||
GE_CHECK_NOTNULL(tensor_input); | |||||
GE_CHECK_NOTNULL(tensor_output); | |||||
string data_op_name = op->GetName(); | |||||
auto iter = shape_map.find(data_op_name); | |||||
if (iter != shape_map.end()) { | |||||
tensor_input->SetShape(ge::GeShape(iter->second)); | |||||
tensor_output->SetShape(ge::GeShape(iter->second)); | |||||
GELOGI("Update input [%s] shape info", data_op_name.c_str()); | |||||
} else { | |||||
GELOGI("No need update input [%s] attr because not found from input_shape.", data_op_name.c_str()); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status UpdateDataOpShapeRange(const OpDescPtr &op, | |||||
map<string, vector<pair<int64_t, int64_t>>> &shape_range_map) { | |||||
GE_CHECK_NOTNULL(op); | |||||
if (shape_range_map.empty()) { | |||||
GELOGI("Shape range map of data op [%s] is empty.", op->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
auto tensor_input = op->MutableInputDesc(0); | |||||
GE_CHECK_NOTNULL(tensor_input); | |||||
string data_op_name = op->GetName(); | |||||
auto origin_shape = tensor_input->GetShape(); | |||||
auto iter = shape_range_map.find(data_op_name); | |||||
if (iter != shape_range_map.end()) { | |||||
auto cur_shape_range = iter->second; | |||||
if (TensorUtils::CheckShapeByShapeRange(origin_shape, cur_shape_range) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "[%s] Check shape by shape range failed.", op->GetName().c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
for (size_t idx = 0; idx < cur_shape_range.size(); idx++) { | |||||
auto left_range = cur_shape_range[idx].first; | |||||
auto right_range = cur_shape_range[idx].second; | |||||
if (left_range != right_range) { | |||||
origin_shape.SetDim(idx, UNKNOWN_DIM); | |||||
} | |||||
} | |||||
tensor_input->SetShape(origin_shape); | |||||
tensor_input->SetShapeRange(cur_shape_range); | |||||
GELOGI("Update input [%s] shape range info", data_op_name.c_str()); | |||||
} else { | |||||
GELOGI("No need to update input [%s] attr because not found from input_shape_range.", data_op_name.c_str()); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range) { | |||||
if (input_shape_range.empty()) { | |||||
return SUCCESS; | |||||
} | |||||
GE_CHECK_NOTNULL(compute_graph); | |||||
map<string, vector<pair<int64_t, int64_t>>> shape_range_map; | |||||
if (!ParseInputShapeRange(input_shape_range, shape_range_map)) { | |||||
GELOGE(PARAM_INVALID, "Parse input shape range failed."); | |||||
return PARAM_INVALID; | |||||
} | |||||
for (NodePtr &input_node : compute_graph->GetDirectNode()) { | |||||
GE_CHECK_NOTNULL(input_node); | |||||
OpDescPtr op = input_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op); | |||||
if (op->GetType() == DATA) { | |||||
if (UpdateDataOpShapeRange(op, shape_range_map) != SUCCESS) { | |||||
GELOGE(FAILED, "Update data op [%s] input shape range failed.", op->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -59,10 +59,13 @@ bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims | |||||
Status CheckDynamicInputParamValid(std::string &dynamic_batch_size, std::string &dynamic_image_size, | Status CheckDynamicInputParamValid(std::string &dynamic_batch_size, std::string &dynamic_image_size, | ||||
std::string &dynamic_dims, const std::string input_shape, | std::string &dynamic_dims, const std::string input_shape, | ||||
const std::string input_format, bool &is_dynamic_input); | |||||
const std::string input_shape_range, const std::string input_format, | |||||
bool &is_dynamic_input); | |||||
bool ParseInputShape(const std::string &input_shape, std::map<string, std::vector<int64_t>> &shape_map, | bool ParseInputShape(const std::string &input_shape, std::map<string, std::vector<int64_t>> &shape_map, | ||||
std::vector<std::pair<string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input = false); | std::vector<std::pair<string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input = false); | ||||
bool ParseInputShapeRange(const std::string &shape_range, | |||||
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map); | |||||
Status CheckOutputTypeParamValid(const std::string output_type); | Status CheckOutputTypeParamValid(const std::string output_type); | ||||
Status CheckBufferOptimizeParamValid(const std::string buffer_optimize); | Status CheckBufferOptimizeParamValid(const std::string buffer_optimize); | ||||
@@ -76,5 +79,9 @@ Status CheckInputFormat(const string &input_format); | |||||
Status CheckKeepTypeParamValid(const std::string &keep_dtype); | Status CheckKeepTypeParamValid(const std::string &keep_dtype); | ||||
void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips); | void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips); | ||||
void EraseEndSemicolon(std::string ¶m); | void EraseEndSemicolon(std::string ¶m); | ||||
Status UpdateDataOpShape(const OpDescPtr &op, std::map<std::string, std::vector<int64_t>> &shape_map); | |||||
Status UpdateDataOpShapeRange(const OpDescPtr &op, | |||||
std::map<std::string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map); | |||||
Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range); | |||||
} | } | ||||
#endif // FRAMEWORK_DOMI_ATC_IR_COMMON_H_ | #endif // FRAMEWORK_DOMI_ATC_IR_COMMON_H_ |
@@ -55,6 +55,7 @@ const std::string IR_OPTION_DISABLE_REUSE_MEMORY_DEFAULT = "0"; | |||||
const std::string IR_OPTION_ENABLE_COMPRESS_WEIGHT_DEFAULT = "false"; | const std::string IR_OPTION_ENABLE_COMPRESS_WEIGHT_DEFAULT = "false"; | ||||
const std::string KEEP_DTYPE_OPTION = "keep_dtype"; | const std::string KEEP_DTYPE_OPTION = "keep_dtype"; | ||||
const std::string kInputShape = "input_shape"; | const std::string kInputShape = "input_shape"; | ||||
const std::string kInputShapeRange = "input_shape_range"; | |||||
const std::string kInputFormat = "input_format"; | const std::string kInputFormat = "input_format"; | ||||
/** | /** | ||||
@@ -289,13 +290,20 @@ graphStatus Impl::InferShapePrepare(const ComputeGraphPtr &compute_graph) { | |||||
graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { | graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { | ||||
GELOGD("Enter Update Data Attr Process!"); | GELOGD("Enter Update Data Attr Process!"); | ||||
if (options_.find(kInputShape) == options_.end()) { | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
std::string input_shape = (options_.find(kInputShape) == options_.end()) ? "" : options_[kInputShape]; | |||||
std::string input_shape_range = (options_.find(kInputShapeRange) == options_.end()) ? "" : options_[kInputShapeRange]; | |||||
map<string, vector<int64_t>> shape_map; | map<string, vector<int64_t>> shape_map; | ||||
vector<pair<string, vector<int64_t>>> user_shape_map; | vector<pair<string, vector<int64_t>>> user_shape_map; | ||||
GE_CHK_BOOL_EXEC(ParseInputShape(options_[kInputShape], shape_map, user_shape_map, true), | |||||
return GRAPH_PARAM_INVALID, "parse input shape failed!"); | |||||
if (!input_shape.empty()) { | |||||
GE_CHK_BOOL_EXEC(ParseInputShape(input_shape, shape_map, user_shape_map, true), | |||||
return GRAPH_PARAM_INVALID, "Parse input shape failed!"); | |||||
} | |||||
std::map<string, std::vector<std::pair<int64_t, int64_t>>> shape_range_map; | |||||
if (!input_shape_range.empty()) { | |||||
GE_CHK_BOOL_EXEC(ParseInputShapeRange(input_shape_range, shape_range_map), | |||||
return GRAPH_PARAM_INVALID, "Parse input shape range failed."); | |||||
} | |||||
auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); | auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); | ||||
GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) { | for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) { | ||||
@@ -303,21 +311,31 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { | |||||
ge::OpDescPtr op = input_node->GetOpDesc(); | ge::OpDescPtr op = input_node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(op); | GE_CHECK_NOTNULL(op); | ||||
if (op->GetType() == DATA) { | if (op->GetType() == DATA) { | ||||
auto tensor_input = op->MutableInputDesc(0); | |||||
auto tensor_output = op->MutableOutputDesc(0); | |||||
GE_CHECK_NOTNULL(tensor_input); | |||||
GE_CHECK_NOTNULL(tensor_output); | |||||
string data_op_name = op->GetName(); | |||||
auto iter = shape_map.find(data_op_name); | |||||
if (iter != shape_map.end()) { | |||||
tensor_input->SetShape(ge::GeShape(iter->second)); | |||||
tensor_output->SetShape(ge::GeShape(iter->second)); | |||||
GELOGD("update input [%s] shape info", data_op_name.c_str()); | |||||
} else { | |||||
GELOGI("no need update input [%s] attr because not found from input_shape.", data_op_name.c_str()); | |||||
if (UpdateDataOpShape(op, shape_map) != SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Update data op [%s] shape failed.", op->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
if (UpdateDataOpShapeRange(op, shape_range_map) != SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Update data op [%s] shape range failed.", op->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
if (shape_range_map.empty()) { | |||||
auto tensor_input = op->MutableInputDesc(0); | |||||
GE_CHECK_NOTNULL(tensor_input); | |||||
GeShape shape = tensor_input->GetShape(); | |||||
std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
if (tensor_input->GetShapeRange(shape_range) != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "[%s] Get shape range failed.", op->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
if (TensorUtils::CheckShapeByShapeRange(shape, shape_range) != SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "[%s] Check shape by shape range failed.", op->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
@@ -400,9 +418,11 @@ graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::stri | |||||
: options_[ge::ir_option::DYNAMIC_IMAGE_SIZE]; | : options_[ge::ir_option::DYNAMIC_IMAGE_SIZE]; | ||||
string dynamic_dims = | string dynamic_dims = | ||||
options_.find(ge::ir_option::DYNAMIC_DIMS) == options_.end() ? "" : options_[ge::ir_option::DYNAMIC_DIMS]; | options_.find(ge::ir_option::DYNAMIC_DIMS) == options_.end() ? "" : options_[ge::ir_option::DYNAMIC_DIMS]; | ||||
string input_shape_range = | |||||
options_.find(ge::INPUT_SHAPE_RANGE) == options_.end() ? "" : options_[ge::INPUT_SHAPE_RANGE]; | |||||
auto status = CheckDynamicInputParamValid(dynamic_batch_size, dynamic_image_size, dynamic_dims, input_shape, | auto status = CheckDynamicInputParamValid(dynamic_batch_size, dynamic_image_size, dynamic_dims, input_shape, | ||||
input_format, is_dynamic_input_); | |||||
input_shape_range, input_format, is_dynamic_input_); | |||||
if (status != ge::SUCCESS) { | if (status != ge::SUCCESS) { | ||||
GELOGE(GRAPH_PARAM_INVALID, "Check dynamic input size failed!"); | GELOGE(GRAPH_PARAM_INVALID, "Check dynamic input size failed!"); | ||||
return GRAPH_PARAM_INVALID; | return GRAPH_PARAM_INVALID; | ||||
@@ -84,6 +84,10 @@ DEFINE_string(input_shape, "", | |||||
"Optional; shape of input data. Required when framework is caffe " | "Optional; shape of input data. Required when framework is caffe " | ||||
"or TensorFLow or MindSpore or Onnx. " | "or TensorFLow or MindSpore or Onnx. " | ||||
"Format: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\""); | "Format: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\""); | ||||
DEFINE_string(input_shape_range, "", | |||||
"Optional; shape range of input data. Required when framework is caffe " | |||||
"or TensorFLow or Onnx. " | |||||
"Format: \"input_name1:[n1~n2,c1,h1,w1];input_name2:[n2~n3,c2,h2,w2]\""); | |||||
DEFINE_bool(h, false, "show this help message"); | DEFINE_bool(h, false, "show this help message"); | ||||
DEFINE_string(cal_conf, "", "Optional; the calibration config file."); | DEFINE_string(cal_conf, "", "Optional; the calibration config file."); | ||||
@@ -240,6 +244,7 @@ class GFlagUtils { | |||||
" --framework Framework type. 0:Caffe; 1:MindSpore; 3:Tensorflow; 5:Onnx\n" | " --framework Framework type. 0:Caffe; 1:MindSpore; 3:Tensorflow; 5:Onnx\n" | ||||
" --input_format Format of input data. E.g.: \"NCHW\"\n" | " --input_format Format of input data. E.g.: \"NCHW\"\n" | ||||
" --input_shape Shape of input data. Separate multiple nodes with semicolons (;). " | " --input_shape Shape of input data. Separate multiple nodes with semicolons (;). " | ||||
" --input_shape_range Shape range of input data. Separate multiple nodes with semicolons (;)." | |||||
"Use double quotation marks (\") to enclose each argument.\n" | "Use double quotation marks (\") to enclose each argument.\n" | ||||
" E.g.: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"\n" | " E.g.: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"\n" | ||||
" --dynamic_batch_size Set dynamic batch size. E.g.: \"batchsize1,batchsize2,batchsize3\"\n" | " --dynamic_batch_size Set dynamic batch size. E.g.: \"batchsize1,batchsize2,batchsize3\"\n" | ||||
@@ -373,7 +378,7 @@ class GFlagUtils { | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | ||||
ge::CheckDynamicInputParamValid(FLAGS_dynamic_batch_size, FLAGS_dynamic_image_size, | ge::CheckDynamicInputParamValid(FLAGS_dynamic_batch_size, FLAGS_dynamic_image_size, | ||||
FLAGS_dynamic_dims, FLAGS_input_shape, | |||||
FLAGS_dynamic_dims, FLAGS_input_shape, FLAGS_input_shape_range, | |||||
FLAGS_input_format, is_dynamic_input) != ge::SUCCESS, | FLAGS_input_format, is_dynamic_input) != ge::SUCCESS, | ||||
ret = ge::FAILED, "check dynamic size(batch size, image size or dims) failed!"); | ret = ge::FAILED, "check dynamic size(batch size, image size or dims) failed!"); | ||||
@@ -985,6 +990,7 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output | |||||
} else { | } else { | ||||
std::map<string, string> atc_params; | std::map<string, string> atc_params; | ||||
atc_params.insert(std::pair<string, string>("input_shape", FLAGS_input_shape)); | atc_params.insert(std::pair<string, string>("input_shape", FLAGS_input_shape)); | ||||
atc_params.insert(std::pair<string, string>(ge::INPUT_SHAPE_RANGE, FLAGS_input_shape_range)); | |||||
atc_params.insert(std::pair<string, string>("out_nodes", FLAGS_out_nodes)); | atc_params.insert(std::pair<string, string>("out_nodes", FLAGS_out_nodes)); | ||||
atc_params.insert(std::pair<string, string>("input_format", FLAGS_input_format)); | atc_params.insert(std::pair<string, string>("input_format", FLAGS_input_format)); | ||||
atc_params.insert(std::pair<string, string>("check_report", FLAGS_check_report)); | atc_params.insert(std::pair<string, string>("check_report", FLAGS_check_report)); | ||||
@@ -576,6 +576,7 @@ Status InitDomiOmgContext(const string &input_shape, const string &input_format, | |||||
GELOGE(PARAM_INVALID, "Failed to parse input shape: %s", input_shape.c_str()); | GELOGE(PARAM_INVALID, "Failed to parse input shape: %s", input_shape.c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -788,6 +789,12 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<stri | |||||
GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "ATC weights parse ret fail."); | GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "ATC weights parse ret fail."); | ||||
// parser input shape range and update op shape range | |||||
std::string input_shape_range; | |||||
ParseAtcParms(atc_params, INPUT_SHAPE_RANGE, input_shape_range); | |||||
GE_RETURN_WITH_LOG_IF_ERROR(UpdateDynamicInputShapeRange(compute_graph, input_shape_range), | |||||
"Update input shape range failed"); | |||||
GELOGI("ATC parser success."); | GELOGI("ATC parser success."); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -311,6 +311,9 @@ const std::string OP_BANK_UPDATE_FLAG = "ge.op_bank_update"; | |||||
// 0: data multi; 1: model multi; | // 0: data multi; 1: model multi; | ||||
const std::string HCOM_MULTI_MODE = "ge.hcomMultiMode"; | const std::string HCOM_MULTI_MODE = "ge.hcomMultiMode"; | ||||
// atc and ir option | |||||
const char *const INPUT_SHAPE_RANGE = "input_shape_range"; | |||||
// Graph run mode | // Graph run mode | ||||
enum GraphRunMode { PREDICTION = 0, TRAIN }; | enum GraphRunMode { PREDICTION = 0, TRAIN }; | ||||
@@ -390,6 +393,7 @@ static const char *const OP_DEBUG_LEVEL = ge::OP_DEBUG_LEVEL.c_str(); | |||||
#ifdef __GNUC__ | #ifdef __GNUC__ | ||||
const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT, | const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT, | ||||
INPUT_SHAPE, | INPUT_SHAPE, | ||||
INPUT_SHAPE_RANGE, | |||||
OP_NAME_MAP, | OP_NAME_MAP, | ||||
DYNAMIC_BATCH_SIZE, | DYNAMIC_BATCH_SIZE, | ||||
DYNAMIC_IMAGE_SIZE, | DYNAMIC_IMAGE_SIZE, | ||||
@@ -45,6 +45,7 @@ include_directories(${GE_CODE_DIR}/inc) | |||||
include_directories(${GE_CODE_DIR}/metadef/inc) | include_directories(${GE_CODE_DIR}/metadef/inc) | ||||
include_directories(${GE_CODE_DIR}/ge) | include_directories(${GE_CODE_DIR}/ge) | ||||
include_directories(${GE_CODE_DIR}/ge/inc) | include_directories(${GE_CODE_DIR}/ge/inc) | ||||
include_directories(${GE_CODE_DIR}/ge/ir_build) | |||||
include_directories(${GE_CODE_DIR}/metadef) | include_directories(${GE_CODE_DIR}/metadef) | ||||
include_directories(${GE_CODE_DIR}/metadef/graph) | include_directories(${GE_CODE_DIR}/metadef/graph) | ||||
include_directories(${GE_CODE_DIR}/inc/external) | include_directories(${GE_CODE_DIR}/inc/external) | ||||
@@ -61,6 +62,7 @@ include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/cce) | |||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/ops) | include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/ops) | ||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain) | include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain) | ||||
include_directories(${GE_CODE_DIR}/tests/ut/ge) | include_directories(${GE_CODE_DIR}/tests/ut/ge) | ||||
include_directories(${GE_CODE_DIR}/tests/ut/common) | |||||
include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge/proto) | include_directories(${CMAKE_BINARY_DIR}/proto/ge/proto) | ||||
@@ -732,6 +734,7 @@ set(KERNEL_TEST_FILES | |||||
set(MULTI_PARTS_TEST_FILES | set(MULTI_PARTS_TEST_FILES | ||||
"graph_ir/ge_operator_factory_unittest.cc" | "graph_ir/ge_operator_factory_unittest.cc" | ||||
"graph_ir/ge_ir_build_unittest.cc" | |||||
"graph/transop_util_unittest.cc" | "graph/transop_util_unittest.cc" | ||||
"common/datatype_transfer_unittest.cc" | "common/datatype_transfer_unittest.cc" | ||||
"common/dump_manager_unittest.cc" | "common/dump_manager_unittest.cc" | ||||
@@ -0,0 +1,100 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include "ir_build/atc_ir_common.h" | |||||
#include "graph/testcase/ge_graph/graph_builder_utils.h" | |||||
#define protected public | |||||
#define private public | |||||
#undef private | |||||
#undef protected | |||||
const string DATA = "Data"; | |||||
const string AddNYes = "AddNYes"; | |||||
const string NETOUTPUT = "NetOutput"; | |||||
using namespace ge; | |||||
class UtestIrCommon : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
static ge::OpDescPtr CreateOpDesc(const std::string &name, const std::string &type) { | |||||
OpDescPtr op_desc = std::make_shared<ge::OpDesc>(name, type); | |||||
ge::GeTensorDesc ge_tensor_desc; | |||||
op_desc->AddInputDesc("input", ge_tensor_desc); | |||||
op_desc->AddOutputDesc("output", ge_tensor_desc); | |||||
return op_desc; | |||||
} | |||||
static ComputeGraphPtr BuildComputeGraph() { | |||||
auto builder = ut::GraphBuilder("test"); | |||||
auto data1 = builder.AddNode("input1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 2, 3}); | |||||
auto data2 = builder.AddNode("input2", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {4, 10}); | |||||
auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1); | |||||
auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); | |||||
builder.AddDataEdge(data1, 0, addn1, 0); | |||||
builder.AddDataEdge(data2, 0, addn1, 1); | |||||
builder.AddDataEdge(addn1, 0,netoutput, 0); | |||||
return builder.GetGraph(); | |||||
} | |||||
TEST(UtestIrCommon, update_data_op_shape) { | |||||
ge::OpDescPtr op_desc = CreateOpDesc("Data", "Data"); | |||||
map<string, vector<int64_t>> shape_map; | |||||
shape_map["Data"] = {{1,2}}; | |||||
Status ret = UpdateDataOpShape(op_desc, shape_map); | |||||
EXPECT_EQ(ret, ge::SUCCESS); | |||||
} | |||||
TEST(UtestIrCommon, update_dynamic_shape_range_success) { | |||||
ComputeGraphPtr graph = BuildComputeGraph(); | |||||
std::string input_shape_range = "input1:[1, 2~3, -1];input2:[3~5, 10]"; | |||||
Status ret = UpdateDynamicInputShapeRange(graph, input_shape_range); | |||||
EXPECT_EQ(ret, ge::SUCCESS); | |||||
} | |||||
TEST(UtestIrCommon, update_dynamic_shape_range_failed) { | |||||
ComputeGraphPtr graph = BuildComputeGraph(); | |||||
// 1 | |||||
std::string input_shape_range = "input1;[1, 2~3, -1]"; | |||||
Status ret = UpdateDynamicInputShapeRange(graph, input_shape_range); | |||||
EXPECT_EQ(ret, ge::PARAM_INVALID); | |||||
// 2 | |||||
input_shape_range = "input1:[1, 2~3, -1)"; | |||||
ret = UpdateDynamicInputShapeRange(graph, input_shape_range); | |||||
EXPECT_EQ(ret, ge::PARAM_INVALID); | |||||
//3 | |||||
input_shape_range = "input1:[1, 3~2, -1];input2:[3~5, 10]"; | |||||
ret = UpdateDynamicInputShapeRange(graph, input_shape_range); | |||||
EXPECT_EQ(ret, ge::FAILED); | |||||
//4 | |||||
input_shape_range = "input1:[1, 2~-3, -1]"; | |||||
ret = UpdateDynamicInputShapeRange(graph, input_shape_range); | |||||
EXPECT_EQ(ret, ge::PARAM_INVALID); | |||||
} |