GitOrigin-RevId: a802a14e8d
tags/v0.4.0
@@ -15,18 +15,16 @@ | |||
using namespace megdnn; | |||
void ParamPackConcatSplitBase::check_exec(const TensorLayout& concated, | |||
const TensorLayout& table, | |||
const TensorLayout& offsets, | |||
const TensorLayout& parts) { | |||
megdnn_assert(table.dtype == dtype::Int32{}, "bad dtype: %s", | |||
table.dtype.name()); | |||
megdnn_assert(concated.ndim == 1 && table.ndim == 1 && parts.ndim == 1 && | |||
concated.stride[0] == 1 && table.stride[0] == 1 && | |||
megdnn_assert(offsets.dtype == dtype::Int32{}, "bad dtype: %s", | |||
offsets.dtype.name()); | |||
megdnn_assert(concated.ndim == 1 && offsets.ndim == 1 && parts.ndim == 1 && | |||
concated.stride[0] == 1 && offsets.stride[0] == 1 && | |||
parts.stride[0] == 1, | |||
"bad layout: concated=%s table=%s parts=%s", | |||
concated.to_string().c_str(), table.to_string().c_str(), | |||
"bad layout: concated=%s offsets=%s parts=%s", | |||
concated.to_string().c_str(), offsets.to_string().c_str(), | |||
parts.to_string().c_str()); | |||
megdnn_assert(table.shape[0] == concated.shape[0] * 2, | |||
"concated=%zu table=%zu", concated.shape[0], table.shape[0]); | |||
} | |||
std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets( | |||
@@ -46,11 +44,13 @@ std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets( | |||
return v + ((alignment - mod) & (alignment - 1)); | |||
}; | |||
std::vector<dt_int32> offsets(shapes.size()); | |||
std::vector<dt_int32> offsets(shapes.size() << 1); | |||
size_t offset = 0; | |||
for (size_t i = 0; i < shapes.size(); i++) { | |||
offsets[i] = offset; | |||
offset = get_aligned(offset) + shapes[i].total_nr_elems(); | |||
offset = get_aligned(offset); | |||
offsets[i * 2] = offset; | |||
offset += shapes[i].total_nr_elems(); | |||
offsets[i * 2 + 1] = offset; | |||
} | |||
return offsets; | |||
} | |||
@@ -24,7 +24,7 @@ size_t ParamPackConcatImpl::get_workspace_in_bytes(const TensorShapeArray& srcs, | |||
template <typename T> | |||
void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, | |||
_megdnn_tensor_in table, | |||
_megdnn_tensor_in offsets, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) { | |||
size_t inp_size = srcs.layout.shape[0], | |||
@@ -35,25 +35,25 @@ void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, | |||
megdnn_assert_internal(src_cpu); | |||
auto src_gpu = reinterpret_cast<const T**>(workspace.raw_ptr); | |||
auto table_outer_gpu = table.ptr<int32_t>(), | |||
table_inner_gpu = table_outer_gpu + out_size; | |||
auto offsets_gpu = offsets.ptr<int32_t>(); | |||
cuda_check(cudaMemcpyAsync(src_gpu, src_cpu, sizeof(const T*) * inp_size, | |||
cudaMemcpyHostToDevice, stream)); | |||
param_pack::concat_proxy<T>(src_gpu, dst.ptr<T>(), out_size, | |||
table_outer_gpu, table_inner_gpu, stream); | |||
param_pack::concat_proxy<T>(src_gpu, dst.ptr<T>(), inp_size, out_size, | |||
offsets_gpu, stream); | |||
} | |||
void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, _megdnn_tensor_in table, | |||
void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, | |||
_megdnn_tensor_in offsets, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) { | |||
check_exec(dst.layout, table.layout, srcs.layout); | |||
#define cb(DType) \ | |||
if (dst.layout.dtype == DType()) { \ | |||
using ctype = typename DTypeTrait<DType>::ctype; \ | |||
exec_internal<ctype>(srcs, table, dst, workspace); \ | |||
return; \ | |||
check_exec(dst.layout, offsets.layout, srcs.layout); | |||
#define cb(DType) \ | |||
if (dst.layout.dtype == DType()) { \ | |||
using ctype = typename DTypeTrait<DType>::ctype; \ | |||
exec_internal<ctype>(srcs, offsets, dst, workspace); \ | |||
return; \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
megdnn_throw("bad type"); | |||
@@ -19,17 +19,24 @@ namespace param_pack { | |||
template <typename T> | |||
__global__ void concat_kernel(const T** srcs, T* dst, | |||
const int32_t* table_outer, | |||
const int32_t* table_inner, | |||
const int32_t* offsets, | |||
size_t srcs_size, | |||
size_t total_size) { | |||
size_t addr = threadIdx.x + blockIdx.x * blockDim.x; | |||
if (addr < total_size) { | |||
int32_t i = table_outer[addr]; | |||
int32_t idx = table_inner[addr]; | |||
if (idx != -1) | |||
dst[addr] = srcs[i][idx]; | |||
else | |||
size_t l = 0, r = srcs_size - 1, mid; | |||
while (l < r) { | |||
mid = (l + r) >> 1; | |||
if (offsets[(mid << 1) + 1] > addr) { | |||
r = mid; | |||
} else { | |||
l = mid + 1; | |||
} | |||
} | |||
if (addr < offsets[l << 1]) | |||
dst[addr] = 0; | |||
else | |||
dst[addr] = srcs[l][addr - offsets[l << 1]]; | |||
} | |||
} | |||
@@ -59,20 +66,20 @@ void split_proxy(const T* src, T** dsts, size_t total_size, | |||
} | |||
template <typename T> | |||
void concat_proxy(const T** srcs, T* dst, size_t total_size, | |||
const int32_t* table_outer, | |||
const int32_t* table_inner, cudaStream_t stream) { | |||
void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size, | |||
const int32_t* offsets, | |||
cudaStream_t stream) { | |||
size_t NR_BLOCKS = DIVUP(total_size, NR_THREADS); | |||
concat_kernel<<<NR_BLOCKS, NR_THREADS, 0, stream>>>( | |||
srcs, dst, table_outer, table_inner, total_size); | |||
srcs, dst, offsets, srcs_size, total_size); | |||
after_kernel_launch(); | |||
} | |||
#define INST(T) \ | |||
template void concat_proxy<T>(const T**, T*, size_t, \ | |||
const int32_t*, const int32_t*, \ | |||
template void concat_proxy<T>(const T**, T*, size_t, size_t, \ | |||
const int32_t*, \ | |||
cudaStream_t); \ | |||
template void split_proxy<T>(const T*, T**, size_t, \ | |||
template void split_proxy<T>(const T*, T**, size_t, \ | |||
const int32_t*, const int32_t*, \ | |||
cudaStream_t); | |||
#define cb(DType) INST(typename DTypeTrait<DType>::ctype) | |||
@@ -25,9 +25,8 @@ void split_proxy(const T* src, T** dsts, size_t total_size, | |||
cudaStream_t stream); | |||
template <typename T> | |||
void concat_proxy(const T** srcs, T* dst, size_t total_size, | |||
const int32_t* table_outer, | |||
const int32_t* table_inner, cudaStream_t stream); | |||
void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size, | |||
const int32_t* offsets, cudaStream_t stream); | |||
} // namespace param_pack | |||
} // namespace cuda | |||
@@ -54,38 +54,40 @@ void ParamPackSplitImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in table, | |||
} | |||
template <typename T> | |||
void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, int32_t* table, | |||
void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, | |||
int32_t* offsets, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace) { | |||
size_t out_size = dst.layout.total_nr_elems(); | |||
auto srcs_ptr = static_cast<const T**>(srcs.raw_ptr); | |||
auto dst_ptr = dst.ptr<T>(); | |||
auto table_outer = table, table_inner = table_outer + out_size; | |||
for (size_t j = 0; j < out_size; j++) { | |||
int32_t i = table_outer[j]; | |||
int32_t idx = table_inner[j]; | |||
if (idx != -1) | |||
dst_ptr[j] = srcs_ptr[i][idx]; | |||
else | |||
dst_ptr[j] = 0; | |||
int32_t last_pos = 0; | |||
for (size_t i = 0; i < srcs.layout[0]; i++) { | |||
int32_t begin = offsets[i * 2], end = offsets[i * 2 + 1]; | |||
while (last_pos < begin) { | |||
dst_ptr[last_pos] = 0; | |||
last_pos++; | |||
} | |||
for (int32_t j = 0; j < end - begin; j++) { | |||
dst_ptr[begin + j] = srcs_ptr[i][j]; | |||
} | |||
last_pos = end; | |||
} | |||
} | |||
void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, _megdnn_tensor_in table, | |||
void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, | |||
_megdnn_tensor_in offsets, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) { | |||
check_exec(dst.layout, table.layout, srcs.layout); | |||
auto table_ptr = table.ptr<int32_t>(); | |||
check_exec(dst.layout, offsets.layout, srcs.layout); | |||
auto offsets_ptr = offsets.ptr<int32_t>(); | |||
#define cb(DType) \ | |||
if (dst.layout.dtype == DType()) { \ | |||
using ctype = typename DTypeTrait<DType>::ctype; \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
exec_internal<ctype>(srcs, table_ptr, dst, workspace)); \ | |||
return; \ | |||
#define cb(DType) \ | |||
if (dst.layout.dtype == DType()) { \ | |||
using ctype = typename DTypeTrait<DType>::ctype; \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
exec_internal<ctype>(srcs, offsets_ptr, dst, workspace)); \ | |||
return; \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
megdnn_throw("bad type"); | |||
@@ -1339,8 +1339,10 @@ void Concat::init_output_comp_node() { | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackConcat); | |||
ParamPackConcat::ParamPackConcat(VarNodeArray& inp, VarNode* table, | |||
const std::vector<dt_int32> offsets_val, | |||
const OperatorNodeConfig& config) | |||
: Super(inp[0]->owner_graph(), config, "ParamPackConcat", inp) { | |||
: Super(inp[0]->owner_graph(), config, "ParamPackConcat", inp), | |||
m_offsets(offsets_val) { | |||
CompNode cn = inp[0]->comp_node(); | |||
add_input({inp[0]}); | |||
for (size_t i = 1; i < inp.size(); i++) { | |||
@@ -1361,14 +1363,16 @@ void ParamPackConcat::add_input_layout_constraint(){ | |||
} | |||
} | |||
SymbolVar ParamPackConcat::make(const SmallVector<SymbolVar> &inp, | |||
const SymbolVar &table, const OperatorNodeConfig& config) { | |||
SymbolVar ParamPackConcat::make(const SmallVector<SymbolVar>& inp, | |||
const SymbolVar& offsets, | |||
const std::vector<dt_int32> offsets_val, | |||
const OperatorNodeConfig& config) { | |||
VarNodeArray array(inp.size()); | |||
for (size_t i = 0; i < inp.size(); i++) { | |||
array[i] = inp[i].node(); | |||
} | |||
return inp.front(). | |||
insert_single_output_opr<ParamPackConcat>(array, table.node(), config); | |||
return inp.front().insert_single_output_opr<ParamPackConcat>( | |||
array, offsets.node(), offsets_val, config); | |||
} | |||
void ParamPackConcat::scn_do_execute() { | |||
@@ -1379,13 +1383,13 @@ void ParamPackConcat::scn_do_execute() { | |||
for (size_t i = 0; i < inputs.size() - 1; i++) { | |||
ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr; | |||
} | |||
auto table = inputs.back()->dev_tensor().as_megdnn(); | |||
auto offsets = inputs.back()->dev_tensor().as_megdnn(); | |||
megdnn::TensorND srcs( | |||
ptr, megdnn::TensorLayout({inputs.size() - 1}, dtype::Int32())); | |||
auto&& dst = output(0)->dev_tensor().as_megdnn(); | |||
m_opr->exec(srcs, table, dst, get_megdnn_workspace_from_var(output(1))); | |||
m_opr->exec(srcs, offsets, dst, get_megdnn_workspace_from_var(output(1))); | |||
} | |||
void ParamPackConcat::init_output_dtype() { | |||
@@ -1396,8 +1400,8 @@ void ParamPackConcat::init_output_static_infer_desc(){ | |||
using namespace cg::static_infer; | |||
auto &&mgr = owner_graph()->static_infer_manager(); | |||
auto infer_out = [](TensorShape &dest, const InpVal &inp) { | |||
dest = {inp.val.back().shape().total_nr_elems()/2}; | |||
auto infer_out = [this](TensorShape &dest, const InpVal &inp) { | |||
dest = {m_offsets.back()}; | |||
return true; | |||
}; | |||
DepVal shp_deps; | |||
@@ -1480,10 +1484,10 @@ void ParamPackSplit::init_output_dtype() { | |||
} | |||
void ParamPackSplit::mem_plan_fwd_in2out_readonly() { | |||
mgb_assert(m_offsets.size() == output().size()); | |||
mgb_assert(m_offsets.size() == output().size() * 2); | |||
for (size_t i = 0; i < output().size(); i++) { | |||
auto layout = output(i)->layout(); | |||
auto spec = SubTensorSpec::make_from_offset_elem(layout, m_offsets[i]); | |||
auto spec = SubTensorSpec::make_from_offset_elem(layout, m_offsets[i * 2]); | |||
m_mem_fwd_success[i] = output(i)->set_fwd_in2out_readonly( | |||
input(0), spec); | |||
mgb_assert(m_mem_fwd_success[i]); | |||
@@ -1524,7 +1528,7 @@ MGB_IMPL_OPR_GRAD(ParamPackSplit) { | |||
} | |||
return ParamPackConcat::make( | |||
grad, opr.input(1), | |||
grad, opr.input(1), opr.get_offsets(), | |||
OperatorNodeConfig{}.follow_comp_node(opr.input(0))) | |||
.node(); | |||
} | |||
@@ -32,31 +32,6 @@ namespace serialization { | |||
public OprMakerVariadic<opr::GetVarShape>{}; | |||
template<> | |||
struct OprLoadDumpImpl<opr::ParamPackConcat, 0> | |||
{ | |||
using ParamPackConcat = opr::ParamPackConcat; | |||
using Param = opr::ParamPackConcat::Param; | |||
static void dump(OprDumpContext &ctx, | |||
const cg::OperatorNodeBase &opr_) { | |||
auto &&opr = opr_.cast_final_safe<ParamPackConcat>(); | |||
ctx.write_param<Param>(opr.param()); | |||
} | |||
static cg::OperatorNodeBase* load( | |||
OprLoadContext &ctx, const cg::VarNodeArray &inputs, | |||
const OperatorNodeConfig &config) { | |||
auto param = ctx.read_param<Param>(); | |||
mgb_assert(!inputs.empty()); | |||
SymbolVarArray ivar{inputs.size() - 1}; | |||
for (size_t i = 0; i < inputs.size() - 1; ++ i) | |||
ivar[i] = inputs[i]; | |||
return ParamPackConcat::make(ivar, inputs.back(), | |||
param, config).node()->owner_opr(); | |||
} | |||
}; | |||
template<> | |||
struct OprLoadDumpImpl<opr::Split, 0> { | |||
using Split = opr::Split; | |||
using Options = Split::Options; | |||
@@ -151,7 +126,6 @@ namespace opr { | |||
MGB_SEREG_OPR(Dimshuffle, 1); | |||
MGB_SEREG_OPR(AxisAddRemove, 1); | |||
MGB_SEREG_OPR(Concat, 0); | |||
MGB_SEREG_OPR(ParamPackConcat, 0); | |||
using GetVarShapeV1 = opr::GetVarShape; | |||
MGB_SEREG_OPR(GetVarShapeV1, 0); | |||
using ReshapeV1 = opr::Reshape; | |||
@@ -193,6 +167,22 @@ namespace opr { | |||
} | |||
MGB_REG_OPR_SHALLOW_COPY(ParamPackSplit, opr_shallow_copy_param_pack_split); | |||
cg::OperatorNodeBase* opr_shallow_copy_param_pack_concat( | |||
const serialization::OprShallowCopyContext &ctx, | |||
const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, | |||
const OperatorNodeConfig &config){ | |||
auto &&opr = opr_.cast_final_safe<ParamPackConcat>(); | |||
auto &&offsets = opr.get_offsets(); | |||
SymbolVarArray ivar{inputs.size() - 1}; | |||
for (size_t i = 0; i < inputs.size() - 1; ++i) | |||
ivar[i] = inputs[i]; | |||
return ParamPackConcat::make(ivar, inputs.back(), offsets, config). | |||
node()->owner_opr(); | |||
} | |||
MGB_REG_OPR_SHALLOW_COPY(ParamPackConcat, opr_shallow_copy_param_pack_concat); | |||
MGB_SEREG_OPR(RelayoutFormat, 1); | |||
MGB_SEREG_OPR(WinogradFilterPreprocess, 1); | |||
} // namespace opr | |||
@@ -539,6 +539,7 @@ MGB_DEFINE_OPR_CLASS(Concat, cg::SingleCNOutshapePureByInshapeOprBase) // { | |||
MGB_DEFINE_OPR_CLASS(ParamPackConcat, cg::SingleCNOperatorNodeBase) // { | |||
//! input pointer buffer | |||
SmallVector<void*> m_inp_ptr; | |||
std::vector<dt_int32> m_offsets; | |||
intl::UniqPtrWithCN<megdnn::ParamPackConcat> m_opr; | |||
void add_input_layout_constraint() override; | |||
@@ -554,15 +555,23 @@ public: | |||
return {}; | |||
} | |||
ParamPackConcat(VarNodeArray &inp, VarNode *table, | |||
const OperatorNodeConfig &config); | |||
static SymbolVar make(const SmallVector<SymbolVar> &inp, | |||
const SymbolVar &table, const OperatorNodeConfig &config = {}); | |||
ParamPackConcat(VarNodeArray& inp, VarNode* offsets, | |||
const std::vector<dt_int32> offsets_val, | |||
const OperatorNodeConfig& config); | |||
static SymbolVar make(const SmallVector<SymbolVar>& inp, | |||
const SymbolVar& offsets, | |||
const std::vector<dt_int32> offsets_val, | |||
const OperatorNodeConfig& config = {}); | |||
static SymbolVar make(const SmallVector<SymbolVar>& inp, | |||
const SymbolVar& offsets, | |||
const std::vector<dt_int32> offsets_val, const Param&, | |||
const OperatorNodeConfig& config) { | |||
return make(inp, offsets, offsets_val, config); | |||
} | |||
static SymbolVar make(const SmallVector<SymbolVar> &inp, | |||
const SymbolVar &table, const Param &, | |||
const OperatorNodeConfig &config) { | |||
return make(inp, table, config); | |||
const std::vector<dt_int32>& get_offsets() const { | |||
return m_offsets; | |||
} | |||
}; | |||
@@ -1906,7 +1906,7 @@ void test_param_pack_concat(const TensorShapeArray &shapes, DType type){ | |||
memcpy(host_table->raw_ptr(), host_table_gen.data(), size * 8); | |||
auto table = opr::Host2DeviceCopy::make(*graph, host_table); | |||
auto z = opr::ParamPackConcat::make(srcs, table); | |||
auto z = opr::ParamPackConcat::make(srcs, table, host_table_gen); | |||
HostTensorND host_z; | |||
auto func = graph->compile({make_callback_copy(z, host_z)}); | |||