GitOrigin-RevId: ed00341d58
tags/v0.4.0
@@ -500,47 +500,20 @@ public: | |||
/* | |||
* \param[in] srcs: TensorND on cpu. srcs[i] is the | |||
* address of the i-th Tensor. | |||
* \param[in] table: with size `2 * srcs.nr_total_elems()`. | |||
* table[addr] corresponding to outer_idx, | |||
* table[addr+srcs.nr_total_elems()] corresponding to | |||
* inner_idx of dsts. | |||
* \param[in] offsets: with size `2 * srcs.shape[0]`. | |||
* offsets[i * 2] and offsets[i * 2 + 1] give | |||
* the begin and end element offsets of the i-th tensor in dst. | |||
* \param[out] dst: output TensorND, which may live on cpu or gpu | |||
*/ | |||
virtual void exec(_megdnn_tensor_in srcs, _megdnn_tensor_in table, | |||
virtual void exec(_megdnn_tensor_in srcs, _megdnn_tensor_in offsets, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorShapeArray& srcs, | |||
const TensorShape& table, | |||
const TensorShape& offsets, | |||
const TensorShape& dst) = 0; | |||
}; | |||
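For reference, the contract documented above can be stated as a short host-side loop: the i-th source occupies dst[offsets[2*i] .. offsets[2*i+1]), and alignment gaps are simply skipped. This is an illustrative sketch of those semantics, not the MegDNN implementation:

    #include <cstddef>

    // Reference semantics only: copy each source into its [begin, end) slice
    // of the packed destination; elements in alignment gaps are not written.
    template <typename T>
    void param_pack_concat_reference(const T* const* srcs, const int* offsets,
                                     size_t nr_srcs, T* dst) {
        for (size_t i = 0; i < nr_srcs; ++i) {
            int begin = offsets[i * 2], end = offsets[i * 2 + 1];
            for (int j = 0; j < end - begin; ++j)
                dst[begin + j] = srcs[i][j];
        }
    }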
/** | |||
* \brief ParamPackSplit, used for network forwarding. | |||
* Split a single large param into several small tensors, also using the | |||
* copy strategy. | |||
*/ | |||
class ParamPackSplit: public ParamPackConcatSplitBase { | |||
DEF_OPR_IMPL(ParamPackSplit, ParamPackConcatSplitBase, 2, 1); | |||
public: | |||
/* | |||
* \param[in] src: input TensorND, which may live on cpu or gpu | |||
* \param[in] table: with size `2 * srcs.nr_total_elems()`. | |||
* table[addr] corresponding to outer_idx, | |||
* table[addr+srcs.nr_total_elems()] corresponding to | |||
* inner_idx of dsts. | |||
* \param[out] dsts: TensorND on cpu. dsts[i] is the address | |||
* of the i-th Tensor | |||
*/ | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in table, | |||
_megdnn_tensor_out dsts, _megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorShape& src, | |||
const TensorShape& table, | |||
const TensorShapeArray& dsts) = 0; | |||
}; | |||
/** | |||
* \brief base class for Tile and Repeat | |||
*/ | |||
class TileRepeatBase: public OperatorBase { | |||
@@ -167,7 +167,6 @@ private: | |||
cb(Resize) \ | |||
cb(ResizeBackward) \ | |||
cb(ParamPackConcat) \ | |||
cb(ParamPackSplit) \ | |||
cb(MaxTensorDiff) \ | |||
cb(MaskConvForward) \ | |||
cb(MaskPropagate) \ | |||
@@ -48,9 +48,9 @@ std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets( | |||
size_t offset = 0; | |||
for (size_t i = 0; i < shapes.size(); i++) { | |||
offset = get_aligned(offset); | |||
offsets[i * 2] = offset; | |||
offsets[i << 1] = offset; | |||
offset += shapes[i].total_nr_elems(); | |||
offsets[i * 2 + 1] = offset; | |||
offsets[(i << 1) + 1] = offset; | |||
} | |||
return offsets; | |||
} | |||
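As a worked example (illustrative numbers, not taken from the library): three tensors of 1, 100 and 31 elements with a 64-byte alignment requirement and 4-byte elements give begin/end pairs rounded up to 16-element boundaries. A minimal standalone sketch of that layout:

    #include <cstdio>
    #include <vector>

    int main() {
        // Example inputs: element counts per tensor and the alignment in
        // elements (64 bytes / 4-byte dtype = 16 elements).
        std::vector<size_t> nr_elems{1, 100, 31};
        size_t align = 16;
        std::vector<int> offsets(nr_elems.size() * 2);
        size_t offset = 0;
        for (size_t i = 0; i < nr_elems.size(); ++i) {
            offset = (offset + align - 1) / align * align;  // round up
            offsets[i * 2] = static_cast<int>(offset);      // begin of tensor i
            offset += nr_elems[i];
            offsets[i * 2 + 1] = static_cast<int>(offset);  // end (exclusive)
        }
        // Prints: 0 1 16 116 128 159, so the packed dst holds 159 elements.
        for (int v : offsets) printf("%d ", v);
        printf("\n");
    }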
@@ -60,56 +60,5 @@ void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, | |||
#undef cb | |||
} | |||
size_t ParamPackSplitImpl::get_workspace_in_bytes( | |||
const TensorShape&, const TensorShape&, const TensorShapeArray& dsts) { | |||
return sizeof(size_t) * dsts.size(); | |||
} | |||
template <typename T> | |||
void ParamPackSplitImpl::exec_internal(_megdnn_tensor_in src, | |||
_megdnn_tensor_in table, | |||
_megdnn_tensor_out dsts, | |||
_megdnn_workspace workspace) { | |||
// inner and outer table must be int32 | |||
megdnn_assert(table.layout.dtype == dtype::Int32()); | |||
// dsts is src pointer, ndim must be 1 | |||
megdnn_assert(dsts.layout.ndim == 1); | |||
auto out_size = dsts.layout.shape[0], | |||
inp_size = src.layout.total_nr_elems(); | |||
auto stream = cuda_stream(this->handle()); | |||
auto total_workspace_size = sizeof(T*) * out_size; | |||
auto dsts_cpu = static_cast<T**>(dsts.raw_ptr); | |||
megdnn_assert_internal(dsts_cpu); | |||
auto dsts_gpu = reinterpret_cast<T**>(workspace.raw_ptr); | |||
auto table_outer_gpu = table.ptr<int32_t>(); | |||
auto table_inner_gpu = table_outer_gpu + inp_size; | |||
cuda_check(cudaMemcpyAsync(dsts_gpu, dsts_cpu, total_workspace_size, | |||
cudaMemcpyHostToDevice, stream)); | |||
// param_pack_split_proxy() | |||
param_pack::split_proxy<T>(src.ptr<T>(), dsts_gpu, inp_size, | |||
table_outer_gpu, table_inner_gpu, stream); | |||
} | |||
void ParamPackSplitImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in table, | |||
_megdnn_tensor_out dsts, | |||
_megdnn_workspace workspace) { | |||
check_exec(src.layout, table.layout, dsts.layout); | |||
#define cb(DType) \ | |||
if (src.layout.dtype == DType()) { \ | |||
using ctype = typename DTypeTrait<DType>::ctype; \ | |||
exec_internal<ctype>(src, table, dsts, workspace); \ | |||
return; \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
megdnn_throw("bad type"); | |||
#undef cb | |||
} | |||
} // namespace cuda | |||
} // namespace megdnn |
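One non-obvious detail in the CUDA path above: srcs/dsts carry a host-side array of device pointers, so that array has to be staged into the workspace before a kernel can dereference it, which is also why get_workspace_in_bytes reserves one pointer-sized slot per tensor. A minimal sketch of that staging pattern using the standard CUDA runtime API (illustrative only, not the MegDNN code):

    #include <cuda_runtime.h>
    #include <cstddef>

    // workspace must hold at least n * sizeof(float*) bytes.
    void stage_pointer_array(float* const* host_ptrs, size_t n,
                             void* workspace, cudaStream_t stream) {
        cudaMemcpyAsync(workspace, host_ptrs, n * sizeof(float*),
                        cudaMemcpyHostToDevice, stream);
        // A kernel launched on `stream` can now cast `workspace` to float**
        // and index the i-th destination (or source) tensor.
    }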
@@ -31,21 +31,5 @@ private: | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace); | |||
}; | |||
class ParamPackSplitImpl final : public ParamPackSplit { | |||
public: | |||
using ParamPackSplit::ParamPackSplit; | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in table, | |||
_megdnn_tensor_out dsts, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorShape& src, | |||
const TensorShape& table, | |||
const TensorShapeArray& dsts) override; | |||
private: | |||
template <typename T> | |||
void exec_internal(_megdnn_tensor_in src, _megdnn_tensor_in table, | |||
_megdnn_tensor_out dsts, _megdnn_workspace workspace); | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -41,31 +41,6 @@ __global__ void concat_kernel(const T** srcs, T* dst, | |||
} | |||
template <typename T> | |||
__global__ void split_kernel(const T* src, T** dsts, | |||
const int32_t* table_outer, | |||
const int32_t* table_inner, | |||
size_t total_size) { | |||
size_t addr = threadIdx.x + blockIdx.x * blockDim.x; | |||
if (addr < total_size) { | |||
int32_t i = table_outer[addr]; | |||
int32_t idx = table_inner[addr]; | |||
if (idx != -1) { | |||
dsts[i][idx] = src[addr]; | |||
} | |||
} | |||
} | |||
template <typename T> | |||
void split_proxy(const T* src, T** dsts, size_t total_size, | |||
const int32_t* table_outer, const int32_t* table_inner, | |||
cudaStream_t stream) { | |||
size_t NR_BLOCKS = DIVUP(total_size, NR_THREADS); | |||
split_kernel<<<NR_BLOCKS, NR_THREADS, 0, stream>>>( | |||
src, dsts, table_outer, table_inner, total_size); | |||
after_kernel_launch(); | |||
} | |||
template <typename T> | |||
void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size, | |||
const int32_t* offsets, | |||
cudaStream_t stream) { | |||
@@ -78,10 +53,7 @@ void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size, | |||
#define INST(T) \ | |||
template void concat_proxy<T>(const T**, T*, size_t, size_t, \ | |||
const int32_t*, \ | |||
cudaStream_t); \ | |||
template void split_proxy<T>(const T*, T**, size_t, \ | |||
const int32_t*, const int32_t*, \ | |||
cudaStream_t); | |||
cudaStream_t); | |||
#define cb(DType) INST(typename DTypeTrait<DType>::ctype) | |||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
#undef cb | |||
@@ -20,11 +20,6 @@ namespace cuda { | |||
namespace param_pack { | |||
template <typename T> | |||
void split_proxy(const T* src, T** dsts, size_t total_size, | |||
const int32_t* table_outer, const int32_t* table_inner, | |||
cudaStream_t stream); | |||
template <typename T> | |||
void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size, | |||
const int32_t* offsets, cudaStream_t stream); | |||
@@ -17,43 +17,6 @@ using namespace megdnn; | |||
using namespace naive; | |||
template <typename T> | |||
void ParamPackSplitImpl::exec_internal(_megdnn_tensor_in src, int32_t* table, | |||
_megdnn_tensor_out dsts, | |||
_megdnn_workspace) { | |||
auto dsts_ptr = static_cast<T**>(dsts.raw_ptr); | |||
auto src_ptr = src.ptr<T>(); | |||
auto inp_size = src.layout.total_nr_elems(); | |||
auto table_outer = table, table_inner = table_outer + inp_size; | |||
for (size_t j = 0; j < inp_size; j++) { | |||
int32_t i = table_outer[j]; | |||
int32_t idx = table_inner[j]; | |||
if (idx != -1) { | |||
dsts_ptr[i][idx] = src_ptr[j]; | |||
} | |||
} | |||
} | |||
void ParamPackSplitImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in table, | |||
_megdnn_tensor_out dsts, | |||
_megdnn_workspace workspace) { | |||
check_exec(src.layout, table.layout, dsts.layout); | |||
auto table_ptr = table.ptr<int32_t>(); | |||
#define cb(DType) \ | |||
if (src.layout.dtype == DType()) { \ | |||
using ctype = typename DTypeTrait<DType>::ctype; \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
exec_internal<ctype>(src, table_ptr, dsts, workspace)); \ | |||
return; \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
megdnn_throw("bad type"); | |||
#undef cb | |||
} | |||
template <typename T> | |||
void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, | |||
int32_t* offsets, | |||
_megdnn_tensor_out dst, | |||
@@ -13,27 +13,10 @@ | |||
namespace megdnn { | |||
namespace naive { | |||
class ParamPackSplitImpl final : public ParamPackSplit { | |||
public: | |||
using ParamPackSplit::ParamPackSplit; | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in table, | |||
_megdnn_tensor_out dsts, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorShape&, const TensorShape&, | |||
const TensorShapeArray&) override { | |||
return 0; | |||
} | |||
private: | |||
template <typename T> | |||
void exec_internal(_megdnn_tensor_in src, int32_t* table, | |||
_megdnn_tensor_out dsts, _megdnn_workspace workspace); | |||
}; | |||
class ParamPackConcatImpl final : public ParamPackConcat { | |||
public: | |||
using ParamPackConcat::ParamPackConcat; | |||
void exec(_megdnn_tensor_in srcs, _megdnn_tensor_in table, | |||
void exec(_megdnn_tensor_in srcs, _megdnn_tensor_in offsets, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorShapeArray&, const TensorShape&, | |||
@@ -43,7 +26,7 @@ public: | |||
private: | |||
template <typename T> | |||
void exec_internal(_megdnn_tensor_in srcs, int32_t* table, | |||
void exec_internal(_megdnn_tensor_in srcs, int32_t* offsets, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace); | |||
}; | |||
@@ -18,56 +18,38 @@ using namespace test; | |||
namespace { | |||
template<class T> | |||
std::vector<int32_t> create_table(const TensorShapeArray& shapes, | |||
size_t align) { | |||
std::vector<int32_t> create_offsets(const TensorShapeArray& shapes, | |||
size_t alignment) { | |||
size_t dtype_size = sizeof(T); | |||
if (align < dtype_size) | |||
align = dtype_size; | |||
if (alignment < dtype_size) | |||
alignment = dtype_size; | |||
alignment /= dtype_size; | |||
align /= dtype_size; | |||
auto get_aligned = [alignment](size_t v) { | |||
auto mod = v & (alignment - 1); | |||
return v + ((alignment - mod) & (alignment - 1)); | |||
}; | |||
size_t offset = shapes[0].total_nr_elems(); | |||
for (size_t i = 1; i < shapes.size(); i++) { | |||
auto d = offset & (align - 1); | |||
offset += (align - d) & (align - 1); | |||
offset += shapes[i].total_nr_elems(); | |||
} | |||
std::vector<int32_t> table(offset * 2); | |||
int32_t* outer_table = table.data(); | |||
int32_t* inner_table = outer_table + offset; | |||
offset = 0; | |||
std::vector<dt_int32> offsets(shapes.size() << 1); | |||
size_t offset = 0; | |||
for (size_t i = 0; i < shapes.size(); i++) { | |||
for (; (offset & (align - 1)) != 0; offset++) { | |||
outer_table[offset] = inner_table[offset] = -1; | |||
} | |||
size_t j = 0; | |||
for (; j < shapes[i].total_nr_elems(); j++) { | |||
outer_table[offset + j] = i; | |||
inner_table[offset + j] = j; | |||
} | |||
offset += j; | |||
offset = get_aligned(offset); | |||
offsets[i << 1] = offset; | |||
offset += shapes[i].total_nr_elems(); | |||
offsets[(i << 1) + 1] = offset; | |||
} | |||
return table; | |||
return offsets; | |||
} | |||
template<class T> | |||
std::vector<T> create_pack(size_t pack_size, const std::vector<int32_t>& table, | |||
std::vector<T> create_pack(size_t pack_size, const std::vector<int32_t>& offsets, | |||
const std::vector<std::vector<T>>& ptr) { | |||
assert(pack_size == table.size() / 2); | |||
const int32_t* outer_table = table.data(); | |||
const int32_t* inner_table = outer_table + pack_size; | |||
std::vector<T> data(pack_size); | |||
for (size_t idx = 0; idx < pack_size; ++idx) { | |||
int32_t out_idx = outer_table[idx]; | |||
int32_t in_idx = inner_table[idx]; | |||
if (in_idx != -1) { | |||
data[idx] = ptr[out_idx][in_idx]; | |||
} | |||
assert(pack_size == offsets.back()); | |||
std::vector<T> data(pack_size, 0); | |||
for (size_t i = 0; i * 2 < offsets.size(); ++i) { | |||
size_t begin = offsets[i * 2], end = offsets[i * 2 + 1]; | |||
for (size_t j = 0; j < end - begin; j++) | |||
data[begin + j] = ptr[i][j]; | |||
} | |||
return data; | |||
} | |||
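To make the padding behaviour concrete: with offsets {0, 2, 4, 7} (two tensors of 2 and 3 elements under a 4-element alignment; illustrative values, not from the tests), the pack built this way is {10, 20, 0, 0, 30, 40, 50}, i.e. the alignment gap stays zero-initialized:

    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<int> offsets{0, 2, 4, 7};              // begin/end pairs
        std::vector<std::vector<int>> params{{10, 20}, {30, 40, 50}};
        std::vector<int> pack(offsets.back(), 0);          // gaps stay zero
        for (size_t i = 0; i * 2 < offsets.size(); ++i)
            for (int j = 0; j < offsets[i * 2 + 1] - offsets[i * 2]; ++j)
                pack[offsets[i * 2] + j] = params[i][j];
        for (int v : pack) printf("%d ", v);               // 10 20 0 0 30 40 50
        printf("\n");
    }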
@@ -95,65 +77,6 @@ T* create_device_data(Handle* handle, const T* data, size_t size) { | |||
return data_device; | |||
} | |||
template<class T> | |||
void test_param_pack_split(Handle* handle, const TensorShapeArray& shapes, | |||
DType type) { | |||
auto split = handle->create_operator<ParamPackSplit>(); | |||
size_t nr_params = shapes.size(); | |||
std::vector<T*> param_ptrs; | |||
for (size_t i = 0; i < nr_params; ++i) { | |||
param_ptrs.push_back(create_device_data<T>(handle, | |||
nullptr, shapes[i].total_nr_elems())); | |||
} | |||
std::vector<std::vector<T>> expected_param = create_params<T>(nr_params, | |||
shapes); | |||
std::vector<int32_t> table = | |||
create_table<T>(shapes, handle->alignment_requirement()); | |||
ASSERT_EQ(table, | |||
ParamPackSplit::gen_offsets( | |||
shapes, handle->alignment_requirement(), sizeof(T))); | |||
size_t pack_size = table.size() / 2; | |||
int32_t* table_gpu = create_device_data<int32_t>(handle, table.data(), | |||
table.size()); | |||
std::vector<T> pack = | |||
create_pack<T>(pack_size, table, expected_param); | |||
T* pack_gpu = create_device_data<T>(handle, pack.data(), pack.size()); | |||
TensorLayout src_layout({pack_size}, type); | |||
TensorND src_tensor(pack_gpu, src_layout); | |||
TensorLayout table_layout({table.size()}, dtype::Int32()); | |||
TensorND table_tensor(table_gpu, table_layout); | |||
test::WorkspaceWrapper workspace(handle, split->get_workspace_in_bytes( | |||
{pack_size}, table_layout, shapes)); | |||
TensorND dst_tensor(param_ptrs.data(), | |||
TensorLayout({nr_params}, dtype::Int32())); | |||
split->exec(src_tensor, table_tensor, dst_tensor, workspace.workspace()); | |||
// check | |||
for (size_t i = 0; i < nr_params; ++i) { | |||
T* actual_param = static_cast<T*>(malloc(shapes[i].total_nr_elems() | |||
* sizeof(T))); | |||
test::megdnn_memcpy_D2H(handle, actual_param, param_ptrs[i], | |||
shapes[i].total_nr_elems() * sizeof(T)); | |||
for (size_t idx = 0; idx < shapes[i].total_nr_elems(); ++idx) { | |||
ASSERT_EQ(actual_param[idx], expected_param[i][idx]); | |||
} | |||
free(actual_param); | |||
} | |||
test::megdnn_free(handle, pack_gpu); | |||
test::megdnn_free(handle, table_gpu); | |||
for (auto ptr : param_ptrs) { | |||
test::megdnn_free(handle, ptr); | |||
} | |||
} | |||
template <class T> | |||
void test_param_pack_concat(Handle* handle, const TensorShapeArray& shapes, | |||
DType type) { | |||
@@ -167,28 +90,28 @@ void test_param_pack_concat(Handle* handle, const TensorShapeArray& shapes, | |||
param_ptrs.push_back(create_device_data<T>(handle, | |||
params[i].data(), shapes[i].total_nr_elems())); | |||
} | |||
std::vector<int32_t> table = | |||
create_table<T>(shapes, handle->alignment_requirement()); | |||
size_t pack_size = table.size() / 2; | |||
int32_t* table_gpu = create_device_data<int32_t>(handle, table.data(), | |||
table.size()); | |||
std::vector<int32_t> offsets = | |||
create_offsets<T>(shapes, handle->alignment_requirement()); | |||
size_t pack_size = offsets.back(); | |||
int32_t* offsets_gpu = create_device_data<int32_t>(handle, offsets.data(), | |||
offsets.size()); | |||
std::vector<T> expected_pack = | |||
create_pack<T>(pack_size, table, params); | |||
create_pack<T>(pack_size, offsets, params); | |||
T* pack_gpu = create_device_data<T>(handle, nullptr, expected_pack.size()); | |||
TensorLayout dst_layout({pack_size}, type); | |||
TensorND dst_tensor(pack_gpu, dst_layout); | |||
TensorLayout table_layout({table.size()}, dtype::Int32()); | |||
TensorND table_tensor(table_gpu, table_layout); | |||
TensorLayout offsets_layout({offsets.size()}, dtype::Int32()); | |||
TensorND offsets_tensor(offsets_gpu, offsets_layout); | |||
test::WorkspaceWrapper workspace(handle, concat->get_workspace_in_bytes( | |||
shapes, table_layout, {pack_size})); | |||
shapes, offsets_layout, {pack_size})); | |||
TensorND src_tensor(param_ptrs.data(), | |||
TensorLayout({nr_params}, dtype::Int32())); | |||
concat->exec(src_tensor, table_tensor, dst_tensor, workspace.workspace()); | |||
concat->exec(src_tensor, offsets_tensor, dst_tensor, workspace.workspace()); | |||
// check | |||
T* actual_pack = static_cast<T*>(malloc(pack_size * sizeof(T))); | |||
@@ -199,7 +122,7 @@ void test_param_pack_concat(Handle* handle, const TensorShapeArray& shapes, | |||
} | |||
free(actual_pack); | |||
test::megdnn_free(handle, pack_gpu); | |||
test::megdnn_free(handle, table_gpu); | |||
test::megdnn_free(handle, offsets_gpu); | |||
for (auto ptr : param_ptrs) { | |||
test::megdnn_free(handle, ptr); | |||
} | |||
@@ -222,9 +145,6 @@ TEST_F(CUDA, PARAM_PACK) { | |||
{111, 111, 111}, | |||
{128, 128, 128}}); | |||
for (auto shapes : shapes_vec) { | |||
test_param_pack_split<int32_t>(handle_cuda(), shapes, dtype::Int32()); | |||
test_param_pack_split<int16_t>(handle_cuda(), shapes, dtype::Int16()); | |||
test_param_pack_split<float>(handle_cuda(), shapes, dtype::Float32()); | |||
test_param_pack_concat<int32_t>(handle_cuda(), shapes, dtype::Int32()); | |||
test_param_pack_concat<int16_t>(handle_cuda(), shapes, dtype::Int16()); | |||
test_param_pack_concat<float>(handle_cuda(), shapes, dtype::Float32()); | |||
@@ -38,8 +38,7 @@ SymbolVar _Opr::_axis_add_remove(SymbolVar src, | |||
} | |||
SymbolVarArray _Opr::param_pack_split( | |||
SymbolVar src, SymbolVar table, | |||
const std::vector<std::vector<size_t>>& shapes, | |||
SymbolVar src, const std::vector<std::vector<size_t>>& shapes, | |||
const OperatorNodeConfig& config) { | |||
auto size = shapes.size(); | |||
mgb::TensorShapeArray shapearr(size); | |||
@@ -48,18 +47,16 @@ SymbolVarArray _Opr::param_pack_split( | |||
} | |||
auto cn = src.node()->comp_node(); | |||
auto table_val = megdnn::ParamPackSplit::gen_offsets( | |||
auto offsets_val = megdnn::ParamPackConcat::gen_offsets( | |||
shapearr, cn.get_mem_addr_alignment(), src.dtype().size()); | |||
if (!table.node()) { | |||
if (config.has_comp_node_set()) { | |||
cn = config.get_single_comp_node(); | |||
} | |||
HostTensorND hv{cn, TensorShape{{table_val.size()}}, dtype::Int32{}}; | |||
memcpy(hv.raw_ptr(), table_val.data(), table_val.size() * sizeof(int)); | |||
table = opr::ImmutableTensor::make(*src.node()->owner_graph(), hv); | |||
if (config.has_comp_node_set()) { | |||
cn = config.get_single_comp_node(); | |||
} | |||
HostTensorND hv{cn, TensorShape{{offsets_val.size()}}, dtype::Int32{}}; | |||
memcpy(hv.raw_ptr(), offsets_val.data(), offsets_val.size() * sizeof(int)); | |||
auto offsets = opr::ImmutableTensor::make(*src.node()->owner_graph(), hv); | |||
return mgb::opr::ParamPackSplit::make(src, table, table_val, shapearr, config); | |||
return mgb::opr::ParamPackSplit::make(src, offsets, offsets_val, shapearr, config); | |||
} | |||
#if MGB_ENABLE_OPR_MM | |||
@@ -44,8 +44,7 @@ static SymbolVar add_update(SymbolVar dest, SymbolVar delta, | |||
// tensor manip | |||
static SymbolVarArray param_pack_split( | |||
SymbolVar src, SymbolVar table, | |||
const std::vector<std::vector<size_t>>& shapes, | |||
SymbolVar src, const std::vector<std::vector<size_t>>& shapes, | |||
const OperatorNodeConfig& config); | |||
static SymbolVar dimshuffle(SymbolVar src, | |||
@@ -159,11 +159,11 @@ def dimshuffle(src, pattern, ndim=0, *, | |||
pattern_mgb.push_back(i) | |||
return _mgb._Opr.dimshuffle(src, pattern_mgb, int(ndim), config) | |||
def param_pack_split(src, shapes, table=None, *, | |||
def param_pack_split(src, shapes, *, | |||
name=None, comp_node=None, config=None): | |||
""" | |||
split param into a list of tensor for given shape | |||
ParamPackSplit operator has two inputs: ``src`` and ``tables`` and would | |||
ParamPackSplit operator has a single input, ``src``, and would | |||
produce a list of outputs. output[i] indicates the address of the tensor | |||
into which the corresponding part of ``src`` transfers its elements. | |||
@@ -172,24 +172,13 @@ def param_pack_split(src, shapes, table=None, *, | |||
output[0] indicates the address of tensor with shapes[0]:(1, 2, 4), | |||
output[1] indicates the address of tensor with shapes[1]:(4, 2, 2), | |||
output[2] indicates the address of tensor with shapes[2]:(4, 2, 1). | |||
And table has double the size of the input tensor. | |||
For each element input[i] of the input tensor, we have | |||
output[outer_index[i]][inner_index[i]] = input[i]. | |||
Table is the concatenation of outer_index and inner_index, so | |||
equivalently, output[table[i]][table[i + len(input)]] = input[i]. | |||
:param src: The concatenated input tensor. | |||
:type src: :class:`SymbolVar` | |||
:param shapes: Shapes of output tensors | |||
:type shapes: list of list of int | |||
:param table: Output element mapping table; if it is None, a table would | |||
be generated from ``shapes`` | |||
:type table: :class:`SymbolVar` with int32 type, or None | |||
""" | |||
config = _helper.gen_config(name, comp_node, config) | |||
if isinstance(table, (list, tuple)) and isinstance(shapes, _mgb.SymbolVar): | |||
# compatible with old API | |||
table, shapes = shapes, table | |||
if not isinstance(shapes, (list, tuple)): | |||
raise TypeError('could not convert {} to tensor shapes'.format( | |||
@@ -202,10 +191,7 @@ def param_pack_split(src, shapes, table=None, *, | |||
assert min(s) > 0 | |||
shapes_mgb.push_back(s) | |||
if table is None: | |||
table = _mgb.SymbolVar() | |||
return _mgb._Opr.param_pack_split(src, table, shapes_mgb, config) | |||
return _mgb._Opr.param_pack_split(src, shapes_mgb, config) | |||
class _modify_subtensor_helper: | |||
def __init__(self, dest, val, *, name=None, comp_node=None, config=None): | |||
@@ -1400,8 +1400,8 @@ void ParamPackConcat::init_output_static_infer_desc(){ | |||
using namespace cg::static_infer; | |||
auto &&mgr = owner_graph()->static_infer_manager(); | |||
auto infer_out = [this](TensorShape &dest, const InpVal &inp) { | |||
dest = {m_offsets.back()}; | |||
auto infer_out = [this](TensorShape& dest, const InpVal& inp) { | |||
dest = {static_cast<unsigned int>(m_offsets.back())}; | |||
return true; | |||
}; | |||
DepVal shp_deps; | |||
@@ -1476,9 +1476,6 @@ SymbolVarArray ParamPackSplit::make(const SymbolVar& src, | |||
return ret; | |||
} | |||
void ParamPackSplit::scn_do_execute() { | |||
} | |||
void ParamPackSplit::init_output_dtype() { | |||
// already initialized in constructor | |||
} | |||
@@ -1518,7 +1515,6 @@ void ParamPackSplit::init_output_static_infer_desc() { | |||
MGB_IMPL_OPR_GRAD(ParamPackSplit) { | |||
mgb_assert(out_grad.size() == opr.output().size()); | |||
SmallVector<SymbolVar> grad; | |||
// last var is workspace, ignore it | |||
for (size_t i = 0; i < out_grad.size(); ++i) { | |||
auto gval = out_grad[i]; | |||
if (!gval) { | |||
@@ -583,7 +583,7 @@ MGB_DEFINE_OPR_CLASS(ParamPackSplit, cg::SingleCNOperatorNodeBase) // { | |||
std::vector<dt_int32> m_offsets; | |||
std::vector<bool> m_mem_fwd_success; | |||
void scn_do_execute() override; | |||
void scn_do_execute() override {} | |||
void init_output_static_infer_desc() override; | |||
bool infer_shape(size_t index, TensorShape &dest, | |||
const cg::static_infer::InpVal &inp); | |||
@@ -1898,15 +1898,15 @@ void test_param_pack_concat(const TensorShapeArray &shapes, DType type){ | |||
srcs.push_back(nd); | |||
} | |||
auto host_table_gen = megdnn::ParamPackSplit::gen_offsets(shapes, | |||
auto host_offsets_gen = megdnn::ParamPackConcat::gen_offsets(shapes, | |||
cn.get_mem_addr_alignment(), 4); | |||
ASSERT_EQ(host_table_gen.size(), size * 2); | |||
auto host_table = std::make_shared<HostTensorND>(); | |||
host_table->comp_node(cn).dtype(dtype::Int32{}).resize({size * 2}); | |||
memcpy(host_table->raw_ptr(), host_table_gen.data(), size * 8); | |||
auto table = opr::Host2DeviceCopy::make(*graph, host_table); | |||
ASSERT_EQ(host_offsets_gen.back(), size); | |||
auto host_offsets = std::make_shared<HostTensorND>(); | |||
host_offsets->comp_node(cn).dtype(dtype::Int32{}).resize({srcs.size() * 2}); | |||
memcpy(host_offsets->raw_ptr(), host_offsets_gen.data(), srcs.size() * 8); | |||
auto offsets = opr::Host2DeviceCopy::make(*graph, host_offsets); | |||
auto z = opr::ParamPackConcat::make(srcs, table, host_table_gen); | |||
auto z = opr::ParamPackConcat::make(srcs, offsets, host_offsets_gen); | |||
HostTensorND host_z; | |||
auto func = graph->compile({make_callback_copy(z, host_z)}); | |||
@@ -1944,17 +1944,17 @@ void test_param_pack_split(const TensorShapeArray& shapes) { | |||
auto make_graph = [&](const typename Checker::SymInpArray& inputs) -> | |||
typename Checker::SymOutArray { | |||
auto table_val = megdnn::ParamPackSplit::gen_offsets( | |||
auto offsets_val = megdnn::ParamPackConcat::gen_offsets( | |||
shapes, cn.get_mem_addr_alignment(), 4); | |||
HostTensorND table; | |||
std::copy_n(table_val.data(), table_val.size(), | |||
table.dtype(dtype::Int32{}) | |||
HostTensorND offsets; | |||
std::copy_n(offsets_val.data(), offsets_val.size(), | |||
offsets.dtype(dtype::Int32{}) | |||
.comp_node(cn) | |||
.resize({table_val.size()}) | |||
.resize({offsets_val.size()}) | |||
.ptr<dt_int32>()); | |||
auto sym_table = opr::SharedDeviceTensor::make( | |||
*inputs[0].node()->owner_graph(), table); | |||
auto out = opr::ParamPackSplit::make(inputs[0], sym_table, table_val, | |||
auto sym_offsets = opr::SharedDeviceTensor::make( | |||
*inputs[0].node()->owner_graph(), offsets); | |||
auto out = opr::ParamPackSplit::make(inputs[0], sym_offsets, offsets_val, | |||
shapes); | |||
mgb_assert(out.size() == nr_out); | |||
typename Checker::SymOutArray ret; | |||