Browse Source

Add transfer HWCN to Fractal_Zn_Lstm

pull/1527/head
dongduo5@huawei.com 4 years ago
parent
commit
64193c9490
8 changed files with 236 additions and 0 deletions
  1. +2
    -0
      ge/CMakeLists.txt
  2. +1
    -0
      ge/common/CMakeLists.txt
  3. +194
    -0
      ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc
  4. +35
    -0
      ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h
  5. +1
    -0
      ge/common/ge_common.mk
  6. +1
    -0
      ge/ge_inference.mk
  7. +1
    -0
      ge/ge_runner.mk
  8. +1
    -0
      tests/ut/ge/CMakeLists.txt

+ 2
- 0
ge/CMakeLists.txt View File

@@ -100,6 +100,7 @@ set(TRAIN_SRC_LIST
"common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc"
"common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc"
"common/formats/format_transfers/format_transfer_transpose.cc" "common/formats/format_transfers/format_transfer_transpose.cc"
"common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc"
"common/formats/formats.cc" "common/formats/formats.cc"
"common/formats/utils/formats_trans_utils.cc" "common/formats/utils/formats_trans_utils.cc"
"common/fp16_t.cc" "common/fp16_t.cc"
@@ -435,6 +436,7 @@ set(INFER_SRC_LIST
"common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" "common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc"
"common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" "common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc"
"common/formats/format_transfers/format_transfer_nchw_fz_c04.cc" "common/formats/format_transfers/format_transfer_nchw_fz_c04.cc"
"common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc"
"common/formats/formats.cc" "common/formats/formats.cc"
"common/profiling/profiling_manager.cc" "common/profiling/profiling_manager.cc"
"common/dump/dump_properties.cc" "common/dump/dump_properties.cc"


+ 1
- 0
ge/common/CMakeLists.txt View File

@@ -49,6 +49,7 @@ set(SRC_LIST
"formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" "formats/format_transfers/format_transfer_dhwcn_fracz3D.cc"
"formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" "formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc"
"formats/format_transfers/format_transfer_nchw_fz_c04.cc" "formats/format_transfers/format_transfer_nchw_fz_c04.cc"
"formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc"
"formats/formats.cc" "formats/formats.cc"
"ge_format_util.cc" "ge_format_util.cc"
"fmk_error_codes.cc" "fmk_error_codes.cc"


+ 194
- 0
ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc View File

@@ -0,0 +1,194 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h"

#include <securec.h>
#include <memory>

#include "common/formats/utils/formats_definitions.h"
#include "common/formats/utils/formats_trans_utils.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
#include "graph/utils/type_utils.h"

namespace ge {
namespace formats {
namespace {
bool CheckDataTypeSupported(const DataType &data_type) {
return (data_type == DT_FLOAT || data_type == DT_FLOAT16);
}

Status TransShapeHwcnToFrazlstm(const DataType &data_type, const std::vector<int64_t> &src_shape,
std::vector<int64_t> &dst_shape) {
auto cube_size = GetCubeSizeByDataType(data_type);
dst_shape.clear();
dst_shape.push_back(Ceil(src_shape.at(kHwcnC), static_cast<int64_t>(cube_size)));
dst_shape.push_back(Ceil(src_shape.at(kHwcnN), static_cast<int64_t>(cube_size)));
dst_shape.push_back(cube_size);
dst_shape.push_back(cube_size);
if (!CheckShapeValid(dst_shape, kFracZDimsNum)) {
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s",
ShapeToString(dst_shape).c_str());
return ACL_ERROR_GE_SHAPE_INVALID;
}
return SUCCESS;
}

Status CheckArgsForHwcnToFrazlstm(const TransArgs &args) {
if (args.src_format != FORMAT_HWCN || args.dst_format != FORMAT_FRACTAL_ZN_LSTM) {
std::string error = "Dose not support trans format from " +
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " +
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format));
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str());
return UNSUPPORTED;
}
if (!CheckDataTypeSupported(args.src_data_type)) {
GELOGE(UNSUPPORTED, "Failed to trans shape from HWCN to FRACTAL_ZN_LSTM, invalid data type %s",
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str());
return UNSUPPORTED;
}
if (!CheckShapeValid(args.src_shape, kHwcnDimsNum)) {
GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str());
return PARAM_INVALID;
}
if (!CheckShapeValid(args.dst_shape, kFracZDimsNum)) {
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str());
return PARAM_INVALID;
}
std::vector<int64_t> expect_dst_shape;
auto ret = TransShapeHwcnToFrazlstm(args.src_data_type, args.src_shape, expect_dst_shape);
if (ret != SUCCESS) {
return ret;
}
if (args.dst_shape != expect_dst_shape) {
GELOGE(PARAM_INVALID,
"Failed to trans format, src and dst shape are not compatible. src shape %s, dst shape %s, "
"expect dst shape %s",
ShapeToString(args.src_shape).c_str(), ShapeToString(args.dst_shape).c_str(),
ShapeToString(expect_dst_shape).c_str());
return PARAM_INVALID;
}

return SUCCESS;
}

Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) {
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>());
if (dst == nullptr) {
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s",
TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str());
return OUT_OF_MEMORY;
}

