GitOrigin-RevId: ed00341d58
tags/v0.4.0
@@ -500,47 +500,20 @@ public:
     /*
      * \param[in] srcs: TensorND on cpu. srcs[i] corresponding to the
      *      address of i-th Tensor.
-     * \param[in] table: with size `2 * srcs.nr_total_elems()`.
-     *      table[addr] corresponding to outer_idx,
-     *      table[addr+srcs.nr_total_elems()] corresponding to
-     *      inner_idx of dsts.
+     * \param[in] offsets: with size `2 * srcs.shape[0]`.
+     *      offsets[i * 2] and offsets[i * 2 + 1] give the begin and the
+     *      end of the i-th tensor's range in dst.
      * \param[out] dst: output TensorND, live on cpu or gpu
      */
-    virtual void exec(_megdnn_tensor_in srcs, _megdnn_tensor_in table,
+    virtual void exec(_megdnn_tensor_in srcs, _megdnn_tensor_in offsets,
                       _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
     virtual size_t get_workspace_in_bytes(const TensorShapeArray& srcs,
-                                          const TensorShape& table,
+                                          const TensorShape& offsets,
                                           const TensorShape& dst) = 0;
 };
 /**
- * \brief ParamPackSplit, used for network forwarding.
- * Split a single large param into several small tensors, use copy stategy
- * either.
- */
-class ParamPackSplit: public ParamPackConcatSplitBase {
-    DEF_OPR_IMPL(ParamPackSplit, ParamPackConcatSplitBase, 2, 1);
-public:
-    /*
-     * \param[in] src: input TensorND, live on cpu or gpu
-     * \param[in] table: with size `2 * srcs.nr_total_elems()`.
-     *      table[addr] corresponding to outer_idx,
-     *      table[addr+srcs.nr_total_elems()] corresponding to
-     *      inner_idx of dsts.
-     * \param[out] dsts: TensorND on cpu. dsts[i] corresponding to the address
-     *      of i-th Tensor
-     */
-    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in table,
-                      _megdnn_tensor_out dsts, _megdnn_workspace workspace) = 0;
-    virtual size_t get_workspace_in_bytes(const TensorShape& src,
-                                          const TensorShape& table,
-                                          const TensorShapeArray& dsts) = 0;
-};
-/**
  * \brief base class for Tile and Repeat
  */
 class TileRepeatBase: public OperatorBase {
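The offsets array replaces the old per-element lookup table: the pair (offsets[i * 2], offsets[i * 2 + 1]) brackets the slice of dst holding the i-th source tensor, with alignment padding between slices. A minimal host-side sketch of that contract (the helper name and types are illustrative, not part of the megdnn API):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Reference copy loop for the offsets contract; 'concat_reference' is a
    // hypothetical helper, not a megdnn function.
    void concat_reference(const std::vector<std::vector<float>>& srcs,
                          const std::vector<int32_t>& offsets,
                          std::vector<float>& dst) {
        for (size_t i = 0; i < srcs.size(); ++i) {
            int32_t begin = offsets[i * 2];    // start of tensor i in dst
            int32_t end = offsets[i * 2 + 1];  // one past its last element
            for (int32_t j = 0; j < end - begin; ++j)
                dst[begin + j] = srcs[i][j];
            // any gap before the next begin is alignment padding
        }
    }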
@@ -167,7 +167,6 @@ private:
     cb(Resize) \
     cb(ResizeBackward) \
     cb(ParamPackConcat) \
-    cb(ParamPackSplit) \
     cb(MaxTensorDiff) \
     cb(MaskConvForward) \
     cb(MaskPropagate) \
@@ -48,9 +48,9 @@ std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets(
     size_t offset = 0;
     for (size_t i = 0; i < shapes.size(); i++) {
         offset = get_aligned(offset);
-        offsets[i * 2] = offset;
+        offsets[i << 1] = offset;
         offset += shapes[i].total_nr_elems();
-        offsets[i * 2 + 1] = offset;
+        offsets[(i << 1) + 1] = offset;
     }
     return offsets;
 }
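gen_offsets rounds each tensor's start up to the alignment boundary (expressed in elements, a power of two) and records one begin/end pair per tensor. A standalone sketch of the same arithmetic, assuming a 4-element alignment (e.g. 16-byte alignment with a 4-byte dtype):

    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<int> sizes = {3, 5, 2};  // total_nr_elems() of each shape
        const int alignment = 4;             // in elements
        auto get_aligned = [](int v, int a) { return (v + a - 1) / a * a; };
        std::vector<int> offsets(sizes.size() * 2);
        int offset = 0;
        for (size_t i = 0; i < sizes.size(); ++i) {
            offset = get_aligned(offset, alignment);
            offsets[i * 2] = offset;          // begin of tensor i
            offset += sizes[i];
            offsets[i * 2 + 1] = offset;      // end of tensor i
        }
        for (int v : offsets)
            printf("%d ", v);                 // prints: 0 3 4 9 12 14
    }

Tensor 0 occupies dst[0..3), tensor 1 dst[4..9) and tensor 2 dst[12..14); elements 3 and 9..11 are alignment padding.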
@@ -60,56 +60,5 @@ void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs,
 #undef cb
 }
-size_t ParamPackSplitImpl::get_workspace_in_bytes(
-        const TensorShape&, const TensorShape&, const TensorShapeArray& dsts) {
-    return sizeof(size_t) * dsts.size();
-}
-template <typename T>
-void ParamPackSplitImpl::exec_internal(_megdnn_tensor_in src,
-                                       _megdnn_tensor_in table,
-                                       _megdnn_tensor_out dsts,
-                                       _megdnn_workspace workspace) {
-    // inner and outer table must be int32
-    megdnn_assert(table.layout.dtype == dtype::Int32());
-    // dsts is src pointer, ndim must be 1
-    megdnn_assert(dsts.layout.ndim == 1);
-    auto out_size = dsts.layout.shape[0],
-         inp_size = src.layout.total_nr_elems();
-    auto stream = cuda_stream(this->handle());
-    auto total_workspace_size = sizeof(T*) * out_size;
-    auto dsts_cpu = static_cast<T**>(dsts.raw_ptr);
-    megdnn_assert_internal(dsts_cpu);
-    auto dsts_gpu = reinterpret_cast<T**>(workspace.raw_ptr);
-    auto table_outer_gpu = table.ptr<int32_t>();
-    auto table_inner_gpu = table_outer_gpu + inp_size;
-    cuda_check(cudaMemcpyAsync(dsts_gpu, dsts_cpu, total_workspace_size,
-                               cudaMemcpyHostToDevice, stream));
-    // param_pack_split_proxy()
-    param_pack::split_proxy<T>(src.ptr<T>(), dsts_gpu, inp_size,
-                               table_outer_gpu, table_inner_gpu, stream);
-}
-void ParamPackSplitImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in table,
-                              _megdnn_tensor_out dsts,
-                              _megdnn_workspace workspace) {
-    check_exec(src.layout, table.layout, dsts.layout);
-#define cb(DType)                                            \
-    if (src.layout.dtype == DType()) {                       \
-        using ctype = typename DTypeTrait<DType>::ctype;     \
-        exec_internal<ctype>(src, table, dsts, workspace);   \
-        return;                                              \
-    }
-    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
-    megdnn_throw("bad type");
-#undef cb
-}
 } // namespace cuda
 } // namespace megdnn
@@ -31,21 +31,5 @@ private:
                        _megdnn_tensor_out dst, _megdnn_workspace workspace);
 };
-class ParamPackSplitImpl final : public ParamPackSplit {
-public:
-    using ParamPackSplit::ParamPackSplit;
-    void exec(_megdnn_tensor_in src, _megdnn_tensor_in table,
-              _megdnn_tensor_out dsts, _megdnn_workspace workspace) override;
-    size_t get_workspace_in_bytes(const TensorShape& src,
-                                  const TensorShape& table,
-                                  const TensorShapeArray& dsts) override;
-private:
-    template <typename T>
-    void exec_internal(_megdnn_tensor_in src, _megdnn_tensor_in table,
-                       _megdnn_tensor_out dsts, _megdnn_workspace workspace);
-};
 } // namespace cuda
 } // namespace megdnn
@@ -41,31 +41,6 @@ __global__ void concat_kernel(const T** srcs, T* dst,
 }
 template <typename T>
-__global__ void split_kernel(const T* src, T** dsts,
-                             const int32_t* table_outer,
-                             const int32_t* table_inner,
-                             size_t total_size) {
-    size_t addr = threadIdx.x + blockIdx.x * blockDim.x;
-    if (addr < total_size) {
-        int32_t i = table_outer[addr];
-        int32_t idx = table_inner[addr];
-        if (idx != -1) {
-            dsts[i][idx] = src[addr];
-        }
-    }
-}
-template <typename T>
-void split_proxy(const T* src, T** dsts, size_t total_size,
-                 const int32_t* table_outer, const int32_t* table_inner,
-                 cudaStream_t stream) {
-    size_t NR_BLOCKS = DIVUP(total_size, NR_THREADS);
-    split_kernel<<<NR_BLOCKS, NR_THREADS, 0, stream>>>(
-            src, dsts, table_outer, table_inner, total_size);
-    after_kernel_launch();
-}
-template <typename T>
 void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size,
                   const int32_t* offsets,
                   cudaStream_t stream) {
@@ -78,10 +53,7 @@ void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size,
 #define INST(T)                                                  \
     template void concat_proxy<T>(const T**, T*, size_t, size_t, \
                                   const int32_t*,                \
-                                  cudaStream_t);                 \
-    template void split_proxy<T>(const T*, T**, size_t,          \
-                                 const int32_t*, const int32_t*, \
-                                 cudaStream_t);
+                                  cudaStream_t);
 #define cb(DType) INST(typename DTypeTrait<DType>::ctype)
     MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
 #undef cb
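With split_proxy gone, only the concat path needs a kernel, and a flat output index can be mapped back to its source tensor by binary-searching the sorted end offsets instead of reading two table entries. A plausible lookup, written here as host-side C++ since the actual concat_kernel body is not shown in this diff ('load_packed' is a hypothetical helper):

    #include <cstddef>
    #include <cstdint>

    // For a packed index 'addr', find the tensor whose [begin, end) range
    // contains it; indices inside alignment gaps read as zero. Assumes the
    // sorted offset pairs produced by gen_offsets.
    template <typename T>
    T load_packed(const T* const* srcs, const int32_t* offsets,
                  size_t nr_srcs, size_t addr) {
        size_t l = 0, r = nr_srcs - 1;
        while (l < r) {  // first tensor whose end offset exceeds addr
            size_t mid = (l + r) >> 1;
            if (static_cast<size_t>(offsets[(mid << 1) + 1]) > addr)
                r = mid;
            else
                l = mid + 1;
        }
        if (addr < static_cast<size_t>(offsets[l << 1]))
            return T(0);  // addr falls inside alignment padding
        return srcs[l][addr - offsets[l << 1]];
    }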
@@ -20,11 +20,6 @@ namespace cuda {
 namespace param_pack {
 template <typename T>
-void split_proxy(const T* src, T** dsts, size_t total_size,
-                 const int32_t* table_outer, const int32_t* table_inner,
-                 cudaStream_t stream);
-template <typename T>
 void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size,
                   const int32_t* offsets, cudaStream_t stream);
@@ -17,43 +17,6 @@ using namespace megdnn;
 using namespace naive;
 template <typename T>
-void ParamPackSplitImpl::exec_internal(_megdnn_tensor_in src, int32_t* table,
-                                       _megdnn_tensor_out dsts,
-                                       _megdnn_workspace) {
-    auto dsts_ptr = static_cast<T**>(dsts.raw_ptr);
-    auto src_ptr = src.ptr<T>();
-    auto inp_size = src.layout.total_nr_elems();
-    auto table_outer = table, table_inner = table_outer + inp_size;
-    for (size_t j = 0; j < inp_size; j++) {
-        int32_t i = table_outer[j];
-        int32_t idx = table_inner[j];
-        if (idx != -1) {
-            dsts_ptr[i][idx] = src_ptr[j];
-        }
-    }
-}
-void ParamPackSplitImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in table,
-                              _megdnn_tensor_out dsts,
-                              _megdnn_workspace workspace) {
-    check_exec(src.layout, table.layout, dsts.layout);
-    auto table_ptr = table.ptr<int32_t>();
-#define cb(DType)                                                       \
-    if (src.layout.dtype == DType()) {                                  \
-        using ctype = typename DTypeTrait<DType>::ctype;                \
-        MEGDNN_DISPATCH_CPU_KERN_OPR(                                   \
-                exec_internal<ctype>(src, table_ptr, dsts, workspace)); \
-        return;                                                         \
-    }
-    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
-    megdnn_throw("bad type");
-#undef cb
-}
-template <typename T>
 void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs,
                                         int32_t* offsets,
                                         _megdnn_tensor_out dst,
@@ -13,27 +13,10 @@
 namespace megdnn {
 namespace naive {
-class ParamPackSplitImpl final : public ParamPackSplit {
-public:
-    using ParamPackSplit::ParamPackSplit;
-    void exec(_megdnn_tensor_in src, _megdnn_tensor_in table,
-              _megdnn_tensor_out dsts, _megdnn_workspace workspace) override;
-    size_t get_workspace_in_bytes(const TensorShape&, const TensorShape&,
-                                  const TensorShapeArray&) override {
-        return 0;
-    }
-private:
-    template <typename T>
-    void exec_internal(_megdnn_tensor_in src, int32_t* table,
-                       _megdnn_tensor_out dsts, _megdnn_workspace workspace);
-};
 class ParamPackConcatImpl final : public ParamPackConcat {
 public:
     using ParamPackConcat::ParamPackConcat;
-    void exec(_megdnn_tensor_in srcs, _megdnn_tensor_in table,
+    void exec(_megdnn_tensor_in srcs, _megdnn_tensor_in offsets,
               _megdnn_tensor_out dst, _megdnn_workspace workspace) override;
     size_t get_workspace_in_bytes(const TensorShapeArray&, const TensorShape&,
@@ -43,7 +26,7 @@ public:
 private:
     template <typename T>
-    void exec_internal(_megdnn_tensor_in srcs, int32_t* table,
+    void exec_internal(_megdnn_tensor_in srcs, int32_t* offsets,
                        _megdnn_tensor_out dst, _megdnn_workspace workspace);
 };
@@ -18,56 +18,38 @@ using namespace test;
 namespace {
 template<class T>
-std::vector<int32_t> create_table(const TensorShapeArray& shapes,
-                                  size_t align) {
+std::vector<int32_t> create_offsets(const TensorShapeArray& shapes,
+                                    size_t alignment) {
     size_t dtype_size = sizeof(T);
-    if (align < dtype_size)
-        align = dtype_size;
-    align /= dtype_size;
-    size_t offset = shapes[0].total_nr_elems();
-    for (size_t i = 1; i < shapes.size(); i++) {
-        auto d = offset & (align - 1);
-        offset += (align - d) & (align - 1);
-        offset += shapes[i].total_nr_elems();
-    }
-    std::vector<int32_t> table(offset * 2);
-    int32_t* outer_table = table.data();
-    int32_t* inner_table = outer_table + offset;
-    offset = 0;
+    if (alignment < dtype_size)
+        alignment = dtype_size;
+    alignment /= dtype_size;
+    auto get_aligned = [alignment](size_t v) {
+        auto mod = v & (alignment - 1);
+        return v + ((alignment - mod) & (alignment - 1));
+    };
+    std::vector<dt_int32> offsets(shapes.size() << 1);
+    size_t offset = 0;
     for (size_t i = 0; i < shapes.size(); i++) {
-        for (; (offset & (align - 1)) != 0; offset++) {
-            outer_table[offset] = inner_table[offset] = -1;
-        }
-        size_t j = 0;
-        for (; j < shapes[i].total_nr_elems(); j++) {
-            outer_table[offset + j] = i;
-            inner_table[offset + j] = j;
-        }
-        offset += j;
+        offset = get_aligned(offset);
+        offsets[i << 1] = offset;
+        offset += shapes[i].total_nr_elems();
+        offsets[(i << 1) + 1] = offset;
     }
-    return table;
+    return offsets;
 }
 template<class T>
-std::vector<T> create_pack(size_t pack_size, const std::vector<int32_t>& table,
+std::vector<T> create_pack(size_t pack_size, const std::vector<int32_t>& offsets,
                            const std::vector<std::vector<T>>& ptr) {
-    assert(pack_size == table.size() / 2);
-    const int32_t* outer_table = table.data();
-    const int32_t* inner_table = outer_table + pack_size;
-    std::vector<T> data(pack_size);
-    for (size_t idx = 0; idx < pack_size; ++idx) {
-        int32_t out_idx = outer_table[idx];
-        int32_t in_idx = inner_table[idx];
-        if (in_idx != -1) {
-            data[idx] = ptr[out_idx][in_idx];
-        }
+    assert(pack_size == offsets.back());
+    std::vector<T> data(pack_size, 0);
+    for (size_t i = 0; i * 2 < offsets.size(); ++i) {
+        size_t begin = offsets[i * 2], end = offsets[i * 2 + 1];
+        for (size_t j = 0; j < end - begin; j++)
+            data[begin + j] = ptr[i][j];
     }
     return data;
 }
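The rewritten helpers compose directly: create_offsets produces the begin/end pairs and create_pack scatters each parameter into its slice, leaving the zero-initialized padding untouched. Reusing the numbers from the earlier sketch (values shown for illustration only):

    // create_offsets<float>({{3}, {5}, {2}}, 16) == {0, 3, 4, 9, 12, 14}
    // create_pack<float>(14, offsets, params) places params[0] at [0..3),
    // params[1] at [4..9) and params[2] at [12..14); indices 3 and 9..11
    // keep their zero fill.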
@@ -95,65 +77,6 @@ T* create_device_data(Handle* handle, const T* data, size_t size) {
     return data_device;
 }
-template<class T>
-void test_param_pack_split(Handle* handle, const TensorShapeArray& shapes,
-                           DType type) {
-    auto split = handle->create_operator<ParamPackSplit>();
-    size_t nr_params = shapes.size();
-    std::vector<T*> param_ptrs;
-    for (size_t i = 0; i < nr_params; ++i) {
-        param_ptrs.push_back(create_device_data<T>(handle,
-                    nullptr, shapes[i].total_nr_elems()));
-    }
-    std::vector<std::vector<T>> expected_param = create_params<T>(nr_params,
-                shapes);
-    std::vector<int32_t> table =
-            create_table<T>(shapes, handle->alignment_requirement());
-    ASSERT_EQ(table,
-              ParamPackSplit::gen_offsets(
-                      shapes, handle->alignment_requirement(), sizeof(T)));
-    size_t pack_size = table.size() / 2;
-    int32_t* table_gpu = create_device_data<int32_t>(handle, table.data(),
-            table.size());
-    std::vector<T> pack =
-            create_pack<T>(pack_size, table, expected_param);
-    T* pack_gpu = create_device_data<T>(handle, pack.data(), pack.size());
-    TensorLayout src_layout({pack_size}, type);
-    TensorND src_tensor(pack_gpu, src_layout);
-    TensorLayout table_layout({table.size()}, dtype::Int32());
-    TensorND table_tensor(table_gpu, table_layout);
-    test::WorkspaceWrapper workspace(handle, split->get_workspace_in_bytes(
-            {pack_size}, table_layout, shapes));
-    TensorND dst_tensor(param_ptrs.data(),
-            TensorLayout({nr_params}, dtype::Int32()));
-    split->exec(src_tensor, table_tensor, dst_tensor, workspace.workspace());
-    // check
-    for (size_t i = 0; i < nr_params; ++i) {
-        T* actual_param = static_cast<T*>(malloc(shapes[i].total_nr_elems()
-                    * sizeof(T)));
-        test::megdnn_memcpy_D2H(handle, actual_param, param_ptrs[i],
-                shapes[i].total_nr_elems() * sizeof(T));
-        for (size_t idx = 0; idx < shapes[i].total_nr_elems(); ++idx) {
-            ASSERT_EQ(actual_param[idx], expected_param[i][idx]);
-        }
-        free(actual_param);
-    }
-    test::megdnn_free(handle, pack_gpu);
-    test::megdnn_free(handle, table_gpu);
-    for (auto ptr : param_ptrs) {
-        test::megdnn_free(handle, ptr);
-    }
-}
 template <class T>
 void test_param_pack_concat(Handle* handle, const TensorShapeArray& shapes,
                             DType type) {
@@ -167,28 +90,28 @@ void test_param_pack_concat(Handle* handle, const TensorShapeArray& shapes,
         param_ptrs.push_back(create_device_data<T>(handle,
                 params[i].data(), shapes[i].total_nr_elems()));
     }
-    std::vector<int32_t> table =
-            create_table<T>(shapes, handle->alignment_requirement());
-    size_t pack_size = table.size() / 2;
-    int32_t* table_gpu = create_device_data<int32_t>(handle, table.data(),
-            table.size());
+    std::vector<int32_t> offsets =
+            create_offsets<T>(shapes, handle->alignment_requirement());
+    size_t pack_size = offsets.back();
+    int32_t* offsets_gpu = create_device_data<int32_t>(handle, offsets.data(),
+            offsets.size());
     std::vector<T> expected_pack =
-            create_pack<T>(pack_size, table, params);
+            create_pack<T>(pack_size, offsets, params);
     T* pack_gpu = create_device_data<T>(handle, nullptr, expected_pack.size());
     TensorLayout dst_layout({pack_size}, type);
     TensorND dst_tensor(pack_gpu, dst_layout);
-    TensorLayout table_layout({table.size()}, dtype::Int32());
-    TensorND table_tensor(table_gpu, table_layout);
+    TensorLayout offsets_layout({offsets.size()}, dtype::Int32());
+    TensorND offsets_tensor(offsets_gpu, offsets_layout);
     test::WorkspaceWrapper workspace(handle, concat->get_workspace_in_bytes(
-            shapes, table_layout, {pack_size}));
+            shapes, offsets_layout, {pack_size}));
     TensorND src_tensor(param_ptrs.data(),
             TensorLayout({nr_params}, dtype::Int32()));
-    concat->exec(src_tensor, table_tensor, dst_tensor, workspace.workspace());
+    concat->exec(src_tensor, offsets_tensor, dst_tensor, workspace.workspace());
     // check
     T* actual_pack = static_cast<T*>(malloc(pack_size * sizeof(T)));
@@ -199,7 +122,7 @@ void test_param_pack_concat(Handle* handle, const TensorShapeArray& shapes,
     }
     free(actual_pack);
     test::megdnn_free(handle, pack_gpu);
-    test::megdnn_free(handle, table_gpu);
+    test::megdnn_free(handle, offsets_gpu);
     for (auto ptr : param_ptrs) {
         test::megdnn_free(handle, ptr);
     }
@@ -222,9 +145,6 @@ TEST_F(CUDA, PARAM_PACK) {
             {111, 111, 111},
             {128, 128, 128}});
     for (auto shapes : shapes_vec) {
-        test_param_pack_split<int32_t>(handle_cuda(), shapes, dtype::Int32());
-        test_param_pack_split<int16_t>(handle_cuda(), shapes, dtype::Int16());
-        test_param_pack_split<float>(handle_cuda(), shapes, dtype::Float32());
         test_param_pack_concat<int32_t>(handle_cuda(), shapes, dtype::Int32());
         test_param_pack_concat<int16_t>(handle_cuda(), shapes, dtype::Int16());
         test_param_pack_concat<float>(handle_cuda(), shapes, dtype::Float32());
@@ -38,8 +38,7 @@ SymbolVar _Opr::_axis_add_remove(SymbolVar src,
 }
 SymbolVarArray _Opr::param_pack_split(
-        SymbolVar src, SymbolVar table,
-        const std::vector<std::vector<size_t>>& shapes,
+        SymbolVar src, const std::vector<std::vector<size_t>>& shapes,
         const OperatorNodeConfig& config) {
     auto size = shapes.size();
     mgb::TensorShapeArray shapearr(size);
@@ -48,18 +47,16 @@ SymbolVarArray _Opr::param_pack_split(
     }
     auto cn = src.node()->comp_node();
-    auto table_val = megdnn::ParamPackSplit::gen_offsets(
+    auto offsets_val = megdnn::ParamPackConcat::gen_offsets(
             shapearr, cn.get_mem_addr_alignment(), src.dtype().size());
-    if (!table.node()) {
-        if (config.has_comp_node_set()) {
-            cn = config.get_single_comp_node();
-        }
-        HostTensorND hv{cn, TensorShape{{table_val.size()}}, dtype::Int32{}};
-        memcpy(hv.raw_ptr(), table_val.data(), table_val.size() * sizeof(int));
-        table = opr::ImmutableTensor::make(*src.node()->owner_graph(), hv);
+    if (config.has_comp_node_set()) {
+        cn = config.get_single_comp_node();
     }
-    return mgb::opr::ParamPackSplit::make(src, table, table_val, shapearr, config);
+    HostTensorND hv{cn, TensorShape{{offsets_val.size()}}, dtype::Int32{}};
+    memcpy(hv.raw_ptr(), offsets_val.data(), offsets_val.size() * sizeof(int));
+    auto offsets = opr::ImmutableTensor::make(*src.node()->owner_graph(), hv);
+    return mgb::opr::ParamPackSplit::make(src, offsets, offsets_val, shapearr, config);
 }
 #if MGB_ENABLE_OPR_MM
@@ -44,8 +44,7 @@ static SymbolVar add_update(SymbolVar dest, SymbolVar delta,
 // tensor manip
 static SymbolVarArray param_pack_split(
-        SymbolVar src, SymbolVar table,
-        const std::vector<std::vector<size_t>>& shapes,
+        SymbolVar src, const std::vector<std::vector<size_t>>& shapes,
         const OperatorNodeConfig& config);
 static SymbolVar dimshuffle(SymbolVar src,
@@ -159,11 +159,11 @@ def dimshuffle(src, pattern, ndim=0, *,
             pattern_mgb.push_back(i)
     return _mgb._Opr.dimshuffle(src, pattern_mgb, int(ndim), config)
-def param_pack_split(src, shapes, table=None, *,
+def param_pack_split(src, shapes, *,
                      name=None, comp_node=None, config=None):
     """
     split param into a list of tensor for given shape
-    ParamPackSplit operator has two inputs: ``src`` and ``tables`` and would
+    ParamPackSplit operator has one input, ``src``, and would
     have a ``output``. output[i] indicates the address of tensor which part of
     ``src`` would transfer its elements into.
@@ -172,24 +172,13 @@ def param_pack_split(src, shapes, table=None, *,
     output[0] indicates the address of tensor with shapes[0]:(1, 2, 4),
     output[1] indicates the address of tensor with shapes[1]:(4, 2, 2),
     output[2] indicates the address of tensor with shapes[2]:(4, 2, 1).
-    And table have the double size of input tensor.
-    For each element in the tensor input[i], we may have
-    output[outer_index[i]][inner_index[i]] = input[i].
-    Table would the concatation of outer_index and inner_index, so more
-    alternatively, output[table[i]][table[i+len(input)]] = input[i]
     :param src: The concatenated input tensor.
     :type src: :class:`SymbolVar`
     :param shapes: Shapes of output tensors
     :type shapes: list of list of int
-    :param table: Output element mapping table; it if it is None, a table would
-        be generated from ``shapes``
-    :type table: :class:`SymbolVar` with int32 type, or None
     """
     config = _helper.gen_config(name, comp_node, config)
-    if isinstance(table, (list, tuple)) and isinstance(shapes, _mgb.SymbolVar):
-        # compatible with old API
-        table, shapes = shapes, table
     if not isinstance(shapes, (list, tuple)):
         raise TypeError('could not convert {} to tensor shapes'.format(
@@ -202,10 +191,7 @@ def param_pack_split(src, shapes, table=None, *,
         assert min(s) > 0
         shapes_mgb.push_back(s)
-    if table is None:
-        table = _mgb.SymbolVar()
-    return _mgb._Opr.param_pack_split(src, table, shapes_mgb, config)
+    return _mgb._Opr.param_pack_split(src, shapes_mgb, config)
 class _modify_subtensor_helper:
     def __init__(self, dest, val, *, name=None, comp_node=None, config=None):
@@ -1400,8 +1400,8 @@ void ParamPackConcat::init_output_static_infer_desc(){
     using namespace cg::static_infer;
     auto &&mgr = owner_graph()->static_infer_manager();
-    auto infer_out = [this](TensorShape &dest, const InpVal &inp) {
-        dest = {m_offsets.back()};
+    auto infer_out = [this](TensorShape& dest, const InpVal& inp) {
+        dest = {static_cast<unsigned int>(m_offsets.back())};
         return true;
     };
     DepVal shp_deps;
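Because m_offsets stores begin/end pairs, the packed output length is simply the final end offset. With the sample offsets used in the earlier sketches:

    // m_offsets == {0, 3, 4, 9, 12, 14}  =>  dest = TensorShape{14}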
@@ -1476,9 +1476,6 @@ SymbolVarArray ParamPackSplit::make(const SymbolVar& src,
     return ret;
 }
-void ParamPackSplit::scn_do_execute() {
-}
 void ParamPackSplit::init_output_dtype() {
     // already initialized in constructor
 }
@@ -1518,7 +1515,6 @@ void ParamPackSplit::init_output_static_infer_desc() {
 MGB_IMPL_OPR_GRAD(ParamPackSplit) {
     mgb_assert(out_grad.size() == opr.output().size());
     SmallVector<SymbolVar> grad;
-    // last var is workspace, ignore it
     for (size_t i = 0; i < out_grad.size(); ++i) {
         auto gval = out_grad[i];
         if (!gval) {
@@ -583,7 +583,7 @@ MGB_DEFINE_OPR_CLASS(ParamPackSplit, cg::SingleCNOperatorNodeBase) // {
     std::vector<dt_int32> m_offsets;
     std::vector<bool> m_mem_fwd_success;
-    void scn_do_execute() override;
+    void scn_do_execute() override {}
     void init_output_static_infer_desc() override;
     bool infer_shape(size_t index, TensorShape &dest,
                      const cg::static_infer::InpVal &inp);
@@ -1898,15 +1898,15 @@ void test_param_pack_concat(const TensorShapeArray &shapes, DType type){
         srcs.push_back(nd);
     }
-    auto host_table_gen = megdnn::ParamPackSplit::gen_offsets(shapes,
+    auto host_offsets_gen = megdnn::ParamPackConcat::gen_offsets(shapes,
             cn.get_mem_addr_alignment(), 4);
-    ASSERT_EQ(host_table_gen.size(), size * 2);
-    auto host_table = std::make_shared<HostTensorND>();
-    host_table->comp_node(cn).dtype(dtype::Int32{}).resize({size * 2});
-    memcpy(host_table->raw_ptr(), host_table_gen.data(), size * 8);
-    auto table = opr::Host2DeviceCopy::make(*graph, host_table);
-    auto z = opr::ParamPackConcat::make(srcs, table, host_table_gen);
+    ASSERT_EQ(host_offsets_gen.back(), size);
+    auto host_offsets = std::make_shared<HostTensorND>();
+    host_offsets->comp_node(cn).dtype(dtype::Int32{}).resize({srcs.size() * 2});
+    memcpy(host_offsets->raw_ptr(), host_offsets_gen.data(), srcs.size() * 8);
+    auto offsets = opr::Host2DeviceCopy::make(*graph, host_offsets);
+    auto z = opr::ParamPackConcat::make(srcs, offsets, host_offsets_gen);
     HostTensorND host_z;
     auto func = graph->compile({make_callback_copy(z, host_z)});
@@ -1944,17 +1944,17 @@ void test_param_pack_split(const TensorShapeArray& shapes) {
     auto make_graph = [&](const typename Checker::SymInpArray& inputs) ->
             typename Checker::SymOutArray {
-        auto table_val = megdnn::ParamPackSplit::gen_offsets(
+        auto offsets_val = megdnn::ParamPackConcat::gen_offsets(
                 shapes, cn.get_mem_addr_alignment(), 4);
-        HostTensorND table;
-        std::copy_n(table_val.data(), table_val.size(),
-                    table.dtype(dtype::Int32{})
+        HostTensorND offsets;
+        std::copy_n(offsets_val.data(), offsets_val.size(),
+                    offsets.dtype(dtype::Int32{})
                             .comp_node(cn)
-                            .resize({table_val.size()})
+                            .resize({offsets_val.size()})
                             .ptr<dt_int32>());
-        auto sym_table = opr::SharedDeviceTensor::make(
-                *inputs[0].node()->owner_graph(), table);
-        auto out = opr::ParamPackSplit::make(inputs[0], sym_table, table_val,
+        auto sym_offsets = opr::SharedDeviceTensor::make(
+                *inputs[0].node()->owner_graph(), offsets);
+        auto out = opr::ParamPackSplit::make(inputs[0], sym_offsets, offsets_val,
                 shapes);
         mgb_assert(out.size() == nr_out);
         typename Checker::SymOutArray ret;