GitOrigin-RevId: a802a14e8d
tags/v0.4.0
@@ -15,18 +15,16 @@ | |||||
using namespace megdnn; | using namespace megdnn; | ||||
void ParamPackConcatSplitBase::check_exec(const TensorLayout& concated, | void ParamPackConcatSplitBase::check_exec(const TensorLayout& concated, | ||||
const TensorLayout& table, | |||||
const TensorLayout& offsets, | |||||
const TensorLayout& parts) { | const TensorLayout& parts) { | ||||
megdnn_assert(table.dtype == dtype::Int32{}, "bad dtype: %s", | |||||
table.dtype.name()); | |||||
megdnn_assert(concated.ndim == 1 && table.ndim == 1 && parts.ndim == 1 && | |||||
concated.stride[0] == 1 && table.stride[0] == 1 && | |||||
megdnn_assert(offsets.dtype == dtype::Int32{}, "bad dtype: %s", | |||||
offsets.dtype.name()); | |||||
megdnn_assert(concated.ndim == 1 && offsets.ndim == 1 && parts.ndim == 1 && | |||||
concated.stride[0] == 1 && offsets.stride[0] == 1 && | |||||
parts.stride[0] == 1, | parts.stride[0] == 1, | ||||
"bad layout: concated=%s table=%s parts=%s", | |||||
concated.to_string().c_str(), table.to_string().c_str(), | |||||
"bad layout: concated=%s offsets=%s parts=%s", | |||||
concated.to_string().c_str(), offsets.to_string().c_str(), | |||||
parts.to_string().c_str()); | parts.to_string().c_str()); | ||||
megdnn_assert(table.shape[0] == concated.shape[0] * 2, | |||||
"concated=%zu table=%zu", concated.shape[0], table.shape[0]); | |||||
} | } | ||||
std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets( | std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets( | ||||
@@ -46,11 +44,13 @@ std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets( | |||||
return v + ((alignment - mod) & (alignment - 1)); | return v + ((alignment - mod) & (alignment - 1)); | ||||
}; | }; | ||||
std::vector<dt_int32> offsets(shapes.size()); | |||||
std::vector<dt_int32> offsets(shapes.size() << 1); | |||||
size_t offset = 0; | size_t offset = 0; | ||||
for (size_t i = 0; i < shapes.size(); i++) { | for (size_t i = 0; i < shapes.size(); i++) { | ||||
offsets[i] = offset; | |||||
offset = get_aligned(offset) + shapes[i].total_nr_elems(); | |||||
offset = get_aligned(offset); | |||||
offsets[i * 2] = offset; | |||||
offset += shapes[i].total_nr_elems(); | |||||
offsets[i * 2 + 1] = offset; | |||||
} | } | ||||
return offsets; | return offsets; | ||||
} | } | ||||
@@ -24,7 +24,7 @@ size_t ParamPackConcatImpl::get_workspace_in_bytes(const TensorShapeArray& srcs, | |||||
template <typename T> | template <typename T> | ||||
void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, | void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, | ||||
_megdnn_tensor_in table, | |||||
_megdnn_tensor_in offsets, | |||||
_megdnn_tensor_out dst, | _megdnn_tensor_out dst, | ||||
_megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
size_t inp_size = srcs.layout.shape[0], | size_t inp_size = srcs.layout.shape[0], | ||||
@@ -35,25 +35,25 @@ void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, | |||||
megdnn_assert_internal(src_cpu); | megdnn_assert_internal(src_cpu); | ||||
auto src_gpu = reinterpret_cast<const T**>(workspace.raw_ptr); | auto src_gpu = reinterpret_cast<const T**>(workspace.raw_ptr); | ||||
auto table_outer_gpu = table.ptr<int32_t>(), | |||||
table_inner_gpu = table_outer_gpu + out_size; | |||||
auto offsets_gpu = offsets.ptr<int32_t>(); | |||||
cuda_check(cudaMemcpyAsync(src_gpu, src_cpu, sizeof(const T*) * inp_size, | cuda_check(cudaMemcpyAsync(src_gpu, src_cpu, sizeof(const T*) * inp_size, | ||||
cudaMemcpyHostToDevice, stream)); | cudaMemcpyHostToDevice, stream)); | ||||
param_pack::concat_proxy<T>(src_gpu, dst.ptr<T>(), out_size, | |||||
table_outer_gpu, table_inner_gpu, stream); | |||||
param_pack::concat_proxy<T>(src_gpu, dst.ptr<T>(), inp_size, out_size, | |||||
offsets_gpu, stream); | |||||
} | } | ||||
void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, _megdnn_tensor_in table, | |||||
void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, | |||||
_megdnn_tensor_in offsets, | |||||
_megdnn_tensor_out dst, | _megdnn_tensor_out dst, | ||||
_megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
check_exec(dst.layout, table.layout, srcs.layout); | |||||
#define cb(DType) \ | |||||
if (dst.layout.dtype == DType()) { \ | |||||
using ctype = typename DTypeTrait<DType>::ctype; \ | |||||
exec_internal<ctype>(srcs, table, dst, workspace); \ | |||||
return; \ | |||||
check_exec(dst.layout, offsets.layout, srcs.layout); | |||||
#define cb(DType) \ | |||||
if (dst.layout.dtype == DType()) { \ | |||||
using ctype = typename DTypeTrait<DType>::ctype; \ | |||||
exec_internal<ctype>(srcs, offsets, dst, workspace); \ | |||||
return; \ | |||||
} | } | ||||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
megdnn_throw("bad type"); | megdnn_throw("bad type"); | ||||
@@ -19,17 +19,24 @@ namespace param_pack { | |||||
template <typename T> | template <typename T> | ||||
__global__ void concat_kernel(const T** srcs, T* dst, | __global__ void concat_kernel(const T** srcs, T* dst, | ||||
const int32_t* table_outer, | |||||
const int32_t* table_inner, | |||||
const int32_t* offsets, | |||||
size_t srcs_size, | |||||
size_t total_size) { | size_t total_size) { | ||||
size_t addr = threadIdx.x + blockIdx.x * blockDim.x; | size_t addr = threadIdx.x + blockIdx.x * blockDim.x; | ||||
if (addr < total_size) { | if (addr < total_size) { | ||||
int32_t i = table_outer[addr]; | |||||
int32_t idx = table_inner[addr]; | |||||
if (idx != -1) | |||||
dst[addr] = srcs[i][idx]; | |||||
else | |||||
size_t l = 0, r = srcs_size - 1, mid; | |||||
while (l < r) { | |||||
mid = (l + r) >> 1; | |||||
if (offsets[(mid << 1) + 1] > addr) { | |||||
r = mid; | |||||
} else { | |||||
l = mid + 1; | |||||
} | |||||
} | |||||
if (addr < offsets[l << 1]) | |||||
dst[addr] = 0; | dst[addr] = 0; | ||||
else | |||||
dst[addr] = srcs[l][addr - offsets[l << 1]]; | |||||
} | } | ||||
} | } | ||||
@@ -59,20 +66,20 @@ void split_proxy(const T* src, T** dsts, size_t total_size, | |||||
} | } | ||||
template <typename T> | template <typename T> | ||||
void concat_proxy(const T** srcs, T* dst, size_t total_size, | |||||
const int32_t* table_outer, | |||||
const int32_t* table_inner, cudaStream_t stream) { | |||||
void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size, | |||||
const int32_t* offsets, | |||||
cudaStream_t stream) { | |||||
size_t NR_BLOCKS = DIVUP(total_size, NR_THREADS); | size_t NR_BLOCKS = DIVUP(total_size, NR_THREADS); | ||||
concat_kernel<<<NR_BLOCKS, NR_THREADS, 0, stream>>>( | concat_kernel<<<NR_BLOCKS, NR_THREADS, 0, stream>>>( | ||||
srcs, dst, table_outer, table_inner, total_size); | |||||
srcs, dst, offsets, srcs_size, total_size); | |||||
after_kernel_launch(); | after_kernel_launch(); | ||||
} | } | ||||
#define INST(T) \ | #define INST(T) \ | ||||
template void concat_proxy<T>(const T**, T*, size_t, \ | |||||
const int32_t*, const int32_t*, \ | |||||
template void concat_proxy<T>(const T**, T*, size_t, size_t, \ | |||||
const int32_t*, \ | |||||
cudaStream_t); \ | cudaStream_t); \ | ||||
template void split_proxy<T>(const T*, T**, size_t, \ | |||||
template void split_proxy<T>(const T*, T**, size_t, \ | |||||
const int32_t*, const int32_t*, \ | const int32_t*, const int32_t*, \ | ||||
cudaStream_t); | cudaStream_t); | ||||
#define cb(DType) INST(typename DTypeTrait<DType>::ctype) | #define cb(DType) INST(typename DTypeTrait<DType>::ctype) | ||||
@@ -25,9 +25,8 @@ void split_proxy(const T* src, T** dsts, size_t total_size, | |||||
cudaStream_t stream); | cudaStream_t stream); | ||||
template <typename T> | template <typename T> | ||||
void concat_proxy(const T** srcs, T* dst, size_t total_size, | |||||
const int32_t* table_outer, | |||||
const int32_t* table_inner, cudaStream_t stream); | |||||
void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size, | |||||
const int32_t* offsets, cudaStream_t stream); | |||||
} // namespace param_pack | } // namespace param_pack | ||||
} // namespace cuda | } // namespace cuda | ||||
@@ -54,38 +54,40 @@ void ParamPackSplitImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in table, | |||||
} | } | ||||
template <typename T> | template <typename T> | ||||
void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, int32_t* table, | |||||
void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, | |||||
int32_t* offsets, | |||||
_megdnn_tensor_out dst, | _megdnn_tensor_out dst, | ||||
_megdnn_workspace) { | _megdnn_workspace) { | ||||
size_t out_size = dst.layout.total_nr_elems(); | |||||
auto srcs_ptr = static_cast<const T**>(srcs.raw_ptr); | auto srcs_ptr = static_cast<const T**>(srcs.raw_ptr); | ||||
auto dst_ptr = dst.ptr<T>(); | auto dst_ptr = dst.ptr<T>(); | ||||
auto table_outer = table, table_inner = table_outer + out_size; | |||||
for (size_t j = 0; j < out_size; j++) { | |||||
int32_t i = table_outer[j]; | |||||
int32_t idx = table_inner[j]; | |||||
if (idx != -1) | |||||
dst_ptr[j] = srcs_ptr[i][idx]; | |||||
else | |||||
dst_ptr[j] = 0; | |||||
int32_t last_pos = 0; | |||||
for (size_t i = 0; i < srcs.layout[0]; i++) { | |||||
int32_t begin = offsets[i * 2], end = offsets[i * 2 + 1]; | |||||
while (last_pos < begin) { | |||||
dst_ptr[last_pos] = 0; | |||||
last_pos++; | |||||
} | |||||
for (int32_t j = 0; j < end - begin; j++) { | |||||
dst_ptr[begin + j] = srcs_ptr[i][j]; | |||||
} | |||||
last_pos = end; | |||||
} | } | ||||
} | } | ||||
void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, _megdnn_tensor_in table, | |||||
void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, | |||||
_megdnn_tensor_in offsets, | |||||
_megdnn_tensor_out dst, | _megdnn_tensor_out dst, | ||||
_megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
check_exec(dst.layout, table.layout, srcs.layout); | |||||
auto table_ptr = table.ptr<int32_t>(); | |||||
check_exec(dst.layout, offsets.layout, srcs.layout); | |||||
auto offsets_ptr = offsets.ptr<int32_t>(); | |||||
#define cb(DType) \ | |||||
if (dst.layout.dtype == DType()) { \ | |||||
using ctype = typename DTypeTrait<DType>::ctype; \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
exec_internal<ctype>(srcs, table_ptr, dst, workspace)); \ | |||||
return; \ | |||||
#define cb(DType) \ | |||||
if (dst.layout.dtype == DType()) { \ | |||||
using ctype = typename DTypeTrait<DType>::ctype; \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
exec_internal<ctype>(srcs, offsets_ptr, dst, workspace)); \ | |||||
return; \ | |||||
} | } | ||||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
megdnn_throw("bad type"); | megdnn_throw("bad type"); | ||||
@@ -1339,8 +1339,10 @@ void Concat::init_output_comp_node() { | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackConcat); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackConcat); | ||||
ParamPackConcat::ParamPackConcat(VarNodeArray& inp, VarNode* table, | ParamPackConcat::ParamPackConcat(VarNodeArray& inp, VarNode* table, | ||||
const std::vector<dt_int32> offsets_val, | |||||
const OperatorNodeConfig& config) | const OperatorNodeConfig& config) | ||||
: Super(inp[0]->owner_graph(), config, "ParamPackConcat", inp) { | |||||
: Super(inp[0]->owner_graph(), config, "ParamPackConcat", inp), | |||||
m_offsets(offsets_val) { | |||||
CompNode cn = inp[0]->comp_node(); | CompNode cn = inp[0]->comp_node(); | ||||
add_input({inp[0]}); | add_input({inp[0]}); | ||||
for (size_t i = 1; i < inp.size(); i++) { | for (size_t i = 1; i < inp.size(); i++) { | ||||
@@ -1361,14 +1363,16 @@ void ParamPackConcat::add_input_layout_constraint(){ | |||||
} | } | ||||
} | } | ||||
SymbolVar ParamPackConcat::make(const SmallVector<SymbolVar> &inp, | |||||
const SymbolVar &table, const OperatorNodeConfig& config) { | |||||
SymbolVar ParamPackConcat::make(const SmallVector<SymbolVar>& inp, | |||||
const SymbolVar& offsets, | |||||
const std::vector<dt_int32> offsets_val, | |||||
const OperatorNodeConfig& config) { | |||||
VarNodeArray array(inp.size()); | VarNodeArray array(inp.size()); | ||||
for (size_t i = 0; i < inp.size(); i++) { | for (size_t i = 0; i < inp.size(); i++) { | ||||
array[i] = inp[i].node(); | array[i] = inp[i].node(); | ||||
} | } | ||||
return inp.front(). | |||||
insert_single_output_opr<ParamPackConcat>(array, table.node(), config); | |||||
return inp.front().insert_single_output_opr<ParamPackConcat>( | |||||
array, offsets.node(), offsets_val, config); | |||||
} | } | ||||
void ParamPackConcat::scn_do_execute() { | void ParamPackConcat::scn_do_execute() { | ||||
@@ -1379,13 +1383,13 @@ void ParamPackConcat::scn_do_execute() { | |||||
for (size_t i = 0; i < inputs.size() - 1; i++) { | for (size_t i = 0; i < inputs.size() - 1; i++) { | ||||
ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr; | ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr; | ||||
} | } | ||||
auto table = inputs.back()->dev_tensor().as_megdnn(); | |||||
auto offsets = inputs.back()->dev_tensor().as_megdnn(); | |||||
megdnn::TensorND srcs( | megdnn::TensorND srcs( | ||||
ptr, megdnn::TensorLayout({inputs.size() - 1}, dtype::Int32())); | ptr, megdnn::TensorLayout({inputs.size() - 1}, dtype::Int32())); | ||||
auto&& dst = output(0)->dev_tensor().as_megdnn(); | auto&& dst = output(0)->dev_tensor().as_megdnn(); | ||||
m_opr->exec(srcs, table, dst, get_megdnn_workspace_from_var(output(1))); | |||||
m_opr->exec(srcs, offsets, dst, get_megdnn_workspace_from_var(output(1))); | |||||
} | } | ||||
void ParamPackConcat::init_output_dtype() { | void ParamPackConcat::init_output_dtype() { | ||||
@@ -1396,8 +1400,8 @@ void ParamPackConcat::init_output_static_infer_desc(){ | |||||
using namespace cg::static_infer; | using namespace cg::static_infer; | ||||
auto &&mgr = owner_graph()->static_infer_manager(); | auto &&mgr = owner_graph()->static_infer_manager(); | ||||
auto infer_out = [](TensorShape &dest, const InpVal &inp) { | |||||
dest = {inp.val.back().shape().total_nr_elems()/2}; | |||||
auto infer_out = [this](TensorShape &dest, const InpVal &inp) { | |||||
dest = {m_offsets.back()}; | |||||
return true; | return true; | ||||
}; | }; | ||||
DepVal shp_deps; | DepVal shp_deps; | ||||
@@ -1480,10 +1484,10 @@ void ParamPackSplit::init_output_dtype() { | |||||
} | } | ||||
void ParamPackSplit::mem_plan_fwd_in2out_readonly() { | void ParamPackSplit::mem_plan_fwd_in2out_readonly() { | ||||
mgb_assert(m_offsets.size() == output().size()); | |||||
mgb_assert(m_offsets.size() == output().size() * 2); | |||||
for (size_t i = 0; i < output().size(); i++) { | for (size_t i = 0; i < output().size(); i++) { | ||||
auto layout = output(i)->layout(); | auto layout = output(i)->layout(); | ||||
auto spec = SubTensorSpec::make_from_offset_elem(layout, m_offsets[i]); | |||||
auto spec = SubTensorSpec::make_from_offset_elem(layout, m_offsets[i * 2]); | |||||
m_mem_fwd_success[i] = output(i)->set_fwd_in2out_readonly( | m_mem_fwd_success[i] = output(i)->set_fwd_in2out_readonly( | ||||
input(0), spec); | input(0), spec); | ||||
mgb_assert(m_mem_fwd_success[i]); | mgb_assert(m_mem_fwd_success[i]); | ||||
@@ -1524,7 +1528,7 @@ MGB_IMPL_OPR_GRAD(ParamPackSplit) { | |||||
} | } | ||||
return ParamPackConcat::make( | return ParamPackConcat::make( | ||||
grad, opr.input(1), | |||||
grad, opr.input(1), opr.get_offsets(), | |||||
OperatorNodeConfig{}.follow_comp_node(opr.input(0))) | OperatorNodeConfig{}.follow_comp_node(opr.input(0))) | ||||
.node(); | .node(); | ||||
} | } | ||||
@@ -32,31 +32,6 @@ namespace serialization { | |||||
public OprMakerVariadic<opr::GetVarShape>{}; | public OprMakerVariadic<opr::GetVarShape>{}; | ||||
template<> | template<> | ||||
struct OprLoadDumpImpl<opr::ParamPackConcat, 0> | |||||
{ | |||||
using ParamPackConcat = opr::ParamPackConcat; | |||||
using Param = opr::ParamPackConcat::Param; | |||||
static void dump(OprDumpContext &ctx, | |||||
const cg::OperatorNodeBase &opr_) { | |||||
auto &&opr = opr_.cast_final_safe<ParamPackConcat>(); | |||||
ctx.write_param<Param>(opr.param()); | |||||
} | |||||
static cg::OperatorNodeBase* load( | |||||
OprLoadContext &ctx, const cg::VarNodeArray &inputs, | |||||
const OperatorNodeConfig &config) { | |||||
auto param = ctx.read_param<Param>(); | |||||
mgb_assert(!inputs.empty()); | |||||
SymbolVarArray ivar{inputs.size() - 1}; | |||||
for (size_t i = 0; i < inputs.size() - 1; ++ i) | |||||
ivar[i] = inputs[i]; | |||||
return ParamPackConcat::make(ivar, inputs.back(), | |||||
param, config).node()->owner_opr(); | |||||
} | |||||
}; | |||||
template<> | |||||
struct OprLoadDumpImpl<opr::Split, 0> { | struct OprLoadDumpImpl<opr::Split, 0> { | ||||
using Split = opr::Split; | using Split = opr::Split; | ||||
using Options = Split::Options; | using Options = Split::Options; | ||||
@@ -151,7 +126,6 @@ namespace opr { | |||||
MGB_SEREG_OPR(Dimshuffle, 1); | MGB_SEREG_OPR(Dimshuffle, 1); | ||||
MGB_SEREG_OPR(AxisAddRemove, 1); | MGB_SEREG_OPR(AxisAddRemove, 1); | ||||
MGB_SEREG_OPR(Concat, 0); | MGB_SEREG_OPR(Concat, 0); | ||||
MGB_SEREG_OPR(ParamPackConcat, 0); | |||||
using GetVarShapeV1 = opr::GetVarShape; | using GetVarShapeV1 = opr::GetVarShape; | ||||
MGB_SEREG_OPR(GetVarShapeV1, 0); | MGB_SEREG_OPR(GetVarShapeV1, 0); | ||||
using ReshapeV1 = opr::Reshape; | using ReshapeV1 = opr::Reshape; | ||||
@@ -193,6 +167,22 @@ namespace opr { | |||||
} | } | ||||
MGB_REG_OPR_SHALLOW_COPY(ParamPackSplit, opr_shallow_copy_param_pack_split); | MGB_REG_OPR_SHALLOW_COPY(ParamPackSplit, opr_shallow_copy_param_pack_split); | ||||
cg::OperatorNodeBase* opr_shallow_copy_param_pack_concat( | |||||
const serialization::OprShallowCopyContext &ctx, | |||||
const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, | |||||
const OperatorNodeConfig &config){ | |||||
auto &&opr = opr_.cast_final_safe<ParamPackConcat>(); | |||||
auto &&offsets = opr.get_offsets(); | |||||
SymbolVarArray ivar{inputs.size() - 1}; | |||||
for (size_t i = 0; i < inputs.size() - 1; ++i) | |||||
ivar[i] = inputs[i]; | |||||
return ParamPackConcat::make(ivar, inputs.back(), offsets, config). | |||||
node()->owner_opr(); | |||||
} | |||||
MGB_REG_OPR_SHALLOW_COPY(ParamPackConcat, opr_shallow_copy_param_pack_concat); | |||||
MGB_SEREG_OPR(RelayoutFormat, 1); | MGB_SEREG_OPR(RelayoutFormat, 1); | ||||
MGB_SEREG_OPR(WinogradFilterPreprocess, 1); | MGB_SEREG_OPR(WinogradFilterPreprocess, 1); | ||||
} // namespace opr | } // namespace opr | ||||
@@ -539,6 +539,7 @@ MGB_DEFINE_OPR_CLASS(Concat, cg::SingleCNOutshapePureByInshapeOprBase) // { | |||||
MGB_DEFINE_OPR_CLASS(ParamPackConcat, cg::SingleCNOperatorNodeBase) // { | MGB_DEFINE_OPR_CLASS(ParamPackConcat, cg::SingleCNOperatorNodeBase) // { | ||||
//! input pointer buffer | //! input pointer buffer | ||||
SmallVector<void*> m_inp_ptr; | SmallVector<void*> m_inp_ptr; | ||||
std::vector<dt_int32> m_offsets; | |||||
intl::UniqPtrWithCN<megdnn::ParamPackConcat> m_opr; | intl::UniqPtrWithCN<megdnn::ParamPackConcat> m_opr; | ||||
void add_input_layout_constraint() override; | void add_input_layout_constraint() override; | ||||
@@ -554,15 +555,23 @@ public: | |||||
return {}; | return {}; | ||||
} | } | ||||
ParamPackConcat(VarNodeArray &inp, VarNode *table, | |||||
const OperatorNodeConfig &config); | |||||
static SymbolVar make(const SmallVector<SymbolVar> &inp, | |||||
const SymbolVar &table, const OperatorNodeConfig &config = {}); | |||||
ParamPackConcat(VarNodeArray& inp, VarNode* offsets, | |||||
const std::vector<dt_int32> offsets_val, | |||||
const OperatorNodeConfig& config); | |||||
static SymbolVar make(const SmallVector<SymbolVar>& inp, | |||||
const SymbolVar& offsets, | |||||
const std::vector<dt_int32> offsets_val, | |||||
const OperatorNodeConfig& config = {}); | |||||
static SymbolVar make(const SmallVector<SymbolVar>& inp, | |||||
const SymbolVar& offsets, | |||||
const std::vector<dt_int32> offsets_val, const Param&, | |||||
const OperatorNodeConfig& config) { | |||||
return make(inp, offsets, offsets_val, config); | |||||
} | |||||
static SymbolVar make(const SmallVector<SymbolVar> &inp, | |||||
const SymbolVar &table, const Param &, | |||||
const OperatorNodeConfig &config) { | |||||
return make(inp, table, config); | |||||
const std::vector<dt_int32>& get_offsets() const { | |||||
return m_offsets; | |||||
} | } | ||||
}; | }; | ||||
@@ -1906,7 +1906,7 @@ void test_param_pack_concat(const TensorShapeArray &shapes, DType type){ | |||||
memcpy(host_table->raw_ptr(), host_table_gen.data(), size * 8); | memcpy(host_table->raw_ptr(), host_table_gen.data(), size * 8); | ||||
auto table = opr::Host2DeviceCopy::make(*graph, host_table); | auto table = opr::Host2DeviceCopy::make(*graph, host_table); | ||||
auto z = opr::ParamPackConcat::make(srcs, table); | |||||
auto z = opr::ParamPackConcat::make(srcs, table, host_table_gen); | |||||
HostTensorND host_z; | HostTensorND host_z; | ||||
auto func = graph->compile({make_callback_copy(z, host_z)}); | auto func = graph->compile({make_callback_copy(z, host_z)}); | ||||