Browse Source

Pre Merge pull request !1502 from 董铎/master

pull/1502/MERGE
董铎 Gitee 4 years ago
parent
commit
40228cd4a5
9 changed files with 266 additions and 0 deletions
  1. +2
    -0
      ge/CMakeLists.txt
  2. +1
    -0
      ge/common/CMakeLists.txt
  3. +223
    -0
      ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc
  4. +35
    -0
      ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h
  5. +1
    -0
      ge/common/formats/utils/formats_definitions.h
  6. +1
    -0
      ge/common/ge_common.mk
  7. +1
    -0
      ge/ge_inference.mk
  8. +1
    -0
      ge/ge_runner.mk
  9. +1
    -0
      tests/ut/ge/CMakeLists.txt

+ 2
- 0
ge/CMakeLists.txt View File

@@ -100,6 +100,7 @@ set(TRAIN_SRC_LIST
"common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc"
"common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc"
"common/formats/format_transfers/format_transfer_transpose.cc"
"common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc"
"common/formats/formats.cc"
"common/formats/utils/formats_trans_utils.cc"
"common/fp16_t.cc"
@@ -435,6 +436,7 @@ set(INFER_SRC_LIST
"common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc"
"common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc"
"common/formats/format_transfers/format_transfer_nchw_fz_c04.cc"
"common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc"
"common/formats/formats.cc"
"common/profiling/profiling_manager.cc"
"common/dump/dump_properties.cc"


+ 1
- 0
ge/common/CMakeLists.txt View File

@@ -49,6 +49,7 @@ set(SRC_LIST
"formats/format_transfers/format_transfer_dhwcn_fracz3D.cc"
"formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc"
"formats/format_transfers/format_transfer_nchw_fz_c04.cc"
"formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc"
"formats/formats.cc"
"ge_format_util.cc"
"fmk_error_codes.cc"


+ 223
- 0
ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc View File

@@ -0,0 +1,223 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h"

#include <securec.h>
#include <memory>

#include "common/formats/utils/formats_definitions.h"
#include "common/formats/utils/formats_trans_utils.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
#include "graph/utils/type_utils.h"

namespace ge {
namespace formats {
namespace {
bool CheckDataTypeSupported(const DataType &data_type) {
return (data_type == DT_FLOAT || data_type == DT_FLOAT16);
}

Status TransShapeHwcnToFrazlstm(const DataType &data_type, const std::vector<int64_t> &src_shape,
std::vector<int64_t> &dst_shape) {
auto cube_size = GetCubeSizeByDataType(data_type);
dst_shape.clear();
dst_shape.push_back(Ceil(src_shape.at(kHwcnC), static_cast<int64_t>(cube_size)));
dst_shape.push_back(Ceil(src_shape.at(kHwcnN), static_cast<int64_t>(cube_size)));
dst_shape.push_back(cube_size);
dst_shape.push_back(cube_size);
if (!CheckShapeValid(dst_shape, kFracZDimsNum)) {
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s",
ShapeToString(dst_shape).c_str());
return ACL_ERROR_GE_SHAPE_INVALID;
}
return SUCCESS;
}

Status CheckArgsForHwcnToFrazlstm(const TransArgs &args) {
if (args.src_format != FORMAT_HWCN || args.dst_format != FORMAT_FRACTAL_ZN_LSTM) {
std::string error = "Dose not support trans format from " +
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " +
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format));
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str());
return UNSUPPORTED;
}
if (!CheckDataTypeSupported(args.src_data_type)) {
GELOGE(UNSUPPORTED, "Failed to trans shape from HWCN to FRACTAL_ZN_LSTM, invalid data type %s",
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str());
return UNSUPPORTED;
}
if (!CheckShapeValid(args.src_shape, kHwcnDimsNum)) {
GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str());
return PARAM_INVALID;
}
if (!CheckShapeValid(args.dst_shape, kFracZDimsNum)) {
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str());
return PARAM_INVALID;
}
std::vector<int64_t> expect_dst_shape;
auto ret = TransShapeHwcnToFrazlstm(args.src_data_type, args.src_shape, expect_dst_shape);
if (ret != SUCCESS) {
return ret;
}
if (args.dst_shape != expect_dst_shape) {
GELOGE(PARAM_INVALID,
"Failed to trans format, src and dst shape are not compatible. src shape %s, dst shape %s, "
"expect dst shape %s",
ShapeToString(args.src_shape).c_str(), ShapeToString(args.dst_shape).c_str(),
ShapeToString(expect_dst_shape).c_str());
return PARAM_INVALID;
}

return SUCCESS;
}

Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) {
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>());
if (dst == nullptr) {
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s",
TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str());
return OUT_OF_MEMORY;
}

