diff --git a/dnn/include/megdnn/oprs/utils.h b/dnn/include/megdnn/oprs/utils.h index ddf3cedf..af22a5a8 100644 --- a/dnn/include/megdnn/oprs/utils.h +++ b/dnn/include/megdnn/oprs/utils.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include "megdnn/internal/opr_header_prologue.h" @@ -94,6 +95,42 @@ class PermutationRNG: public RNGBase { void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); }; +class ShuffleRNGForward : public OperatorBase { + DEF_OPR_IMPL(ShuffleRNGForward, OperatorBase, 1, 2); + DEF_OPR_PARAM(ShuffleRNG); + +public: + virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_tensor_out indices, + _megdnn_workspace workspace) = 0; + void deduce_layout(const TensorLayout& src, TensorLayout& dst, + TensorLayout& indices); + virtual size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& dst, + const TensorLayout& indices) = 0; + +protected: + void check_exec(const TensorLayout& src, const TensorLayout& dst, + const TensorLayout& indices, size_t workspace_in_bytes); +}; +using ShuffleRNG = ShuffleRNGForward; + +class ShuffleRNGBackward : public OperatorBase { + DEF_OPR_IMPL(ShuffleRNGBackward, OperatorBase, 2, 1); + DEF_OPR_PARAM(ShuffleRNG); + +public: + virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices, + _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes(const TensorLayout& diff, + const TensorLayout& indices, + const TensorLayout& grad) = 0; + +protected: + void check_exec(const TensorLayout& diff, const TensorLayout& indices, + const TensorLayout& grad, size_t workspace_in_bytes); +}; + /*! 
* \brief sleep for specific time on the computing device; useful for testing * async problems diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index e74915aa..d8634b33 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -781,6 +781,9 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0) 'Float32 are supported.'), 'DTypeEnum::Int32')) +(pdef('ShuffleRNG'). + add_fields('uint64', 'seed', 0)) + (pdef('Flip'). add_fields('bool', 'vertical', 'false', 'horizontal', 'false')) diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index b2b58b45..81696480 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -165,6 +165,8 @@ private: cb(BetaRNG) \ cb(PoissonRNG) \ cb(PermutationRNG) \ + cb(ShuffleRNGForward) \ + cb(ShuffleRNGBackward) \ cb(SeparableConvForward) \ cb(SeparableFilterForward) \ cb(BNForward) \ diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h index 92e109ee..fcf4fa10 100644 --- a/dnn/src/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -128,6 +128,8 @@ DEF(GammaRNG, 3, true, true); DEF(BetaRNG, 3, true, true); DEF(PoissonRNG, 2, true, true); DEF(PermutationRNG, 1, true, true); +DEF(ShuffleRNGForward, 3, true, true); +DEF(ShuffleRNGBackward, 3, true, false); DEF(ChecksumForward, 1, true, false); DEF(CheckHasInf, 2, true, true); DEF(LSQForward, 5, true, true); diff --git a/dnn/src/common/rng.cpp b/dnn/src/common/rng.cpp index fefb4add..68acbf5f 100644 --- a/dnn/src/common/rng.cpp +++ b/dnn/src/common/rng.cpp @@ -15,6 +15,47 @@ namespace megdnn { +void ShuffleRNGForward::deduce_layout(const TensorLayout& src, + TensorLayout& dst, + TensorLayout& indices) { + dst = src; + indices = TensorLayout(TensorShape({src.shape[0]}), dtype::Int32()); +} +void ShuffleRNGForward::check_exec(const TensorLayout& src, + const TensorLayout& dst, + const TensorLayout& indices, + size_t workspace_in_bytes) { + TensorLayout 
dst_expected, indices_expected; + megdnn_assert_contiguous(src); + deduce_layout(src, dst_expected, indices_expected); + + megdnn_assert_eq_layout(dst_expected, dst); + megdnn_assert_eq_layout(indices_expected, indices); + megdnn_assert_contiguous(indices); + megdnn_assert(src.dtype == dst.dtype); + megdnn_assert(indices.dtype == dtype::Int32()); + + auto required_workspace_in_bytes = + get_workspace_in_bytes(src, dst, indices); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); +} + +void ShuffleRNGBackward::check_exec(const TensorLayout& diff, + const TensorLayout& indices, + const TensorLayout& grad, + size_t workspace_in_bytes) { + megdnn_assert( + diff.shape[0] == indices.shape[0] && diff.dtype == grad.dtype && + indices.dtype == dtype::Int32{} && diff.is_contiguous() && + indices.is_contiguous() && grad.is_contiguous(), + "invalid layouts: diff=%s indices=%s grad=%s", + diff.to_string().c_str(), indices.to_string().c_str(), + grad.to_string().c_str()); + auto required_workspace_in_bytes = + get_workspace_in_bytes(diff, indices, grad); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); +} + void PermutationRNG::check_exec( const TensorLayout &dst, size_t workspace_in_bytes) { megdnn_assert((dst.dtype == dtype::Float32() || diff --git a/dnn/src/cuda/rng/kernel.cu b/dnn/src/cuda/rng/kernel.cu index 0db1e914..5ce4d704 100644 --- a/dnn/src/cuda/rng/kernel.cu +++ b/dnn/src/cuda/rng/kernel.cu @@ -55,6 +55,42 @@ __global__ void permute_duplicate_keys_kernel(KeyType* keys, ValueType* indexs, } } +template +__global__ void shuffle_fwd_kernel(uint32_t step, uint32_t src_size, const T* sptr, + T* dptr, const int* iptr) { + uint32_t idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < src_size) { + uint32_t r = idx / step; + dptr[idx]=sptr[iptr[r] * step + idx % step]; + } +} +template +void shuffle_forward(T* sptr, T* dptr, dt_int32* iptr, + size_t len, size_t step, cudaStream_t stream) { + uint32_t src_size = len * step; + 
shuffle_fwd_kernel<<>>( + step, src_size, sptr, dptr, iptr); + after_kernel_launch(); +} + +template +__global__ void shuffle_bwd_kernel(uint32_t step, uint32_t src_size, T* sptr, + T* dptr, const int* iptr) { + uint32_t idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < src_size) { + uint32_t r = idx / step; + sptr[iptr[r] * step + idx % step]=dptr[idx]; + } +} +template +void shuffle_backward(T* dptr, dt_int32* iptr, T* sptr, + size_t len, size_t step, cudaStream_t stream) { + uint32_t src_size = len * step; + shuffle_bwd_kernel<<>>( + step, src_size, sptr, dptr, iptr); + after_kernel_launch(); +} + uint32_t get_permutation_bits(size_t N) { double uniq_rand_num_prob = 0.9; double thresh = std::log(uniq_rand_num_prob) * 12; @@ -156,6 +192,14 @@ INST_PERMUTATION(dt_int16) INST_PERMUTATION(dt_float32) #undef INST_PERMUTATION +#define INST_SHUFFLE(T) \ + template void shuffle_forward(T* sptr, T* dptr, dt_int32* iptr,\ + size_t len, size_t step, cudaStream_t stream);\ + template void shuffle_backward(T* dptr, dt_int32* iptr, T* sptr,\ + size_t len, size_t step, cudaStream_t stream); + + ARGSORT_FOREACH_CTYPE(INST_SHUFFLE) +#undef INST_SHUFFLE } // namespace random #define INST(_dtype) \ diff --git a/dnn/src/cuda/rng/kernel.cuh b/dnn/src/cuda/rng/kernel.cuh index 1ba6cbe8..46b3bb0e 100644 --- a/dnn/src/cuda/rng/kernel.cuh +++ b/dnn/src/cuda/rng/kernel.cuh @@ -253,6 +253,17 @@ void permutation_forward(ctype* dst, void* workspace, size_t size, uint64_t seed size_t get_permutation_workspace_in_bytes(size_t N); +template +void shuffle_forward(T* sptr, T* dptr, dt_int32* iptr, + size_t len, size_t step, cudaStream_t stream); + +template +void shuffle_backward(T* dptr, dt_int32* iptr, T* sptr, + size_t len, size_t step, cudaStream_t stream); + +#define ARGSORT_FOREACH_CTYPE(cb) \ + cb(float) cb(int32_t) DNN_INC_FLOAT16(cb(dt_float16)) + } // namespace random } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/rng/opr_impl.cpp 
b/dnn/src/cuda/rng/opr_impl.cpp index d31b8d82..5ea71aad 100644 --- a/dnn/src/cuda/rng/opr_impl.cpp +++ b/dnn/src/cuda/rng/opr_impl.cpp @@ -9,11 +9,11 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include "./opr_impl.h" +#include "./kernel.cuh" #include "src/common/utils.h" #include "src/cuda/handle.h" #include "src/cuda/utils.h" -#include "./opr_impl.h" -#include "./kernel.cuh" using namespace megdnn; using namespace cuda; @@ -261,5 +261,77 @@ size_t PermutationRNGImpl::get_workspace_in_bytes(const TensorLayout &layout){ return random::get_permutation_workspace_in_bytes(size); } +ShuffleRNGForwardImpl::ShuffleRNGForwardImpl(Handle* handle) + : ShuffleRNGForward(handle), + m_seed(0), + m_offset(0), + m_stream(cuda_stream(handle)) {} + +void ShuffleRNGForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_tensor_out indices, + _megdnn_workspace workspace) { + check_exec(src.layout, dst.layout, indices.layout, workspace.size); + ensure_seed(m_param.seed); + auto wk = workspace.ptr(); + const auto len = indices.layout[0]; + random::permutation_forward(indices.ptr(), wk, len, + m_seed, m_offset, m_stream); + size_t step = 1; /* elements in one first-axis row: product (not sum) of the trailing dims */ + for (size_t i = 1; i < src.layout.ndim; ++i) { + step *= src.layout[i]; + } + if (step <= 0) + step = 1; + switch (src.layout.dtype.enumv()) { +#define cb(DType) \ + case DTypeTrait::enumv: \ + random::shuffle_forward::ctype>( \ + src.ptr::ctype>(), \ + dst.ptr::ctype>(), indices.ptr(), \ + len, step, m_stream); \ + break; + ARGSORT_FOREACH_CTYPE(cb) +#undef cb + default : megdnn_throw("bad dtype"); + } + m_offset += 8; +} + +size_t ShuffleRNGForwardImpl::get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, + const TensorLayout& indices) { + size_t size = indices.total_nr_elems(); + return random::get_permutation_workspace_in_bytes(size); +} + +ShuffleRNGBackwardImpl::ShuffleRNGBackwardImpl(Handle* handle) + : ShuffleRNGBackward(handle), 
m_stream(cuda_stream(handle)) {} + +void ShuffleRNGBackwardImpl::exec(_megdnn_tensor_in diff, + _megdnn_tensor_in indices, + _megdnn_tensor_out grad, + _megdnn_workspace workspace) { + check_exec(diff.layout, indices.layout, grad.layout, workspace.size); /* validate layouts before launching the scatter kernel, as the forward exec does */ + const auto len = indices.layout[0]; + size_t step = 1; /* elements in one first-axis row: product (not sum) of the trailing dims */ + for (size_t i = 1; i < diff.layout.ndim; ++i) { + step *= diff.layout[i]; + } + if (step <= 0) + step = 1; + switch (diff.layout.dtype.enumv()) { +#define cb(DType) \ + case DTypeTrait::enumv: \ + random::shuffle_backward::ctype>( \ + diff.ptr::ctype>(), indices.ptr(), \ + grad.ptr::ctype>(), len, step, m_stream); \ + break; + ARGSORT_FOREACH_CTYPE(cb) +#undef cb + default: + megdnn_throw("bad dtype"); + } +} + // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/rng/opr_impl.h b/dnn/src/cuda/rng/opr_impl.h index 451c224f..83b4453b 100644 --- a/dnn/src/cuda/rng/opr_impl.h +++ b/dnn/src/cuda/rng/opr_impl.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #pragma once @@ -152,6 +153,45 @@ public: } }; +class ShuffleRNGForwardImpl : public ShuffleRNGForward { + uint64_t m_seed, m_offset; + cudaStream_t m_stream; + +public: + using ShuffleRNGForward::ShuffleRNGForward; + ShuffleRNGForwardImpl(Handle* handle); + + void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_tensor_out indices, _megdnn_workspace workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& dst, + const TensorLayout& indices) override; + + void seed(uint64_t seed) { m_seed = seed; } + + void ensure_seed(uint64_t seed) { + if (m_seed != seed) { + this->seed(seed); + } + } +}; + +class ShuffleRNGBackwardImpl : public ShuffleRNGBackward { + cudaStream_t m_stream; + +public: + using ShuffleRNGBackward::ShuffleRNGBackward; + ShuffleRNGBackwardImpl(Handle* handle); + void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices, + _megdnn_tensor_out grad, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, + const TensorLayout&, + const TensorLayout&) override { + return 0; + } +}; + } // namespace cuda } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/rng/opr_impl.cpp b/dnn/src/naive/rng/opr_impl.cpp index 3d24adb4..f8ec1c2a 100644 --- a/dnn/src/naive/rng/opr_impl.cpp +++ b/dnn/src/naive/rng/opr_impl.cpp @@ -6,12 +6,13 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ -#include "src/naive/handle.h" -#include "src/common/utils.h" #include "./opr_impl.h" +#include "src/common/utils.h" +#include "src/naive/handle.h" #include @@ -229,7 +230,29 @@ namespace { } } -} // anonymous namespace + template + void shuffle_fwd(const T* __restrict sptr, T* __restrict dptr, + const dt_int32* iptr, const size_t len, + const size_t step) MEGDNN_NOEXCEPT { + for (size_t i = 0; i < len; ++i) { + for (size_t j = 0; j < step; ++j) { + dptr[i * step + j] = sptr[iptr[i] * step + j]; + } + } + } + + template + void shuffle_bwd(T* __restrict sptr, const T* __restrict dptr, + const dt_int32* iptr, const size_t len, + const size_t step) MEGDNN_NOEXCEPT { + for (size_t i = 0; i < len; ++i) { + for (size_t j = 0; j < step; ++j) { + sptr[iptr[i] * step + j] = dptr[i * step + j]; + } + } + } + +} // anonymous namespace uint64_t Splitmix64::operator() () { uint64_t z = (m_s += UINT64_C(0x9E3779B97F4A7C15)); @@ -394,5 +417,54 @@ void PermutationRNGImpl::exec( } } -// vim: syntax=cpp.doxygen +void ShuffleRNGForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_tensor_out indices, + _megdnn_workspace workspace) { + check_exec(src.layout, dst.layout, indices.layout, workspace.size); + const auto len = indices.layout[0]; + auto iptr = indices.ptr(); + auto prng = &m_rng.ensure_seed(m_param.seed); + fill_permutation(prng, iptr, len); + size_t step = 1; /* elements in one first-axis row: product (not sum) of the trailing dims */ + for (size_t i = 1; i < src.layout.ndim; ++i) { + step *= src.layout[i]; + } + if (step <= 0) + step = 1; + +#define cb(DType) \ + if (src.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + shuffle_fwd(src.ptr(), dst.ptr(), iptr, len, step)); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb +} +void ShuffleRNGBackwardImpl::exec(_megdnn_tensor_in diff, + _megdnn_tensor_in indices, + _megdnn_tensor_out grad, + _megdnn_workspace workspace) { + check_exec(diff.layout, indices.layout, grad.layout, workspace.size); + const auto 
len = indices.layout[0]; + auto iptr = indices.ptr(); + size_t step = 1; /* elements in one first-axis row: product (not sum) of the trailing dims */ + for (size_t i = 1; i < diff.layout.ndim; ++i) { + step *= diff.layout[i]; + } + if (step <= 0) + step = 1; +#define cb(DType) \ + if (diff.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(shuffle_bwd( \ + grad.ptr(), diff.ptr(), iptr, len, step)); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/rng/opr_impl.h b/dnn/src/naive/rng/opr_impl.h index 370c55d2..739af6d6 100644 --- a/dnn/src/naive/rng/opr_impl.h +++ b/dnn/src/naive/rng/opr_impl.h @@ -128,6 +128,35 @@ public: size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; } }; + +class ShuffleRNGForwardImpl : public ShuffleRNGForward { + Xoroshiro128plus m_rng; + +public: + using ShuffleRNGForward::ShuffleRNGForward; + void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_tensor_out indices, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, + const TensorLayout&, + const TensorLayout&) override { + return 0; + } +}; + +class ShuffleRNGBackwardImpl : public ShuffleRNGBackward { + Xoroshiro128plus m_rng; + +public: + using ShuffleRNGBackward::ShuffleRNGBackward; + void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices, + _megdnn_tensor_out grad, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, + const TensorLayout&, + const TensorLayout&) override { + return 0; + } +}; + } // namespace naive } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/test/cuda/rng.cpp b/dnn/test/cuda/rng.cpp index 77ba9602..eba1c022 100644 --- a/dnn/test/cuda/rng.cpp +++ b/dnn/test/cuda/rng.cpp @@ -143,6 +143,60 @@ void run_permutation(Handle* handle) { } } +template +void run_shuffle(Handle* handle, bool bwd_flag) { + using ctype = typename DTypeTrait::ctype; + auto run = [&](TensorShape shape) { + auto opr = 
handle->create_operator(); + TensorLayout srclay{shape, T()}; + TensorLayout dstlay{shape, T()}; + TensorLayout indexlay{TensorShape{shape[0]}, dtype::Int32()}; + Tensor workspace( + handle, {TensorShape{opr->get_workspace_in_bytes(srclay, dstlay, + indexlay)}, + dtype::Byte()}); + SyncedTensor src(handle, srclay); + SyncedTensor dst(handle, dstlay); + SyncedTensor::ctype> index(handle, indexlay); + auto sptr = src.ptr_mutable_host(); + size_t size = src.layout().total_nr_elems(); + for (size_t j = 0; j < size; ++j) { + sptr[j] = j; + } + opr->exec(src.tensornd_dev(), dst.tensornd_dev(), index.tensornd_dev(), + {workspace.ptr(), workspace.layout().total_nr_elems()}); + + auto dptr = dst.ptr_mutable_host(); + auto iptr = index.ptr_mutable_host(); + size_t len = index.layout().total_nr_elems(); + size_t step = size / len; + for (size_t i = 0; i < len; ++i) { + for (size_t j = 0; j < step; ++j) { + ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]); + } + } + if (bwd_flag) { + for (size_t j = 0; j < size; ++j) { + sptr[j] = 0; + } + auto oprbwd = handle->create_operator(); + oprbwd->exec( + dst.tensornd_dev(), index.tensornd_dev(), + src.tensornd_dev(), + {workspace.ptr(), workspace.layout().total_nr_elems()}); + auto sptr_bwd = src.ptr_mutable_host(); + for (size_t i = 0; i < len; ++i) { + for (size_t j = 0; j < step; ++j) { + ASSERT_EQ(dptr[i * step + j], sptr_bwd[iptr[i] * step + j]); + } + } + } + }; + + run({10}); + run({6, 3}); +} + } // anonymous namespace TEST_F(CUDA, UNIFORM_RNG_F32) { @@ -215,6 +269,30 @@ TEST_F(CUDA, PERMUTATION_RNG_INT16) { run_permutation(handle_cuda()); } +TEST_F(CUDA, SHUFFLE_RNG_F32) { + run_shuffle(handle_cuda(), false); +} + +TEST_F(CUDA, SHUFFLE_RNG_INT32) { + run_shuffle(handle_cuda(), false); +} + +TEST_F(CUDA, SHUFFLE_RNG_F16) { + run_shuffle(handle_cuda(), false); +} + +TEST_F(CUDA, SHUFFLE_RNG_BWD_F32) { + run_shuffle(handle_cuda(), true); +} + +TEST_F(CUDA, SHUFFLE_RNG_BWD_INT32) { + run_shuffle(handle_cuda(), true); +} + 
+TEST_F(CUDA, SHUFFLE_RNG_BWD_F16) { + run_shuffle(handle_cuda(), true); +} + } // namespace test } // namespace megdnn diff --git a/dnn/test/naive/rng.cpp b/dnn/test/naive/rng.cpp index 75a82223..32b65b7a 100644 --- a/dnn/test/naive/rng.cpp +++ b/dnn/test/naive/rng.cpp @@ -6,12 +6,13 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ -#include "megdnn.h" -#include "test/naive/fixture.h" #include "test/naive/rng.h" +#include "megdnn.h" #include "test/common/tensor.h" +#include "test/naive/fixture.h" namespace megdnn { @@ -181,7 +182,59 @@ namespace { ASSERT_LE(std::abs(res[i] - ctype(i)), 1e-8); } } -} + + template + void run_shuffle(Handle* handle, bool bwd_flag) { + using ctype = typename DTypeTrait::ctype; + auto run = [&](TensorShape shape) { + auto opr = handle->create_operator(); + TensorLayout srclay{shape, T()}; + TensorLayout dstlay{shape, T()}; + TensorLayout indexlay{TensorShape{shape[0]}, dtype::Int32()}; + Tensor workspace( + handle, {TensorShape{opr->get_workspace_in_bytes(srclay, dstlay, + indexlay)}, + dtype::Byte()}); + Tensor src(handle, srclay); + Tensor dst(handle, dstlay); + Tensor::ctype> index(handle, indexlay); + auto sptr = src.ptr(); + size_t size = src.layout().total_nr_elems(); + for (size_t j = 0; j < size; ++j) { + sptr[j] = j; + } + opr->exec(src.tensornd(), dst.tensornd(), index.tensornd(), + {workspace.ptr(), workspace.layout().total_nr_elems()}); + + auto dptr = dst.ptr(); + auto iptr = index.ptr(); + size_t len = index.layout().total_nr_elems(); + size_t step = size / len; + for (size_t i = 0; i < len; ++i) { + for (size_t j = 0; j < step; ++j) { + ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]); + } + } + if (bwd_flag) { + for (size_t j = 0; j < size; ++j) { + 
sptr[j] = 0; + } + auto oprbwd = handle->create_operator(); + oprbwd->exec( + dst.tensornd(), index.tensornd(), src.tensornd(), + {workspace.ptr(), workspace.layout().total_nr_elems()}); + for (size_t i = 0; i < len; ++i) { + for (size_t j = 0; j < step; ++j) { + ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]); + } + } + } + }; + + run({10}); + run({6, 3}); + } +} // namespace TEST_F(NAIVE, UNIFORM_RNG_F32) { run_uniform(handle()); @@ -235,10 +288,31 @@ TEST_F(NAIVE, PERMUTATION_RNG_INT16) { run_permutation(handle()); } -} // namespace test -} // namespace megdnn +TEST_F(NAIVE, SHUFFLE_RNG_FWD_F32) { + run_shuffle(handle(), false); +} -// vim: syntax=cpp.doxygen +TEST_F(NAIVE, SHUFFLE_RNG_FWD_INT32) { + run_shuffle(handle(), false); +} +TEST_F(NAIVE, SHUFFLE_RNG_FWD_F16) { + run_shuffle(handle(), false); +} +TEST_F(NAIVE, SHUFFLE_RNG_BWD_F32) { + run_shuffle(handle(), true); +} +TEST_F(NAIVE, SHUFFLE_RNG_BWD_INT32) { + run_shuffle(handle(), true); +} + +TEST_F(NAIVE, SHUFFLE_RNG_BWD_F16) { + run_shuffle(handle(), true); +} + +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/imperative/python/megengine/random/__init__.py b/imperative/python/megengine/random/__init__.py index e59a2a56..8a642240 100644 --- a/imperative/python/megengine/random/__init__.py +++ b/imperative/python/megengine/random/__init__.py @@ -6,7 +6,7 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-from .rng import RNG, beta, gamma, normal, permutation, poisson, seed, uniform +from .rng import RNG, beta, gamma, normal, permutation, poisson, seed, shuffle, uniform __all__ = [ "RNG", @@ -17,6 +17,7 @@ __all__ = [ "poisson", "seed", "uniform", + "shuffle", ] # pylint: disable=undefined-variable del rng # type: ignore[name-defined] diff --git a/imperative/python/megengine/random/rng.py b/imperative/python/megengine/random/rng.py index d61928f4..45c19090 100644 --- a/imperative/python/megengine/random/rng.py +++ b/imperative/python/megengine/random/rng.py @@ -27,6 +27,7 @@ from ..core.ops.builtin import ( GaussianRNG, PermutationRNG, PoissonRNG, + ShuffleRNG, UniformRNG, ) from ..core.tensor import utils @@ -41,6 +42,7 @@ __all__ = [ "beta", "poisson", "permutation", + "shuffle", ] _rng = None @@ -219,6 +221,13 @@ def _permutation(n: int, seed: int, device: str, handle: int, dtype: str) -> Ten return output +def _shuffle(inp: Tensor, seed: int, handle: int) -> Tensor: + assert inp.size > 0, "size needs to be greater than 0" + op = ShuffleRNG(seed=seed, handle=handle) + output, _ = apply(op, inp) + inp._reset(output) + + class RNG: r""":class:`RNG` exposes a number of methods for generating random numbers. @@ -581,6 +590,45 @@ class RNG: n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype ) + def shuffle(self, inp: Tensor): + r"""Modify a sequence in-place by shuffling its contents. + This function only shuffles the Tensor along the first axis of a multi-dimensional Tensor. + The order of sub-Tensors is changed but their contents remains the same. + + Args: + inp: input tensor. + + Examples: + + .. testcode:: + + import numpy as np + import megengine as mge + import megengine.random as rand + + x = mge.tensor(np.arange(10)) + rand.shuffle(x) + print(x.numpy()) + y = mge.tensor(np.arange(18)).reshape(6,3) + rand.shuffle(y) + print(y.numpy()) + + Outputs: + + .. testoutput:: + :options: +SKIP + + [7 9 3 0 8 2 4 5 6 1] + [[12. 13. 14.] + [ 3. 4. 
5.] + [15. 16. 17.] + [ 0. 1. 2.] + [ 9. 10. 11.] + [ 6. 7. 8.]] + """ + _seed = self._seed() if callable(self._seed) else self._seed + _shuffle(inp=inp, seed=_seed, handle=self._handle) + def __del__(self): if self._handle != 0: _delete_rng_handle(self._handle) @@ -599,6 +647,7 @@ gamma = _default_handle.gamma beta = _default_handle.beta poisson = _default_handle.poisson permutation = _default_handle.permutation +shuffle = _default_handle.shuffle def _random_seed_generator(): diff --git a/imperative/python/test/unit/random/test_rng.py b/imperative/python/test/unit/random/test_rng.py index ab575fe3..6df000bb 100644 --- a/imperative/python/test/unit/random/test_rng.py +++ b/imperative/python/test/unit/random/test_rng.py @@ -18,6 +18,7 @@ from megengine.core._imperative_rt.ops import ( get_global_rng_seed, new_rng_handle, ) +from megengine.core.autodiff.grad import Grad from megengine.core.ops.builtin import ( BetaRNG, GammaRNG, @@ -397,6 +398,45 @@ def test_PermutationRNG(): assert sum_result(out, np.sort) == 1000 +@pytest.mark.skipif( + get_device_count("xpu") <= 1, reason="xpu counts need > 1", +) +def test_ShuffleRNG(): + g = [] + + def cb(grad): + g.append(grad) + + n, m = 6, 3 + arr = np.arange(n * m) + out0 = Tensor(arr, dtype="float32") + grad = Grad().wrt(out0, callback=cb) + random.shuffle(out0) + grad(out0, F.ones_like(out0)) + m1 = RNG(seed=111, device="xpu0") + m2 = RNG(seed=111, device="xpu1") + m3 = RNG(seed=222, device="xpu0") + out1 = Tensor(arr, dtype="float32", device="xpu0") + out2 = Tensor(arr, dtype="float32", device="xpu1") + out3 = Tensor(arr, dtype="float32", device="xpu0") + m1.shuffle(out1) + m2.shuffle(out2) + m3.shuffle(out3) + + np.testing.assert_equal(out1.numpy(), out2.numpy()) + assert out1.device == "xpu0" and out2.device == "xpu1" + assert not (out1.numpy() == out3.numpy()).all() + + out = Tensor(arr, dtype="float32").reshape(n, m) + m1.shuffle(out) + + out_shp = out.shape + if isinstance(out_shp, tuple): + assert out_shp == (n, m) 
+ else: + assert all(out.shape.numpy() == np.array([n, m])) + + def test_seed(): set_global_seed(10) out1 = uniform(size=[10, 10]) diff --git a/imperative/src/impl/ops/rng.cpp b/imperative/src/impl/ops/rng.cpp index 441bbcb5..9ebc09ac 100644 --- a/imperative/src/impl/ops/rng.cpp +++ b/imperative/src/impl/ops/rng.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "megbrain/imperative/ops/rng.h" @@ -14,8 +15,8 @@ #include "megbrain/graph/helper.h" #include "megbrain/opr/rand.h" -#include "../op_trait.h" #include "../dnn_op_helper.h" +#include "../op_trait.h" namespace mgb::imperative::rng { @@ -259,13 +260,27 @@ struct OpMeth { } }; +template <> +struct OpMeth { + using DnnOp = megdnn::ShuffleRNG; + using Param = DnnOp::Param; + using OpNode = mgb::opr::ShuffleRNG; + static Param make_param(const ShuffleRNG& rng) { + auto handle_seed = RNGDnnOpManager::get_seed(rng.handle); + mgb_assert(handle_seed == rng.seed, + "inconsistent rng seed: rng op: %lu handle: %lu", + handle_seed, rng.seed); + return {handle_seed}; + } +}; + template struct _InferLayout; template struct _RNGOprMaker; -template +template struct _RNGOprInvoker; template<> @@ -316,50 +331,63 @@ struct _InferLayout return inp.layout; } }; - -#define _INST_RNG_INVOLKER(DNN_NR_INPUTS) \ -template<> \ -struct _RNGOprInvoker { \ - template \ - static void exec(Opr *dnn_op, const SmallVector& inputs,const TensorPtr& dest){ \ - size_t wk_size = 0; \ - wk_size = dnn_op->get_workspace_in_bytes(_FOR_EACH_IN(->layout())dest->layout()); \ - auto workspace = Blob::make(dest->comp_node(), wk_size); \ - megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \ - dnn_op->exec(_FOR_EACH_IN(->dev_tensor().as_megdnn()) \ - 
dest->dev_tensor().as_megdnn(), dnn_wk); \ - } \ -}; +#define _INST_RNG_INVOLKER(DNN_NR_INPUTS, DNN_NR_OUTPUTS) \ + template <> \ + struct _RNGOprInvoker { \ + template \ + static void exec(Opr* dnn_op, const SmallVector& inputs, \ + const SmallVector& outputs) { \ + size_t wk_size = 0; \ + wk_size = dnn_op->get_workspace_in_bytes( \ + _FOR_EACH_IN(->layout()) _FOR_EACH_OUT(->layout())); \ + auto workspace = Blob::make(outputs[0]->comp_node(), wk_size); \ + megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \ + dnn_op->exec(_FOR_EACH_IN(->dev_tensor().as_megdnn()) \ + _FOR_EACH_OUT(->dev_tensor().as_megdnn()), \ + dnn_wk); \ + } \ + }; -#define _INST_RNG_MAKER(MGB_NR_INPUTS) \ -template<> \ -struct _RNGOprMaker { \ - template \ - static SymbolVar make(const VarNodeArray& inputs, const Op& rng){ \ - auto param = OpMeth::make_param(rng); \ - OperatorNodeConfig config; \ - if (rng.handle) { \ - config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)}; \ - } else { \ - config = {rng.make_name()}; \ - } \ - return OpMeth::OpNode::make(_FOR_EACH_IN() param, config); \ - } \ -}; +#define _INST_RNG_MAKER(MGB_NR_INPUTS) \ + template <> \ + struct _RNGOprMaker { \ + template \ + static auto make(const VarNodeArray& inputs, const Op& rng) { \ + auto param = OpMeth::make_param(rng); \ + OperatorNodeConfig config; \ + if (rng.handle) { \ + config = {rng.make_name(), \ + RNGDnnOpManager::get_comp_node(rng.handle)}; \ + } else { \ + config = {rng.make_name()}; \ + } \ + return OpMeth::OpNode::make(_FOR_EACH_IN() param, config); \ + } \ + }; -#define _FOR_EACH_IN(subfix) -_INST_RNG_INVOLKER(0) +#define _FOR_EACH_IN(subfix) +#define _FOR_EACH_OUT(subfix) outputs[0] subfix +_INST_RNG_INVOLKER(0, 1) +#undef _FOR_EACH_OUT #undef _FOR_EACH_IN #define _FOR_EACH_IN(subfix) inputs[0] subfix, -_INST_RNG_INVOLKER(1) +#define _FOR_EACH_OUT(subfix) outputs[0] subfix +_INST_RNG_INVOLKER(1, 1) +#undef _FOR_EACH_OUT + +#define _FOR_EACH_OUT(subfix) outputs[0] subfix, 
outputs[1] subfix +_INST_RNG_INVOLKER(1, 2) _INST_RNG_MAKER(1) +#undef _FOR_EACH_OUT #undef _FOR_EACH_IN #define _FOR_EACH_IN(subfix) inputs[0] subfix, inputs[1] subfix, -_INST_RNG_INVOLKER(2) +#define _FOR_EACH_OUT(subfix) outputs[0] subfix +_INST_RNG_INVOLKER(2, 1) _INST_RNG_MAKER(2) +#undef _FOR_EACH_OUT #undef _FOR_EACH_IN #undef _INST_RNG_INVOLKER @@ -392,7 +420,9 @@ void exec(const OpDef& op, const SmallVector& inputs, handle_seed, dnn_op->param().seed); } dnn_op->param() = OpMeth::make_param(rng); - _RNGOprInvoker::DnnOp::NR_INPUTS>::exec(dnn_op,inputs,dest); + _RNGOprInvoker::DnnOp::NR_INPUTS, + OpMeth::DnnOp::NR_OUTPUTS>::exec(dnn_op, inputs, + outputs); } template @@ -420,24 +450,45 @@ SmallVector infer_output_attrs( return {dest}; } -template -std::tuple, SmallVector> infer_output_mem_desc( - const OpDef& def, - const SmallVector& inputs_tensors, - const SmallVector& inputs_mems) { - auto &&dest = infer_output_attrs(def, inputs_tensors); - SmallVector outputs = {{dest[0].layout, 0, dest[0].comp_node, StorageIdentifier::make(1)}}; - - return {outputs, {}}; +template <> +SmallVector infer_output_attrs( + const OpDef& op, const SmallVector& inputs) { + SmallVector dests(2); + auto&& rng = op.cast_final_safe(); + auto handle = rng.handle; + if (handle) { + dests[0].comp_node = RNGDnnOpManager::get_comp_node(handle); + dests[1].comp_node = RNGDnnOpManager::get_comp_node(handle); + } else { + dests[0].comp_node = inputs[0]->comp_node(); + dests[1].comp_node = inputs[0]->comp_node(); + } + dests[0].layout = TensorLayout(inputs[0]->layout()); + dests[0].layout.dtype = inputs[0]->layout().dtype; + dests[1].layout = + TensorLayout(TensorShape({inputs[0]->layout()[0]}), dtype::Int32()); + return dests; } +template +std::tuple, SmallVector> +infer_output_mem_desc(const OpDef& def, + const SmallVector& inputs_tensors, + const SmallVector& inputs_mems) { + auto&& dests = infer_output_attrs(def, inputs_tensors); + SmallVector outputs; + for (size_t i = 0; i < 
dests.size(); ++i) { + outputs.push_back({dests[i].layout, 0, dests[i].comp_node, + StorageIdentifier::make(i + 1)}); + } + return {outputs, {}}; +} template SmallVector apply_on_physical_tensor( const OpDef& def, const SmallVector& inputs) { SmallVector outputs; - SmallVector desc; - desc = infer_output_attrs(def, inputs); + SmallVector desc = infer_output_attrs(def, inputs); for (auto&& i : desc) { outputs.push_back(Tensor::make(i.layout, i.comp_node)); } @@ -454,10 +505,8 @@ void execute( exec(def, inputs, outputs, {}); } -template -SymbolVar apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { +template +Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { size_t nr_inp = inputs.size(); constexpr size_t dnn_nr_inp = OpMeth::DnnOp::NR_INPUTS; auto&& rng = def.cast_final_safe(); @@ -487,7 +536,21 @@ std::tuple, bool> infer_output_attrs_fallible( return {{dest}, true}; } -} // anonymous namespace +template <> +std::tuple, bool> +infer_output_attrs_fallible( + const OpDef& def, const SmallVector& inputs) { + SmallVector dests(2); + dests[0].comp_node = inputs[0].comp_node; + dests[0].layout = TensorLayout(inputs[0].layout); + dests[0].layout.dtype = inputs[0].layout.dtype; + dests[1].comp_node = inputs[0].comp_node; + dests[1].layout = TensorLayout(TensorShape({inputs[0].layout.shape[0]}), + dtype::Int32()); + return {dests, true}; +} + +} // anonymous namespace Handle new_handle(CompNode comp_node, uint64_t seed) { return RNGDnnOpManager::inst().new_handle(comp_node, seed); @@ -509,23 +572,24 @@ CompNode get_rng_handle_compnode(Handle handle){ return RNGDnnOpManager::get_comp_node(handle); } -#define REG_RNG_OP(NAME)\ -namespace { \ -OP_TRAIT_REG(NAME, NAME, OpMeth::OpNode) \ - .apply_on_var_node(apply_on_var_node) \ - .apply_on_physical_tensor(apply_on_physical_tensor) \ - .infer_output_attrs_fallible(infer_output_attrs_fallible) \ - .infer_output_mem_desc(infer_output_mem_desc) \ - .execute(execute) \ - .fallback(); \ -} \ - 
-REG_RNG_OP(UniformRNG) -REG_RNG_OP(GaussianRNG) -REG_RNG_OP(GammaRNG) -REG_RNG_OP(PermutationRNG) -REG_RNG_OP(PoissonRNG) -REG_RNG_OP(BetaRNG) +#define REG_RNG_OP(NAME, Output) \ + namespace { \ + OP_TRAIT_REG(NAME, NAME, OpMeth::OpNode) \ + .apply_on_var_node(apply_on_var_node) \ + .apply_on_physical_tensor(apply_on_physical_tensor) \ + .infer_output_attrs_fallible(infer_output_attrs_fallible) \ + .infer_output_mem_desc(infer_output_mem_desc) \ + .execute(execute) \ + .fallback(); \ + } + +REG_RNG_OP(UniformRNG, SymbolVar) +REG_RNG_OP(GaussianRNG, SymbolVar) +REG_RNG_OP(GammaRNG, SymbolVar) +REG_RNG_OP(PermutationRNG, SymbolVar) +REG_RNG_OP(PoissonRNG, SymbolVar) +REG_RNG_OP(BetaRNG, SymbolVar) +REG_RNG_OP(ShuffleRNG, SymbolVarArray) #undef REG_RNG_OP } // namespace mgb::imperative::rng diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 3d29ce37..0fcc13ee 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -215,6 +215,19 @@ def PermutationRNG: MgbHashableOp<"PermutationRNG", [PermutationRNGParam]> { let cmpFunction = [{return $0.handle == $1.handle && $0.dtype == $1.dtype;}]; } +def ShuffleRNG: MgbHashableOp<"ShuffleRNG", [ShuffleRNGParam]> { + let extraArguments = (ins + MgbSizeTAddr:$handle + ); + let hashFunction = [{ + return mgb::hash_pair_combine( + mgb::hash($_self.dyn_typeinfo()), + mgb::hash($_self.handle) + ); + }]; + let cmpFunction = [{return $0.handle == $1.handle;}]; +} + def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> { let extraArguments = (ins MgbCompNodeAttr:$comp_node diff --git a/src/opr/impl/rand.cpp b/src/opr/impl/rand.cpp index 7cd3f076..02e91c69 100644 --- a/src/opr/impl/rand.cpp +++ b/src/opr/impl/rand.cpp @@ -192,6 +192,8 @@ template class RNGOprBase<::megdnn::GammaRNG>; template class RNGOprBase<::megdnn::PermutationRNG>; template class RNGOprBase<::megdnn::BetaRNG>; template class RNGOprBase<::megdnn::PoissonRNG>; +template class 
RNGOprBase<::megdnn::ShuffleRNGForward>; +template class RNGOprBase<::megdnn::ShuffleRNGBackward>; #if MGB_ENABLE_GRAD IMPL(GaussianRNG); IMPL(UniformRNG); @@ -200,9 +202,87 @@ IMPL(PoissonRNG); IMPL(PermutationRNG); IMPL(BetaRNG); #endif -} +} // namespace intl +} // namespace opr +} // namespace mgb + +/* ================= ShuffleRNGForward ================= */ + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleRNGForward); + +ShuffleRNGForward::ShuffleRNGForward(VarNode* data, const Param& param, + const OperatorNodeConfig& config) + : Super({data->owner_graph(), config, "shuffle_rng", {data}}, param) { + add_input({data}); + add_output(None)->dtype(data->dtype()); + add_output(None)->dtype(dtype::Int32{}); + cg::add_workspace_output(this); + add_equivalence_component>(this); +} + +SymbolVarArray ShuffleRNGForward::make(SymbolVar in_tensor, const Param& param, + const OperatorNodeConfig& config) { + auto node = in_tensor.node()->owner_graph()->insert_opr( + std::make_unique(in_tensor.node(), param, + config)); + mgb_assert(node->output().size() == 3); + return {node->output(0), node->output(1)}; } + +void ShuffleRNGForward::init_output_static_infer_desc() { + using namespace cg::static_infer; + auto&& mgr = owner_graph()->static_infer_manager(); + + mgr.register_shape_infer(output(0), + ShapeInferDesc::make_identity(input(0))); + + auto infer_oshp1 = [this](TensorShape& dest, const InpVal& iv) { + TensorLayout o0, o1; + m_dnn_opr->deduce_layout({iv.val[0].shape(), input(0)->dtype()}, o0, + o1); + dest = o1; + return true; + }; + mgr.register_shape_infer( + output(1), + {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_oshp1}); + + auto infer_wk = [this](TensorShape& dest, const InpVal& inp) { + ensure_megdnn_opr(); + dest.ndim = 1; + dest.shape[0] = m_dnn_opr->get_workspace_in_bytes( + {inp.val[0].shape(), input(0)->dtype()}, + {output(0)->shape(), output(0)->dtype()}, + {output(1)->shape(), output(1)->dtype()}); + return true; + }; + mgr.register_shape_infer( + 
output(2), + {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_wk}); } +void ShuffleRNGForward::add_input_layout_constraint() { + input(0)->add_layout_constraint_contiguous(); +}; + +void ShuffleRNGForward::scn_do_execute() { + m_dnn_opr->exec(input(0)->dev_tensor().as_megdnn(), + output(0)->dev_tensor().as_megdnn(), + output(1)->dev_tensor().as_megdnn(), + get_megdnn_workspace_from_var(output(2))); +} + +#if MGB_ENABLE_GRAD +MGB_IMPL_OPR_GRAD(ShuffleRNGForward) { + mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]); + if (!out_grad[0]) + return nullptr; + return ShuffleRNGBackward::make(out_grad[0], opr.output(1), opr.input(0)).node(); +} +#endif + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleRNGBackward); +MEGDNN_OPR_INIT3(ShuffleRNGBackward, "shuffle_rng_bwd", 2, true) + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/rand.sereg.h b/src/opr/impl/rand.sereg.h index 68b16ea1..c9ae7ea0 100644 --- a/src/opr/impl/rand.sereg.h +++ b/src/opr/impl/rand.sereg.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #include "megbrain/opr/rand.h" @@ -14,6 +15,23 @@ namespace mgb { +namespace serialization { + +template <> +struct OprMaker { + using Opr = opr::ShuffleRNG; + using Param = Opr::Param; + static cg::OperatorNodeBase* make(const Param& param, + const cg::VarNodeArray& inputs, + ComputingGraph& graph, + const OperatorNodeConfig& config) { + MGB_MARK_USED_VAR(graph); + auto out = Opr::make(inputs[0], param, config); + return out[0].node()->owner_opr(); + } +}; +} // namespace serialization + namespace opr { using UniformRNGV1 = opr::UniformRNG; @@ -24,9 +42,10 @@ MGB_SEREG_OPR(GammaRNG, 2); MGB_SEREG_OPR(PoissonRNG, 1); MGB_SEREG_OPR(PermutationRNG, 1); MGB_SEREG_OPR(BetaRNG, 2); +MGB_SEREG_OPR(ShuffleRNG, 1); +MGB_SEREG_OPR(ShuffleRNGBackward, 3); -} // namespace opr -} // namespace mgb +} // namespace opr +} // namespace mgb // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} - diff --git a/src/opr/include/megbrain/opr/rand.h b/src/opr/include/megbrain/opr/rand.h index 7bea8bfc..fc4fe5ad 100644 --- a/src/opr/include/megbrain/opr/rand.h +++ b/src/opr/include/megbrain/opr/rand.h @@ -6,14 +6,15 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #pragma once #include "megbrain/graph.h" -#include "megbrain/opr/internal/out_shape_by_sym_var.h" #include "megbrain/opr/internal/megdnn_opr_wrapper.h" +#include "megbrain/opr/internal/out_shape_by_sym_var.h" #include "megdnn/oprs.h" namespace mgb { @@ -41,22 +42,24 @@ MGB_DEFINE_CLS_WITH_SUPER(RNGOprBase, cg::SingleCNOperatorNodeBase) // { }; /* ================= RNG with shape ================= */ -#define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \ -MGB_DEFINE_OPR_CLASS(RNG,RNGOprBase) \ - cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \ - public: \ - RNG(VarNode *shape, const Param ¶m, const OperatorNodeConfig &config); \ - static SymbolVar make(SymbolVar shape, const Param ¶m = {}, \ - const OperatorNodeConfig &config = {}); \ - static SymbolVar make(ComputingGraph &graph, const TensorShape &shape, \ - const OperatorNodeConfig &config, \ - const Param ¶m = {}) { \ - return make(var_from_tensor_shape(graph, config, "rng", shape), \ - param, config); \ - } \ - void init_output_static_infer_desc() override; \ - void scn_do_execute() override; \ -}; +#define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \ + MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase) \ + cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \ + \ +public: \ + RNG(VarNode* shape, const Param& param, const OperatorNodeConfig& config); \ + static SymbolVar make(SymbolVar shape, const Param& param = {}, \ + const OperatorNodeConfig& config = {}); \ + static SymbolVar make(ComputingGraph& graph, const TensorShape& shape, \ + const OperatorNodeConfig& config, \ + const Param& param = {}) { \ + return make(var_from_tensor_shape(graph, config, "rng", shape), param, \ + config); \ + } \ + void init_output_static_infer_desc() override; \ + void scn_do_execute() override; \ + } \ + ; _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(UniformRNG) _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(GaussianRNG) @@ -71,7 +74,7 @@ MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase) public: \ RNG(_INPUTS(VarNode*), const Param ¶m, \ const 
OperatorNodeConfig &config); \ - static SymbolVar make(_INPUTS(SymbolVar),const Param ¶m = {}, \ + static _OUTPUTS make(_INPUTS(SymbolVar),const Param ¶m = {}, \ const OperatorNodeConfig &config = {}); \ void init_output_static_infer_desc() override; \ void scn_do_execute() override; \ @@ -79,17 +82,24 @@ MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase) /* ================= 1 input ================= */ #define _INPUTS(preifx) preifx i0 +#define _OUTPUTS SymbolVar _DEFINE_RNG_OPR_WITH_INPUT_CLASS(PoissonRNG) +#undef _OUTPUTS +#define _OUTPUTS SymbolVarArray +_DEFINE_RNG_OPR_WITH_INPUT_CLASS(ShuffleRNGForward) +#undef _OUTPUTS #undef _INPUTS /* ================= 2 input ================= */ #define _INPUTS(preifx) preifx i0, preifx i1 +#define _OUTPUTS SymbolVar _DEFINE_RNG_OPR_WITH_INPUT_CLASS(BetaRNG) _DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG) +#undef _OUTPUTS #undef _INPUTS #undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS -} // intl +} // intl using UniformRNG = intl::UniformRNG; using GaussianRNG = intl::GaussianRNG; @@ -97,9 +107,20 @@ using GammaRNG = intl::GammaRNG; using PermutationRNG = intl::PermutationRNG; using PoissonRNG = intl::PoissonRNG; using BetaRNG = intl::BetaRNG; -} // namespace opr -} // namespace mgb +using ShuffleRNG = intl::ShuffleRNGForward; +MGB_DEFINE_OPR_CLASS(ShuffleRNGBackward, + intl::MegDNNOprWrapperBwd) //{ +public: +ShuffleRNGBackward(VarNode* out_diff, VarNode* indices, VarNode* result_shape, + const Param& param, const OperatorNodeConfig& config); -// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} +static SymbolVar make(SymbolVar out_diff, SymbolVar indices, + SymbolVar result_shape, const Param& param = {}, + const OperatorNodeConfig& config = {}); +}; + +} // namespace opr +} // namespace mgb +// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/test/rand.cpp b/src/opr/test/rand.cpp index d4e28491..dc719d41 100644 --- a/src/opr/test/rand.cpp +++ b/src/opr/test/rand.cpp @@ -333,6 
+333,38 @@ TEST(TestOprRand, EmptyShape) { } +TEST(TestOprRand, ShuffleForward) { + auto run = [&](TensorShape shape) { + std::shared_ptr src_host(new HostTensorND{ + CompNode::load("xpux"), shape, dtype::Float32()}); + auto sptr = src_host->ptr(); + auto size = shape.total_nr_elems(); + for (size_t i = 0; i < size; ++i) { + sptr[i] = i; + } + auto graph = ComputingGraph::make(); + auto src_sym = opr::Host2DeviceCopy::make(*graph, src_host); + auto rec = opr::ShuffleRNG::make(src_sym, {10}); + HostTensorND host_y, host_index; + auto func = graph->compile({make_callback_copy(rec[0], host_y), + make_callback_copy(rec[1], host_index)}); + func->execute(); + auto dptr = host_y.ptr(); + auto iptr = host_index.ptr(); + + size_t len = shape[0]; + size_t step = size / len; + for (size_t i = 0; i < len; ++i) { + for (size_t j = 0; j < step; ++j) { + assert(dptr[i * step + j] == sptr[iptr[i] * step + j]); + } + } + }; + run({10}); + run({6, 3}); + run({1, 1}); +} + TEST(TestOprRand, UniformReprod) { static constexpr size_t SIZE = 123; auto graph = ComputingGraph::make(); diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs index 1923337c..9eff6038 100644 --- a/src/serialization/impl/schema.fbs +++ b/src/serialization/impl/schema.fbs @@ -114,6 +114,7 @@ union OperatorParam { param.BetaRNG = 80, param.SlidingWindowTranspose = 81, param.Padding = 82, + param.ShuffleRNG = 83, } table Operator {