|
@@ -32,7 +32,7 @@ namespace ge { |
|
|
namespace { |
|
|
namespace { |
|
|
const size_t kConcatV2InputNum = 3; |
|
|
const size_t kConcatV2InputNum = 3; |
|
|
const int kSupportEmptyTensorRank = 1; |
|
|
const int kSupportEmptyTensorRank = 1; |
|
|
const std::set<DataType> concatv2_supported_type = {DT_INT32, DT_FLOAT}; |
|
|
|
|
|
|
|
|
const std::set<DataType> concatv2_supported_type = {DT_INT32, DT_FLOAT, DT_INT64}; |
|
|
|
|
|
|
|
|
template <typename T> |
|
|
template <typename T> |
|
|
void GetOutputData(std::vector<T> &y_data, int64_t loop, size_t &input_size, |
|
|
void GetOutputData(std::vector<T> &y_data, int64_t loop, size_t &input_size, |
|
@@ -88,6 +88,7 @@ Status ConcatV2Kernel::Compute(const ge::OpDescPtr op_desc_ptr, const vector<ge: |
|
|
|
|
|
|
|
|
std::vector<int32_t> y_data_int32_t; |
|
|
std::vector<int32_t> y_data_int32_t; |
|
|
std::vector<float> y_data_float; |
|
|
std::vector<float> y_data_float; |
|
|
|
|
|
std::vector<int64_t> y_data_int64_t; |
|
|
|
|
|
|
|
|
// Index 0 can always gets a GeTensorDesc object from any OpDescPtr. |
|
|
// Index 0 can always gets a GeTensorDesc object from any OpDescPtr. |
|
|
auto output_tensor_desc = op_desc_ptr->GetOutputDesc(0); |
|
|
auto output_tensor_desc = op_desc_ptr->GetOutputDesc(0); |
|
@@ -106,6 +107,7 @@ Status ConcatV2Kernel::Compute(const ge::OpDescPtr op_desc_ptr, const vector<ge: |
|
|
switch (data_type) { |
|
|
switch (data_type) { |
|
|
SET_OUTPUT(DT_INT32, int32_t) |
|
|
SET_OUTPUT(DT_INT32, int32_t) |
|
|
SET_OUTPUT(DT_FLOAT, float) |
|
|
SET_OUTPUT(DT_FLOAT, float) |
|
|
|
|
|
SET_OUTPUT(DT_INT64, int64_t) |
|
|
default: |
|
|
default: |
|
|
break; |
|
|
break; |
|
|
} |
|
|
} |
|
|