@@ -6,7 +6,8 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once
#include "megdnn/internal/opr_header_prologue.h"

@@ -94,6 +95,42 @@ class PermutationRNG: public RNGBase {
    void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);
};
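/*!
 * \brief shuffle the input tensor along its first axis
 *
 * dst is src with its first-axis slices rearranged by a randomly drawn
 * permutation; the permutation itself is written to indices (Int32) so that
 * the backward pass can undo it.
 */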
class ShuffleRNGForward : public OperatorBase {
    DEF_OPR_IMPL(ShuffleRNGForward, OperatorBase, 1, 2);
    DEF_OPR_PARAM(ShuffleRNG);

public:
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_tensor_out indices,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst,
                       TensorLayout& indices);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst,
                                          const TensorLayout& indices) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    const TensorLayout& indices, size_t workspace_in_bytes);
};
using ShuffleRNG = ShuffleRNGForward;
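/*!
 * \brief gradient of ShuffleRNGForward: scatters diff back through the
 * permutation recorded in indices, i.e. grad[indices[i]] = diff[i] per slice
 */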
class ShuffleRNGBackward : public OperatorBase {
    DEF_OPR_IMPL(ShuffleRNGBackward, OperatorBase, 2, 1);
    DEF_OPR_PARAM(ShuffleRNG);

public:
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& indices,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& diff, const TensorLayout& indices,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
/*!
 * \brief sleep for specific time on the computing device; useful for testing
 * async problems
@@ -781,6 +781,9 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
        'Float32 are supported.'),
    'DTypeEnum::Int32'))

(pdef('ShuffleRNG').
 add_fields('uint64', 'seed', 0))

(pdef('Flip').
 add_fields('bool', 'vertical', 'false', 'horizontal', 'false'))

@@ -165,6 +165,8 @@ private:
    cb(BetaRNG) \
    cb(PoissonRNG) \
    cb(PermutationRNG) \
    cb(ShuffleRNGForward) \
    cb(ShuffleRNGBackward) \
    cb(SeparableConvForward) \
    cb(SeparableFilterForward) \
    cb(BNForward) \

@@ -128,6 +128,8 @@ DEF(GammaRNG, 3, true, true);
DEF(BetaRNG, 3, true, true);
DEF(PoissonRNG, 2, true, true);
DEF(PermutationRNG, 1, true, true);
DEF(ShuffleRNGForward, 3, true, true);
DEF(ShuffleRNGBackward, 3, true, false);
DEF(ChecksumForward, 1, true, false);
DEF(CheckHasInf, 2, true, true);
DEF(LSQForward, 5, true, true);
@@ -15,6 +15,47 @@
namespace megdnn {

void ShuffleRNGForward::deduce_layout(const TensorLayout& src,
                                      TensorLayout& dst,
                                      TensorLayout& indices) {
    dst = src;
    indices = TensorLayout(TensorShape({src.shape[0]}), dtype::Int32());
}

void ShuffleRNGForward::check_exec(const TensorLayout& src,
                                   const TensorLayout& dst,
                                   const TensorLayout& indices,
                                   size_t workspace_in_bytes) {
    TensorLayout dst_expected, indices_expected;
    megdnn_assert_contiguous(src);
    deduce_layout(src, dst_expected, indices_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
    megdnn_assert_eq_layout(indices_expected, indices);
    megdnn_assert_contiguous(indices);
    megdnn_assert(src.dtype == dst.dtype);
    megdnn_assert(indices.dtype == dtype::Int32());
    auto required_workspace_in_bytes =
            get_workspace_in_bytes(src, dst, indices);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

void ShuffleRNGBackward::check_exec(const TensorLayout& diff,
                                    const TensorLayout& indices,
                                    const TensorLayout& grad,
                                    size_t workspace_in_bytes) {
    megdnn_assert(
            diff.shape[0] == indices.shape[0] && diff.dtype == grad.dtype &&
                    indices.dtype == dtype::Int32{} && diff.is_contiguous() &&
                    indices.is_contiguous() && grad.is_contiguous(),
            "invalid layouts: diff=%s indices=%s grad=%s",
            diff.to_string().c_str(), indices.to_string().c_str(),
            grad.to_string().c_str());
    auto required_workspace_in_bytes =
            get_workspace_in_bytes(diff, indices, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

void PermutationRNG::check_exec(
        const TensorLayout &dst, size_t workspace_in_bytes) {
    megdnn_assert((dst.dtype == dtype::Float32() ||
@@ -55,6 +55,42 @@ __global__ void permute_duplicate_keys_kernel(KeyType* keys, ValueType* indexs,
    }
}
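// Gather kernel for the forward shuffle: each thread copies one element.
// `step` is the number of elements in one first-axis slice, `r = idx / step`
// selects the destination row, and iptr[r] names the source row to read from.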
template <typename T>
__global__ void shuffle_fwd_kernel(uint32_t step, uint32_t src_size,
                                   const T* sptr, T* dptr, const int* iptr) {
    uint32_t idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < src_size) {
        uint32_t r = idx / step;
        dptr[idx] = sptr[iptr[r] * step + idx % step];
    }
}
template <typename T>
void shuffle_forward(T* sptr, T* dptr, dt_int32* iptr, size_t len, size_t step,
                     cudaStream_t stream) {
    uint32_t src_size = len * step;
    shuffle_fwd_kernel<<<DIVUP(src_size, 512), 512, 0, stream>>>(
            step, src_size, sptr, dptr, iptr);
    after_kernel_launch();
}
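// Scatter kernel for the backward pass: the inverse of the gather above,
// writing each diff element back to the position its source row came from.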
template <typename T>
__global__ void shuffle_bwd_kernel(uint32_t step, uint32_t src_size, T* sptr,
                                   T* dptr, const int* iptr) {
    uint32_t idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < src_size) {
        uint32_t r = idx / step;
        sptr[iptr[r] * step + idx % step] = dptr[idx];
    }
}

template <typename T>
void shuffle_backward(T* dptr, dt_int32* iptr, T* sptr, size_t len, size_t step,
                      cudaStream_t stream) {
    uint32_t src_size = len * step;
    shuffle_bwd_kernel<<<DIVUP(src_size, 512), 512, 0, stream>>>(
            step, src_size, sptr, dptr, iptr);
    after_kernel_launch();
}
uint32_t get_permutation_bits(size_t N) {
    double uniq_rand_num_prob = 0.9;
    double thresh = std::log(uniq_rand_num_prob) * 12;
@@ -156,6 +192,14 @@ INST_PERMUTATION(dt_int16)
INST_PERMUTATION(dt_float32)
#undef INST_PERMUTATION

#define INST_SHUFFLE(T)                                                    \
    template void shuffle_forward<T>(T* sptr, T* dptr, dt_int32* iptr,     \
                                     size_t len, size_t step,              \
                                     cudaStream_t stream);                 \
    template void shuffle_backward<T>(T* dptr, dt_int32* iptr, T* sptr,    \
                                      size_t len, size_t step,             \
                                      cudaStream_t stream);

ARGSORT_FOREACH_CTYPE(INST_SHUFFLE)
#undef INST_SHUFFLE
} // namespace random

#define INST(_dtype) \
@@ -253,6 +253,17 @@ void permutation_forward(ctype* dst, void* workspace, size_t size, uint64_t seed
size_t get_permutation_workspace_in_bytes(size_t N);

template <typename T>
void shuffle_forward(T* sptr, T* dptr, dt_int32* iptr,
                     size_t len, size_t step, cudaStream_t stream);

template <typename T>
void shuffle_backward(T* dptr, dt_int32* iptr, T* sptr,
                      size_t len, size_t step, cudaStream_t stream);

#define ARGSORT_FOREACH_CTYPE(cb) \
    cb(float) cb(int32_t) DNN_INC_FLOAT16(cb(dt_float16))

} // namespace random
} // namespace cuda
} // namespace megdnn

@@ -9,11 +9,11 @@
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./opr_impl.h"
#include "./kernel.cuh"
#include "src/common/utils.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"
#include "./opr_impl.h"
#include "./kernel.cuh"

using namespace megdnn;
using namespace cuda;

@@ -261,5 +261,76 @@ size_t PermutationRNGImpl::get_workspace_in_bytes(const TensorLayout &layout){
    return random::get_permutation_workspace_in_bytes(size);
}
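// The CUDA implementation reuses the permutation kernels: it first draws a
// random permutation of [0, n) into `indices`, then gathers first-axis
// slices through it. m_offset is advanced after every call so that repeated
// shuffles under the same seed do not replay the same random stream.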
ShuffleRNGForwardImpl::ShuffleRNGForwardImpl(Handle* handle)
        : ShuffleRNGForward(handle),
          m_seed(0),
          m_offset(0),
          m_stream(cuda_stream(handle)) {}

void ShuffleRNGForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                                 _megdnn_tensor_out indices,
                                 _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, indices.layout, workspace.size);
    ensure_seed(m_param.seed);
    auto wk = workspace.ptr<void>();
    const auto len = indices.layout[0];
    random::permutation_forward<dt_int32>(indices.ptr<dt_int32>(), wk, len,
                                          m_seed, m_offset, m_stream);
    // step is the number of elements in one slice along the first axis,
    // i.e. the product (not the sum) of the trailing dimensions
    size_t step = 1;
    for (size_t i = 1; i < src.layout.ndim; ++i) {
        step *= src.layout[i];
    }
    switch (src.layout.dtype.enumv()) {
#define cb(DType)                                                             \
    case DTypeTrait<DType>::enumv:                                            \
        random::shuffle_forward<DTypeTrait<DType>::ctype>(                    \
                src.ptr<DTypeTrait<DType>::ctype>(),                          \
                dst.ptr<DTypeTrait<DType>::ctype>(), indices.ptr<dt_int32>(), \
                len, step, m_stream);                                         \
        break;
        ARGSORT_FOREACH_CTYPE(cb)
#undef cb
        default:
            megdnn_throw("bad dtype");
    }
    m_offset += 8;
}
size_t ShuffleRNGForwardImpl::get_workspace_in_bytes(
        const TensorLayout&, const TensorLayout&,
        const TensorLayout& indices) {
    size_t size = indices.total_nr_elems();
    return random::get_permutation_workspace_in_bytes(size);
}

ShuffleRNGBackwardImpl::ShuffleRNGBackwardImpl(Handle* handle)
        : ShuffleRNGBackward(handle), m_stream(cuda_stream(handle)) {}

void ShuffleRNGBackwardImpl::exec(_megdnn_tensor_in diff,
                                  _megdnn_tensor_in indices,
                                  _megdnn_tensor_out grad,
                                  _megdnn_workspace workspace) {
    check_exec(diff.layout, indices.layout, grad.layout, workspace.size);
    const auto len = indices.layout[0];
    // same slice size as in the forward pass: product of the trailing dims
    size_t step = 1;
    for (size_t i = 1; i < diff.layout.ndim; ++i) {
        step *= diff.layout[i];
    }
    switch (diff.layout.dtype.enumv()) {
#define cb(DType)                                                              \
    case DTypeTrait<DType>::enumv:                                             \
        random::shuffle_backward<DTypeTrait<DType>::ctype>(                    \
                diff.ptr<DTypeTrait<DType>::ctype>(), indices.ptr<dt_int32>(), \
                grad.ptr<DTypeTrait<DType>::ctype>(), len, step, m_stream);    \
        break;
        ARGSORT_FOREACH_CTYPE(cb)
#undef cb
        default:
            megdnn_throw("bad dtype");
    }
}

// vim: syntax=cpp.doxygen
@@ -6,7 +6,8 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once

@@ -152,6 +153,45 @@ public:
    }
};
class ShuffleRNGForwardImpl : public ShuffleRNGForward {
    uint64_t m_seed, m_offset;
    cudaStream_t m_stream;

public:
    using ShuffleRNGForward::ShuffleRNGForward;
    ShuffleRNGForwardImpl(Handle* handle);
    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
              _megdnn_tensor_out indices, _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(const TensorLayout& src,
                                  const TensorLayout& dst,
                                  const TensorLayout& indices) override;
    void seed(uint64_t seed) { m_seed = seed; }
    void ensure_seed(uint64_t seed) {
        if (m_seed != seed) {
            this->seed(seed);
        }
    }
};

class ShuffleRNGBackwardImpl : public ShuffleRNGBackward {
    cudaStream_t m_stream;

public:
    using ShuffleRNGBackward::ShuffleRNGBackward;
    ShuffleRNGBackwardImpl(Handle* handle);
    void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices,
              _megdnn_tensor_out grad, _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                  const TensorLayout&) override {
        return 0;
    }
};

} // namespace cuda
} // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -6,12 +6,13 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/naive/handle.h"
#include "src/common/utils.h"
#include "./opr_impl.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"

#include <cmath>

@@ -229,7 +230,29 @@ namespace {
    }
}
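// Reference implementations: the forward pass gathers first-axis slices
// through the permutation (dst slice i = src slice iptr[i]), the backward
// pass scatters them back; `step` is the number of elements per slice.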
template <typename T>
void shuffle_fwd(const T* __restrict sptr, T* __restrict dptr,
                 const dt_int32* iptr, const size_t len,
                 const size_t step) MEGDNN_NOEXCEPT {
    for (size_t i = 0; i < len; ++i) {
        for (size_t j = 0; j < step; ++j) {
            dptr[i * step + j] = sptr[iptr[i] * step + j];
        }
    }
}

template <typename T>
void shuffle_bwd(T* __restrict sptr, const T* __restrict dptr,
                 const dt_int32* iptr, const size_t len,
                 const size_t step) MEGDNN_NOEXCEPT {
    for (size_t i = 0; i < len; ++i) {
        for (size_t j = 0; j < step; ++j) {
            sptr[iptr[i] * step + j] = dptr[i * step + j];
        }
    }
}
} // anonymous namespace

uint64_t Splitmix64::operator() () {
    uint64_t z = (m_s += UINT64_C(0x9E3779B97F4A7C15));
@@ -394,5 +417,54 @@ void PermutationRNGImpl::exec(
    }
}
void ShuffleRNGForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                                 _megdnn_tensor_out indices,
                                 _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, indices.layout, workspace.size);
    const auto len = indices.layout[0];
    auto iptr = indices.ptr<dt_int32>();
    auto prng = &m_rng.ensure_seed(m_param.seed);
    fill_permutation<dt_int32>(prng, iptr, len);
    // elements per first-axis slice: the product (not the sum) of the
    // trailing dimensions
    size_t step = 1;
    for (size_t i = 1; i < src.layout.ndim; ++i) {
        step *= src.layout[i];
    }
#define cb(DType)                                                             \
    if (src.layout.dtype == DType()) {                                        \
        using T = typename DTypeTrait<DType>::ctype;                          \
        MEGDNN_DISPATCH_CPU_KERN_OPR(                                         \
                shuffle_fwd<T>(src.ptr<T>(), dst.ptr<T>(), iptr, len, step)); \
        return;                                                               \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
}
void ShuffleRNGBackwardImpl::exec(_megdnn_tensor_in diff,
                                  _megdnn_tensor_in indices,
                                  _megdnn_tensor_out grad,
                                  _megdnn_workspace workspace) {
    check_exec(diff.layout, indices.layout, grad.layout, workspace.size);
    const auto len = indices.layout[0];
    auto iptr = indices.ptr<dt_int32>();
    // same slice size as in the forward pass
    size_t step = 1;
    for (size_t i = 1; i < diff.layout.ndim; ++i) {
        step *= diff.layout[i];
    }
#define cb(DType)                                                       \
    if (diff.layout.dtype == DType()) {                                 \
        using T = typename DTypeTrait<DType>::ctype;                    \
        MEGDNN_DISPATCH_CPU_KERN_OPR(shuffle_bwd<T>(                    \
                grad.ptr<T>(), diff.ptr<T>(), iptr, len, step));        \
        return;                                                         \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
}

// vim: syntax=cpp.doxygen
@@ -128,6 +128,35 @@ public:
    size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; }
};

class ShuffleRNGForwardImpl : public ShuffleRNGForward {
    Xoroshiro128plus m_rng;

public:
    using ShuffleRNGForward::ShuffleRNGForward;
    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
              _megdnn_tensor_out indices, _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                  const TensorLayout&) override {
        return 0;
    }
};

class ShuffleRNGBackwardImpl : public ShuffleRNGBackward {
    Xoroshiro128plus m_rng;

public:
    using ShuffleRNGBackward::ShuffleRNGBackward;
    void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices,
              _megdnn_tensor_out grad, _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                  const TensorLayout&) override {
        return 0;
    }
};

} // namespace naive
} // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -143,6 +143,60 @@ void run_permutation(Handle* handle) {
    }
}
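// Fills src with 0..n-1, runs the shuffle, and checks dst[i] == src[idx[i]]
// slice-wise; with bwd_flag it additionally zeroes src, runs the backward
// pass on dst, and verifies the scatter restores the same correspondence.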
template <typename T>
void run_shuffle(Handle* handle, bool bwd_flag) {
    using ctype = typename DTypeTrait<T>::ctype;
    auto run = [&](TensorShape shape) {
        auto opr = handle->create_operator<ShuffleRNGForward>();
        TensorLayout srclay{shape, T()};
        TensorLayout dstlay{shape, T()};
        TensorLayout indexlay{TensorShape{shape[0]}, dtype::Int32()};
        Tensor<dt_byte> workspace(
                handle, {TensorShape{opr->get_workspace_in_bytes(
                                 srclay, dstlay, indexlay)},
                         dtype::Byte()});
        SyncedTensor<ctype> src(handle, srclay);
        SyncedTensor<ctype> dst(handle, dstlay);
        SyncedTensor<DTypeTrait<dt_int32>::ctype> index(handle, indexlay);
        auto sptr = src.ptr_mutable_host();
        size_t size = src.layout().total_nr_elems();
        for (size_t j = 0; j < size; ++j) {
            sptr[j] = j;
        }
        opr->exec(src.tensornd_dev(), dst.tensornd_dev(), index.tensornd_dev(),
                  {workspace.ptr(), workspace.layout().total_nr_elems()});
        auto dptr = dst.ptr_mutable_host();
        auto iptr = index.ptr_mutable_host();
        size_t len = index.layout().total_nr_elems();
        size_t step = size / len;
        for (size_t i = 0; i < len; ++i) {
            for (size_t j = 0; j < step; ++j) {
                ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]);
            }
        }
        if (bwd_flag) {
            for (size_t j = 0; j < size; ++j) {
                sptr[j] = 0;
            }
            auto oprbwd = handle->create_operator<ShuffleRNGBackward>();
            oprbwd->exec(
                    dst.tensornd_dev(), index.tensornd_dev(),
                    src.tensornd_dev(),
                    {workspace.ptr(), workspace.layout().total_nr_elems()});
            auto sptr_bwd = src.ptr_mutable_host();
            for (size_t i = 0; i < len; ++i) {
                for (size_t j = 0; j < step; ++j) {
                    ASSERT_EQ(dptr[i * step + j], sptr_bwd[iptr[i] * step + j]);
                }
            }
        }
    };
    run({10});
    run({6, 3});
}
} // anonymous namespace

TEST_F(CUDA, UNIFORM_RNG_F32) {
@@ -215,6 +269,30 @@ TEST_F(CUDA, PERMUTATION_RNG_INT16) {
    run_permutation<dtype::Int16>(handle_cuda());
}

TEST_F(CUDA, SHUFFLE_RNG_F32) {
    run_shuffle<dtype::Float32>(handle_cuda(), false);
}

TEST_F(CUDA, SHUFFLE_RNG_INT32) {
    run_shuffle<dtype::Int32>(handle_cuda(), false);
}

TEST_F(CUDA, SHUFFLE_RNG_F16) {
    run_shuffle<dtype::Float16>(handle_cuda(), false);
}

TEST_F(CUDA, SHUFFLE_RNG_BWD_F32) {
    run_shuffle<dtype::Float32>(handle_cuda(), true);
}

TEST_F(CUDA, SHUFFLE_RNG_BWD_INT32) {
    run_shuffle<dtype::Int32>(handle_cuda(), true);
}

TEST_F(CUDA, SHUFFLE_RNG_BWD_F16) {
    run_shuffle<dtype::Float16>(handle_cuda(), true);
}

} // namespace test
} // namespace megdnn
@@ -6,12 +6,13 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn.h"
#include "test/naive/fixture.h"
#include "test/naive/rng.h"
#include "megdnn.h"
#include "test/common/tensor.h"
#include "test/naive/fixture.h"

namespace megdnn {

@@ -181,7 +182,59 @@ namespace {
        ASSERT_LE(std::abs(res[i] - ctype(i)), 1e-8);
    }
}
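// Same scheme as the permutation tests: verify dst[i] == src[idx[i]] per
// slice after the forward pass, and that the backward pass scatters the
// shuffled values back through the recorded indices.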
template <typename T>
void run_shuffle(Handle* handle, bool bwd_flag) {
    using ctype = typename DTypeTrait<T>::ctype;
    auto run = [&](TensorShape shape) {
        auto opr = handle->create_operator<ShuffleRNGForward>();
        TensorLayout srclay{shape, T()};
        TensorLayout dstlay{shape, T()};
        TensorLayout indexlay{TensorShape{shape[0]}, dtype::Int32()};
        Tensor<dt_byte> workspace(
                handle, {TensorShape{opr->get_workspace_in_bytes(
                                 srclay, dstlay, indexlay)},
                         dtype::Byte()});
        Tensor<ctype> src(handle, srclay);
        Tensor<ctype> dst(handle, dstlay);
        Tensor<DTypeTrait<dt_int32>::ctype> index(handle, indexlay);
        auto sptr = src.ptr();
        size_t size = src.layout().total_nr_elems();
        for (size_t j = 0; j < size; ++j) {
            sptr[j] = j;
        }
        opr->exec(src.tensornd(), dst.tensornd(), index.tensornd(),
                  {workspace.ptr(), workspace.layout().total_nr_elems()});
        auto dptr = dst.ptr();
        auto iptr = index.ptr();
        size_t len = index.layout().total_nr_elems();
        size_t step = size / len;
        for (size_t i = 0; i < len; ++i) {
            for (size_t j = 0; j < step; ++j) {
                ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]);
            }
        }
        if (bwd_flag) {
            for (size_t j = 0; j < size; ++j) {
                sptr[j] = 0;
            }
            auto oprbwd = handle->create_operator<ShuffleRNGBackward>();
            oprbwd->exec(
                    dst.tensornd(), index.tensornd(), src.tensornd(),
                    {workspace.ptr(), workspace.layout().total_nr_elems()});
            for (size_t i = 0; i < len; ++i) {
                for (size_t j = 0; j < step; ++j) {
                    ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]);
                }
            }
        }
    };
    run({10});
    run({6, 3});
}
} // namespace
TEST_F(NAIVE, UNIFORM_RNG_F32) {
    run_uniform<dtype::Float32>(handle());
@@ -235,10 +288,31 @@ TEST_F(NAIVE, PERMUTATION_RNG_INT16) {
    run_permutation<dtype::Int16>(handle());
}
TEST_F(NAIVE, SHUFFLE_RNG_FWD_F32) {
    run_shuffle<dtype::Float32>(handle(), false);
}
TEST_F(NAIVE, SHUFFLE_RNG_FWD_INT32) {
    run_shuffle<dtype::Int32>(handle(), false);
}

TEST_F(NAIVE, SHUFFLE_RNG_FWD_F16) {
    run_shuffle<dtype::Float16>(handle(), false);
}

TEST_F(NAIVE, SHUFFLE_RNG_BWD_F32) {
    run_shuffle<dtype::Float32>(handle(), true);
}

TEST_F(NAIVE, SHUFFLE_RNG_BWD_INT32) {
    run_shuffle<dtype::Int32>(handle(), true);
}

TEST_F(NAIVE, SHUFFLE_RNG_BWD_F16) {
    run_shuffle<dtype::Float16>(handle(), true);
}

} // namespace test
} // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -6,7 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .rng import RNG, beta, gamma, normal, permutation, poisson, seed, uniform
from .rng import RNG, beta, gamma, normal, permutation, poisson, seed, shuffle, uniform

__all__ = [
    "RNG",
@@ -17,6 +17,7 @@ __all__ = [
    "poisson",
    "seed",
    "uniform",
    "shuffle",
]
# pylint: disable=undefined-variable
del rng  # type: ignore[name-defined]
@@ -27,6 +27,7 @@ from ..core.ops.builtin import (
    GaussianRNG,
    PermutationRNG,
    PoissonRNG,
    ShuffleRNG,
    UniformRNG,
)
from ..core.tensor import utils
@@ -41,6 +42,7 @@ __all__ = [
    "beta",
    "poisson",
    "permutation",
    "shuffle",
]

_rng = None

@@ -219,6 +221,13 @@ def _permutation(n: int, seed: int, device: str, handle: int, dtype: str) -> Ten
    return output
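# In-place helper: ShuffleRNG returns the shuffled tensor plus the permutation
# indices; only the shuffled data is kept and written back into `inp`.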
def _shuffle(inp: Tensor, seed: int, handle: int) -> None:
    assert inp.size > 0, "size needs to be greater than 0"
    op = ShuffleRNG(seed=seed, handle=handle)
    output, _ = apply(op, inp)
    inp._reset(output)
class RNG:
    r""":class:`RNG` exposes a number of methods for generating random numbers.
@@ -581,6 +590,45 @@ class RNG:
            n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype
        )
    def shuffle(self, inp: Tensor):
        r"""Modifies a sequence in place by shuffling its contents.

        This function only shuffles the Tensor along the first axis of a
        multi-dimensional Tensor; the order of sub-Tensors is changed but
        their contents remain the same.

        Args:
            inp: input tensor.

        Examples:

        .. testcode::

            import numpy as np
            import megengine as mge
            import megengine.random as rand

            x = mge.tensor(np.arange(10))
            rand.shuffle(x)
            print(x.numpy())
            y = mge.tensor(np.arange(18)).reshape(6,3)
            rand.shuffle(y)
            print(y.numpy())

        Outputs:

        .. testoutput::
            :options: +SKIP

            [7 9 3 0 8 2 4 5 6 1]
            [[12. 13. 14.]
             [ 3.  4.  5.]
             [15. 16. 17.]
             [ 0.  1.  2.]
             [ 9. 10. 11.]
             [ 6.  7.  8.]]
        """
        _seed = self._seed() if callable(self._seed) else self._seed
        _shuffle(inp=inp, seed=_seed, handle=self._handle)
    def __del__(self):
        if self._handle != 0:
            _delete_rng_handle(self._handle)

@@ -599,6 +647,7 @@ gamma = _default_handle.gamma
beta = _default_handle.beta
poisson = _default_handle.poisson
permutation = _default_handle.permutation
shuffle = _default_handle.shuffle


def _random_seed_generator():
@@ -18,6 +18,7 @@ from megengine.core._imperative_rt.ops import (
    get_global_rng_seed,
    new_rng_handle,
)
from megengine.core.autodiff.grad import Grad
from megengine.core.ops.builtin import (
    BetaRNG,
    GammaRNG,
@@ -397,6 +398,45 @@ def test_PermutationRNG():
    assert sum_result(out, np.sort) == 1000
@pytest.mark.skipif(
    get_device_count("xpu") <= 1, reason="requires more than one xpu device",
)
def test_ShuffleRNG():
    g = []

    def cb(grad):
        g.append(grad)

    n, m = 6, 3
    arr = np.arange(n * m)
    out0 = Tensor(arr, dtype="float32")
    grad = Grad().wrt(out0, callback=cb)
    random.shuffle(out0)
    grad(out0, F.ones_like(out0))

    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    out1 = Tensor(arr, dtype="float32", device="xpu0")
    out2 = Tensor(arr, dtype="float32", device="xpu1")
    out3 = Tensor(arr, dtype="float32", device="xpu0")
    m1.shuffle(out1)
    m2.shuffle(out2)
    m3.shuffle(out3)
    np.testing.assert_equal(out1.numpy(), out2.numpy())
    assert out1.device == "xpu0" and out2.device == "xpu1"
    assert not (out1.numpy() == out3.numpy()).all()

    out = Tensor(arr, dtype="float32").reshape(n, m)
    m1.shuffle(out)
    out_shp = out.shape
    if isinstance(out_shp, tuple):
        assert out_shp == (n, m)
    else:
        assert all(out.shape.numpy() == np.array([n, m]))
def test_seed():
    set_global_seed(10)
    out1 = uniform(size=[10, 10])
@@ -6,7 +6,8 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megbrain/imperative/ops/rng.h"
@@ -14,8 +15,8 @@
#include "megbrain/graph/helper.h"
#include "megbrain/opr/rand.h"

#include "../op_trait.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"

namespace mgb::imperative::rng {

@@ -259,13 +260,27 @@ struct OpMeth<BetaRNG> {
    }
};
template <>
struct OpMeth<ShuffleRNG> {
    using DnnOp = megdnn::ShuffleRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::ShuffleRNG;
    static Param make_param(const ShuffleRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(handle_seed == rng.seed,
                   "inconsistent rng seed: rng op: %lu handle: %lu",
                   handle_seed, rng.seed);
        return {handle_seed};
    }
};

template <bool>
struct _InferLayout;

template <int nr_in>
struct _RNGOprMaker;

template <int nr_in>
template <int nr_in, int nr_out>
struct _RNGOprInvoker;
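// the invoker is now parameterized on both input and output arity, so that
// ShuffleRNG (1 input, 2 outputs: data + indices) can share the dispatch
// machinery with the single-output RNG ops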
template<>
@@ -316,50 +331,63 @@ struct _InferLayout<false>
        return inp.layout;
    }
};
#define _INST_RNG_INVOLKER(DNN_NR_INPUTS) \
template<> \
struct _RNGOprInvoker<DNN_NR_INPUTS> { \
    template<typename Opr> \
    static void exec(Opr *dnn_op, const SmallVector<TensorPtr>& inputs, const TensorPtr& dest){ \
        size_t wk_size = 0; \
        wk_size = dnn_op->get_workspace_in_bytes(_FOR_EACH_IN(->layout())dest->layout()); \
        auto workspace = Blob::make(dest->comp_node(), wk_size); \
        megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \
        dnn_op->exec(_FOR_EACH_IN(->dev_tensor().as_megdnn()) \
                dest->dev_tensor().as_megdnn(), dnn_wk); \
    } \
};

#define _INST_RNG_INVOLKER(DNN_NR_INPUTS, DNN_NR_OUTPUTS)                 \
    template <>                                                           \
    struct _RNGOprInvoker<DNN_NR_INPUTS, DNN_NR_OUTPUTS> {                \
        template <typename Opr>                                           \
        static void exec(Opr* dnn_op, const SmallVector<TensorPtr>& inputs, \
                         const SmallVector<TensorPtr>& outputs) {         \
            size_t wk_size = 0;                                           \
            wk_size = dnn_op->get_workspace_in_bytes(                     \
                    _FOR_EACH_IN(->layout()) _FOR_EACH_OUT(->layout()));  \
            auto workspace = Blob::make(outputs[0]->comp_node(), wk_size); \
            megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \
            dnn_op->exec(_FOR_EACH_IN(->dev_tensor().as_megdnn())         \
                                 _FOR_EACH_OUT(->dev_tensor().as_megdnn()), \
                         dnn_wk);                                         \
        }                                                                 \
    };

#define _INST_RNG_MAKER(MGB_NR_INPUTS) \
template<> \
struct _RNGOprMaker<MGB_NR_INPUTS> { \
    template<typename Op> \
    static SymbolVar make(const VarNodeArray& inputs, const Op& rng){ \
        auto param = OpMeth<Op>::make_param(rng); \
        OperatorNodeConfig config; \
        if (rng.handle) { \
            config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)}; \
        } else { \
            config = {rng.make_name()}; \
        } \
        return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config); \
    } \
};

#define _INST_RNG_MAKER(MGB_NR_INPUTS)                                    \
    template <>                                                           \
    struct _RNGOprMaker<MGB_NR_INPUTS> {                                  \
        template <typename Op>                                            \
        static auto make(const VarNodeArray& inputs, const Op& rng) {     \
            auto param = OpMeth<Op>::make_param(rng);                     \
            OperatorNodeConfig config;                                    \
            if (rng.handle) {                                             \
                config = {rng.make_name(),                                \
                          RNGDnnOpManager::get_comp_node(rng.handle)};    \
            } else {                                                      \
                config = {rng.make_name()};                               \
            }                                                             \
            return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config); \
        }                                                                 \
    };
#define _FOR_EACH_IN(subfix)
_INST_RNG_INVOLKER(0)
#define _FOR_EACH_IN(subfix)
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(0, 1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix) inputs[0] subfix,
_INST_RNG_INVOLKER(1)
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(1, 1)
#undef _FOR_EACH_OUT

#define _FOR_EACH_OUT(subfix) outputs[0] subfix, outputs[1] subfix
_INST_RNG_INVOLKER(1, 2)
_INST_RNG_MAKER(1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix) inputs[0] subfix, inputs[1] subfix,
_INST_RNG_INVOLKER(2)
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(2, 1)
_INST_RNG_MAKER(2)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#undef _INST_RNG_INVOLKER
@@ -392,7 +420,9 @@ void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
                   handle_seed, dnn_op->param().seed);
    }
    dnn_op->param() = OpMeth<Op>::make_param(rng);
    _RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS>::exec(dnn_op,inputs,dest);
    _RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS,
                   OpMeth<Op>::DnnOp::NR_OUTPUTS>::exec(dnn_op, inputs,
                                                        outputs);
}

template <typename Op>
@@ -420,24 +450,45 @@ SmallVector<LogicalTensorDesc> infer_output_attrs(
    return {dest};
}
template <typename Op>
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems) {
    auto&& dest = infer_output_attrs<Op>(def, inputs_tensors);
    SmallVector<MemoryDesc> outputs = {
            {dest[0].layout, 0, dest[0].comp_node, StorageIdentifier::make(1)}};
    return {outputs, {}};

template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<ShuffleRNG>(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    SmallVector<LogicalTensorDesc> dests(2);
    auto&& rng = op.cast_final_safe<ShuffleRNG>();
    auto handle = rng.handle;
    if (handle) {
        dests[0].comp_node = RNGDnnOpManager::get_comp_node(handle);
        dests[1].comp_node = RNGDnnOpManager::get_comp_node(handle);
    } else {
        dests[0].comp_node = inputs[0]->comp_node();
        dests[1].comp_node = inputs[0]->comp_node();
    }
    dests[0].layout = TensorLayout(inputs[0]->layout());
    dests[0].layout.dtype = inputs[0]->layout().dtype;
    dests[1].layout =
            TensorLayout(TensorShape({inputs[0]->layout()[0]}), dtype::Int32());
    return dests;
}

template <typename Op>
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
infer_output_mem_desc(const OpDef& def,
                      const SmallVector<TensorPtr>& inputs_tensors,
                      const SmallVector<MemoryDesc>& inputs_mems) {
    auto&& dests = infer_output_attrs<Op>(def, inputs_tensors);
    SmallVector<MemoryDesc> outputs;
    for (size_t i = 0; i < dests.size(); ++i) {
        outputs.push_back({dests[i].layout, 0, dests[i].comp_node,
                           StorageIdentifier::make(i + 1)});
    }
    return {outputs, {}};
}
template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<TensorPtr> outputs;
    SmallVector<LogicalTensorDesc> desc;
    desc = infer_output_attrs<Op>(def, inputs);
    SmallVector<LogicalTensorDesc> desc = infer_output_attrs<Op>(def, inputs);
    for (auto&& i : desc) {
        outputs.push_back(Tensor::make(i.layout, i.comp_node));
    }
@@ -454,10 +505,8 @@ void execute(
    exec<Op>(def, inputs, outputs, {});
}

template<typename Op>
SymbolVar apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
template <typename Op, typename Output>
Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    size_t nr_inp = inputs.size();
    constexpr size_t dnn_nr_inp = OpMeth<Op>::DnnOp::NR_INPUTS;
    auto&& rng = def.cast_final_safe<Op>();
@@ -487,7 +536,21 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
    return {{dest}, true};
}
template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool>
infer_output_attrs_fallible<ShuffleRNG>(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    SmallVector<LogicalTensorDesc> dests(2);
    dests[0].comp_node = inputs[0].comp_node;
    dests[0].layout = TensorLayout(inputs[0].layout);
    dests[0].layout.dtype = inputs[0].layout.dtype;
    dests[1].comp_node = inputs[0].comp_node;
    dests[1].layout = TensorLayout(TensorShape({inputs[0].layout.shape[0]}),
                                   dtype::Int32());
    return {dests, true};
}

} // anonymous namespace

Handle new_handle(CompNode comp_node, uint64_t seed) {
    return RNGDnnOpManager::inst().new_handle(comp_node, seed);
@@ -509,23 +572,24 @@ CompNode get_rng_handle_compnode(Handle handle){
    return RNGDnnOpManager::get_comp_node(handle);
}
#define REG_RNG_OP(NAME)                                                \
namespace {                                                             \
OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode)                          \
        .apply_on_var_node(apply_on_var_node<NAME>)                     \
        .apply_on_physical_tensor(apply_on_physical_tensor<NAME>)       \
        .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
        .infer_output_mem_desc(infer_output_mem_desc<NAME>)             \
        .execute(execute<NAME>)                                         \
        .fallback();                                                    \
}                                                                       \

REG_RNG_OP(UniformRNG)
REG_RNG_OP(GaussianRNG)
REG_RNG_OP(GammaRNG)
REG_RNG_OP(PermutationRNG)
REG_RNG_OP(PoissonRNG)
REG_RNG_OP(BetaRNG)

#define REG_RNG_OP(NAME, Output)                                            \
    namespace {                                                             \
    OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode)                          \
            .apply_on_var_node(apply_on_var_node<NAME, Output>)             \
            .apply_on_physical_tensor(apply_on_physical_tensor<NAME>)       \
            .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
            .infer_output_mem_desc(infer_output_mem_desc<NAME>)             \
            .execute(execute<NAME>)                                         \
            .fallback();                                                    \
    }

REG_RNG_OP(UniformRNG, SymbolVar)
REG_RNG_OP(GaussianRNG, SymbolVar)
REG_RNG_OP(GammaRNG, SymbolVar)
REG_RNG_OP(PermutationRNG, SymbolVar)
REG_RNG_OP(PoissonRNG, SymbolVar)
REG_RNG_OP(BetaRNG, SymbolVar)
REG_RNG_OP(ShuffleRNG, SymbolVarArray)
} // namespace mgb::imperative::rng | } // namespace mgb::imperative::rng | ||||
@@ -215,6 +215,19 @@ def PermutationRNG: MgbHashableOp<"PermutationRNG", [PermutationRNGParam]> { | |||||
let cmpFunction = [{return $0.handle == $1.handle && $0.dtype == $1.dtype;}]; | let cmpFunction = [{return $0.handle == $1.handle && $0.dtype == $1.dtype;}]; | ||||
} | } | ||||
def ShuffleRNG: MgbHashableOp<"ShuffleRNG", [ShuffleRNGParam]> {
  let extraArguments = (ins
    MgbSizeTAddr:$handle
  );
  let hashFunction = [{
    return mgb::hash_pair_combine(
      mgb::hash($_self.dyn_typeinfo()),
      mgb::hash($_self.handle)
    );
  }];
  let cmpFunction = [{return $0.handle == $1.handle;}];
}
def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> {
  let extraArguments = (ins
    MgbCompNodeAttr:$comp_node

@@ -192,6 +192,8 @@ template class RNGOprBase<::megdnn::GammaRNG>;
template class RNGOprBase<::megdnn::PermutationRNG>;
template class RNGOprBase<::megdnn::BetaRNG>;
template class RNGOprBase<::megdnn::PoissonRNG>;
template class RNGOprBase<::megdnn::ShuffleRNGForward>;
template class RNGOprBase<::megdnn::ShuffleRNGBackward>;
#if MGB_ENABLE_GRAD
IMPL(GaussianRNG);
IMPL(UniformRNG);
@@ -200,9 +202,87 @@ IMPL(PoissonRNG);
IMPL(PermutationRNG);
IMPL(BetaRNG);
#endif
}  // namespace intl
}  // namespace opr
}  // namespace mgb
/* ================= ShuffleRNGForward ================= */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleRNGForward);

ShuffleRNGForward::ShuffleRNGForward(VarNode* data, const Param& param,
                                     const OperatorNodeConfig& config)
        : Super({data->owner_graph(), config, "shuffle_rng", {data}}, param) {
    add_input({data});
    add_output(None)->dtype(data->dtype());
    add_output(None)->dtype(dtype::Int32{});
    cg::add_workspace_output(this);
    add_equivalence_component<ScalarHash<void*>>(this);
}

SymbolVarArray ShuffleRNGForward::make(SymbolVar in_tensor, const Param& param,
                                       const OperatorNodeConfig& config) {
    auto node = in_tensor.node()->owner_graph()->insert_opr(
            std::make_unique<ShuffleRNGForward>(in_tensor.node(), param,
                                                config));
    mgb_assert(node->output().size() == 3);
    return {node->output(0), node->output(1)};
}
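// outputs: (0) the shuffled data, (1) the Int32 permutation indices, and
// (2) the workspace var added by cg::add_workspace_output above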
void ShuffleRNGForward::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    mgr.register_shape_infer(output(0),
                             ShapeInferDesc::make_identity(input(0)));

    auto infer_oshp1 = [this](TensorShape& dest, const InpVal& iv) {
        TensorLayout o0, o1;
        m_dnn_opr->deduce_layout({iv.val[0].shape(), input(0)->dtype()}, o0,
                                 o1);
        dest = o1;
        return true;
    };
    mgr.register_shape_infer(
            output(1),
            {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_oshp1});

    auto infer_wk = [this](TensorShape& dest, const InpVal& inp) {
        ensure_megdnn_opr();
        dest.ndim = 1;
        dest.shape[0] = m_dnn_opr->get_workspace_in_bytes(
                {inp.val[0].shape(), input(0)->dtype()},
                {output(0)->shape(), output(0)->dtype()},
                {output(1)->shape(), output(1)->dtype()});
        return true;
    };
    mgr.register_shape_infer(
            output(2),
            {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_wk});
}
void ShuffleRNGForward::add_input_layout_constraint() {
    input(0)->add_layout_constraint_contiguous();
}

void ShuffleRNGForward::scn_do_execute() {
    m_dnn_opr->exec(input(0)->dev_tensor().as_megdnn(),
                    output(0)->dev_tensor().as_megdnn(),
                    output(1)->dev_tensor().as_megdnn(),
                    get_megdnn_workspace_from_var(output(2)));
}

#if MGB_ENABLE_GRAD
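// once the indices are drawn, shuffling is a fixed permutation, so the
// gradient is simply the inverse permutation applied to out_grad[0]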
MGB_IMPL_OPR_GRAD(ShuffleRNGForward) {
    mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]);
    if (!out_grad[0])
        return nullptr;
    return ShuffleRNGBackward::make(out_grad[0], opr.output(1), opr.input(0))
            .node();
}
#endif

MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleRNGBackward);
MEGDNN_OPR_INIT3(ShuffleRNGBackward, "shuffle_rng_bwd", 2, true)

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -6,7 +6,8 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megbrain/opr/rand.h"
@@ -14,6 +15,23 @@
namespace mgb {

namespace serialization {
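// ShuffleRNG has multiple outputs, so the default OprMaker (which expects a
// single SymbolVar from make()) does not apply; this specialization unpacks
// the SymbolVarArray and registers the owner operator of the first output.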
template <>
struct OprMaker<opr::ShuffleRNG, 1> {
    using Opr = opr::ShuffleRNG;
    using Param = Opr::Param;
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& inputs,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        auto out = Opr::make(inputs[0], param, config);
        return out[0].node()->owner_opr();
    }
};
} // namespace serialization
namespace opr {

using UniformRNGV1 = opr::UniformRNG;
@@ -24,9 +42,10 @@ MGB_SEREG_OPR(GammaRNG, 2);
MGB_SEREG_OPR(PoissonRNG, 1);
MGB_SEREG_OPR(PermutationRNG, 1);
MGB_SEREG_OPR(BetaRNG, 2);
MGB_SEREG_OPR(ShuffleRNG, 1);
MGB_SEREG_OPR(ShuffleRNGBackward, 3);
} // namespace opr
} // namespace mgb
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
@@ -6,14 +6,15 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "megbrain/graph.h" | #include "megbrain/graph.h" | ||||
#include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | ||||
#include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
namespace mgb { | namespace mgb { | ||||
@@ -41,22 +42,24 @@ MGB_DEFINE_CLS_WITH_SUPER(RNGOprBase, cg::SingleCNOperatorNodeBase) // {
};

/* ================= RNG with shape ================= */
#define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \
MGB_DEFINE_OPR_CLASS(RNG,RNGOprBase<megdnn::RNG>) \
    cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \
    public: \
        RNG(VarNode *shape, const Param &param, const OperatorNodeConfig &config); \
        static SymbolVar make(SymbolVar shape, const Param &param = {}, \
                const OperatorNodeConfig &config = {}); \
        static SymbolVar make(ComputingGraph &graph, const TensorShape &shape, \
                const OperatorNodeConfig &config, \
                const Param &param = {}) { \
            return make(var_from_tensor_shape(graph, config, "rng", shape), \
                    param, config); \
        } \
        void init_output_static_infer_desc() override; \
        void scn_do_execute() override; \
};

#define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG)                                  \
    MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>)                         \
    cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override;        \
                                                                               \
public:                                                                        \
    RNG(VarNode* shape, const Param& param, const OperatorNodeConfig& config); \
    static SymbolVar make(SymbolVar shape, const Param& param = {},            \
                          const OperatorNodeConfig& config = {});              \
    static SymbolVar make(ComputingGraph& graph, const TensorShape& shape,     \
                          const OperatorNodeConfig& config,                    \
                          const Param& param = {}) {                           \
        return make(var_from_tensor_shape(graph, config, "rng", shape), param, \
                    config);                                                   \
    }                                                                          \
    void init_output_static_infer_desc() override;                             \
    void scn_do_execute() override;                                            \
    }                                                                          \
    ;

_DEFINE_RNG_OPR_WITH_SHAPE_CLASS(UniformRNG)
_DEFINE_RNG_OPR_WITH_SHAPE_CLASS(GaussianRNG)
public:                                                                  \
    RNG(_INPUTS(VarNode*), const Param &param,                           \
        const OperatorNodeConfig &config);                               \
    static SymbolVar make(_INPUTS(SymbolVar), const Param &param = {},   \
    static _OUTPUTS make(_INPUTS(SymbolVar), const Param &param = {},    \
                         const OperatorNodeConfig &config = {});         \
    void init_output_static_infer_desc() override;                       \
    void scn_do_execute() override;                                      \
@@ -79,17 +82,24 @@ MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>)

/* ================= 1 input ================= */
#define _INPUTS(prefix) prefix i0
#define _OUTPUTS SymbolVar
_DEFINE_RNG_OPR_WITH_INPUT_CLASS(PoissonRNG)
#undef _OUTPUTS
#define _OUTPUTS SymbolVarArray
_DEFINE_RNG_OPR_WITH_INPUT_CLASS(ShuffleRNGForward)
#undef _OUTPUTS
#undef _INPUTS

/* ================= 2 input ================= */
#define _INPUTS(prefix) prefix i0, prefix i1
#define _OUTPUTS SymbolVar
_DEFINE_RNG_OPR_WITH_INPUT_CLASS(BetaRNG)
_DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG)
#undef _OUTPUTS
#undef _INPUTS
#undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS
} // intl
using UniformRNG = intl::UniformRNG;
using GaussianRNG = intl::GaussianRNG;
@@ -97,9 +107,20 @@ using GammaRNG = intl::GammaRNG;
using PermutationRNG = intl::PermutationRNG;
using PoissonRNG = intl::PoissonRNG;
using BetaRNG = intl::BetaRNG;
using ShuffleRNG = intl::ShuffleRNGForward;

MGB_DEFINE_OPR_CLASS(ShuffleRNGBackward,
                     intl::MegDNNOprWrapperBwd<megdnn::ShuffleRNGBackward>) //{
public:
    ShuffleRNGBackward(VarNode* out_diff, VarNode* indices,
                       VarNode* result_shape, const Param& param,
                       const OperatorNodeConfig& config);

    static SymbolVar make(SymbolVar out_diff, SymbolVar indices,
                          SymbolVar result_shape, const Param& param = {},
                          const OperatorNodeConfig& config = {});
};

} // namespace opr
} // namespace mgb

// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -333,6 +333,38 @@ TEST(TestOprRand, EmptyShape) {
}

TEST(TestOprRand, ShuffleForward) {
    auto run = [&](TensorShape shape) {
        std::shared_ptr<HostTensorND> src_host(new HostTensorND{
                CompNode::load("xpux"), shape, dtype::Float32()});
        auto sptr = src_host->ptr<dt_float32>();
        auto size = shape.total_nr_elems();
        for (size_t i = 0; i < size; ++i) {
            sptr[i] = i;
        }
        auto graph = ComputingGraph::make();
        auto src_sym = opr::Host2DeviceCopy::make(*graph, src_host);
        auto rec = opr::ShuffleRNG::make(src_sym, {10});
        HostTensorND host_y, host_index;
        auto func = graph->compile({make_callback_copy(rec[0], host_y),
                                    make_callback_copy(rec[1], host_index)});
        func->execute();
        auto dptr = host_y.ptr<dt_float32>();
        auto iptr = host_index.ptr<dt_int32>();
        size_t len = shape[0];
        size_t step = size / len;
        for (size_t i = 0; i < len; ++i) {
            for (size_t j = 0; j < step; ++j) {
                ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]);
            }
        }
    };
    run({10});
    run({6, 3});
    run({1, 1});
}
TEST(TestOprRand, UniformReprod) {
    static constexpr size_t SIZE = 123;
    auto graph = ComputingGraph::make();

@@ -114,6 +114,7 @@ union OperatorParam {
    param.BetaRNG = 80,
    param.SlidingWindowTranspose = 81,
    param.Padding = 82,
    param.ShuffleRNG = 83,
}

table Operator {