|
@@ -27,19 +27,19 @@ using namespace std; |
|
|
|
|
|
|
|
|
ShapeTransferAccordingToFormat::ShapeTransferAccordingToFormat(void) { |
|
|
ShapeTransferAccordingToFormat::ShapeTransferAccordingToFormat(void) { |
|
|
getNewShapeFuncMap = { |
|
|
getNewShapeFuncMap = { |
|
|
{ge::FORMAT_NCHW, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNCHWShapeByAxisValue)}, |
|
|
|
|
|
{ge::FORMAT_NHWC, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNHWCShapeByAxisValue)}, |
|
|
|
|
|
{ge::FORMAT_NC1HWC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNC1HWC0ShapeByAxisValue)}, |
|
|
|
|
|
{ge::FORMAT_FRACTAL_Z, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetFzShapeByAxisValue)}, |
|
|
|
|
|
{ge::FORMAT_HWCN, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetHWCNShapeByAxisValue)}, |
|
|
|
|
|
{ge::FORMAT_C1HWNCoC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetC1HWNCoC0ShapeByAxisValue)}, |
|
|
|
|
|
{ge::FORMAT_FRACTAL_NZ, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNzShapeByAxisValue)}}; |
|
|
|
|
|
|
|
|
{ge::FORMAT_NCHW, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNCHWShapeByAxisValue)}, |
|
|
|
|
|
{ge::FORMAT_NHWC, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNHWCShapeByAxisValue)}, |
|
|
|
|
|
{ge::FORMAT_NC1HWC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNC1HWC0ShapeByAxisValue)}, |
|
|
|
|
|
{ge::FORMAT_FRACTAL_Z, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetFzShapeByAxisValue)}, |
|
|
|
|
|
{ge::FORMAT_HWCN, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetHWCNShapeByAxisValue)}, |
|
|
|
|
|
{ge::FORMAT_C1HWNCoC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetC1HWNCoC0ShapeByAxisValue)}, |
|
|
|
|
|
{ge::FORMAT_FRACTAL_NZ, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNzShapeByAxisValue)}}; |
|
|
|
|
|
|
|
|
mapOfDtypeAndC0 = { |
|
|
mapOfDtypeAndC0 = { |
|
|
{ge::DT_FLOAT16, SHAPE_NUMBER_16}, {ge::DT_FLOAT, SHAPE_NUMBER_16}, {ge::DT_INT8, SHAPE_NUMBER_32}, |
|
|
|
|
|
{ge::DT_INT16, SHAPE_NUMBER_16}, {ge::DT_INT32, SHAPE_NUMBER_16}, {ge::DT_INT64, SHAPE_NUMBER_16}, |
|
|
|
|
|
{ge::DT_UINT8, SHAPE_NUMBER_16}, {ge::DT_UINT16, SHAPE_NUMBER_32}, {ge::DT_UINT32, SHAPE_NUMBER_16}, |
|
|
|
|
|
{ge::DT_UINT64, SHAPE_NUMBER_16}, {ge::DT_BOOL, SHAPE_NUMBER_16}}; |
|
|
|
|
|
|
|
|
{ge::DT_FLOAT16, SHAPE_NUMBER_16}, {ge::DT_FLOAT, SHAPE_NUMBER_16}, {ge::DT_INT8, SHAPE_NUMBER_32}, |
|
|
|
|
|
{ge::DT_INT16, SHAPE_NUMBER_16}, {ge::DT_INT32, SHAPE_NUMBER_16}, {ge::DT_INT64, SHAPE_NUMBER_16}, |
|
|
|
|
|
{ge::DT_UINT8, SHAPE_NUMBER_16}, {ge::DT_UINT16, SHAPE_NUMBER_32}, {ge::DT_UINT32, SHAPE_NUMBER_16}, |
|
|
|
|
|
{ge::DT_UINT64, SHAPE_NUMBER_16}, {ge::DT_BOOL, SHAPE_NUMBER_16}}; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool ShapeTransferAccordingToFormat::GetNCHWShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType, |
|
|
bool ShapeTransferAccordingToFormat::GetNCHWShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType, |
|
@@ -97,9 +97,9 @@ bool ShapeTransferAccordingToFormat::GetFzShapeByAxisValue(vector<int64_t>& newS |
|
|
/* sizeOfOriginalVec - 1 mean the last value of original vec |
|
|
/* sizeOfOriginalVec - 1 mean the last value of original vec |
|
|
* sizeOfOriginalVec - 2 mean the second last value of original vec */ |
|
|
* sizeOfOriginalVec - 2 mean the second last value of original vec */ |
|
|
newShape[sizeOfOriginalVec - MINUS_VALUE_ONE] = |
|
|
newShape[sizeOfOriginalVec - MINUS_VALUE_ONE] = |
|
|
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], SHAPE_NUMBER_16); |
|
|
|
|
|
|
|
|
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], SHAPE_NUMBER_16); |
|
|
newShape[sizeOfOriginalVec - MINUS_VALUE_TWO] = |
|
|
newShape[sizeOfOriginalVec - MINUS_VALUE_TWO] = |
|
|
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], axisValue[AXIS_C0]); |
|
|
|
|
|
|
|
|
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], axisValue[AXIS_C0]); |
|
|
newShape.push_back(SHAPE_NUMBER_16); |
|
|
newShape.push_back(SHAPE_NUMBER_16); |
|
|
newShape.push_back(axisValue[AXIS_C0]); |
|
|
newShape.push_back(axisValue[AXIS_C0]); |
|
|
} else { |
|
|
} else { |
|
@@ -163,10 +163,10 @@ bool ShapeTransferAccordingToFormat::GetNzShapeByAxisValue(vector<int64_t>& newS |
|
|
/* sizeOfOriginalVec - 1 mean the last value of original vec |
|
|
/* sizeOfOriginalVec - 1 mean the last value of original vec |
|
|
* sizeOfOriginalVec - 2 mean the second last value of original vec */ |
|
|
* sizeOfOriginalVec - 2 mean the second last value of original vec */ |
|
|
newShape[sizeOfOriginalVec - MINUS_VALUE_ONE] = |
|
|
newShape[sizeOfOriginalVec - MINUS_VALUE_ONE] = |
|
|
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], (int64_t)SHAPE_NUMBER_16); |
|
|
|
|
|
|
|
|
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], (int64_t)SHAPE_NUMBER_16); |
|
|
|
|
|
|
|
|
newShape[sizeOfOriginalVec - MINUS_VALUE_TWO] = |
|
|
newShape[sizeOfOriginalVec - MINUS_VALUE_TWO] = |
|
|
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], axisValue[AXIS_C0]); |
|
|
|
|
|
|
|
|
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], axisValue[AXIS_C0]); |
|
|
newShape.push_back(SHAPE_NUMBER_16); |
|
|
newShape.push_back(SHAPE_NUMBER_16); |
|
|
newShape.push_back(axisValue[AXIS_C0]); |
|
|
newShape.push_back(axisValue[AXIS_C0]); |
|
|
return true; |
|
|
return true; |
|
@@ -177,7 +177,7 @@ bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(ShapeAndFormat& s |
|
|
shapeAndFormatInfo.newShape = shapeAndFormatInfo.oldShape; |
|
|
shapeAndFormatInfo.newShape = shapeAndFormatInfo.oldShape; |
|
|
if (shapeAndFormatInfo.oldFormat >= ge::FORMAT_RESERVED || shapeAndFormatInfo.newFormat >= ge::FORMAT_RESERVED) { |
|
|
if (shapeAndFormatInfo.oldFormat >= ge::FORMAT_RESERVED || shapeAndFormatInfo.newFormat >= ge::FORMAT_RESERVED) { |
|
|
GELOGE(GRAPH_FAILED, "Old format %u or new format %u is invalid!", shapeAndFormatInfo.oldFormat, |
|
|
GELOGE(GRAPH_FAILED, "Old format %u or new format %u is invalid!", shapeAndFormatInfo.oldFormat, |
|
|
shapeAndFormatInfo.newFormat); |
|
|
|
|
|
|
|
|
shapeAndFormatInfo.newFormat); |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@@ -223,8 +223,8 @@ bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(ShapeAndFormat& s |
|
|
c0 = SHAPE_DIM_VALUE_C04; |
|
|
c0 = SHAPE_DIM_VALUE_C04; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool status = axisutil_object->GetAxisValueByOriginFormat( |
|
|
|
|
|
shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.oldShape, c0, axisValue, ndValue); |
|
|
|
|
|
|
|
|
bool status = axisutil_object->GetAxisValueByOriginFormat(shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.oldShape, |
|
|
|
|
|
c0, axisValue, ndValue); |
|
|
if (status != true && shapeAndFormatInfo.newFormat != ge::FORMAT_FRACTAL_NZ) { |
|
|
if (status != true && shapeAndFormatInfo.newFormat != ge::FORMAT_FRACTAL_NZ) { |
|
|
delete axisutil_object; |
|
|
delete axisutil_object; |
|
|
return true; |
|
|
return true; |
|
@@ -238,5 +238,5 @@ bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(ShapeAndFormat& s |
|
|
} |
|
|
} |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
} // namespace transformer |
|
|
|
|
|
} // namespace common |
|
|
|
|
|
|
|
|
} // namespace transformer |
|
|
|
|
|
} // namespace common |