|
|
@@ -0,0 +1,194 @@ |
|
|
|
/** |
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
* You may obtain a copy of the License at |
|
|
|
* |
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
* |
|
|
|
* Unless required by applicable law or agreed to in writing, software |
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h" |
|
|
|
|
|
|
|
#include <securec.h> |
|
|
|
#include <memory> |
|
|
|
|
|
|
|
#include "common/formats/utils/formats_definitions.h" |
|
|
|
#include "common/formats/utils/formats_trans_utils.h" |
|
|
|
#include "framework/common/debug/ge_log.h" |
|
|
|
#include "framework/common/debug/log.h" |
|
|
|
#include "graph/utils/type_utils.h" |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
namespace formats { |
|
|
|
namespace { |
|
|
|
bool CheckDataTypeSupported(const DataType &data_type) { |
|
|
|
return (data_type == DT_FLOAT || data_type == DT_FLOAT16); |
|
|
|
} |
|
|
|
|
|
|
|
Status TransShapeHwcnToFrazlstm(const DataType &data_type, const std::vector<int64_t> &src_shape, |
|
|
|
std::vector<int64_t> &dst_shape) { |
|
|
|
auto cube_size = GetCubeSizeByDataType(data_type); |
|
|
|
dst_shape.clear(); |
|
|
|
dst_shape.push_back(Ceil(src_shape.at(kHwcnC), static_cast<int64_t>(cube_size))); |
|
|
|
dst_shape.push_back(Ceil(src_shape.at(kHwcnN), static_cast<int64_t>(cube_size))); |
|
|
|
dst_shape.push_back(cube_size); |
|
|
|
dst_shape.push_back(cube_size); |
|
|
|
if (!CheckShapeValid(dst_shape, kFracZDimsNum)) { |
|
|
|
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", |
|
|
|
ShapeToString(dst_shape).c_str()); |
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status CheckArgsForHwcnToFrazlstm(const TransArgs &args) { |
|
|
|
if (args.src_format != FORMAT_HWCN || args.dst_format != FORMAT_FRACTAL_ZN_LSTM) { |
|
|
|
std::string error = "Dose not support trans format from " + |
|
|
|
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + |
|
|
|
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); |
|
|
|
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); |
|
|
|
return UNSUPPORTED; |
|
|
|
} |
|
|
|
if (!CheckDataTypeSupported(args.src_data_type)) { |
|
|
|
GELOGE(UNSUPPORTED, "Failed to trans shape from HWCN to FRACTAL_ZN_LSTM, invalid data type %s", |
|
|
|
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); |
|
|
|
return UNSUPPORTED; |
|
|
|
} |
|
|
|
if (!CheckShapeValid(args.src_shape, kHwcnDimsNum)) { |
|
|
|
GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str()); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
if (!CheckShapeValid(args.dst_shape, kFracZDimsNum)) { |
|
|
|
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str()); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
std::vector<int64_t> expect_dst_shape; |
|
|
|
auto ret = TransShapeHwcnToFrazlstm(args.src_data_type, args.src_shape, expect_dst_shape); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
return ret; |
|
|
|
} |
|
|
|
if (args.dst_shape != expect_dst_shape) { |
|
|
|
GELOGE(PARAM_INVALID, |
|
|
|
"Failed to trans format, src and dst shape are not compatible. src shape %s, dst shape %s, " |
|
|
|
"expect dst shape %s", |
|
|
|
ShapeToString(args.src_shape).c_str(), ShapeToString(args.dst_shape).c_str(), |
|
|
|
ShapeToString(expect_dst_shape).c_str()); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { |
|
|
|
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); |
|
|
|
if (dst == nullptr) { |
|
|
|
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", |
|
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(), |
|
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); |
|
|
|
return OUT_OF_MEMORY; |
|
|
|
} |
|
|
|
|
|
|
|
auto ret = memcpy_s(dst.get(), static_cast<size_t>(total_size), args.data, static_cast<size_t>(total_size)); |
|
|
|
if (ret != EOK) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to copy data ==="); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
result.data = dst; |
|
|
|
result.length = static_cast<size_t>(total_size); |
|
|
|
return SUCCESS; |
|
|
|
|
|
|
|
auto c = args.src_shape.at(kHwcnC); |
|
|
|
auto n = args.src_shape.at(kHwcnN); |
|
|
|
|
|
|
|
int64_t K = 16; |
|
|
|
auto x_total_size= c / K; |
|
|
|
auto y_total_size = n / K; |
|
|
|
|
|
|
|
for (int64_t x_f = 0; x_f < x_total_size; x_f++) { |
|
|
|
for (int64_t y_f = 0; y_f < y_total_size; y_f++) { |
|
|
|
int64_t x_src_head = x_f * K; |
|
|
|
int64_t y_src_head = y_f * K; |
|
|
|
|
|
|
|
for (int64_t x_src = x_src_head; x_src < x_src_head + K; x_src++) { |
|
|
|
int index = 0; |
|
|
|
for (int64_t y_src = y_src_head; y_src < y_src_head + K; y_src++) { |
|
|
|
int64_t x_dst = x_src_head + index; |
|
|
|
++index; |
|
|
|
int64_t src_idx = x_src * c + y_src; |
|
|
|
auto src_offset = src_idx * size; |
|
|
|
|
|
|
|
int64_t y_dst = x_src + y_src - x_dst; |
|
|
|
auto dst_idx = x_dst * c + y_dst; |
|
|
|
auto dst_offset = dst_idx * size; |
|
|
|
|
|
|
|
auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) |
|
|
|
? total_size - dst_offset |
|
|
|
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); |
|
|
|
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, |
|
|
|
static_cast<size_t>(size)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
result.data = dst; |
|
|
|
result.length = static_cast<size_t>(total_size); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
Status FormatTransferHwcnFractalznlstm::TransFormat(const TransArgs &args, TransResult &result) { |
|
|
|
if (CheckArgsForHwcnToFrazlstm(args) != SUCCESS) { |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
int size = GetSizeByDataType(args.src_data_type); |
|
|
|
auto total_size = GetItemNumByShape(args.dst_shape) * size; |
|
|
|
if (total_size <= 0) { |
|
|
|
int64_t src_size = GetItemNumByShape(args.src_shape); |
|
|
|
if (total_size == 0 && src_size == 0) { |
|
|
|
result.length = static_cast<size_t>(total_size); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, |
|
|
|
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
GELOGD("Begin to trans format from HWCN to C1HWNCoC0, src shape %s, data type %s, dst shape %s, memory size %ld", |
|
|
|
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), |
|
|
|
ShapeToString(args.dst_shape).c_str(), total_size); |
|
|
|
if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", |
|
|
|
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), |
|
|
|
ShapeToString(args.dst_shape).c_str(), total_size); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status FormatTransferHwcnFractalznlstm::TransShape(Format src_format, const std::vector<int64_t> &src_shape, |
|
|
|
DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { |
|
|
|
if (src_format == FORMAT_HWCN && CheckDataTypeSupported(data_type)) { |
|
|
|
if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { |
|
|
|
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", |
|
|
|
ShapeToString(src_shape).c_str()); |
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID; |
|
|
|
} |
|
|
|
return TransShapeHwcnToFrazlstm(data_type, src_shape, dst_shape); |
|
|
|
} else if (src_format != FORMAT_HWCN) { |
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID; |
|
|
|
} else { |
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
REGISTER_FORMAT_TRANSFER(FormatTransferHwcnFractalznlstm, FORMAT_HWCN, FORMAT_FRACTAL_ZN_LSTM) |
|
|
|
} // namespace formats |
|
|
|
} // namespace ge |