auto ret = memcpy_s(dst.get(), static_cast<size_t>(total_size), args.data, static_cast<size_t>(total_size));
if (ret != EOK) {
GELOGE(INTERNAL_ERROR, "Failed to copy data ===");
return INTERNAL_ERROR;
}
result.data = dst;
result.length = static_cast<size_t>(total_size);
return SUCCESS;

auto c = args.src_shape.at(kHwcnC);
auto n = args.src_shape.at(kHwcnN);
int64_t K = 16;
auto x_total_size= c / K;
auto y_total_size = n / K;

for (int64_t x_f = 0; x_f < x_total_size; x_f++) {
for (int64_t y_f = 0; y_f < y_total_size; y_f++) {
int64_t x_src_head = x_f * K;
int64_t y_src_head = y_f * K;

for (int64_t x_src = x_src_head; x_src < x_src_head + K; x_src++) {
int index = 0;
for (int64_t y_src = y_src_head; y_src < y_src_head + K; y_src++) {
int64_t x_dst = x_src_head + index;
++index;
int64_t src_idx = x_src * c + y_src;
auto src_offset = src_idx * size;

int64_t y_dst = x_src + y_src - x_dst;
auto dst_idx = x_dst * c + y_dst;
auto dst_offset = dst_idx * size;

auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
? total_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size));
}
}
}

}

result.data = dst;
result.length = static_cast<size_t>(total_size);
}
} // namespace

Status FormatTransferHwcnFractalznlstm::TransFormat(const TransArgs &args, TransResult &result) {
if (CheckArgsForHwcnToFrazlstm(args) != SUCCESS) {
return PARAM_INVALID;
}
int size = GetSizeByDataType(args.src_data_type);
auto total_size = GetItemNumByShape(args.dst_shape) * size;
if (total_size <= 0) {
int64_t src_size = GetItemNumByShape(args.src_shape);
if (total_size == 0 && src_size == 0) {
result.length = static_cast<size_t>(total_size);
return SUCCESS;
}

GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size,
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str());
return PARAM_INVALID;
}
GELOGD("Begin to trans format from HWCN to C1HWNCoC0, src shape %s, data type %s, dst shape %s, memory size %ld",
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(),
ShapeToString(args.dst_shape).c_str(), total_size);
if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld",
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(),
ShapeToString(args.dst_shape).c_str(), total_size);
return INTERNAL_ERROR;
}
return SUCCESS;
}

