From dc53d49bb4d1bc5e588b64b6ec327a95d48ced2b Mon Sep 17 00:00:00 2001 From: "dongduo5@huawei.com" Date: Mon, 12 Apr 2021 16:31:40 +0800 Subject: [PATCH 01/10] Add transfer from HWCN to FRACTAL_ZN_LSTM --- ge/CMakeLists.txt | 2 + ge/common/CMakeLists.txt | 1 + .../format_transfer_hwcn_fractal_zn_lstm.cc | 215 +++++++++++++++++++++ ge/common/ge_common.mk | 1 + ge/ge_inference.mk | 1 + ge/ge_runner.mk | 1 + tests/ut/ge/CMakeLists.txt | 1 + 7 files changed, 222 insertions(+) create mode 100644 ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 89745019..c79ba061 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -100,6 +100,7 @@ set(TRAIN_SRC_LIST "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" "common/formats/format_transfers/format_transfer_transpose.cc" + "common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc" "common/formats/formats.cc" "common/formats/utils/formats_trans_utils.cc" "common/fp16_t.cc" @@ -435,6 +436,7 @@ set(INFER_SRC_LIST "common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" "common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" "common/formats/format_transfers/format_transfer_nchw_fz_c04.cc" + "common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc" "common/formats/formats.cc" "common/profiling/profiling_manager.cc" "common/dump/dump_properties.cc" diff --git a/ge/common/CMakeLists.txt b/ge/common/CMakeLists.txt index 75cb8ad1..51ac0e8a 100755 --- a/ge/common/CMakeLists.txt +++ b/ge/common/CMakeLists.txt @@ -49,6 +49,7 @@ set(SRC_LIST "formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" "formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" "formats/format_transfers/format_transfer_nchw_fz_c04.cc" + "formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc" "formats/formats.cc" "ge_format_util.cc" "fmk_error_codes.cc" diff --git a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc new file mode 100644 index 00000000..669b7895 --- /dev/null +++ b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc @@ -0,0 +1,215 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h" + +#include +#include + +#include "common/formats/utils/formats_definitions.h" +#include "common/formats/utils/formats_trans_utils.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" +#include "graph/utils/type_utils.h" + +namespace ge { +namespace formats { +namespace { +bool CheckDataTypeSupported(const DataType &data_type) { + return (data_type == DT_FLOAT || data_type == DT_FLOAT16); +} + +Status TransShapeHwcnToFrazlstm(const DataType &data_type, const std::vector &src_shape, + std::vector &dst_shape) { + auto cube_size = GetCubeSizeByDataType(data_type); + dst_shape.clear(); + dst_shape.push_back(Ceil(src_shape.at(kHwcnC), static_cast(cube_size))); + dst_shape.push_back(src_shape.at(kHwcnH)); + dst_shape.push_back(src_shape.at(kHwcnW)); + dst_shape.push_back(src_shape.at(kHwcnN)); + dst_shape.push_back(cube_size); + dst_shape.push_back(cube_size); + if (!CheckShapeValid(dst_shape, kFracZDimsNum)) { + GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", + ShapeToString(dst_shape).c_str()); + return ACL_ERROR_GE_SHAPE_INVALID; + } + return SUCCESS; +} + +Status CheckArgsForHwcnToFrazlstm(const TransArgs &args) { + if (args.src_format != FORMAT_HWCN || args.dst_format != FORMAT_FRACTAL_ZN_LSTM) { + std::string error = "Dose not support trans format from " + + FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + + FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); + GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); + return UNSUPPORTED; + } + if (!CheckDataTypeSupported(args.src_data_type)) { + GELOGE(UNSUPPORTED, "Failed to trans shape from HWCN to C1HWNCoC0, invalid data type %s", + TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); + return UNSUPPORTED; + } + if (!CheckShapeValid(args.src_shape, kHwcnDimsNum)) { + GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str()); + return PARAM_INVALID; + } + if (!CheckShapeValid(args.dst_shape, kFracZDimsNum)) { + GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str()); + return PARAM_INVALID; + } + std::vector expect_dst_shape; + auto ret = TransShapeHwcnToFrazlstm(args.src_data_type, args.src_shape, expect_dst_shape); + if (ret != SUCCESS) { + return ret; + } + if (args.dst_shape != expect_dst_shape) { + GELOGE(PARAM_INVALID, + "Failed to trans format, src and dst shape are not compatible. src shape %s, dst shape %s, " + "expect dst shape %s", + ShapeToString(args.src_shape).c_str(), ShapeToString(args.dst_shape).c_str(), + ShapeToString(expect_dst_shape).c_str()); + return PARAM_INVALID; + } + + return SUCCESS; +} + +Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { + std::shared_ptr dst(new (std::nothrow) uint8_t[total_size], std::default_delete()); + if (dst == nullptr) { + GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", + TypeUtils::FormatToSerialString(args.src_format).c_str(), + TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); + return OUT_OF_MEMORY; + } + + auto h = args.src_shape.at(kHwcnH); + auto w = args.src_shape.at(kHwcnW); + auto c = args.src_shape.at(kHwcnC); + auto n = args.src_shape.at(kHwcnN); + auto c1 = args.dst_shape.at(kC1hwncoc0C1); + auto c0 = args.dst_shape.at(kC1hwncoc0C0); + auto co = args.dst_shape.at(kC1hwncoc0Co); + int64_t coc0 = co * c0; + int64_t ncoc0 = n * coc0; + int64_t wncoc0 = w * ncoc0; + int64_t hwncoc0 = h * wncoc0; + int64_t cn = c * n; + int64_t wcn = w * cn; + + for (int64_t c1_idx = 0; c1_idx < c1; c1_idx++) { + int64_t c1_head_addr = c1_idx * hwncoc0; + for (int64_t h_idx = 0; h_idx < h; h_idx++) { + int64_t h_head_addr = c1_head_addr + h_idx * wncoc0; + for (int64_t w_idx = 0; w_idx < w; w_idx++) { + int64_t w_head_addr = h_head_addr + w_idx * ncoc0; + for (int64_t n_idx = 0; n_idx < n; n_idx++) { + int64_t n_head_addr = w_head_addr + n_idx * coc0; + for (int64_t co_idx = 0; co_idx < co; co_idx++) { + int64_t co_head_addr = n_head_addr + co_idx * c0; + for (int64_t c0_idx = 0; c0_idx < c0; c0_idx++) { + int64_t dst_idx = c0_idx + co_head_addr; + auto dst_offset = dst_idx * size; + auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) + ? total_size - dst_offset + : static_cast(SECUREC_MEM_MAX_LEN); + int64_t c_idx = c0_idx + c1_idx * c0; + int64_t src_idx = h_idx * wcn + w_idx * cn + c_idx * n + n_idx; + auto src_offset = src_idx * size; + + if (c_idx < c && c0_idx == co_idx) { + auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, + static_cast(size)); + if (ret != EOK) { + GELOGE(INTERNAL_ERROR, + "Failed to copy data from HWCN[%ld, %ld, %ld, %ld] offset %ld to " + "C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] offset %ld, err-code %d", + h_idx, w_idx, c_idx, n_idx, src_offset, c1_idx, h_idx, w_idx, n_idx, co_idx, c0_idx, + dst_offset, ret); + return INTERNAL_ERROR; + } + } else { + auto ret = + memset_s(dst.get() + dst_offset, static_cast(protected_size), 0, static_cast(size)); + if (ret != EOK) { + GELOGE(INTERNAL_ERROR, + "Failed to set to 0 to C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] offset %ld, " + "err-code %d", + c1_idx, h_idx, w_idx, n_idx, co_idx, c0_idx, dst_offset, ret); + return INTERNAL_ERROR; + } + } + } + } + } + } + } + } + result.data = dst; + result.length = static_cast(total_size); + return SUCCESS; +} +} // namespace + +Status FormatTransferHwcnFractalznlstm::TransFormat(const TransArgs &args, TransResult &result) { + if (CheckArgsForHwcnToFrazlstm(args) != SUCCESS) { + return PARAM_INVALID; + } + int size = GetSizeByDataType(args.src_data_type); + auto total_size = GetItemNumByShape(args.dst_shape) * size; + if (total_size <= 0) { + int64_t src_size = GetItemNumByShape(args.src_shape); + if (total_size == 0 && src_size == 0) { + result.length = static_cast(total_size); + return SUCCESS; + } + + GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, + ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); + return PARAM_INVALID; + } + GELOGD("Begin to trans format from HWCN to C1HWNCoC0, src shape %s, data type %s, dst shape %s, memory size %ld", + ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), + ShapeToString(args.dst_shape).c_str(), total_size); + if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", + ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), + ShapeToString(args.dst_shape).c_str(), total_size); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +Status FormatTransferHwcnFractalznlstm::TransShape(Format src_format, const std::vector &src_shape, + DataType data_type, Format dst_format, std::vector &dst_shape) { + if (src_format == FORMAT_HWCN && CheckDataTypeSupported(data_type)) { + if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { + GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", + ShapeToString(src_shape).c_str()); + return ACL_ERROR_GE_SHAPE_INVALID; + } + return TransShapeHwcnToFrazlstm(data_type, src_shape, dst_shape); + } else if (src_format != FORMAT_HWCN) { + return ACL_ERROR_GE_SHAPE_INVALID; + } else { + return ACL_ERROR_GE_SHAPE_INVALID; + } +} + +REGISTER_FORMAT_TRANSFER(FormatTransferHwcnFractalznlstm, FORMAT_HWCN, FORMAT_FRACTAL_ZN_LSTM) +} // namespace formats +} // namespace ge diff --git a/ge/common/ge_common.mk b/ge/common/ge_common.mk index e28090ad..7a7a3801 100755 --- a/ge/common/ge_common.mk +++ b/ge/common/ge_common.mk @@ -31,6 +31,7 @@ GE_COMMON_LOCAL_SRC_FILES := \ formats/format_transfers/format_transfer_dhwcn_fracz3D.cc \ formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc \ formats/format_transfers/format_transfer_nchw_fz_c04.cc \ + formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc \ formats/formats.cc \ ge_format_util.cc \ fmk_error_codes.cc \ diff --git a/ge/ge_inference.mk b/ge/ge_inference.mk index 32fc206d..8bbb0a6d 100755 --- a/ge/ge_inference.mk +++ b/ge/ge_inference.mk @@ -24,6 +24,7 @@ COMMON_LOCAL_SRC_FILES := \ common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc \ common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc \ common/formats/format_transfers/format_transfer_nchw_fz_c04.cc \ + common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc \ common/formats/formats.cc \ common/profiling/profiling_manager.cc \ common/dump/dump_properties.cc \ diff --git a/ge/ge_runner.mk b/ge/ge_runner.mk index 49515fe4..302b9c13 100644 --- a/ge/ge_runner.mk +++ b/ge/ge_runner.mk @@ -19,6 +19,7 @@ LIBGE_LOCAL_SRC_FILES := \ common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc \ common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc \ common/formats/format_transfers/format_transfer_transpose.cc \ + common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc \ common/formats/formats.cc \ common/formats/utils/formats_trans_utils.cc \ common/fp16_t.cc \ diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index dabc1485..3a57bf1a 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -364,6 +364,7 @@ set(COMMON_FORMAT_SRC_FILES "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc" "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc" "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc \ "${GE_CODE_DIR}/ge/common/formats/utils/formats_trans_utils.cc" "${GE_CODE_DIR}/ge/graph/manager/util/hcom_util.cc" "${GE_CODE_DIR}/ge/common/dump/dump_manager.cc" From 64162f801ec291e4cf70ea4f804d432dc229d05f Mon Sep 17 00:00:00 2001 From: "dongduo5@huawei.com" Date: Mon, 12 Apr 2021 16:42:43 +0800 Subject: [PATCH 02/10] Add transfer from HWCN to FRACTAL_ZN_LSTM --- .../format_transfer_hwcn_fractal_zn_lstm.h | 35 ++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h diff --git a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h new file mode 100644 index 00000000..d233ccd7 --- /dev/null +++ b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_C1HWNCOC0_H_ +#define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_C1HWNCOC0_H_ + +#include + +#include "register/register_format_transfer.h" + +namespace ge { +namespace formats { +class FormatTransferHwcnFractalznlstm : public FormatTransfer { + public: + Status TransFormat(const TransArgs &args, TransResult &result) override; + Status TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, + std::vector &dst_shape) override; +}; +} // namespace formats +} // namespace ge + +#endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_C1HWNCOC0_H_ From b6ec85e1bc840ba55e3239b2c68c7f420810feb0 Mon Sep 17 00:00:00 2001 From: "dongduo5@huawei.com" Date: Mon, 12 Apr 2021 17:00:55 +0800 Subject: [PATCH 03/10] Add transfer from HWCN to FRACTAL_ZN_LSTM --- .../formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h index d233ccd7..0c48eac4 100644 --- a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h +++ b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_C1HWNCOC0_H_ -#define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_C1HWNCOC0_H_ +#ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_FRACTAL_ZN_LSTM_H_ +#define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_FRACTAL_ZN_LSTM_H_ #include @@ -32,4 +32,4 @@ class FormatTransferHwcnFractalznlstm : public FormatTransfer { } // namespace formats } // namespace ge -#endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_C1HWNCOC0_H_ +#endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_FRACTAL_ZN_LSTM_H_ From db8bf4e589c5b89cee496f5ec53b2cec2db00e15 Mon Sep 17 00:00:00 2001 From: "dongduo5@huawei.com" Date: Mon, 12 Apr 2021 19:24:27 +0800 Subject: [PATCH 04/10] add fractal zn index --- .../formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc | 4 ++-- ge/common/formats/utils/formats_definitions.h | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc index 669b7895..279c840b 100644 --- a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc +++ b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc @@ -42,7 +42,7 @@ Status TransShapeHwcnToFrazlstm(const DataType &data_type, const std::vector Date: Mon, 12 Apr 2021 19:50:41 +0800 Subject: [PATCH 05/10] add fractal zn index --- .../formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc index 279c840b..b046df1f 100644 --- a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc +++ b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc @@ -34,7 +34,7 @@ bool CheckDataTypeSupported(const DataType &data_type) { Status TransShapeHwcnToFrazlstm(const DataType &data_type, const std::vector &src_shape, std::vector &dst_shape) { - auto cube_size = GetCubeSizeByDataType(data_type); + /*auto cube_size = GetCubeSizeByDataType(data_type); dst_shape.clear(); dst_shape.push_back(Ceil(src_shape.at(kHwcnC), static_cast(cube_size))); dst_shape.push_back(src_shape.at(kHwcnH)); @@ -46,7 +46,7 @@ Status TransShapeHwcnToFrazlstm(const DataType &data_type, const std::vector dst(new (std::nothrow) uint8_t[total_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", From 7841c4d196fb352aa1004fec7a5b0219e806e9a5 Mon Sep 17 00:00:00 2001 From: "dongduo5@huawei.com" Date: Tue, 13 Apr 2021 21:18:08 +0800 Subject: [PATCH 06/10] add fractal zn lstm --- .../format_transfer_hwcn_fractal_zn_lstm.cc | 24 +++++++++++++--------- ge/common/formats/utils/formats_definitions.h | 4 ---- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc index b046df1f..12dd57e0 100644 --- a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc +++ b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc @@ -34,19 +34,17 @@ bool CheckDataTypeSupported(const DataType &data_type) { Status TransShapeHwcnToFrazlstm(const DataType &data_type, const std::vector &src_shape, std::vector &dst_shape) { - /*auto cube_size = GetCubeSizeByDataType(data_type); + auto cube_size = GetCubeSizeByDataType(data_type); dst_shape.clear(); dst_shape.push_back(Ceil(src_shape.at(kHwcnC), static_cast(cube_size))); - dst_shape.push_back(src_shape.at(kHwcnH)); - dst_shape.push_back(src_shape.at(kHwcnW)); - dst_shape.push_back(src_shape.at(kHwcnN)); + dst_shape.push_back(Ceil(src_shape.at(kHwcnN), static_cast(cube_size))); dst_shape.push_back(cube_size); dst_shape.push_back(cube_size); - if (!CheckShapeValid(dst_shape, kFracZnLstmDimsNum)) { + if (!CheckShapeValid(dst_shape, kFracZDimsNum)) { GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); return ACL_ERROR_GE_SHAPE_INVALID; - }*/ + } return SUCCESS; } @@ -59,7 +57,7 @@ Status CheckArgsForHwcnToFrazlstm(const TransArgs &args) { return UNSUPPORTED; } if (!CheckDataTypeSupported(args.src_data_type)) { - GELOGE(UNSUPPORTED, "Failed to trans shape from HWCN to C1HWNCoC0, invalid data type %s", + GELOGE(UNSUPPORTED, "Failed to trans shape from HWCN to FRACTAL_ZN_LSTM, invalid data type %s", TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); return UNSUPPORTED; } @@ -89,7 +87,12 @@ Status CheckArgsForHwcnToFrazlstm(const TransArgs &args) { } Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { + + auto ret = memcpy_s(dst.get(), static_cast(total_size), args.data, static_cast(total_size)); + result.data = dst; + result.length = static_cast(total_size); return SUCCESS; + std::shared_ptr dst(new (std::nothrow) uint8_t[total_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", @@ -102,9 +105,10 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in auto w = args.src_shape.at(kHwcnW); auto c = args.src_shape.at(kHwcnC); auto n = args.src_shape.at(kHwcnN); - auto c1 = args.dst_shape.at(kC1hwncoc0C1); - auto c0 = args.dst_shape.at(kC1hwncoc0C0); - auto co = args.dst_shape.at(kC1hwncoc0Co); + auto hwc1 = args.dst_shape.at(kFracZHWC1); + auto n0 = args.dst_shape.at(kFracZN0); + auto ni = args.dst_shape.at(kFracZNi); + auto c0 = args.dst_shape.at(kFracZC0); int64_t coc0 = co * c0; int64_t ncoc0 = n * coc0; int64_t wncoc0 = w * ncoc0; diff --git a/ge/common/formats/utils/formats_definitions.h b/ge/common/formats/utils/formats_definitions.h index 01430a9d..62ead019 100755 --- a/ge/common/formats/utils/formats_definitions.h +++ b/ge/common/formats/utils/formats_definitions.h @@ -101,10 +101,6 @@ enum DhwncDimIndex { kDhwncDimsNum }; -enum FracZnLstmIndex { - kFracZnLstmDimsNum = 6, -}; - } // namespace formats } // namespace ge #endif // GE_COMMON_FORMATS_UTILS_FORMATS_DEFINITIONS_H_ From 67aa5682cd756c9853e17d8baf3030d4607089a3 Mon Sep 17 00:00:00 2001 From: "dongduo5@huawei.com" Date: Tue, 13 Apr 2021 21:38:40 +0800 Subject: [PATCH 07/10] add fractal zn lstm --- .../format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc index 12dd57e0..8cf0bdd8 100644 --- a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc +++ b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc @@ -92,7 +92,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in result.data = dst; result.length = static_cast(total_size); return SUCCESS; - + std::shared_ptr dst(new (std::nothrow) uint8_t[total_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", @@ -116,7 +116,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in int64_t cn = c * n; int64_t wcn = w * cn; - for (int64_t c1_idx = 0; c1_idx < c1; c1_idx++) { + /*for (int64_t c1_idx = 0; c1_idx < c1; c1_idx++) { int64_t c1_head_addr = c1_idx * hwncoc0; for (int64_t h_idx = 0; h_idx < h; h_idx++) { int64_t h_head_addr = c1_head_addr + h_idx * wncoc0; @@ -163,7 +163,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in } } } - } + }*/ result.data = dst; result.length = static_cast(total_size); return SUCCESS; From fcc5839a5b1decc2f32e8caa47679abc839f57dc Mon Sep 17 00:00:00 2001 From: "dongduo5@huawei.com" Date: Tue, 13 Apr 2021 21:41:48 +0800 Subject: [PATCH 08/10] add fractal zn lstm --- .../formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc index 8cf0bdd8..e7a9c55e 100644 --- a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc +++ b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc @@ -101,7 +101,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in return OUT_OF_MEMORY; } - auto h = args.src_shape.at(kHwcnH); + /*auto h = args.src_shape.at(kHwcnH); auto w = args.src_shape.at(kHwcnW); auto c = args.src_shape.at(kHwcnC); auto n = args.src_shape.at(kHwcnN); @@ -114,7 +114,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in int64_t wncoc0 = w * ncoc0; int64_t hwncoc0 = h * wncoc0; int64_t cn = c * n; - int64_t wcn = w * cn; + int64_t wcn = w * cn;*/ /*for (int64_t c1_idx = 0; c1_idx < c1; c1_idx++) { int64_t c1_head_addr = c1_idx * hwncoc0; From 1aa5ed78718c0d2d1574334964ca74171266d085 Mon Sep 17 00:00:00 2001 From: "dongduo5@huawei.com" Date: Tue, 13 Apr 2021 21:47:06 +0800 Subject: [PATCH 09/10] add fractal zn lstm --- .../format_transfer_hwcn_fractal_zn_lstm.cc | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc index e7a9c55e..2ccb7a2d 100644 --- a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc +++ b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc @@ -87,12 +87,6 @@ Status CheckArgsForHwcnToFrazlstm(const TransArgs &args) { } Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { - - auto ret = memcpy_s(dst.get(), static_cast(total_size), args.data, static_cast(total_size)); - result.data = dst; - result.length = static_cast(total_size); - return SUCCESS; - std::shared_ptr dst(new (std::nothrow) uint8_t[total_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", @@ -101,6 +95,11 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in return OUT_OF_MEMORY; } + auto ret = memcpy_s(dst.get(), static_cast(total_size), args.data, static_cast(total_size)); + result.data = dst; + result.length = static_cast(total_size); + return SUCCESS; + /*auto h = args.src_shape.at(kHwcnH); auto w = args.src_shape.at(kHwcnW); auto c = args.src_shape.at(kHwcnC); @@ -163,10 +162,10 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in } } } - }*/ + } result.data = dst; result.length = static_cast(total_size); - return SUCCESS; + return SUCCESS;*/ } } // namespace From ab463188e6ebbdc61a35b4ff220a1e6ead3c7110 Mon Sep 17 00:00:00 2001 From: "dongduo5@huawei.com" Date: Tue, 13 Apr 2021 21:53:53 +0800 Subject: [PATCH 10/10] add fractal zn lstm --- .../format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc index 2ccb7a2d..d73dd2a7 100644 --- a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc +++ b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc @@ -65,7 +65,7 @@ Status CheckArgsForHwcnToFrazlstm(const TransArgs &args) { GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str()); return PARAM_INVALID; } - if (!CheckShapeValid(args.dst_shape, kFracZnLstmDimsNum)) { + if (!CheckShapeValid(args.dst_shape, kFracZDimsNum)) { GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str()); return PARAM_INVALID; } @@ -96,6 +96,10 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in } auto ret = memcpy_s(dst.get(), static_cast(total_size), args.data, static_cast(total_size)); + if (ret != EOK) { + GELOGE(INTERNAL_ERROR, "Failed to copy data ==="); + return INTERNAL_ERROR; + } result.data = dst; result.length = static_cast(total_size); return SUCCESS;