auto ret = memcpy_s(dst.get(), static_cast<size_t>(total_size), args.data, static_cast<size_t>(total_size));
if (ret != EOK) {
GELOGE(INTERNAL_ERROR, "Failed to copy data ===");
return INTERNAL_ERROR;
}
result.data = dst;
result.length = static_cast<size_t>(total_size);
return SUCCESS;

/*auto h = args.src_shape.at(kHwcnH);
auto w = args.src_shape.at(kHwcnW);
auto c = args.src_shape.at(kHwcnC);
auto n = args.src_shape.at(kHwcnN);
auto hwc1 = args.dst_shape.at(kFracZHWC1);
auto n0 = args.dst_shape.at(kFracZN0);
auto ni = args.dst_shape.at(kFracZNi);
auto c0 = args.dst_shape.at(kFracZC0);
int64_t coc0 = co * c0;
int64_t ncoc0 = n * coc0;
int64_t wncoc0 = w * ncoc0;
int64_t hwncoc0 = h * wncoc0;
int64_t cn = c * n;
int64_t wcn = w * cn;*/

/*for (int64_t c1_idx = 0; c1_idx < c1; c1_idx++) {
int64_t c1_head_addr = c1_idx * hwncoc0;
for (int64_t h_idx = 0; h_idx < h; h_idx++) {
int64_t h_head_addr = c1_head_addr + h_idx * wncoc0;
for (int64_t w_idx = 0; w_idx < w; w_idx++) {
int64_t w_head_addr = h_head_addr + w_idx * ncoc0;
for (int64_t n_idx = 0; n_idx < n; n_idx++) {
int64_t n_head_addr = w_head_addr + n_idx * coc0;
for (int64_t co_idx = 0; co_idx < co; co_idx++) {
int64_t co_head_addr = n_head_addr + co_idx * c0;
for (int64_t c0_idx = 0; c0_idx < c0; c0_idx++) {
int64_t dst_idx = c0_idx + co_head_addr;
auto dst_offset = dst_idx * size;
auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
? total_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
int64_t c_idx = c0_idx + c1_idx * c0;
int64_t src_idx = h_idx * wcn + w_idx * cn + c_idx * n + n_idx;
auto src_offset = src_idx * size;

if (c_idx < c && c0_idx == co_idx) {
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size));
if (ret != EOK) {
GELOGE(INTERNAL_ERROR,
"Failed to copy data from HWCN[%ld, %ld, %ld, %ld] offset %ld to "
"C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] offset %ld, err-code %d",
h_idx, w_idx, c_idx, n_idx, src_offset, c1_idx, h_idx, w_idx, n_idx, co_idx, c0_idx,
dst_offset, ret);
return INTERNAL_ERROR;
}
} else {
auto ret =
memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size));
if (ret != EOK) {
GELOGE(INTERNAL_ERROR,
"Failed to set to 0 to C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] offset %ld, "
"err-code %d",
c1_idx, h_idx, w_idx, n_idx, co_idx, c0_idx, dst_offset, ret);
return INTERNAL_ERROR;
}
}
}
}
}
}
}
}
result.data = dst;
result.length = static_cast<size_t>(total_size);
return SUCCESS;*/
}
} // namespace

