From 64193c9490e312450110a313eb11e2ddde4d1f55 Mon Sep 17 00:00:00 2001 From: "dongduo5@huawei.com" Date: Thu, 15 Apr 2021 14:07:24 +0800 Subject: [PATCH] Add transfer HWCN to Fractal_Zn_Lstm --- ge/CMakeLists.txt | 2 + ge/common/CMakeLists.txt | 1 + .../format_transfer_hwcn_fractal_zn_lstm.cc | 194 +++++++++++++++++++++ .../format_transfer_hwcn_fractal_zn_lstm.h | 35 ++++ ge/common/ge_common.mk | 1 + ge/ge_inference.mk | 1 + ge/ge_runner.mk | 1 + tests/ut/ge/CMakeLists.txt | 1 + 8 files changed, 236 insertions(+) create mode 100644 ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc create mode 100644 ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 89745019..c79ba061 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -100,6 +100,7 @@ set(TRAIN_SRC_LIST "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" "common/formats/format_transfers/format_transfer_transpose.cc" + "common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc" "common/formats/formats.cc" "common/formats/utils/formats_trans_utils.cc" "common/fp16_t.cc" @@ -435,6 +436,7 @@ set(INFER_SRC_LIST "common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" "common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" "common/formats/format_transfers/format_transfer_nchw_fz_c04.cc" + "common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc" "common/formats/formats.cc" "common/profiling/profiling_manager.cc" "common/dump/dump_properties.cc" diff --git a/ge/common/CMakeLists.txt b/ge/common/CMakeLists.txt index 75cb8ad1..51ac0e8a 100755 --- a/ge/common/CMakeLists.txt +++ b/ge/common/CMakeLists.txt @@ -49,6 +49,7 @@ set(SRC_LIST "formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" "formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" "formats/format_transfers/format_transfer_nchw_fz_c04.cc" + "formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc" "formats/formats.cc" "ge_format_util.cc" "fmk_error_codes.cc" diff --git a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc new file mode 100644 index 00000000..7ed3e7d3 --- /dev/null +++ b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc @@ -0,0 +1,194 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h" + +#include +#include + +#include "common/formats/utils/formats_definitions.h" +#include "common/formats/utils/formats_trans_utils.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" +#include "graph/utils/type_utils.h" + +namespace ge { +namespace formats { +namespace { +bool CheckDataTypeSupported(const DataType &data_type) { + return (data_type == DT_FLOAT || data_type == DT_FLOAT16); +} + +Status TransShapeHwcnToFrazlstm(const DataType &data_type, const std::vector &src_shape, + std::vector &dst_shape) { + auto cube_size = GetCubeSizeByDataType(data_type); + dst_shape.clear(); + dst_shape.push_back(Ceil(src_shape.at(kHwcnC), static_cast(cube_size))); + dst_shape.push_back(Ceil(src_shape.at(kHwcnN), static_cast(cube_size))); + dst_shape.push_back(cube_size); + dst_shape.push_back(cube_size); + if (!CheckShapeValid(dst_shape, kFracZDimsNum)) { + GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", + ShapeToString(dst_shape).c_str()); + return ACL_ERROR_GE_SHAPE_INVALID; + } + return SUCCESS; +} + +Status CheckArgsForHwcnToFrazlstm(const TransArgs &args) { + if (args.src_format != FORMAT_HWCN || args.dst_format != FORMAT_FRACTAL_ZN_LSTM) { + std::string error = "Dose not support trans format from " + + FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + + FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); + GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); + return UNSUPPORTED; + } + if (!CheckDataTypeSupported(args.src_data_type)) { + GELOGE(UNSUPPORTED, "Failed to trans shape from HWCN to FRACTAL_ZN_LSTM, invalid data type %s", + TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); + return UNSUPPORTED; + } + if (!CheckShapeValid(args.src_shape, kHwcnDimsNum)) { + GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str()); + return PARAM_INVALID; + } + if (!CheckShapeValid(args.dst_shape, kFracZDimsNum)) { + GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str()); + return PARAM_INVALID; + } + std::vector expect_dst_shape; + auto ret = TransShapeHwcnToFrazlstm(args.src_data_type, args.src_shape, expect_dst_shape); + if (ret != SUCCESS) { + return ret; + } + if (args.dst_shape != expect_dst_shape) { + GELOGE(PARAM_INVALID, + "Failed to trans format, src and dst shape are not compatible. src shape %s, dst shape %s, " + "expect dst shape %s", + ShapeToString(args.src_shape).c_str(), ShapeToString(args.dst_shape).c_str(), + ShapeToString(expect_dst_shape).c_str()); + return PARAM_INVALID; + } + + return SUCCESS; +} + +Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { + std::shared_ptr dst(new (std::nothrow) uint8_t[total_size], std::default_delete()); + if (dst == nullptr) { + GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", + TypeUtils::FormatToSerialString(args.src_format).c_str(), + TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); + return OUT_OF_MEMORY; + } + + auto ret = memcpy_s(dst.get(), static_cast(total_size), args.data, static_cast(total_size)); + if (ret != EOK) { + GELOGE(INTERNAL_ERROR, "Failed to copy data ==="); + return INTERNAL_ERROR; + } + result.data = dst; + result.length = static_cast(total_size); + return SUCCESS; + + auto c = args.src_shape.at(kHwcnC); + auto n = args.src_shape.at(kHwcnN); + + int64_t K = 16; + auto x_total_size= c / K; + auto y_total_size = n / K; + + for (int64_t x_f = 0; x_f < x_total_size; x_f++) { + for (int64_t y_f = 0; y_f < y_total_size; y_f++) { + int64_t x_src_head = x_f * K; + int64_t y_src_head = y_f * K; + + for (int64_t x_src = x_src_head; x_src < x_src_head + K; x_src++) { + int index = 0; + for (int64_t y_src = y_src_head; y_src < y_src_head + K; y_src++) { + int64_t x_dst = x_src_head + index; + ++index; + int64_t src_idx = x_src * c + y_src; + auto src_offset = src_idx * size; + + int64_t y_dst = x_src + y_src - x_dst; + auto dst_idx = x_dst * c + y_dst; + auto dst_offset = dst_idx * size; + + auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) + ? total_size - dst_offset + : static_cast(SECUREC_MEM_MAX_LEN); + auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, + static_cast(size)); + } + } + } + + } + + result.data = dst; + result.length = static_cast(total_size); +} +} // namespace + +Status FormatTransferHwcnFractalznlstm::TransFormat(const TransArgs &args, TransResult &result) { + if (CheckArgsForHwcnToFrazlstm(args) != SUCCESS) { + return PARAM_INVALID; + } + int size = GetSizeByDataType(args.src_data_type); + auto total_size = GetItemNumByShape(args.dst_shape) * size; + if (total_size <= 0) { + int64_t src_size = GetItemNumByShape(args.src_shape); + if (total_size == 0 && src_size == 0) { + result.length = static_cast(total_size); + return SUCCESS; + } + + GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, + ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); + return PARAM_INVALID; + } + GELOGD("Begin to trans format from HWCN to C1HWNCoC0, src shape %s, data type %s, dst shape %s, memory size %ld", + ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), + ShapeToString(args.dst_shape).c_str(), total_size); + if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", + ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), + ShapeToString(args.dst_shape).c_str(), total_size); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +Status FormatTransferHwcnFractalznlstm::TransShape(Format src_format, const std::vector &src_shape, + DataType data_type, Format dst_format, std::vector &dst_shape) { + if (src_format == FORMAT_HWCN && CheckDataTypeSupported(data_type)) { + if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { + GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", + ShapeToString(src_shape).c_str()); + return ACL_ERROR_GE_SHAPE_INVALID; + } + return TransShapeHwcnToFrazlstm(data_type, src_shape, dst_shape); + } else if (src_format != FORMAT_HWCN) { + return ACL_ERROR_GE_SHAPE_INVALID; + } else { + return ACL_ERROR_GE_SHAPE_INVALID; + } +} + +REGISTER_FORMAT_TRANSFER(FormatTransferHwcnFractalznlstm, FORMAT_HWCN, FORMAT_FRACTAL_ZN_LSTM) +} // namespace formats +} // namespace ge diff --git a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h new file mode 100644 index 00000000..0c48eac4 --- /dev/null +++ b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_FRACTAL_ZN_LSTM_H_ +#define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_FRACTAL_ZN_LSTM_H_ + +#include + +#include "register/register_format_transfer.h" + +namespace ge { +namespace formats { +class FormatTransferHwcnFractalznlstm : public FormatTransfer { + public: + Status TransFormat(const TransArgs &args, TransResult &result) override; + Status TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, + std::vector &dst_shape) override; +}; +} // namespace formats +} // namespace ge + +#endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_FRACTAL_ZN_LSTM_H_ diff --git a/ge/common/ge_common.mk b/ge/common/ge_common.mk index e28090ad..7a7a3801 100755 --- a/ge/common/ge_common.mk +++ b/ge/common/ge_common.mk @@ -31,6 +31,7 @@ GE_COMMON_LOCAL_SRC_FILES := \ formats/format_transfers/format_transfer_dhwcn_fracz3D.cc \ formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc \ formats/format_transfers/format_transfer_nchw_fz_c04.cc \ + formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc \ formats/formats.cc \ ge_format_util.cc \ fmk_error_codes.cc \ diff --git a/ge/ge_inference.mk b/ge/ge_inference.mk index 32fc206d..8bbb0a6d 100755 --- a/ge/ge_inference.mk +++ b/ge/ge_inference.mk @@ -24,6 +24,7 @@ COMMON_LOCAL_SRC_FILES := \ common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc \ common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc \ common/formats/format_transfers/format_transfer_nchw_fz_c04.cc \ + common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc \ common/formats/formats.cc \ common/profiling/profiling_manager.cc \ common/dump/dump_properties.cc \ diff --git a/ge/ge_runner.mk b/ge/ge_runner.mk index 49515fe4..add99caa 100644 --- a/ge/ge_runner.mk +++ b/ge/ge_runner.mk @@ -19,6 +19,7 @@ LIBGE_LOCAL_SRC_FILES := \ common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc \ common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc \ common/formats/format_transfers/format_transfer_transpose.cc \ + common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc \ common/formats/formats.cc \ common/formats/utils/formats_trans_utils.cc \ common/fp16_t.cc \ diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index dabc1485..bd2cae87 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -364,6 +364,7 @@ set(COMMON_FORMAT_SRC_FILES "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc" "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc" "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc" "${GE_CODE_DIR}/ge/common/formats/utils/formats_trans_utils.cc" "${GE_CODE_DIR}/ge/graph/manager/util/hcom_util.cc" "${GE_CODE_DIR}/ge/common/dump/dump_manager.cc"