Status FormatTransferHwcnFractalznlstm::TransShape(Format src_format, const std::vector<int64_t> &src_shape,
DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) {
if (src_format == FORMAT_HWCN && CheckDataTypeSupported(data_type)) {
if (!CheckShapeValid(src_shape, kHwcnDimsNum)) {
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s",
ShapeToString(src_shape).c_str());
return ACL_ERROR_GE_SHAPE_INVALID;
}
return TransShapeHwcnToFrazlstm(data_type, src_shape, dst_shape);
} else if (src_format != FORMAT_HWCN) {
return ACL_ERROR_GE_SHAPE_INVALID;
} else {
return ACL_ERROR_GE_SHAPE_INVALID;
}
}

REGISTER_FORMAT_TRANSFER(FormatTransferHwcnFractalznlstm, FORMAT_HWCN, FORMAT_FRACTAL_ZN_LSTM)
} // namespace formats
} // namespace ge

+ 35
- 0
ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h View File

@@ -0,0 +1,35 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_FRACTAL_ZN_LSTM_H_
#define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_FRACTAL_ZN_LSTM_H_

#include <vector>

#include "register/register_format_transfer.h"

namespace ge {
namespace formats {
class FormatTransferHwcnFractalznlstm : public FormatTransfer {
public:
Status TransFormat(const TransArgs &args, TransResult &result) override;
Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
std::vector<int64_t> &dst_shape) override;
};
} // namespace formats
} // namespace ge

#endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_FRACTAL_ZN_LSTM_H_

+ 1
- 0
ge/common/ge_common.mk View File

@@ -31,6 +31,7 @@ GE_COMMON_LOCAL_SRC_FILES := \
formats/format_transfers/format_transfer_dhwcn_fracz3D.cc \ formats/format_transfers/format_transfer_dhwcn_fracz3D.cc \
formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc \ formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc \
formats/format_transfers/format_transfer_nchw_fz_c04.cc \ formats/format_transfers/format_transfer_nchw_fz_c04.cc \
formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc \
formats/formats.cc \ formats/formats.cc \
ge_format_util.cc \ ge_format_util.cc \
fmk_error_codes.cc \ fmk_error_codes.cc \


+ 1
- 0
ge/ge_inference.mk View File

@@ -24,6 +24,7 @@ COMMON_LOCAL_SRC_FILES := \
common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc \ common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc \
common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc \ common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc \
common/formats/format_transfers/format_transfer_nchw_fz_c04.cc \ common/formats/format_transfers/format_transfer_nchw_fz_c04.cc \
common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc \
common/formats/formats.cc \ common/formats/formats.cc \
common/profiling/profiling_manager.cc \ common/profiling/profiling_manager.cc \
common/dump/dump_properties.cc \ common/dump/dump_properties.cc \


+ 1
- 0
ge/ge_runner.mk View File

@@ -19,6 +19,7 @@ LIBGE_LOCAL_SRC_FILES := \
common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc \ common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc \
common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc \ common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc \
common/formats/format_transfers/format_transfer_transpose.cc \ common/formats/format_transfers/format_transfer_transpose.cc \
common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc \
common/formats/formats.cc \ common/formats/formats.cc \
common/formats/utils/formats_trans_utils.cc \ common/formats/utils/formats_trans_utils.cc \
common/fp16_t.cc \ common/fp16_t.cc \


+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -364,6 +364,7 @@ set(COMMON_FORMAT_SRC_FILES
"${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc" "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc"
"${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc" "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc"
"${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc" "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc"
"${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc"
"${GE_CODE_DIR}/ge/common/formats/utils/formats_trans_utils.cc" "${GE_CODE_DIR}/ge/common/formats/utils/formats_trans_utils.cc"
"${GE_CODE_DIR}/ge/graph/manager/util/hcom_util.cc" "${GE_CODE_DIR}/ge/graph/manager/util/hcom_util.cc"
"${GE_CODE_DIR}/ge/common/dump/dump_manager.cc" "${GE_CODE_DIR}/ge/common/dump/dump_manager.cc"


Loading…
Cancel
Save