You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

multi_batch_options.cc 13 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "multi_batch_options.h"
  17. #include "framework/common/debug/ge_log.h"
  18. #include "framework/omg/omg_inner_types.h"
  19. #include "framework/common/util.h"
  20. #include "framework/common/string_util.h"
  21. #include "common/formats/utils/formats_trans_utils.h"
  22. #include "common/util/error_manager/error_manager.h"
  23. #include "graph/debug/ge_attr_define.h"
  24. #include "graph/utils/node_utils.h"
  25. #include "graph/ge_context.h"
  26. #include "graph/common/local_context.h"
  27. #include "framework/common/types.h"
  28. namespace ge {
  29. namespace multibatch {
  // Radix handed to std::strtol when option tokens are converted to integers.
  constexpr int kDecimal = 10;
  // Inclusive bounds on the number of shape groups, compared against shapes.size() in CheckDynamicParams.
  constexpr uint8_t kMaxShapesCount = 100;
  constexpr uint8_t kMinShapesCount = 2;
  // Value that marks a dimension as dynamic inside an input shape.
  constexpr int kDynmaicDims = -1;
  // Not referenced in the visible part of this file.
  constexpr int kDynamicBatchDynamicDimsNum = 1;
  // Number of kDynmaicDims entries CheckDynamicImageSizeShape requires in a shape.
  constexpr int kDynamicImgSizeDynamciDimsNum = 2;
  // Not referenced in the visible part of this file.
  constexpr size_t kMaxNDDimNum = 4;
  constexpr size_t kMinNDDimNum = 1;
  38. void ParseDynamicSize(string dynamic_size, vector<vector<int64_t>> &shapes) {
  39. std::vector<std::string> shape_strs = ge::StringUtils::Split(dynamic_size, ';');
  40. for (const auto &shape_str : shape_strs) {
  41. if (shape_str.empty()) {
  42. continue;
  43. }
  44. std::vector<int64_t> shape;
  45. std::vector<std::string> dims = ge::StringUtils::Split(shape_str, ',');
  46. for (const auto &dim : dims) {
  47. if (dim.empty()) {
  48. continue;
  49. }
  50. shape.emplace_back(std::strtol(dim.c_str(), nullptr, kDecimal));
  51. }
  52. if (!shape.empty()) {
  53. shapes.emplace_back(shape);
  54. }
  55. }
  56. }
  57. ///
  58. /// @ingroup ge
  59. /// @brief Init Dynamic Param from Options.
  60. /// @param [out] std::vector<std::vector<int64_t>> &shapes: Result for Params.
  61. /// @return true: Configed for Multi batch / false: Not configed for Multi batch.
  62. ///
  63. bool InitDynamicParams(vector<vector<int64_t>> &shapes) {
  64. if (!GetLocalOmgContext().dynamic_batch_size.empty()) {
  65. GELOGD("Found dynamic batch option, value %s", GetLocalOmgContext().dynamic_batch_size.c_str());
  66. std::vector<std::string> dims = ge::StringUtils::Split(GetLocalOmgContext().dynamic_batch_size, ',');
  67. for (const auto &dim : dims) {
  68. if (dim.empty()) {
  69. continue;
  70. }
  71. shapes.emplace_back(std::vector<int64_t>({std::strtol(dim.c_str(), nullptr, kDecimal)}));
  72. GELOGI("Found dynamic batch, shape %s", formats::JoinToString(*shapes.rbegin()).c_str());
  73. }
  74. }
  75. if (!GetLocalOmgContext().dynamic_image_size.empty()) {
  76. GELOGD("Found dynamic image size option, value %s", GetLocalOmgContext().dynamic_image_size.c_str());
  77. ParseDynamicSize(GetLocalOmgContext().dynamic_image_size, shapes);
  78. for (const auto &shape : shapes) {
  79. GELOGI("Found dynamic image size, shape %s", formats::JoinToString(shape).c_str());
  80. }
  81. }
  82. if (!GetLocalOmgContext().dynamic_dims.empty()) {
  83. GELOGD("Found dynamic dims option, value %s", GetLocalOmgContext().dynamic_dims.c_str());
  84. ParseDynamicSize(GetLocalOmgContext().dynamic_dims, shapes);
  85. for (const auto &shape : shapes) {
  86. GELOGI("Found dynamic dims, shape %s", formats::JoinToString(shape).c_str());
  87. }
  88. }
  89. return !shapes.empty();
  90. }
  91. ///
  92. /// @ingroup ge
  93. /// @brief parse each data's own dynamic dims.
  94. /// @param [out] map<string, vector<vector<int64_t>>> &data_to_dynamic_info: key:data_name. value:dynamic dims.
  95. /// @return true: Configed for Multi batch / false: Not configed for Multi batch.
  96. ///
  97. Status ParserDataToDynmaicInfo(const vector<vector<int64_t>> &shapes,
  98. vector<pair<string, vector<int64_t>>> &data_name_and_shape,
  99. map<string, vector<vector<int64_t>> > &data_to_dynamic_info) {
  100. size_t cur_data_index = 0;
  101. for (size_t index = 0; index < data_name_and_shape.size(); ++index) {
  102. auto &cur_item = data_name_and_shape[index];
  103. auto &data_name = cur_item.first;
  104. auto &data_shape = cur_item.second;
  105. auto dynamic_dims_num = std::count_if(data_shape.begin(), data_shape.end(),
  106. [&data_shape](int64_t dim){ return dim < 0; });
  107. vector<vector<int64_t> > dynamic_info;
  108. for (auto &dynamic_gear_info : shapes) {
  109. vector<int64_t> one_gear;
  110. if (dynamic_gear_info.size() == static_cast<size_t>(dynamic_dims_num)) {
  111. one_gear = dynamic_gear_info;
  112. } else if (dynamic_gear_info.size() > static_cast<size_t>(dynamic_dims_num)) {
  113. auto tmp_index = cur_data_index;
  114. for (size_t i = 0; i < static_cast<size_t>(dynamic_dims_num); ++i) {
  115. if (tmp_index >= dynamic_gear_info.size()) {
  116. GELOGE(PARAM_INVALID, "Data: %s shape: %s make dynamic dims overflow", data_name.c_str(),
  117. formats::JoinToString(data_shape).c_str());
  118. return FAILED;
  119. }
  120. one_gear.push_back(dynamic_gear_info[tmp_index++]);
  121. }
  122. } else {
  123. GELOGE(PARAM_INVALID, "Dynamic dims num of data: %s shape: %s can not be more than one gear dynamic info size",
  124. data_name.c_str(), formats::JoinToString(data_shape).c_str());
  125. return FAILED;
  126. }
  127. dynamic_info.push_back(one_gear);
  128. }
  129. cur_data_index += dynamic_dims_num;
  130. data_to_dynamic_info[data_name] = dynamic_info;
  131. }
  132. return SUCCESS;
  133. }
  134. ///
  135. /// @ingroup ge
  136. /// @brief Check Dynamic Param is invalid.
  137. /// @param [in] const vector<vector<int64_t>> &shapes: Params for check.
  138. /// @return SUCCESS: valid / PARAM_INVALID: invalid.
  139. ///
  140. Status CheckDynamicParams(const vector<vector<int64_t>> &shapes) {
  141. if (shapes.size() < kMinShapesCount) {
  142. ErrorManager::GetInstance().ATCReportErrMessage(
  143. "E10035", {"shapesize", "minshapesize"}, {std::to_string(shapes.size()), std::to_string(kMinShapesCount - 1)});
  144. GELOGE(PARAM_INVALID,
  145. "Input parameter[--dynamic_batch_size, --dynamic_image_size or --dynamic_dims]'s "
  146. "value size [%zu] must be greater than [%zu].",
  147. shapes.size(), kMinShapesCount - 1);
  148. return PARAM_INVALID;
  149. }
  150. if (shapes.size() > kMaxShapesCount) {
  151. ErrorManager::GetInstance().ATCReportErrMessage(
  152. "E10036", {"shapesize", "maxshapesize"}, {std::to_string(shapes.size()), std::to_string(kMaxShapesCount + 1)});
  153. GELOGE(PARAM_INVALID,
  154. "Input parameter[--dynamic_batch_size, --dynamic_image_size or --dynamic_dims]'s "
  155. "value size [%zu] must be less than [%zu].",
  156. shapes.size(), kMaxShapesCount + 1);
  157. return PARAM_INVALID;
  158. }
  159. std::set<std::vector<int64_t>> shapes_set;
  160. size_t shape_size = shapes.at(0).size();
  161. for (auto &shape : shapes) {
  162. if (shape_size != shape.size()) {
  163. ErrorManager::GetInstance().ATCReportErrMessage("E10037", {"shapesize1", "shapesize2"},
  164. {std::to_string(shape_size), std::to_string(shape.size())});
  165. GELOGE(PARAM_INVALID,
  166. "Input parameter[--dynamic_batch_size, --dynamic_image_size or --dynamic_dims]'s "
  167. "value size must be same, first group's size is %zu and another's is %zu.",
  168. shape_size, shape.size());
  169. return PARAM_INVALID;
  170. }
  171. for (auto dim : shape) {
  172. if (dim <= 0) {
  173. ErrorManager::GetInstance().ATCReportErrMessage("E10038", {"dim"}, {std::to_string(dim)});
  174. GELOGE(PARAM_INVALID, "Invalid dim %ld, all dims must be greater than 0", dim);
  175. return PARAM_INVALID;
  176. }
  177. }
  178. shapes_set.insert(shape);
  179. }
  180. if (shapes_set.size() != shapes.size()) {
  181. ErrorManager::GetInstance().ATCReportErrMessage("E10039");
  182. GELOGE(PARAM_INVALID,
  183. "Input parameter[--dynamic_batch_size, --dynamic_image_size or --dynamic_dims] exist duplicate shapes.");
  184. return PARAM_INVALID;
  185. }
  186. return SUCCESS;
  187. }
  188. ///
  189. /// @ingroup ge
  190. /// @brief Get GeShape from configed shape.
  191. /// @param [in] const std::vector<int64_t> &batch_shape: Configed shape.
  192. /// @param [out] GeShape &data_shape: GeShape for configed shape.
  193. /// @return SUCCESS / PARAM_INVALID
  194. ///
  195. Status CalcShape(const std::vector<int64_t> &batch_shape, GeShape &data_shape) {
  196. size_t batch_shape_index = 0;
  197. for (size_t i = 0; i < data_shape.GetDimNum(); ++i) {
  198. if (data_shape.GetDim(i) < 0) {
  199. if (batch_shape_index >= batch_shape.size()) {
  200. ErrorManager::GetInstance().ATCReportErrMessage(
  201. "E19012", {"function", "reason"},
  202. {"CalcShape", "the batch shape count " + std::to_string(batch_shape.size()) +
  203. " does not match the data shape " + data_shape.ToString()});
  204. GELOGE(PARAM_INVALID,
  205. "Failed to calc tensor shape, the batch shape count %zu, does not match the data shape %s",
  206. batch_shape.size(), data_shape.ToString().c_str());
  207. return PARAM_INVALID;
  208. }
  209. data_shape.SetDim(i, batch_shape[batch_shape_index++]);
  210. }
  211. }
  212. if (batch_shape_index != batch_shape.size()) {
  213. ErrorManager::GetInstance().ATCReportErrMessage(
  214. "E19012", {"function", "reason"}, {"CalcShape", "the batch shape count " + std::to_string(batch_shape.size()) +
  215. " does not match the data shape " + data_shape.ToString()});
  216. GELOGE(PARAM_INVALID, "Failed to calc tensor shape, the batch shape count %zu, does not match the data shape %s",
  217. batch_shape.size(), data_shape.ToString().c_str());
  218. return PARAM_INVALID;
  219. }
  220. return SUCCESS;
  221. }
  222. ///
  223. /// @ingroup ge
  224. /// @brief Set mbatch_dynamic_type on node.
  225. /// @param [in] const OpDescPtr &op_desc: Node for set attribute.
  226. /// @return 0: SUCCESS / others: INTERNAL_ERROR
  227. ///
  228. Status StampDynamicType(const OpDescPtr &op_desc) {
  229. GE_CHECK_NOTNULL(op_desc);
  230. int32_t dynamic_type = static_cast<int32_t>(FIXED);
  231. if (!GetLocalOmgContext().dynamic_batch_size.empty()) {
  232. dynamic_type = static_cast<int32_t>(DYNAMIC_BATCH);
  233. }
  234. if (!GetLocalOmgContext().dynamic_image_size.empty()) {
  235. dynamic_type = static_cast<int32_t>(DYNAMIC_IMAGE);
  236. }
  237. if (!GetLocalOmgContext().dynamic_dims.empty()) {
  238. dynamic_type = static_cast<int32_t>(DYNAMIC_DIMS);
  239. }
  240. if (!AttrUtils::SetInt(op_desc, ATTR_DYNAMIC_TYPE, dynamic_type)) {
  241. GELOGE(INTERNAL_ERROR, "Failed to add dynamic type attr for node %s", op_desc->GetName().c_str());
  242. return INTERNAL_ERROR;
  243. }
  244. return SUCCESS;
  245. }
  246. ///
  247. /// @ingroup ge
  248. /// @brief Check dynamic batch Shape.
  249. /// @param [in] const vector<int64_t> &shape: data_shape to be checked.
  250. /// @param [in] const string &data_name: cur data name.
  251. /// @return 0: true/false
  252. ///
  253. bool CheckDynamicBatchShape(const vector<int64_t> &shape, const string &data_name) {
  254. if (shape[0] == kDynmaicDims) {
  255. for (size_t i = 1; i < shape.size(); ++i) {
  256. if (shape[i] < 1) {
  257. ErrorManager::GetInstance().ATCReportErrMessage("E10018", {"index", "shape"},
  258. {std::to_string(i), std::to_string(shape[i])});
  259. GELOGE(ge::PARAM_INVALID,
  260. "Only batch N can be -1 when set --dynamic_batch_size, current data: %s shape[%zu] is %ld",
  261. data_name.c_str(), i, shape[i]);
  262. return false;
  263. }
  264. }
  265. return true;
  266. } else {
  267. return false;
  268. }
  269. }
  270. ///
  271. /// @ingroup ge
  272. /// @brief Check Dynamic image size shape.
  273. /// @param [in] unordered_map<string, vector<int64_t>> &shape_map: map of data_name and data_shape.
  274. /// @param [in] const std::string &input_format: format of input.
  275. /// @return 0: true/false
  276. ///
  277. bool CheckDynamicImageSizeShape(const vector<int64_t> &shape, const string &data_name,
  278. const std::string &input_format) {
  279. int64_t height = 0;
  280. int64_t width = 0;
  281. if (input_format == "NCHW") {
  282. height = shape[NCHW_DIM_H];
  283. width = shape[NCHW_DIM_W];
  284. }
  285. if (input_format == "NHWC") {
  286. height = shape[NHWC_DIM_H];
  287. width = shape[NHWC_DIM_W];
  288. }
  289. if (height == kDynmaicDims && width == kDynmaicDims &&
  290. std::count(shape.begin(), shape.end(), kDynmaicDims) == kDynamicImgSizeDynamciDimsNum) {
  291. return true;
  292. } else {
  293. ErrorManager::GetInstance().ATCReportErrMessage("E10019");
  294. GELOGE(ge::PARAM_INVALID,
  295. "--input_shape's shape is invalid, only height and width can be -1 when set --dynamic_image_size.");
  296. return false;
  297. }
  298. }
  299. } // namespace multibatch
  300. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示