Status FormatTransferHwcnFractalznlstm::TransFormat(const TransArgs &args, TransResult &result) {
if (CheckArgsForHwcnToFrazlstm(args) != SUCCESS) {
return PARAM_INVALID;
}
int size = GetSizeByDataType(args.src_data_type);
auto total_size = GetItemNumByShape(args.dst_shape) * size;
if (total_size <= 0) {
int64_t src_size = GetItemNumByShape(args.src_shape);
if (total_size == 0 && src_size == 0) {
result.length = static_cast<size_t>(total_size);
return SUCCESS;
}

GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size,
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str());
return PARAM_INVALID;
}
GELOGD("Begin to trans format from HWCN to C1HWNCoC0, src shape %s, data type %s, dst shape %s, memory size %ld",
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(),
ShapeToString(args.dst_shape).c_str(), total_size);
if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld",
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(),
ShapeToString(args.dst_shape).c_str(), total_size);
return INTERNAL_ERROR;
}
return SUCCESS;
}

Status FormatTransferHwcnFractalznlstm::TransShape(Format src_format, const std::vector<int64_t> &src_shape,
DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) {
if (src_format == FORMAT_HWCN && CheckDataTypeSupported(data_type)) {
if (!CheckShapeValid(src_shape, kHwcnDimsNum)) {
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s",
ShapeToString(src_shape).c_str());
return ACL_ERROR_GE_SHAPE_INVALID;
}
return TransShapeHwcnToFrazlstm(data_type, src_shape, dst_shape);
} else if (src_format != FORMAT_HWCN) {
return ACL_ERROR_GE_SHAPE_INVALID;
} else {
return ACL_ERROR_GE_SHAPE_INVALID;
}
}

REGISTER_FORMAT_TRANSFER(FormatTransferHwcnFractalznlstm, FORMAT_HWCN, FORMAT_FRACTAL_ZN_LSTM)
} // namespace formats
} // namespace ge

+ 35
- 0
ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.h View File

@@ -0,0 +1,35 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_FRACTAL_ZN_LSTM_H_
#define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_FRACTAL_ZN_LSTM_H_

#include <vector>

#include "register/register_format_transfer.h"

namespace ge {
namespace formats {
class FormatTransferHwcnFractalznlstm : public FormatTransfer {
public:
Status TransFormat(const TransArgs &args, TransResult &result) override;
Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
std::vector<int64_t> &dst_shape) override;
};
} // namespace formats
} // namespace ge

#endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_FRACTAL_ZN_LSTM_H_

+ 1
- 0
ge/common/formats/utils/formats_definitions.h View File

@@ -100,6 +100,7 @@ enum DhwncDimIndex {
kDhwncC,
kDhwncDimsNum
};

} // namespace formats
} // namespace ge
#endif // GE_COMMON_FORMATS_UTILS_FORMATS_DEFINITIONS_H_

+ 1
- 0
ge/common/ge_common.mk View File

@@ -31,6 +31,7 @@ GE_COMMON_LOCAL_SRC_FILES := \
formats/format_transfers/format_transfer_dhwcn_fracz3D.cc \
formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc \
formats/format_transfers/format_transfer_nchw_fz_c04.cc \
formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc \
formats/formats.cc \
ge_format_util.cc \
fmk_error_codes.cc \


+ 1
- 0
ge/ge_inference.mk View File

@@ -24,6 +24,7 @@ COMMON_LOCAL_SRC_FILES := \
common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc \
common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc \
common/formats/format_transfers/format_transfer_nchw_fz_c04.cc \
common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc \
common/formats/formats.cc \
common/profiling/profiling_manager.cc \
common/dump/dump_properties.cc \


+ 1
- 0
ge/ge_runner.mk View File

@@ -19,6 +19,7 @@ LIBGE_LOCAL_SRC_FILES := \
common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc \
common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc \
common/formats/format_transfers/format_transfer_transpose.cc \
common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc \
common/formats/formats.cc \
common/formats/utils/formats_trans_utils.cc \
common/fp16_t.cc \


+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -364,6 +364,7 @@ set(COMMON_FORMAT_SRC_FILES
"${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc"
"${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc"
"${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc"
"${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc \
"${GE_CODE_DIR}/ge/common/formats/utils/formats_trans_utils.cc"
"${GE_CODE_DIR}/ge/graph/manager/util/hcom_util.cc"
"${GE_CODE_DIR}/ge/common/dump/dump_manager.cc"


Loading…
Cancel
Save