GitOrigin-RevId: 4c05ebc266
release-1.5
@@ -21,19 +21,77 @@ class RNGBase: public OperatorBase {
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0;

    protected:
        virtual void check_exec(const TensorLayout &dst,
                size_t workspace_in_bytes) = 0;
};

//! sample from poisson distribution
class PoissonRNG: public OperatorBase {
    DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1);
    DEF_OPR_PARAM(PoissonRNG);

    public:
        virtual void exec(_megdnn_tensor_in lam,
                _megdnn_tensor_out dst,
                _megdnn_workspace workspace) = 0;
        virtual size_t get_workspace_in_bytes(const TensorLayout &lam,
                const TensorLayout &dst) = 0;

    protected:
        void check_exec(const TensorLayout &lam, const TensorLayout &dst,
                size_t workspace_in_bytes);
};

//! sample from beta distribution
class BetaRNG: public OperatorBase {
    DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1);
    DEF_OPR_PARAM(BetaRNG);

    public:
        virtual void exec(_megdnn_tensor_in alpha,
                _megdnn_tensor_in beta,
                _megdnn_tensor_out dst,
                _megdnn_workspace workspace) = 0;
        virtual size_t get_workspace_in_bytes(const TensorLayout &alpha,
                const TensorLayout &beta, const TensorLayout &dst) = 0;

    protected:
        void check_exec(const TensorLayout &alpha, const TensorLayout &beta,
                const TensorLayout &dst, size_t workspace_in_bytes);
};

//! sample from gamma distribution
class GammaRNG: public OperatorBase {
    DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1);
    DEF_OPR_PARAM(GammaRNG);

    public:
        virtual void exec(_megdnn_tensor_in shape,
                _megdnn_tensor_in scale,
                _megdnn_tensor_out dst,
                _megdnn_workspace workspace) = 0;
        virtual size_t get_workspace_in_bytes(const TensorLayout &shape,
                const TensorLayout &scale, const TensorLayout &dst) = 0;

    protected:
        void check_exec(const TensorLayout &shape, const TensorLayout &scale,
                const TensorLayout &dst, size_t workspace_in_bytes);
};

//! sample from uniform distribution on the interval (0, 1]
class UniformRNG: public RNGBase {
    DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1);
    DEF_OPR_PARAM(UniformRNG);

    protected:
        void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);
};

//! sample from gaussian distribution
class GaussianRNG: public RNGBase {
    DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1);
    DEF_OPR_PARAM(GaussianRNG);

    protected:
        void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);
};

class PermutationRNG: public RNGBase {
    DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1);
    DEF_OPR_PARAM(PermutationRNG);

    protected:
        void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);
};

/*!
@@ -735,11 +735,34 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
     'dtype', Doc('dtype', 'data type of output value'),
     'DTypeEnum::Float32'))

(pdef('UniformRNG').
 add_fields('uint64', 'seed', 0).
 add_fields(
     'dtype', Doc('dtype', 'The dtype of output Tensor. Only Float32 is supported.'),
     'DTypeEnum::Float32'))

(pdef('GaussianRNG').
 add_fields('uint64', 'seed', 0).
 add_fields('float32', 'mean', 0, 'std', 1).
 add_fields(
     'dtype', Doc('dtype', 'The dtype of output Tensor. Only Float32 is supported.'),
     'DTypeEnum::Float32'))

(pdef('GammaRNG').
 add_fields('uint64', 'seed', 0))

(pdef('BetaRNG').
 add_fields('uint64', 'seed', 0))

(pdef('PoissonRNG').
 add_fields('uint64', 'seed', 0))

(pdef('PermutationRNG').
 add_fields('uint64', 'seed', 0).
 add_fields(
     'dtype', Doc('dtype', 'The dtype of output Tensor. Int32, Int16 and '
                  'Float32 are supported.'),
     'DTypeEnum::Int32'))

(pdef('Flip').
 add_fields('bool', 'vertical', 'false', 'horizontal', 'false'))
@@ -159,6 +159,10 @@ private:
    cb(SleepForward) \
    cb(UniformRNG) \
    cb(GaussianRNG) \
    cb(GammaRNG) \
    cb(BetaRNG) \
    cb(PoissonRNG) \
    cb(PermutationRNG) \
    cb(SeparableConvForward) \
    cb(SeparableFilterForward) \
    cb(BNForward) \

@@ -120,6 +120,10 @@ DEF(TQTBackward, 5, true, false);
DEF(PowC, 2, false, true);
DEF(UniformRNG, 1, true, true);
DEF(GaussianRNG, 1, true, true);
DEF(GammaRNG, 3, true, true);
DEF(BetaRNG, 3, true, true);
DEF(PoissonRNG, 2, true, true);
DEF(PermutationRNG, 1, true, true);
DEF(ChecksumForward, 1, true, false);
DEF(CheckHasInf, 2, true, true);
DEF(LSQForward, 5, true, true);
@@ -15,13 +15,62 @@
namespace megdnn {

void PermutationRNG::check_exec(
        const TensorLayout &dst, size_t workspace_in_bytes) {
    megdnn_assert((dst.dtype == dtype::Float32() ||
                   dst.dtype == dtype::Int32() ||
                   dst.dtype == dtype::Int16()) &&
                  dst.dtype.enumv() == param().dtype &&
                  dst.is_contiguous());
    megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(dst));
}

void PoissonRNG::check_exec(const TensorLayout &lam, const TensorLayout &dst,
        size_t workspace_in_bytes) {
    megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT &&
                  lam.dtype == dst.dtype);
    megdnn_assert(dst.is_contiguous() && lam.is_contiguous());
    megdnn_assert(lam.total_nr_elems() == dst.total_nr_elems());
    megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(lam, dst));
}

void GammaRNG::check_exec(const TensorLayout &shape, const TensorLayout &scale,
        const TensorLayout &dst, size_t workspace_in_bytes) {
    megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT &&
                  shape.dtype == dst.dtype &&
                  scale.dtype == dst.dtype);
    megdnn_assert(shape.is_contiguous() && scale.is_contiguous() &&
                  dst.is_contiguous());
    megdnn_assert(shape.total_nr_elems() == dst.total_nr_elems() &&
                  scale.total_nr_elems() == dst.total_nr_elems());
    megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(shape, scale, dst));
}

void BetaRNG::check_exec(const TensorLayout &alpha, const TensorLayout &beta,
        const TensorLayout &dst, size_t workspace_in_bytes) {
    megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT &&
                  alpha.dtype == dst.dtype &&
                  beta.dtype == dst.dtype);
    megdnn_assert(alpha.is_contiguous() && beta.is_contiguous() &&
                  dst.is_contiguous());
    megdnn_assert(alpha.total_nr_elems() == dst.total_nr_elems() &&
                  beta.total_nr_elems() == dst.total_nr_elems());
    megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(alpha, beta, dst));
}

#define INST_CHECK_EXEC(RNG_NAME)                                             \
    void RNG_NAME::check_exec(                                                \
            const TensorLayout &dst, size_t workspace_in_bytes) {             \
        megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT &&         \
                      dst.dtype.enumv() == param().dtype &&                   \
                      dst.is_contiguous());                                   \
        megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(dst));     \
    }

INST_CHECK_EXEC(UniformRNG)
INST_CHECK_EXEC(GaussianRNG)
#undef INST_CHECK_EXEC

} // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -49,23 +49,42 @@ bool use_segmented(uint32_t M, uint32_t /*N*/) {
    return M >= 8;
}

__global__ void kern_arange(int* dst, uint32_t n, uint32_t mod) {
    uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i < n) {
        dst[i] = i % mod;
    }
}

template <typename ctype>
size_t get_sort_workspace(uint32_t M, uint32_t N, bool is_ascending) {
    if (use_bitonic(M, N)) {
        return 0;
    }
    return argsort::cub_sort_pairs<ctype, int>(is_ascending, NULL, 0, NULL,
            NULL, NULL, NULL, M, N, 0, sizeof(float) * 8, NULL);
}
} // anonymous namespace
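// cub_sort_pairs is now declared in argsort.cuh and templated on both key
// and value types, with an explicit radix range (begin_bit, end_bit)
// replacing the hard-coded 0..sizeof(float)*8; PermutationRNG uses this to
// sort only the low `bits` bits that were actually randomized.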
template <typename KeyType, typename ValueType>
MEGDNN_NOINLINE size_t argsort::cub_sort_pairs(
        bool is_ascending, void* workspace, size_t workspace_size,
        const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in,
        ValueType* values_out, uint32_t M, uint32_t N, int begin_bit,
        int end_bit, cudaStream_t stream) {
    cudaError_t err;
    if (use_segmented(M, N)) {
        if (is_ascending) {
            err = cub::DeviceSegmentedRadixSort::SortPairs(
                    workspace, workspace_size, keys_in, keys_out, values_in,
                    values_out, N * M, M, StridedOffsetIterator(0, N),
                    StridedOffsetIterator(N, N), begin_bit, end_bit, stream);
            cuda_check(err);
        } else {
            err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
                    workspace, workspace_size, keys_in, keys_out, values_in,
                    values_out, N * M, M, StridedOffsetIterator(0, N),
                    StridedOffsetIterator(N, N), begin_bit, end_bit, stream);
            cuda_check(err);
        }
    } else {
        if (is_ascending) {

@@ -73,7 +92,7 @@ MEGDNN_NOINLINE size_t cub_sort_pairs(
            err = cub::DeviceRadixSort::SortPairs(
                    workspace, workspace_size, keys_in + N * i,
                    keys_out + N * i, values_in + N * i, values_out + N * i,
                    N, begin_bit, end_bit, stream);
            cuda_check(err);
            if (!keys_in) {
                return workspace_size;

@@ -84,7 +103,7 @@ MEGDNN_NOINLINE size_t cub_sort_pairs(
            err = cub::DeviceRadixSort::SortPairsDescending(
                    workspace, workspace_size, keys_in + N * i,
                    keys_out + N * i, values_in + N * i, values_out + N * i,
                    N, begin_bit, end_bit, stream);
            cuda_check(err);
            if (!keys_in) {
                return workspace_size;

@@ -95,23 +114,6 @@ MEGDNN_NOINLINE size_t cub_sort_pairs(
    return workspace_size;
}

size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype,
                                           bool is_ascending,
                                           bool iptr_src_given) {

@@ -151,17 +153,28 @@ void argsort::forward(const dtype* sptr, dtype* dptr, int* iptr,
                stream));
    } else {
        cub_sort_pairs(is_ascending, workspace, wk_size, sptr, dptr, iptr_src,
                       iptr, M, N, 0, sizeof(float) * 8, stream);
    }
}

namespace megdnn {
namespace cuda {

#define INST_CUB_SORT(dtype)                                                  \
    template MEGDNN_NOINLINE size_t argsort::cub_sort_pairs<dtype, dtype>(    \
            bool, void*, size_t, const dtype*, dtype*, const dtype*, dtype*,  \
            uint32_t, uint32_t, int, int, cudaStream_t);

#define INST_FORWARD(dtype)                                                   \
    template void argsort::forward<dtype>(const dtype*, dtype*, int*, void*,  \
                                          uint32_t, uint32_t, bool,           \
                                          cudaStream_t, const int*);

ARGSORT_FOREACH_CTYPE(INST_FORWARD)
INST_CUB_SORT(uint32_t)
INST_CUB_SORT(uint64_t)
#undef INST_CUB_SORT
#undef INST_FORWARD
}
} // namespace megdnn

@@ -24,6 +24,12 @@ size_t get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype,
                                  bool is_ascending,
                                  bool iptr_src_given = false);

template <typename KeyType, typename ValueType>
size_t cub_sort_pairs(
        bool is_ascending, void* workspace, size_t workspace_size,
        const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in,
        ValueType* values_out, uint32_t M, uint32_t N, int begin_bit,
        int end_bit, cudaStream_t stream);

/*!
 * \param iptr_src pointer to indices; a range would be generated if it is null
 */
@@ -0,0 +1,174 @@
/**
 * \file dnn/src/cuda/rng/kernel.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include <curand_kernel.h>
#include <device_launch_parameters.h>
#include "../argsort/argsort.cuh"
#include "./kernel.cuh"
#include "src/cuda/cuda_shfl_compat.cuh"
#include "src/cuda/utils.cuh"

namespace megdnn {
namespace cuda {
namespace random {
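// The radix sort below is stable, so a run of equal (masked) keys leaves its
// index values in their pre-sort order, biasing the permutation inside the
// run. This kernel lets the first thread of each duplicate run measure the
// run length and Fisher-Yates shuffle the corresponding index segment.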
template <typename KeyType, typename ValueType>
__global__ void permute_duplicate_keys_kernel(KeyType* keys, ValueType* indexs,
                                              KeyType mask, size_t size,
                                              uint64_t seed, uint64_t offset) {
    uint32_t idx = threadIdx.x + blockDim.x * blockIdx.x;
    if (idx >= size - 1) return;
    uint32_t lane_idx = threadIdx.x & 0x1F;
    KeyType cur_key = keys[idx] & mask;
    KeyType r_key = __shfl_down(cur_key, 1, 32);
    if (lane_idx == 31) r_key = keys[idx + 1] & mask;
    if (cur_key != r_key) return;
    KeyType l_key = __shfl_up(cur_key, 1, 32);
    if (idx != 0 && lane_idx == 0) l_key = keys[idx - 1] & mask;
    if (cur_key == l_key) return;

    indexs += idx;
    int32_t duplicate_size = 1;
    for (; idx + duplicate_size < size &&
           cur_key == (keys[idx + duplicate_size] & mask);
         ++duplicate_size) {}
    Philox state;
    curand_init(seed, idx, offset, &state);
    for (int32_t i = duplicate_size - 1; i > 0; --i) {
        int32_t r = static_cast<int32_t>(curand(&state) & 0x7fffffff) % (i + 1);
        if (i != r) {
            ValueType tmp = indexs[i];
            indexs[i] = indexs[r];
            indexs[r] = tmp;
        }
    }
}
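// Pick the key width via a birthday-problem style bound: bits is the
// smallest width (capped at 64) for which all N random keys are expected to
// be distinct with probability around uniq_rand_num_prob, keeping the
// duplicate runs repaired above short and rare.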
uint32_t get_permutation_bits(size_t N) {
    double uniq_rand_num_prob = 0.9;
    double thresh = std::log(uniq_rand_num_prob) * 12;
    double dN = static_cast<double>(N);
    uint32_t bits = std::min(64, static_cast<int>(std::ceil(std::log2(
                                     dN - (6 * dN * dN + 1) / thresh))));
    return bits;
}
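// Workspace layout: double buffers for the random keys and for the index
// values (sort input/output), followed by cub's own sort workspace rounded
// up to key alignment.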
size_t get_permutation_workspace_in_bytes(size_t size) {
    uint32_t bits = get_permutation_bits(size);
    size_t work_size = 0;
#define cb(KeyType, ValueType)                                               \
    size_t random_src_size = size * sizeof(KeyType);                         \
    size_t indexs_size = size * sizeof(ValueType);                           \
    size_t sort_worksize = argsort::cub_sort_pairs<KeyType, ValueType>(      \
            false, NULL, 0, NULL, NULL, NULL, NULL, 1, size, 0, bits, NULL); \
    work_size = 2 * random_src_size + 2 * indexs_size +                      \
                DIVUP(sort_worksize, sizeof(KeyType)) * sizeof(KeyType);
    if (bits > 32) {
        cb(uint64_t, uint64_t)
    } else {
        cb(uint32_t, uint32_t)
    }
#undef cb
    return work_size;
}
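// Permutation pipeline: (1) fill values_in with 0..size-1 and keys_in with
// random bits, (2) radix-sort the (key, index) pairs over the low `bits`
// bits, (3) re-shuffle index runs whose keys collided, (4) cast the index
// sequence to the output dtype.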
template <bool is_32bit, typename ctype>
void permutation_cuda(ctype* dst, void* workspace, size_t size, uint64_t seed,
                      uint64_t offset, uint32_t bits, cudaStream_t stream) {
    int threads = 512;
    int blocks = DIVUP(size, threads);
    using KeyType = typename std::conditional<is_32bit, uint32_t, uint64_t>::type;
    using ValueType = KeyType;

    // split workspace
    KeyType* keys_in = static_cast<KeyType*>(workspace);
    KeyType* keys_out = keys_in + size;
    ValueType* values_in = static_cast<ValueType*>(keys_out + size);
    ValueType* values_out = values_in + size;
    void* extra_workspace = static_cast<void*>(values_out + size);

    // init indices
    ElemwiseOpParamN<0> ele_param(size);
    typedef RangeKernel<ValueType> rangeOp;
    rangeOp range_op;
    range_op.output = values_in;
    run_elemwise<rangeOp, ValueType, 0>(ele_param, stream, range_op);

    // generate random sample keys
    typedef RandomKernel<KeyType> randomOP;
    randomOP random_op;
    random_op.output = keys_in;
    random_op.seed = seed;
    random_op.offset = offset;
    run_elemwise<randomOP, KeyType, 0>(ele_param, stream, random_op);

    // argsort random sample
    size_t wk_size = argsort::cub_sort_pairs<KeyType, ValueType>(
            false, NULL, 0, NULL, NULL, NULL, NULL, 1, size, 0, bits, NULL);
    argsort::cub_sort_pairs<KeyType, ValueType>(
            false, extra_workspace, wk_size, keys_in, keys_out, values_in,
            values_out, 1, size, 0, bits, stream);

    // permute duplicate samples
    KeyType mask = static_cast<KeyType>((1ULL << bits) - 1);
    permute_duplicate_keys_kernel<KeyType, ValueType>
            <<<blocks, threads, 0, stream>>>(keys_out, values_out, mask, size,
                                             seed, offset);
    after_kernel_launch();

    typedef AsTypeKernel<ValueType, ctype> asTypeOP;
    asTypeOP as_type_op;
    as_type_op.input = values_out;
    as_type_op.output = dst;
    run_elemwise<asTypeOP, ValueType, 0>(ele_param, stream, as_type_op);
}
template <typename ctype>
void permutation_forward(ctype* dst, void* workspace, size_t size,
                         uint64_t seed, uint64_t offset, cudaStream_t stream) {
    uint32_t bits = get_permutation_bits(size);
    if (bits <= 32) {
        permutation_cuda<true, ctype>(dst, workspace, size, seed, offset,
                                      bits, stream);
    } else {
        permutation_cuda<false, ctype>(dst, workspace, size, seed, offset,
                                       bits, stream);
    }
}

#define INST_PERMUTATION(T)                                                   \
    template void permutation_forward<T>(T*, void*, size_t, uint64_t,         \
                                         uint64_t, cudaStream_t);

INST_PERMUTATION(dt_int32)
INST_PERMUTATION(dt_int16)
INST_PERMUTATION(dt_float32)
#undef INST_PERMUTATION

} // namespace random

#define INST(_dtype)                                                          \
    INST_RUN_ELEMWISE(random::GammaKernel<DTypeTrait<_dtype>::ctype>,         \
                      DTypeTrait<_dtype>::ctype, 0);                          \
    INST_RUN_ELEMWISE(random::PoissonKernel<DTypeTrait<_dtype>::ctype>,       \
                      DTypeTrait<_dtype>::ctype, 0);                          \
    INST_RUN_ELEMWISE(random::BetaKernel<DTypeTrait<_dtype>::ctype>,          \
                      DTypeTrait<_dtype>::ctype, 0);

INST(megdnn::dtype::Float32)
INST(megdnn::dtype::Float16)
INST(megdnn::dtype::BFloat16)
#undef INST

} // namespace cuda
} // namespace megdnn
@@ -0,0 +1,258 @@
/**
 * \file dnn/src/cuda/rng/kernel.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include <cuda_runtime_api.h>
#include <stdint.h>
#include <curand.h>
#include <curand_kernel.h>
#include "megdnn/dtype.h"
#include "src/cuda/elemwise_helper.cuh"
#include "src/cuda/utils.cuh"
#if MEGDNN_CC_HOST
#include "megdnn/oprs.h"
#endif

namespace megdnn {
namespace cuda {
namespace random {

using Philox = curandStatePhilox4_32_10_t;
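// curand_uniform() samples from (0, 1]; fold the endpoint 1.0 back to 0.0
// so callers effectively draw from [0, 1).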
QUALIFIERS float _curand_uniform(Philox* state) {
    float r = curand_uniform(state);
    if (r >= 1.0f) {
        r = 0.0f;
    }
    return r;
}

template <typename ctype, typename = void>
struct RandomKernel;

template <typename ctype>
using enable_64bit = typename std::enable_if<std::is_integral<ctype>::value &&
                                             ((sizeof(ctype)) == 8)>::type;

template <typename ctype>
using enable_32bit = typename std::enable_if<std::is_integral<ctype>::value &&
                                             ((sizeof(ctype)) <= 4)>::type;

template <typename ctype>
struct RandomKernel<ctype, enable_64bit<ctype>> {
    ctype* output;
    uint64_t seed, offset;
    uint64_t mask = static_cast<uint64_t>(std::numeric_limits<ctype>::max());
    __device__ void operator()(uint32_t idx) {
        Philox local_state;
        curand_init(seed, idx, offset, &local_state);
        uint4 rand = curand4(&local_state);
        uint64_t val = (static_cast<uint64_t>(rand.x) << 32) | rand.y;
        output[idx] = static_cast<ctype>(val & mask);
    }
#if MEGDNN_CC_HOST
    RandomKernel(ctype* output, uint64_t seed, uint64_t offset)
            : output{output}, seed{seed}, offset{offset} {}
#endif
};

template <typename ctype>
struct RandomKernel<ctype, enable_32bit<ctype>> {
    ctype* output;
    uint64_t seed, offset;
    uint32_t mask = static_cast<uint32_t>(std::numeric_limits<ctype>::max());
    __device__ void operator()(uint32_t idx) {
        Philox local_state;
        curand_init(seed, idx, offset, &local_state);
        uint32_t val = curand(&local_state);
        output[idx] = static_cast<ctype>(val & mask);
    }
#if MEGDNN_CC_HOST
    RandomKernel(ctype* output, uint64_t seed, uint64_t offset)
            : output{output}, seed{seed}, offset{offset} {}
#endif
};

template <typename ctype>
struct RangeKernel {
    ctype* output;
    __device__ void operator()(uint32_t idx) {
        output[idx] = static_cast<ctype>(idx);
    }
#if MEGDNN_CC_HOST
    RangeKernel(ctype* output) : output{output} {}
#endif
};

template <typename ctype_src, typename ctype_dst>
struct AsTypeKernel {
    ctype_src* input;
    ctype_dst* output;
    using ctype_mask = typename std::conditional<
            std::is_integral<ctype_dst>::value, ctype_dst, ctype_src>::type;
    ctype_src mask = static_cast<ctype_src>(std::numeric_limits<ctype_mask>::max());
    __device__ void operator()(uint32_t idx) {
        output[idx] = static_cast<ctype_dst>(input[idx] & mask);
    }
#if MEGDNN_CC_HOST
    AsTypeKernel(ctype_src* input, ctype_dst* output)
            : input{input}, output{output} {}
#endif
};

template <typename ctype>
struct GammaKernel {
    ctype* output;
    ctype* shape;
    ctype* scale;
    uint64_t seed, offset;
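    // Marsaglia-Tsang squeeze method: for a >= 1 accept scale * d * v with
    // d = a - 1/3 and c = 1/sqrt(9d); for a < 1 the shape is bumped to a + 1
    // and the result boosted by u^(1/a) (folded into `scale` up front).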
    static __device__ float sample_gamma(float a, float b, Philox* state) {
        float scale = b;
        if (a <= 0)
            return 0.f;
        if (a < 1.0f) {
            scale *= powf(_curand_uniform(state), 1.0f / a);
            a += 1.0f;
        }
        float d = a - 1.0f / 3.0f;
        float c = 1.0f / sqrtf(9.0f * d);
        while (1) {
            float x, y;
            x = curand_normal(state);
            y = 1.0f + c * x;
            if (y <= 0)
                continue;
            float v = y * y * y;
            float u = _curand_uniform(state);
            float xx = x * x;
            if ((u < 1.0f - 0.0331f * xx * xx) ||
                logf(u) < 0.5f * xx + d * (1.0f - v + logf(v)))
                return scale * d * v;
        }
    }

    __device__ void operator()(uint32_t idx) {
        Philox local_state;
        curand_init(seed, idx, offset, &local_state);
        float a = static_cast<float>(shape[idx]);
        float b = static_cast<float>(scale[idx]);
        output[idx] = static_cast<ctype>(sample_gamma(a, b, &local_state));
    }
#if MEGDNN_CC_HOST
    GammaKernel(const TensorND& output, const TensorND& shape,
                const TensorND& scale, uint64_t seed, uint64_t offset)
            : output{output.ptr<ctype>()},
              shape{shape.ptr<ctype>()},
              scale{scale.ptr<ctype>()},
              seed{seed},
              offset{offset} {}
#endif
};
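// Poisson draws delegate to curand_poisson(), which switches sampling
// algorithms internally based on lambda; each element gets an independent
// Philox substream selected by (seed, idx, offset).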
template <typename ctype>
struct PoissonKernel {
    ctype* output;
    ctype* lambda;
    uint64_t seed, offset;
    __device__ void operator()(uint32_t idx) {
        Philox local_state;
        curand_init(seed, idx, offset, &local_state);
        float lam = static_cast<float>(lambda[idx]);
        output[idx] = static_cast<ctype>(curand_poisson(&local_state, lam));
    }
#if MEGDNN_CC_HOST
    PoissonKernel(const TensorND& output, const TensorND& lambda,
                  uint64_t seed, uint64_t offset)
            : output{output.ptr<ctype>()},
              lambda{lambda.ptr<ctype>()},
              seed{seed},
              offset{offset} {}
#endif
};
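// Beta sampling: for alpha, beta < 1 this is Johnk's rejection algorithm
// (accept x = u^(1/a), y = v^(1/b) when x + y < 1, with a log-space fallback
// when both underflow to 0); otherwise it returns ga / (ga + gb) for two
// unit-scale gamma draws.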
template <typename ctype>
struct BetaKernel {
    ctype* output;
    ctype* alpha;
    ctype* beta;
    uint64_t seed, offset;
    __device__ void operator()(uint32_t idx) {
        Philox local_state;
        curand_init(seed, idx, offset, &local_state);
        float a = static_cast<float>(alpha[idx]);
        float b = static_cast<float>(beta[idx]);
        if (a <= 0 || b <= 0) {
            output[idx] = 0;
            return;
        }
        if (a < 1.0f && b < 1.0f) {
            float u, v, x, y;
            while (true) {
                u = _curand_uniform(&local_state);
                v = _curand_uniform(&local_state);
                x = powf(u, 1.0f / a);
                y = powf(v, 1.0f / b);
                if (x + y < 1.0f) {
                    if (x + y > 0) {
                        output[idx] = static_cast<ctype>(x / (x + y));
                        return;
                    } else {
                        float logx = logf(u) / a;
                        float logy = logf(v) / b;
                        float log_max = logx > logy ? logx : logy;
                        logx -= log_max;
                        logy -= log_max;
                        output[idx] = static_cast<ctype>(
                                exp(logx - log(exp(logx) + exp(logy))));
                        return;
                    }
                }
            }
        } else {
            float ga = GammaKernel<float>::sample_gamma(a, 1.0f, &local_state);
            float gb = GammaKernel<float>::sample_gamma(b, 1.0f, &local_state);
            output[idx] = static_cast<ctype>(ga / (ga + gb));
            return;
        }
    }
#if MEGDNN_CC_HOST
    BetaKernel(const TensorND& output, const TensorND& alpha,
               const TensorND& beta, uint64_t seed, uint64_t offset)
            : output{output.ptr<ctype>()},
              alpha{alpha.ptr<ctype>()},
              beta{beta.ptr<ctype>()},
              seed{seed},
              offset{offset} {}
#endif
};

template <typename ctype>
void permutation_forward(ctype* dst, void* workspace, size_t size,
                         uint64_t seed, uint64_t offset, cudaStream_t stream);

size_t get_permutation_workspace_in_bytes(size_t N);

} // namespace random
} // namespace cuda
} // namespace megdnn
@@ -13,6 +13,7 @@
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"
#include "./opr_impl.h"
#include "./kernel.cuh"

using namespace megdnn;
using namespace cuda;

@@ -122,5 +123,143 @@ size_t GaussianRNGImpl::get_workspace_in_bytes(const TensorLayout &layout) {
    return 0;
}

GammaRNGImpl::GammaRNGImpl(Handle *handle):
    GammaRNG(handle),
    m_seed(0),
    m_offset(0),
    m_stream(cuda_stream(handle))
{
}

void GammaRNGImpl::exec(_megdnn_tensor_in shape, _megdnn_tensor_in scale,
        _megdnn_tensor_inout dst, _megdnn_workspace workspace) {
    check_exec(shape.layout, scale.layout, dst.layout, workspace.size);
    auto size = dst.layout.total_nr_elems();
    megdnn_assert(size);
    ensure_seed(m_param.seed);
    ElemwiseOpParamN<0> ele_param(size);
    switch (dst.layout.dtype.enumv()) {
#define cb(_dt)                                                               \
    case DTypeTrait<_dt>::enumv:                                              \
    {                                                                         \
        using ctype = DTypeTrait<_dt>::ctype;                                 \
        run_elemwise<random::GammaKernel<ctype>, ctype, 0>(ele_param,         \
                m_stream, {dst, shape, scale, m_seed, m_offset});             \
        break;                                                                \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
        default:
            megdnn_throw("bad dtype");
    }
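    // advance the Philox counter for the next exec(); the fixed increments
    // used by these oprs (16/20/32 here, 8 for PermutationRNG) appear to be
    // upper bounds on the counter blocks a single element may consume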
    m_offset += 16;
}

PoissonRNGImpl::PoissonRNGImpl(Handle *handle):
    PoissonRNG(handle),
    m_seed(0),
    m_offset(0),
    m_stream(cuda_stream(handle))
{
}

void PoissonRNGImpl::exec(_megdnn_tensor_in lam, _megdnn_tensor_out dst,
        _megdnn_workspace workspace) {
    check_exec(lam.layout, dst.layout, workspace.size);
    auto size = dst.layout.total_nr_elems();
    megdnn_assert(size);
    ensure_seed(m_param.seed);
    ElemwiseOpParamN<0> ele_param(size);
    switch (dst.layout.dtype.enumv()) {
#define cb(_dt)                                                               \
    case DTypeTrait<_dt>::enumv:                                              \
    {                                                                         \
        using ctype = DTypeTrait<_dt>::ctype;                                 \
        run_elemwise<random::PoissonKernel<ctype>, ctype, 0>(ele_param,       \
                m_stream, {dst, lam, m_seed, m_offset});                      \
        break;                                                                \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
        default:
            megdnn_throw("bad dtype");
    }
    m_offset += 20;
}

BetaRNGImpl::BetaRNGImpl(Handle *handle):
    BetaRNG(handle),
    m_seed(0),
    m_offset(0),
    m_stream(cuda_stream(handle))
{
}

void BetaRNGImpl::exec(_megdnn_tensor_in alpha, _megdnn_tensor_in beta,
        _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(alpha.layout, beta.layout, dst.layout, workspace.size);
    auto size = dst.layout.total_nr_elems();
    megdnn_assert(size);
    ensure_seed(m_param.seed);
    ElemwiseOpParamN<0> ele_param(size);
    switch (dst.layout.dtype.enumv()) {
#define cb(_dt)                                                               \
    case DTypeTrait<_dt>::enumv:                                              \
    {                                                                         \
        using ctype = DTypeTrait<_dt>::ctype;                                 \
        run_elemwise<random::BetaKernel<ctype>, ctype, 0>(ele_param,          \
                m_stream, {dst, alpha, beta, m_seed, m_offset});              \
        break;                                                                \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
        default:
            megdnn_throw("bad dtype");
    }
    m_offset += 32;
}

PermutationRNGImpl::PermutationRNGImpl(Handle *handle):
    PermutationRNG(handle),
    m_seed(0),
    m_offset(0),
    m_stream(cuda_stream(handle))
{
}

void PermutationRNGImpl::exec(
        _megdnn_tensor_inout dst, _megdnn_workspace workspace) {
    check_exec(dst.layout, workspace.size);
    auto size = dst.layout.total_nr_elems();
    megdnn_assert(size);
    ensure_seed(m_param.seed);
    auto wk = workspace.ptr<void>();
    switch (dst.layout.dtype.enumv()) {
#define cb(_dt)                                                               \
    case DTypeTrait<_dt>::enumv:                                              \
    {                                                                         \
        using ctype = DTypeTrait<_dt>::ctype;                                 \
        ctype max_size = DTypeTrait<_dt>::max() - 1;                          \
        megdnn_assert(ctype(size) < max_size);                                \
        random::permutation_forward<ctype>(dst.ptr<ctype>(), wk, size,        \
                m_seed, m_offset, m_stream);                                  \
        break;                                                                \
    }
        cb(::megdnn::dtype::Float32)
        cb(::megdnn::dtype::Int32)
        cb(::megdnn::dtype::Int16)
#undef cb
        default:
            megdnn_throw("bad dtype");
    }
    m_offset += 8;
}

size_t PermutationRNGImpl::get_workspace_in_bytes(const TensorLayout &layout) {
    size_t size = layout.total_nr_elems();
    return random::get_permutation_workspace_in_bytes(size);
}

// vim: syntax=cpp.doxygen
@@ -10,9 +10,9 @@
 */
#pragma once

#include <curand.h>
#include "megdnn/oprs.h"
#include "src/cuda/handle.h"

namespace megdnn {
namespace cuda {

@@ -22,51 +22,136 @@ class CuRandHandle {
    uint64_t m_seed;

    CuRandHandle(const CuRandHandle&) = delete;
    CuRandHandle& operator=(const CuRandHandle&) = delete;

public:
    CuRandHandle(cudaStream_t stream, uint64_t seed = 0);
    ~CuRandHandle();

    void seed(uint64_t seed);

    curandGenerator_t gen() const { return m_gen; }

    void ensure_seed(uint64_t seed) {
        if (m_seed != seed) {
            this->seed(seed);
        }
    }
};

class UniformRNGImpl : public UniformRNG {
    CuRandHandle m_curand_handle;

public:
    UniformRNGImpl(Handle* handle);
    void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override;

    size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; }
};

class GaussianRNGImpl : public GaussianRNG {
    CuRandHandle m_curand_handle;

public:
    GaussianRNGImpl(Handle* handle);

    void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override;

    size_t get_workspace_in_bytes(const TensorLayout& layout) override;
};

class GammaRNGImpl : public GammaRNG {
    uint64_t m_seed, m_offset;
    cudaStream_t m_stream;

public:
    GammaRNGImpl(Handle* handle);

    void exec(_megdnn_tensor_in shape, _megdnn_tensor_in scale,
              _megdnn_tensor_out dst, _megdnn_workspace) override;

    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                  const TensorLayout&) override {
        return 0;
    }

    void seed(uint64_t seed) { m_seed = seed; }

    void ensure_seed(uint64_t seed) {
        if (m_seed != seed) {
            this->seed(seed);
        }
    }
};

class BetaRNGImpl : public BetaRNG {
    uint64_t m_seed, m_offset;
    cudaStream_t m_stream;

public:
    BetaRNGImpl(Handle* handle);

    void exec(_megdnn_tensor_in alpha, _megdnn_tensor_in beta,
              _megdnn_tensor_out dst, _megdnn_workspace) override;

    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                  const TensorLayout&) override {
        return 0;
    }

    void seed(uint64_t seed) { m_seed = seed; }

    void ensure_seed(uint64_t seed) {
        if (m_seed != seed) {
            this->seed(seed);
        }
    }
};

class PoissonRNGImpl : public PoissonRNG {
    uint64_t m_seed, m_offset;
    cudaStream_t m_stream;

public:
    PoissonRNGImpl(Handle* handle);

    void exec(_megdnn_tensor_in lam, _megdnn_tensor_out dst,
              _megdnn_workspace) override;

    size_t get_workspace_in_bytes(const TensorLayout&,
                                  const TensorLayout&) override {
        return 0;
    }

    void seed(uint64_t seed) { m_seed = seed; }

    void ensure_seed(uint64_t seed) {
        if (m_seed != seed) {
            this->seed(seed);
        }
    }
};

class PermutationRNGImpl : public PermutationRNG {
    uint64_t m_seed, m_offset;
    cudaStream_t m_stream;

public:
    PermutationRNGImpl(Handle* handle);
    void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override;

    size_t get_workspace_in_bytes(const TensorLayout& layout) override;

    void seed(uint64_t seed) { m_seed = seed; }

    void ensure_seed(uint64_t seed) {
        if (m_seed != seed) {
            this->seed(seed);
        }
    }
};

} // namespace cuda
} // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -78,6 +78,157 @@ namespace {
    }
}

template <typename T>
T normal_sample(Xoroshiro128plus* rng) {
    T v;
    fill_gaussian<T>(rng, &v, 1, T(0.f), T(1.f));
    return v;
}

template <typename T>
T uniform_sample(Xoroshiro128plus* rng) {
    return uniform_int2float<T>((*rng)());
}
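// CPU-side Marsaglia-Tsang gamma sampler mirroring the CUDA
// GammaKernel::sample_gamma: for shape < 1 it samples Gamma(shape + 1)
// (hence d = a + 2/3) and multiplies by uniform^(1/a) after acceptance.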
template <typename T, typename U>
void fill_gamma(Xoroshiro128plus* rng, U* dst, size_t size, U* shape,
                U* scale) {
    for (size_t i = 0; i < size; ++i) {
        T a = static_cast<T>(shape[i]);
        T b = static_cast<T>(scale[i]);
        T scale = b;
        bool a_less_one = a < 1.f ? true : false;
        if (a <= 0) {
            dst[i] = U(0.0f);
            continue;
        }
        T d = a + (a_less_one ? 2.0f / 3.0f : -1.0f / 3.0f);
        T c = 1.0f / std::sqrt(9.0f * d);
        while (true) {
            T x, y;
            x = normal_sample<T>(rng);
            y = 1.0f + c * x;
            if (y <= 0)
                continue;
            T v = y * y * y;
            T u = uniform_sample<T>(rng);
            T xx = x * x;
            if ((u < 1.0f - 0.0331f * xx * xx) ||
                std::log(u) < 0.5f * xx + d * (1.0f - v + std::log(v))) {
                dst[i] = U(scale * d * v);
                if (a_less_one)
                    dst[i] *= U(std::pow(uniform_sample<T>(rng), T(1.f / a)));
                break;
            }
        }
    }
}
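// CPU-side Poisson sampler: for lambda < 10, Knuth's method (multiply
// uniforms until the product drops below exp(-lambda)); for larger lambda,
// a transformed-rejection scheme in the style of Hormann's PTRS using the
// precomputed constants b, a, inv_alpha and vr.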
template <typename T, typename U>
void fill_poisson(Xoroshiro128plus* rng, U* dst, U* lam, size_t size) {
    for (size_t i = 0; i < size; ++i) {
        T lambda = static_cast<T>(lam[i]);
        T exp_neg_lambda = std::exp(-lambda);
        T log_lambda = std::log(lambda), sqrt_lambda = std::sqrt(lambda);
        T b = 0.931f + 2.53f * sqrt_lambda;
        T a = -0.059f + 0.02483f * b;
        T inv_alpha = 1.1239f + 1.1328f / (b - 3.4f);
        T vr = 0.9277f - 3.6224f / (b - 2.f);
        T u, v, u_shifted, k;
        if (lambda == 0) {
            dst[i] = U(0);
            continue;
        }
        if (lambda < 10) {
            T prod = 1, x = 0;
            u = 0;
            while (true) {
                u = uniform_sample<T>(rng);
                prod *= u;
                if (prod <= exp_neg_lambda) {
                    dst[i] = U(x);
                    break;
                }
                x += 1;
            }
            continue;
        }
        while (true) {
            u = uniform_sample<T>(rng) - T(0.5f);
            v = uniform_sample<T>(rng);
            u_shifted = T(0.5f) - std::abs(u);
            k = std::floor((T(2.f) * a / u_shifted + b) * u + lambda +
                           T(0.43f));
            if (u_shifted >= 0.07 && v < vr) {
                dst[i] = U(k);
                break;
            }
            if (k < 0 || (u_shifted < T(0.013f) && v > u_shifted)) {
                continue;
            }
            if ((std::log(v) + std::log(inv_alpha) -
                 std::log(a / (u_shifted * u_shifted) + b)) <=
                (-lambda + k * log_lambda - std::lgamma(k + 1))) {
                dst[i] = U(k);
                break;
            }
        }
    }
}
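// CPU-side beta sampler with the same structure as the CUDA BetaKernel:
// Johnk's algorithm when both parameters are below 1, otherwise the ratio
// ga / (ga + gb) of two unit-scale gamma draws.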
template <typename T, typename U>
void fill_beta(Xoroshiro128plus* rng, U* dst, U* alpha, U* beta, size_t size) {
    for (size_t i = 0; i < size; ++i) {
        T a = static_cast<T>(alpha[i]), b = static_cast<T>(beta[i]);
        if (a < 1.0f && b < 1.0f) {
            T u, v, x, y;
            while (true) {
                u = uniform_sample<T>(rng);
                v = uniform_sample<T>(rng);
                x = std::pow(u, 1.0f / a);
                y = std::pow(v, 1.0f / b);
                if (x + y < 1.0f) {
                    if (x + y > 0) {
                        dst[i] = static_cast<U>(x / (x + y));
                        break;
                    } else {
                        T logx = std::log(u) / a;
                        T logy = std::log(v) / b;
                        T log_max = std::max(logx, logy);
                        logx -= log_max;
                        logy -= log_max;
                        dst[i] = static_cast<U>(std::exp(logx -
                                std::log(std::exp(logx) + std::exp(logy))));
                        break;
                    }
                }
            }
        } else {
            T ga, gb, one = 1;
            fill_gamma<T, T>(rng, &ga, 1, &a, &one);
            fill_gamma<T, T>(rng, &gb, 1, &b, &one);
            dst[i] = static_cast<U>(ga / (ga + gb));
        }
    }
}
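// In-place Fisher-Yates shuffle of the identity sequence 0..size-1; the
// mask clears the top bit so the value cast to int64_t stays non-negative
// before the modulo.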
template <typename T>
void fill_permutation(Xoroshiro128plus* rng, T* dst, size_t size) {
    const int64_t mask = std::numeric_limits<int64_t>::max();
    for (size_t i = 0; i < size; ++i) {
        dst[i] = static_cast<T>(i);
    }
    for (int64_t i = size - 1; i > 0; --i) {
        int64_t r = static_cast<int64_t>((*rng)() & mask) % (i + 1);
        if (i != r) {
            T tmp = dst[i];
            dst[i] = dst[r];
            dst[r] = tmp;
        }
    }
}

} // anonymous namespace

uint64_t Splitmix64::operator() () {

@@ -150,5 +301,98 @@ void GaussianRNGImpl::exec(
    }
}

void GammaRNGImpl::exec(_megdnn_tensor_in shape, _megdnn_tensor_in scale,
        _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(shape.layout, scale.layout, dst.layout, workspace.size);
    auto size = dst.layout.total_nr_elems();
    auto prng = &m_rng.ensure_seed(m_param.seed);
    switch (dst.layout.dtype.enumv()) {
#define cb(_dt)                                                               \
    case DTypeTrait<_dt>::enumv:                                              \
    {                                                                         \
        using ctype = DTypeTrait<_dt>::ctype;                                 \
        auto ptr = dst.ptr<ctype>();                                          \
        MEGDNN_DISPATCH_CPU_KERN_OPR({fill_gamma<float>(prng, ptr,            \
                size, shape.ptr<ctype>(), scale.ptr<ctype>());};);            \
        return;                                                               \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
        default:
            megdnn_throw("bad dtype");
    }
}

void PoissonRNGImpl::exec(_megdnn_tensor_in lam,
        _megdnn_tensor_inout dst, _megdnn_workspace workspace) {
    check_exec(lam.layout, dst.layout, workspace.size);
    auto size = dst.layout.total_nr_elems();
    auto prng = &m_rng.ensure_seed(m_param.seed);
    switch (dst.layout.dtype.enumv()) {
#define cb(_dt)                                                               \
    case DTypeTrait<_dt>::enumv:                                              \
    {                                                                         \
        using ctype = DTypeTrait<_dt>::ctype;                                 \
        auto dst_ptr = dst.ptr<ctype>();                                      \
        auto lam_ptr = lam.ptr<ctype>();                                      \
        MEGDNN_DISPATCH_CPU_KERN_OPR({fill_poisson<float>(prng, dst_ptr,      \
                lam_ptr, size);};);                                           \
        return;                                                               \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
        default:
            megdnn_throw("bad dtype");
    }
}

void BetaRNGImpl::exec(_megdnn_tensor_in alpha, _megdnn_tensor_in beta,
        _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(alpha.layout, beta.layout, dst.layout, workspace.size);
    auto size = dst.layout.total_nr_elems();
    auto prng = &m_rng.ensure_seed(m_param.seed);
    switch (dst.layout.dtype.enumv()) {
#define cb(_dt)                                                               \
    case DTypeTrait<_dt>::enumv:                                              \
    {                                                                         \
        using ctype = DTypeTrait<_dt>::ctype;                                 \
        auto dst_ptr = dst.ptr<ctype>();                                      \
        MEGDNN_DISPATCH_CPU_KERN_OPR({fill_beta<float>(prng, dst_ptr,         \
                alpha.ptr<ctype>(), beta.ptr<ctype>(), size);};);             \
        return;                                                               \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
        default:
            megdnn_throw("bad dtype");
    }
}

void PermutationRNGImpl::exec(
        _megdnn_tensor_inout dst, _megdnn_workspace workspace) {
    check_exec(dst.layout, workspace.size);
    auto size = dst.layout.total_nr_elems();
    auto prng = &m_rng.ensure_seed(m_param.seed);
    switch (dst.layout.dtype.enumv()) {
#define cb(_dt)                                                               \
    case DTypeTrait<_dt>::enumv:                                              \
    {                                                                         \
        using ctype = DTypeTrait<_dt>::ctype;                                 \
        ctype max_size = DTypeTrait<_dt>::max() - 1;                          \
        megdnn_assert((ctype(size) < max_size));                              \
        auto ptr = dst.ptr<ctype>();                                          \
        MEGDNN_DISPATCH_CPU_KERN_OPR({fill_permutation<ctype>(prng, ptr,      \
                size);};);                                                    \
        return;                                                               \
    }
        cb(::megdnn::dtype::Float32)
        cb(::megdnn::dtype::Int32)
        cb(::megdnn::dtype::Int16)
#undef cb
        default:
            megdnn_throw("bad dtype");
    }
}

// vim: syntax=cpp.doxygen
@@ -10,8 +10,8 @@
 */
#pragma once

#include <cstdint>
#include "megdnn/oprs.h"

namespace megdnn {
namespace naive {

@@ -19,12 +19,11 @@ namespace naive {
//! see http://xoroshiro.di.unimi.it/splitmix64.c
class Splitmix64 {
    uint64_t m_s;

public:
    explicit Splitmix64(uint64_t seed = 0) : m_s{seed} {}

    uint64_t operator()();
};

/*!

@@ -36,51 +35,99 @@ class Xoroshiro128plus {
        return (x << k) | (x >> (64 - k));
    }

public:
    explicit Xoroshiro128plus(uint64_t seed = 0) { this->seed(seed); }

    //! reset state if seed changed
    Xoroshiro128plus& ensure_seed(uint64_t seed) {
        if (seed != m_init_seed) {
            this->seed(seed);
        }
        return *this;
    }

    //! set seed
    void seed(uint64_t seed);

    uint64_t operator()();
};

class UniformRNGImpl : public UniformRNG {
    Xoroshiro128plus m_rng;

public:
    using UniformRNG::UniformRNG;
    void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override;

    size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; }
};

class GaussianRNGImpl : public GaussianRNG {
    Xoroshiro128plus m_rng;

public:
    using GaussianRNG::GaussianRNG;
    void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override;

    size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; }
};

class GammaRNGImpl : public GammaRNG {
    Xoroshiro128plus m_rng;

public:
    using GammaRNG::GammaRNG;

    void exec(_megdnn_tensor_in shape, _megdnn_tensor_in scale,
              _megdnn_tensor_out dst, _megdnn_workspace) override;

    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                  const TensorLayout&) override {
        return 0;
    }
};

class PoissonRNGImpl : public PoissonRNG {
    Xoroshiro128plus m_rng;

public:
    using PoissonRNG::PoissonRNG;

    void exec(_megdnn_tensor_in lam, _megdnn_tensor_inout dst,
              _megdnn_workspace) override;

    size_t get_workspace_in_bytes(const TensorLayout&,
                                  const TensorLayout&) override {
        return 0;
    }
};

class BetaRNGImpl : public BetaRNG {
    Xoroshiro128plus m_rng;

public:
    using BetaRNG::BetaRNG;

    void exec(_megdnn_tensor_in alpha, _megdnn_tensor_in beta,
              _megdnn_tensor_out dst, _megdnn_workspace) override;

    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                  const TensorLayout&) override {
        return 0;
    }
};

class PermutationRNGImpl : public PermutationRNG {
    Xoroshiro128plus m_rng;

public:
    using PermutationRNG::PermutationRNG;
    void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override;

    size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; }
};

} // namespace naive
} // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -8,36 +8,165 @@
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "test/naive/rng.h"
#include "megdnn/oprs.h"
#include "test/common/tensor.h"
#include "test/cuda/fixture.h"

namespace megdnn {
namespace test {
namespace {

template <typename T>
void run_gamma(Handle* handle) {
    using ctype = typename DTypeTrait<T>::ctype;
    auto opr = handle->create_operator<GammaRNG>();

    TensorLayout ly{TensorShape{2000000 * 5}, T()};
    SyncedTensor<ctype> out(handle, ly);
    SyncedTensor<ctype> shape(handle, ly);
    SyncedTensor<ctype> scale(handle, ly);
    auto shape_ptr = shape.ptr_mutable_host();
    auto scale_ptr = scale.ptr_mutable_host();
    for (int i = 0; i < 5; ++i) {
        for (int j = 0; j < 2000000; ++j) {
            shape_ptr[i * 2000000 + j] = 2 * 0.3 * i + 0.3;
            scale_ptr[i * 2000000 + j] = i * 0.2 + 0.1;
        }
    }
    opr->exec(shape.tensornd_dev(), scale.tensornd_dev(), out.tensornd_dev(),
              {});

    auto ptr = out.ptr_mutable_host();
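    // Gamma(shape a, scale b) has mean a*b and variance a*b^2; get_mean_var
    // presumably returns (sample mean, sample variance), so both moments are
    // checked per parameter bucket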
for (int i = 0; i < 5; ++i) { | |||||
float a = 2 * 0.3 * i + 0.3, b = i * 0.2 + 0.1; | |||||
float mean = a *b; | |||||
float std = a * (b * b); | |||||
auto stat = get_mean_var(ptr + i * 2000000, 2000000, ctype(mean)); | |||||
ASSERT_LE(std::abs(stat.first - mean), 0.01); | |||||
ASSERT_LE(std::abs(stat.second - std), 0.01); | |||||
} | |||||
} | |||||
template <typename T> | |||||
void run_poisson(Handle* handle) { | |||||
using ctype = typename DTypeTrait<T>::ctype; | |||||
auto opr = handle->create_operator<PoissonRNG>(); | |||||
TensorLayout ly{TensorShape{200000 * 5}, T()}; | |||||
SyncedTensor<ctype> out(handle, ly); | |||||
SyncedTensor<ctype> lam(handle, ly); | |||||
auto lam_ptr = lam.ptr_mutable_host(); | |||||
for (int i = 0; i < 5; ++i) { | |||||
for (int j = 0; j < 200000; ++j) { | |||||
lam_ptr[i * 200000 + j] = ctype(i + 1); | |||||
} | |||||
} | |||||
opr->exec(lam.tensornd_dev(), out.tensornd_dev(), {}); | |||||
auto ptr = out.ptr_mutable_host(); | |||||
for (int i = 0; i < 5; ++i) { | |||||
auto stat = get_mean_var(ptr + i * 200000, 200000, ctype(i + 1)); | |||||
ASSERT_LE(std::abs(stat.first - ctype(i + 1)), 0.01); | |||||
ASSERT_LE(std::abs(stat.second - ctype(i + 1)), 0.01); | |||||
} | |||||
} | |||||
template <typename T> | |||||
void run_beta(Handle* handle) { | |||||
using ctype = typename DTypeTrait<T>::ctype; | |||||
auto opr = handle->create_operator<BetaRNG>(); | |||||
TensorLayout ly{TensorShape{200000 * 5}, T()}; | |||||
SyncedTensor<ctype> out(handle, ly); | |||||
SyncedTensor<ctype> alpha(handle, ly); | |||||
SyncedTensor<ctype> beta(handle, ly); | |||||
auto alpha_ptr = alpha.ptr_mutable_host(); | |||||
auto beta_ptr = beta.ptr_mutable_host(); | |||||
for (int i = 0; i < 5; ++i) { | |||||
for (int j = 0; j < 200000; ++j) { | |||||
alpha_ptr[i * 200000 + j] = 0.3 * i + 0.1; | |||||
beta_ptr[i * 200000 + j] = 2 * i * 0.3 + 0.1; | |||||
} | |||||
} | |||||
opr->exec(alpha.tensornd_dev(), beta.tensornd_dev(), out.tensornd_dev(), | |||||
{}); | |||||
auto ptr = out.ptr_mutable_host(); | |||||
for (int i = 0; i < 5; ++i) { | |||||
        float a = 0.3 * i + 0.1, b = 2 * i * 0.3 + 0.1; | |||||
        float mean = a / (a + b); | |||||
        float var = a * b / ((a + b) * (a + b) * (a + b + 1));  // Beta(a, b) variance | |||||
        auto stat = get_mean_var(ptr + i * 200000, 200000, ctype(mean)); | |||||
        ASSERT_LE(std::abs(stat.first - mean), 0.01); | |||||
        ASSERT_LE(std::abs(stat.second - var), 0.01); | |||||
} | |||||
} | |||||
template <typename T> | |||||
void run_permutation(Handle* handle) { | |||||
using ctype = typename DTypeTrait<T>::ctype; | |||||
size_t sample_num = | |||||
std::min(200000, static_cast<int>(DTypeTrait<T>::max()) - 10); | |||||
auto opr = handle->create_operator<PermutationRNG>(); | |||||
opr->param().dtype = DTypeTrait<T>::enumv; | |||||
TensorLayout ly{TensorShape{sample_num}, T()}; | |||||
Tensor<dt_byte> workspace( | |||||
handle, | |||||
{TensorShape{opr->get_workspace_in_bytes(ly)}, dtype::Byte()}); | |||||
SyncedTensor<ctype> t(handle, ly); | |||||
opr->exec(t.tensornd_dev(), | |||||
{workspace.ptr(), workspace.layout().total_nr_elems()}); | |||||
auto ptr = t.ptr_mutable_host(); | |||||
auto size = t.layout().total_nr_elems(); | |||||
std::vector<ctype> res(size); | |||||
int not_same = 0; | |||||
for (size_t i = 0; i < size; ++i) { | |||||
if ((ptr[i] - ctype(i)) >= ctype(1)) not_same++; | |||||
res[i] = ptr[i]; | |||||
} | |||||
ASSERT_GT(not_same, 5000); | |||||
std::sort(res.begin(), res.end()); | |||||
for (size_t i = 0; i < size; ++i) { | |||||
ASSERT_LE(std::abs(res[i] - ctype(i)), 1e-8); | |||||
} | |||||
} | |||||
} // anonymous namespace | |||||
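The helpers above validate samples through their first two moments; the closed forms they compare against are the standard ones (textbook results, stated here for reference):

    \operatorname{Gamma}(k,\theta):\; \mathbb{E}[X]=k\theta, \quad \operatorname{Var}[X]=k\theta^{2}
    \operatorname{Poisson}(\lambda):\; \mathbb{E}[X]=\operatorname{Var}[X]=\lambda
    \operatorname{Beta}(\alpha,\beta):\; \mathbb{E}[X]=\frac{\alpha}{\alpha+\beta}, \quad
        \operatorname{Var}[X]=\frac{\alpha\beta}{(\alpha+\beta)^{2}(\alpha+\beta+1)}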
TEST_F(CUDA, UNIFORM_RNG_F32) { | TEST_F(CUDA, UNIFORM_RNG_F32) { | ||||
auto opr = handle_cuda()->create_operator<UniformRNG>(); | auto opr = handle_cuda()->create_operator<UniformRNG>(); | ||||
opr->param().dtype = DTypeTrait<dtype::Float32>::enumv; | |||||
SyncedTensor<> t(handle_cuda(), {TensorShape{200000}, dtype::Float32()}); | SyncedTensor<> t(handle_cuda(), {TensorShape{200000}, dtype::Float32()}); | ||||
opr->exec(t.tensornd_dev(), {}); | opr->exec(t.tensornd_dev(), {}); | ||||
assert_uniform_correct(t.ptr_mutable_host(), | |||||
t.layout().total_nr_elems()); | |||||
assert_uniform_correct(t.ptr_mutable_host(), t.layout().total_nr_elems()); | |||||
} | } | ||||
TEST_F(CUDA, GAUSSIAN_RNG_F32) { | TEST_F(CUDA, GAUSSIAN_RNG_F32) { | ||||
auto opr = handle_cuda()->create_operator<GaussianRNG>(); | auto opr = handle_cuda()->create_operator<GaussianRNG>(); | ||||
opr->param().mean = 0.8; | opr->param().mean = 0.8; | ||||
opr->param().std = 2.3; | opr->param().std = 2.3; | ||||
for (size_t size: {1, 200000, 200001}) { | |||||
opr->param().dtype = DTypeTrait<dtype::Float32>::enumv; | |||||
for (size_t size : {1, 200000, 200001}) { | |||||
TensorLayout ly{{size}, dtype::Float32()}; | TensorLayout ly{{size}, dtype::Float32()}; | ||||
Tensor<dt_byte> workspace(handle_cuda(), | |||||
{TensorShape{opr->get_workspace_in_bytes(ly)}, | |||||
dtype::Byte()}); | |||||
Tensor<dt_byte> workspace( | |||||
handle_cuda(), | |||||
{TensorShape{opr->get_workspace_in_bytes(ly)}, dtype::Byte()}); | |||||
SyncedTensor<> t(handle_cuda(), ly); | SyncedTensor<> t(handle_cuda(), ly); | ||||
opr->exec(t.tensornd_dev(), | opr->exec(t.tensornd_dev(), | ||||
{workspace.ptr(), workspace.layout().total_nr_elems()}); | |||||
auto ptr = t.ptr_mutable_host(); | auto ptr = t.ptr_mutable_host(); | ||||
ASSERT_LE(std::abs(ptr[0] - 0.8), 2.3); | ASSERT_LE(std::abs(ptr[0] - 0.8), 2.3); | ||||
@@ -50,10 +179,43 @@ TEST_F(CUDA, GAUSSIAN_RNG_F32) { | |||||
} | } | ||||
} | } | ||||
} // namespace test | |||||
} // namespace megdnn | |||||
TEST_F(CUDA, GAMMA_RNG_F32) { | |||||
run_gamma<dtype::Float32>(handle_cuda()); | |||||
} | |||||
// vim: syntax=cpp.doxygen | |||||
TEST_F(CUDA, GAMMA_RNG_F16) { | |||||
run_gamma<dtype::Float16>(handle_cuda()); | |||||
} | |||||
TEST_F(CUDA, POISSON_RNG_F32) { | |||||
run_poisson<dtype::Float32>(handle_cuda()); | |||||
} | |||||
TEST_F(CUDA, POISSON_RNG_F16) { | |||||
run_poisson<dtype::Float16>(handle_cuda()); | |||||
} | |||||
TEST_F(CUDA, BETA_RNG_F32) { | |||||
run_beta<dtype::Float32>(handle_cuda()); | |||||
} | |||||
TEST_F(CUDA, BETA_RNG_F16) { | |||||
run_beta<dtype::Float16>(handle_cuda()); | |||||
} | |||||
TEST_F(CUDA, PERMUTATION_RNG_F32) { | |||||
run_permutation<dtype::Float32>(handle_cuda()); | |||||
} | |||||
TEST_F(CUDA, PERMUTATION_RNG_INT32) { | |||||
run_permutation<dtype::Int32>(handle_cuda()); | |||||
} | |||||
TEST_F(CUDA, PERMUTATION_RNG_INT16) { | |||||
run_permutation<dtype::Int16>(handle_cuda()); | |||||
} | |||||
} // namespace test | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
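A note on why 0.01 tolerances are safe at these sample sizes: the standard error of the sample mean shrinks as sigma/sqrt(n), so even the widest bucket sits well under the threshold. A back-of-the-envelope check (the constants are the test's; the arithmetic is elementary):

    import math

    n = 2_000_000                 # samples per bucket in the gamma tests
    k, theta = 2.7, 0.9           # largest shape/scale bucket (i = 4)
    sigma = math.sqrt(k) * theta  # std of Gamma(k, theta)
    print(sigma / math.sqrt(n))   # ~1e-3, comfortably below the 0.01 tolerance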
@@ -32,6 +32,7 @@ namespace { | |||||
template<typename dtype> | template<typename dtype> | ||||
void run_uniform(Handle *handle) { | void run_uniform(Handle *handle) { | ||||
auto opr = handle->create_operator<UniformRNG>(); | auto opr = handle->create_operator<UniformRNG>(); | ||||
opr->param().dtype = DTypeTrait<dtype>::enumv; | |||||
Tensor<typename DTypeTrait<dtype>::ctype> t( | Tensor<typename DTypeTrait<dtype>::ctype> t( | ||||
handle, {TensorShape{200000}, dtype()}); | handle, {TensorShape{200000}, dtype()}); | ||||
opr->exec(t.tensornd(), {}); | opr->exec(t.tensornd(), {}); | ||||
@@ -44,6 +45,7 @@ namespace { | |||||
auto opr = handle->create_operator<GaussianRNG>(); | auto opr = handle->create_operator<GaussianRNG>(); | ||||
opr->param().mean = 0.8; | opr->param().mean = 0.8; | ||||
opr->param().std = 2.3; | opr->param().std = 2.3; | ||||
opr->param().dtype = DTypeTrait<dtype>::enumv; | |||||
Tensor<ctype> t(handle, {TensorShape{200001}, dtype()}); | Tensor<ctype> t(handle, {TensorShape{200001}, dtype()}); | ||||
opr->exec(t.tensornd(), {}); | opr->exec(t.tensornd(), {}); | ||||
@@ -53,8 +55,131 @@ namespace { | |||||
ASSERT_LE(std::abs(ptr[i] - 0.8), ctype(15)); | ASSERT_LE(std::abs(ptr[i] - 0.8), ctype(15)); | ||||
} | } | ||||
auto stat = get_mean_var(ptr, size, ctype(0.8)); | auto stat = get_mean_var(ptr, size, ctype(0.8)); | ||||
ASSERT_LE(std::abs(stat.first - 0.8), 5e-3); | ASSERT_LE(std::abs(stat.first - 0.8), 5e-3); | ||||
ASSERT_LE(std::abs(stat.second - 2.3 * 2.3), 5e-2); | |||||
} | |||||
template <typename dtype> | |||||
void run_gamma(Handle* handle) { | |||||
    using ctype = typename DTypeTrait<dtype>::ctype; | |||||
    auto opr = handle->create_operator<GammaRNG>(); | |||||
    TensorLayout ly{TensorShape{2000000 * 5}, dtype()}; | |||||
    Tensor<ctype> out(handle, ly); | |||||
    Tensor<ctype> shape(handle, ly); | |||||
    Tensor<ctype> scale(handle, ly); | |||||
    auto shape_ptr = shape.ptr(); | |||||
    auto scale_ptr = scale.ptr(); | |||||
    for (int i = 0; i < 5; ++i) { | |||||
        for (int j = 0; j < 2000000; ++j) { | |||||
            shape_ptr[i * 2000000 + j] = 2 * 0.3 * i + 0.5; | |||||
            scale_ptr[i * 2000000 + j] = i * 0.2 + 0.1; | |||||
        } | |||||
    } | |||||
    opr->exec(shape.tensornd(), scale.tensornd(), out.tensornd(), {}); | |||||
    auto ptr = out.ptr(); | |||||
    for (int i = 0; i < 5; ++i) { | |||||
        float a = 2 * 0.3 * i + 0.5, b = i * 0.2 + 0.1; | |||||
        float mean = a * b; | |||||
        float var = a * (b * b);  // Gamma(k, theta): var = k * theta^2 | |||||
        auto stat = get_mean_var(ptr + i * 2000000, 2000000, ctype(mean)); | |||||
        ASSERT_LE(std::abs(stat.first - mean), 0.01); | |||||
        ASSERT_LE(std::abs(stat.second - var), 0.01); | |||||
    } | |||||
} | |||||
template <typename dtype> | |||||
void run_poisson(Handle* handle) { | |||||
    using ctype = typename DTypeTrait<dtype>::ctype; | |||||
    auto opr = handle->create_operator<PoissonRNG>(); | |||||
    TensorLayout ly{TensorShape{200000 * 5}, dtype()}; | |||||
    Tensor<ctype> out(handle, ly); | |||||
    Tensor<ctype> lam(handle, ly); | |||||
    auto lam_ptr = lam.ptr(); | |||||
    for (int i = 0; i < 5; ++i) { | |||||
        for (int j = 0; j < 200000; ++j) { | |||||
            lam_ptr[i * 200000 + j] = ctype(i + 1); | |||||
        } | |||||
    } | |||||
    opr->exec(lam.tensornd(), out.tensornd(), {}); | |||||
    auto ptr = out.ptr(); | |||||
    for (int i = 0; i < 5; ++i) { | |||||
        auto stat = get_mean_var(ptr + i * 200000, 200000, ctype(i + 1)); | |||||
        ASSERT_LE(std::abs(stat.first - ctype(i + 1)), 0.01); | |||||
        ASSERT_LE(std::abs(stat.second - ctype(i + 1)), 0.01); | |||||
    } | |||||
} | |||||
template <typename dtype> | |||||
void run_beta(Handle* handle) { | |||||
    using ctype = typename DTypeTrait<dtype>::ctype; | |||||
    auto opr = handle->create_operator<BetaRNG>(); | |||||
    TensorLayout ly{TensorShape{200000 * 5}, dtype()}; | |||||
    Tensor<ctype> out(handle, ly); | |||||
    Tensor<ctype> alpha(handle, ly); | |||||
    Tensor<ctype> beta(handle, ly); | |||||
    auto alpha_ptr = alpha.ptr(); | |||||
    auto beta_ptr = beta.ptr(); | |||||
    for (int i = 0; i < 5; ++i) { | |||||
        for (int j = 0; j < 200000; ++j) { | |||||
            alpha_ptr[i * 200000 + j] = 0.3 * i + 0.1; | |||||
            beta_ptr[i * 200000 + j] = 2 * i * 0.3 + 0.1; | |||||
        } | |||||
    } | |||||
    opr->exec(alpha.tensornd(), beta.tensornd(), out.tensornd(), {}); | |||||
    auto ptr = out.ptr(); | |||||
    for (int i = 0; i < 5; ++i) { | |||||
        float a = 0.3 * i + 0.1, b = 2 * i * 0.3 + 0.1; | |||||
        float mean = a / (a + b); | |||||
        float var = a * b / ((a + b) * (a + b) * (a + b + 1));  // Beta(a, b) variance | |||||
        auto stat = get_mean_var(ptr + i * 200000, 200000, ctype(mean)); | |||||
        ASSERT_LE(std::abs(stat.first - mean), 0.01); | |||||
        ASSERT_LE(std::abs(stat.second - var), 0.01); | |||||
    } | |||||
} | |||||
template <typename dtype> | |||||
void run_permutation(Handle* handle) { | |||||
    using ctype = typename DTypeTrait<dtype>::ctype; | |||||
    size_t sample_num = | |||||
            std::min(200000, static_cast<int>(DTypeTrait<dtype>::max()) - 10); | |||||
    auto opr = handle->create_operator<PermutationRNG>(); | |||||
    opr->param().dtype = DTypeTrait<dtype>::enumv; | |||||
    TensorLayout ly{TensorShape{sample_num}, dtype()}; | |||||
    Tensor<ctype> t(handle, ly); | |||||
    opr->exec(t.tensornd(), {}); | |||||
    auto ptr = t.ptr(); | |||||
    auto size = t.layout().total_nr_elems(); | |||||
    std::vector<ctype> res(size); | |||||
    int not_same = 0; | |||||
    for (size_t i = 0; i < size; ++i) { | |||||
        if ((ptr[i] - ctype(i)) >= ctype(1)) not_same++; | |||||
        res[i] = ptr[i]; | |||||
    } | |||||
    ASSERT_GT(not_same, 5000); | |||||
    std::sort(res.begin(), res.end()); | |||||
    for (size_t i = 0; i < size; ++i) { | |||||
        ASSERT_LE(std::abs(res[i] - ctype(i)), 1e-8); | |||||
    } | |||||
} | } | ||||
} | } | ||||
@@ -74,6 +199,42 @@ TEST_F(NAIVE, GAUSSIAN_RNG_F16) { | |||||
DNN_INC_FLOAT16(run_gaussian<dtype::Float16>(handle())); | DNN_INC_FLOAT16(run_gaussian<dtype::Float16>(handle())); | ||||
} | } | ||||
TEST_F(NAIVE, GAMMA_RNG_F32) { | |||||
run_gamma<dtype::Float32>(handle()); | |||||
} | |||||
TEST_F(NAIVE, GAMMA_RNG_F16) { | |||||
DNN_INC_FLOAT16(run_gamma<dtype::Float16>(handle())); | |||||
} | |||||
TEST_F(NAIVE, POISSON_RNG_F32) { | |||||
run_poisson<dtype::Float32>(handle()); | |||||
} | |||||
TEST_F(NAIVE, POISSON_RNG_F16) { | |||||
DNN_INC_FLOAT16(run_poisson<dtype::Float16>(handle())); | |||||
} | |||||
TEST_F(NAIVE, BETA_RNG_F32) { | |||||
run_beta<dtype::Float32>(handle()); | |||||
} | |||||
TEST_F(NAIVE, BETA_RNG_F16) { | |||||
DNN_INC_FLOAT16(run_beta<dtype::Float16>(handle())); | |||||
} | |||||
TEST_F(NAIVE, PERMUTATION_RNG_F32) { | |||||
run_permutation<dtype::Float32>(handle()); | |||||
} | |||||
TEST_F(NAIVE, PERMUTATION_RNG_INT32) { | |||||
run_permutation<dtype::Int32>(handle()); | |||||
} | |||||
TEST_F(NAIVE, PERMUTATION_RNG_INT16) { | |||||
run_permutation<dtype::Int16>(handle()); | |||||
} | |||||
} // namespace test | } // namespace test | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -6,8 +6,17 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
from .distribution import normal, uniform | |||||
from .rng import RNG, seed | |||||
from .rng import RNG, beta, gamma, normal, permutation, poisson, seed, uniform | |||||
__all__ = [ | |||||
"RNG", | |||||
"beta", | |||||
"gamma", | |||||
"normal", | |||||
"permutation", | |||||
"poisson", | |||||
"seed", | |||||
"uniform", | |||||
] | |||||
# pylint: disable=undefined-variable | # pylint: disable=undefined-variable | ||||
del distribution, rng # type: ignore[name-defined] | |||||
del rng # type: ignore[name-defined] |
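With the re-export above, every sampler is reachable directly from `megengine.random`. A minimal usage sketch of the new surface (outputs are nondeterministic; shapes follow the docstrings later in this diff):

    import megengine.random as rand

    rand.seed(42)                                      # reseed the global generator
    g = rand.gamma(shape=2.0, scale=1.0, size=(2, 2))
    p = rand.poisson(lam=3.0, size=(2, 2))
    b = rand.beta(alpha=2.0, beta=0.8, size=(2, 2))
    perm = rand.permutation(10)                        # int32 permutation of 0..9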
@@ -1,95 +0,0 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
from typing import Iterable, Optional | |||||
from .. import Tensor | |||||
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed | |||||
from .rng import _normal, _uniform | |||||
__all__ = ["normal", "uniform"] | |||||
def normal( | |||||
mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None | |||||
) -> Tensor: | |||||
r""" | |||||
Random variable with Gaussian distribution :math:`N(\mu, \sigma)`. | |||||
:param size: output tensor size. | |||||
:param mean: the mean or expectation of the distribution. | |||||
:param std: the standard deviation of the distribution (variance = :math:`\sigma ^ 2`). | |||||
:return: the output tensor. | |||||
Examples: | |||||
.. testcode:: | |||||
import megengine as mge | |||||
import megengine.random as rand | |||||
x = rand.normal(mean=0, std=1, size=(2, 2)) | |||||
print(x.numpy()) | |||||
Outputs: | |||||
.. testoutput:: | |||||
:options: +SKIP | |||||
[[-0.20235455 -0.6959438 ] | |||||
[-1.4939808 -1.5824696 ]] | |||||
""" | |||||
return _normal( | |||||
mean=mean, | |||||
std=std, | |||||
size=size, | |||||
seed=_get_global_rng_seed(), | |||||
device=None, | |||||
handle=0, | |||||
) | |||||
def uniform( | |||||
low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None | |||||
) -> Tensor: | |||||
r""" | |||||
Random variable with uniform distribution $U(0, 1)$. | |||||
:param size: output tensor size. | |||||
:param low: lower range. | |||||
:param high: upper range. | |||||
:return: the output tensor. | |||||
Examples: | |||||
.. testcode:: | |||||
import megengine as mge | |||||
import megengine.random as rand | |||||
x = rand.uniform(size=(2, 2)) | |||||
print(x.numpy()) | |||||
Outputs: | |||||
.. testoutput:: | |||||
:options: +SKIP | |||||
[[0.76901674 0.70496535] | |||||
[0.09365904 0.62957656]] | |||||
""" | |||||
return _uniform( | |||||
low=low, | |||||
high=high, | |||||
size=size, | |||||
seed=_get_global_rng_seed(), | |||||
device=None, | |||||
handle=0, | |||||
) |
@@ -6,8 +6,9 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import collections | |||||
import time | import time | ||||
from typing import Iterable, Optional | |||||
from typing import Iterable, Optional, Union | |||||
from numpy.random import MT19937 | from numpy.random import MT19937 | ||||
@@ -15,15 +16,97 @@ from .. import Tensor | |||||
from ..core._imperative_rt.core2 import apply | from ..core._imperative_rt.core2 import apply | ||||
from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle | from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle | ||||
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed | from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed | ||||
from ..core._imperative_rt.ops import ( | |||||
get_rng_handle_compnode as _get_rng_handle_compnode, | |||||
) | |||||
from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle | from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle | ||||
from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed | from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed | ||||
from ..core.ops.builtin import GaussianRNG, UniformRNG | |||||
from ..core.ops.builtin import ( | |||||
BetaRNG, | |||||
GammaRNG, | |||||
GaussianRNG, | |||||
PermutationRNG, | |||||
PoissonRNG, | |||||
UniformRNG, | |||||
) | |||||
from ..core.tensor import utils | from ..core.tensor import utils | ||||
from ..device import get_default_device | from ..device import get_default_device | ||||
__all__ = [ | |||||
"seed", | |||||
"RNG", | |||||
"uniform", | |||||
"normal", | |||||
"gamma", | |||||
"beta", | |||||
"poisson", | |||||
"permutation", | |||||
] | |||||
_rng = None | _rng = None | ||||
def _infer_broadcasted_shape(inps: Iterable[Tensor]) -> tuple: | |||||
broadcasted_ndim = inps[0].ndim | |||||
broadcasted_shape = list(inps[0]._tuple_shape) | |||||
for i in range(1, len(inps)): | |||||
cur_ndim = inps[i].ndim | |||||
cur_shape = list(inps[i]._tuple_shape) | |||||
n_dim = max(cur_ndim, broadcasted_ndim) | |||||
for j in range(n_dim - 1, -1, -1): | |||||
cur_dim = cur_ndim + j - n_dim | |||||
broad_dim = broadcasted_ndim + j - n_dim | |||||
cur_size = cur_shape[cur_dim] if cur_dim >= 0 else 1 | |||||
broad_size = broadcasted_shape[broad_dim] if broad_dim >= 0 else 1 | |||||
assert cur_size == broad_size or cur_size == 1 or broad_size == 1, ( | |||||
"The size of inps[{}] ({}) must match the size ({}) at " | |||||
"dim {}".format(i, cur_size, broad_size, j) | |||||
) | |||||
broad_size = max(cur_size, broad_size) | |||||
if broad_dim < 0: | |||||
broadcasted_shape = [broad_size] + broadcasted_shape | |||||
broadcasted_ndim += 1 | |||||
else: | |||||
broadcasted_shape[broad_dim] = broad_size | |||||
return tuple(broadcasted_shape) | |||||
def _broadcast_tensors_with_size( | |||||
inps: Iterable[Tensor], size: Iterable[int] | |||||
) -> Iterable[Tensor]: | |||||
assert inps, "The inps cloud not be empty" | |||||
target_shape = _infer_broadcasted_shape(inps) | |||||
if isinstance(size, collections.abc.Iterable): | |||||
target_shape = tuple(size) + target_shape | |||||
target_ndim = len(target_shape) | |||||
for i in range(len(inps)): | |||||
if inps[i]._tuple_shape != target_shape: | |||||
inps[i] = ( | |||||
inps[i] | |||||
.reshape((1,) * (target_ndim - inps[i].ndim) + inps[i]._tuple_shape) | |||||
._broadcast(target_shape) | |||||
) | |||||
return inps | |||||
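`_infer_broadcasted_shape` reimplements NumPy's right-aligned broadcast rule, and `_broadcast_tensors_with_size` then prepends `size` before materializing each input at the target shape. The shape arithmetic can be sanity-checked against NumPy (a sketch with made-up shapes):

    import numpy as np

    # right-to-left dimension matching, exactly like np.broadcast_shapes
    assert np.broadcast_shapes((2, 1, 3), (4, 3)) == (2, 4, 3)

    # with size=(20, 30), inputs are broadcast to size + broadcasted shape
    assert (20, 30) + np.broadcast_shapes((2, 1, 3), (4, 3)) == (20, 30, 2, 4, 3)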
def _uniform( | |||||
low: float, | |||||
high: float, | |||||
size: Optional[Iterable[int]], | |||||
seed: int, | |||||
device: str, | |||||
handle: int, | |||||
) -> Tensor: | |||||
assert low < high, "Uniform is not defined when low >= high" | |||||
if size is None: | |||||
size = (1,) | |||||
op = UniformRNG(seed=seed, handle=handle, dtype="float32") | |||||
_ref = Tensor([], dtype="int32", device=device) | |||||
shape = utils.astensor1d(size, _ref, dtype="int32", device=device) | |||||
(output,) = apply(op, shape) | |||||
return low + (high - low) * output | |||||
def _normal( | def _normal( | ||||
mean: float, | mean: float, | ||||
std: float, | std: float, | ||||
@@ -34,63 +117,477 @@ def _normal( | |||||
) -> Tensor: | ) -> Tensor: | ||||
if size is None: | if size is None: | ||||
size = (1,) | size = (1,) | ||||
op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle) | |||||
op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle, dtype="float32") | |||||
_ref = Tensor([], dtype="int32", device=device) | _ref = Tensor([], dtype="int32", device=device) | ||||
shape = utils.astensor1d(size, _ref, dtype="int32", device=device) | shape = utils.astensor1d(size, _ref, dtype="int32", device=device) | ||||
(output,) = apply(op, shape) | (output,) = apply(op, shape) | ||||
return output | return output | ||||
def _uniform( | |||||
low: float, | |||||
high: float, | |||||
def _gamma( | |||||
shape: Union[Tensor, float], | |||||
scale: Union[Tensor, float], | |||||
size: Optional[Iterable[int]], | size: Optional[Iterable[int]], | ||||
seed: int, | seed: int, | ||||
device: str, | |||||
handle: int, | handle: int, | ||||
) -> Tensor: | ) -> Tensor: | ||||
assert low < high, "Uniform is not defined when low >= high" | |||||
if size is None: | |||||
size = (1,) | |||||
op = UniformRNG(seed=seed, handle=handle) | |||||
handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle) | |||||
if not isinstance(shape, Tensor): | |||||
assert shape > 0, "Gamma is not defined when shape <= 0" | |||||
shape = Tensor(shape, dtype="float32", device=handle_cn) | |||||
if not isinstance(scale, Tensor): | |||||
assert scale > 0, "Gamma is not defined when scale <= 0" | |||||
scale = Tensor(scale, dtype="float32", device=handle_cn) | |||||
assert ( | |||||
handle_cn is None or handle_cn == shape.device | |||||
), "The shape ({}) must be the same device with handle ({})".format( | |||||
shape.device, handle_cn | |||||
) | |||||
assert ( | |||||
handle_cn is None or handle_cn == scale.device | |||||
), "The scale ({}) must be the same device with handle ({})".format( | |||||
scale.device, handle_cn | |||||
) | |||||
if isinstance(size, int) and size != 0: | |||||
size = (size,) | |||||
shape, scale = _broadcast_tensors_with_size([shape, scale], size) | |||||
op = GammaRNG(seed=seed, handle=handle) | |||||
(output,) = apply(op, shape, scale) | |||||
return output | |||||
def _beta( | |||||
alpha: Union[Tensor, float], | |||||
beta: Union[Tensor, float], | |||||
size: Optional[Iterable[int]], | |||||
seed: int, | |||||
handle: int, | |||||
) -> Tensor: | |||||
handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle) | |||||
if not isinstance(alpha, Tensor): | |||||
assert alpha > 0, "Beta is not defined when alpha <= 0" | |||||
alpha = Tensor(alpha, dtype="float32", device=handle_cn) | |||||
if not isinstance(beta, Tensor): | |||||
assert beta > 0, "Beta is not defined when beta <= 0" | |||||
beta = Tensor(beta, dtype="float32", device=handle_cn) | |||||
assert ( | |||||
handle_cn is None or handle_cn == alpha.device | |||||
), "The alpha ({}) must be the same device with handle ({})".format( | |||||
alpha.device, handle_cn | |||||
) | |||||
assert ( | |||||
handle_cn is None or handle_cn == beta.device | |||||
), "The beta ({}) must be the same device with handle ({})".format( | |||||
beta.device, handle_cn | |||||
) | |||||
if isinstance(size, int) and size != 0: | |||||
size = (size,) | |||||
alpha, beta = _broadcast_tensors_with_size([alpha, beta], size) | |||||
op = BetaRNG(seed=seed, handle=handle) | |||||
(output,) = apply(op, alpha, beta) | |||||
return output | |||||
def _poisson( | |||||
lam: Union[Tensor, float], size: Optional[Iterable[int]], seed: int, handle: int | |||||
) -> Tensor: | |||||
handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle) | |||||
if not isinstance(lam, Tensor): | |||||
assert lam > 0, "Poisson is not defined when lam <= 0" | |||||
lam = Tensor(lam, dtype="float32", device=handle_cn) | |||||
if isinstance(size, int) and size != 0: | |||||
size = (size,) | |||||
assert ( | |||||
handle_cn is None or handle_cn == lam.device | |||||
), "The lam ({}) must be the same device with handle ({})".format( | |||||
lam.device, handle_cn | |||||
) | |||||
(lam,) = _broadcast_tensors_with_size([lam], size) | |||||
op = PoissonRNG(seed=seed, handle=handle) | |||||
(output,) = apply(op, lam) | |||||
return output | |||||
def _permutation(n: int, seed: int, device: str, handle: int, dtype: str) -> Tensor: | |||||
assert isinstance(n, int) and n > 0, "Permutation is not defined when n <= 0" | |||||
size = (n,) | |||||
op = PermutationRNG(seed=seed, handle=handle, dtype=dtype) | |||||
_ref = Tensor([], dtype="int32", device=device) | _ref = Tensor([], dtype="int32", device=device) | ||||
shape = utils.astensor1d(size, _ref, dtype="int32", device=device) | shape = utils.astensor1d(size, _ref, dtype="int32", device=device) | ||||
(output,) = apply(op, shape) | (output,) = apply(op, shape) | ||||
return low + (high - low) * output | |||||
return output | |||||
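`_permutation` only builds the shape tensor `(n,)` and dispatches to the PermutationRNG kernel; the kernel itself is not part of this diff. As a reference for its contract, a Fisher-Yates shuffle produces the same distribution over permutations of `0..n-1` (a plain-Python sketch, not the actual implementation):

    import random

    def reference_permutation(n: int, seed: int) -> list:
        rng = random.Random(seed)
        out = list(range(n))
        for i in range(n - 1, 0, -1):   # Fisher-Yates: all n! orders equally likely
            j = rng.randint(0, i)
            out[i], out[j] = out[j], out[i]
        return out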
class RNG: | class RNG: | ||||
def __init__(self, seed=0, device=None): | |||||
self.seed = seed | |||||
self.device = device if device else get_default_device() | |||||
self.handle = _new_rng_handle(self.device, self.seed) | |||||
r""" | |||||
:class:`RNG` exposes a number of methods for generating random numbers. | |||||
:param seed: random seed used to initialize the pseudo-random number generator. | |||||
Default: None | |||||
:param device: the device of generated tensor. Default: None | |||||
Examples: | |||||
.. testcode:: | |||||
import megengine.random as rand | |||||
rng = rand.RNG(seed=100) | |||||
x = rng.uniform(size=(2, 2)) | |||||
print(x.numpy()) | |||||
Outputs: | |||||
.. testoutput:: | |||||
:options: +SKIP | |||||
[[0.84811664 0.6147553 ] | |||||
[0.59429836 0.64727545]] | |||||
""" | |||||
    def __init__(self, seed: Optional[int] = None, device: Optional[str] = None): | |||||
self._device = device if device else get_default_device() | |||||
if seed is not None: | |||||
self._seed = seed | |||||
self._handle = _new_rng_handle(self._device, self._seed) | |||||
else: | |||||
self._seed = _get_global_rng_seed | |||||
self._handle = 0 | |||||
self._device = None | |||||
def uniform( | def uniform( | ||||
self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None | self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None | ||||
): | ): | ||||
r""" | |||||
        Random variable with uniform distribution :math:`U(low, high)`. | |||||
:param low: lower range. Default: 0 | |||||
:param high: upper range. Default: 1 | |||||
:param size: the size of output tensor. Default: None | |||||
:return: the output tensor. | |||||
Examples: | |||||
.. testcode:: | |||||
import megengine as mge | |||||
import megengine.random as rand | |||||
x = rand.uniform(size=(2, 2)) | |||||
print(x.numpy()) | |||||
Outputs: | |||||
.. testoutput:: | |||||
:options: +SKIP | |||||
[[0.91600335 0.6680226 ] | |||||
[0.2046729 0.2769141 ]] | |||||
""" | |||||
_seed = self._seed() if callable(self._seed) else self._seed | |||||
return _uniform( | return _uniform( | ||||
low=low, | low=low, | ||||
high=high, | high=high, | ||||
size=size, | size=size, | ||||
seed=self.seed, | |||||
device=self.device, | |||||
handle=self.handle, | |||||
seed=_seed, | |||||
device=self._device, | |||||
handle=self._handle, | |||||
) | ) | ||||
def normal( | def normal( | ||||
self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None | self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None | ||||
): | ): | ||||
r""" | |||||
Random variable with Gaussian distribution :math:`N(\mu, \sigma)`. | |||||
:param mean: the mean or expectation of the distribution. Default: 0 | |||||
:param std: the standard deviation of the distribution (variance = :math:`\sigma ^ 2`). | |||||
Default: 1 | |||||
:param size: the size of output tensor. Default: None | |||||
:return: the output tensor. | |||||
Examples: | |||||
.. testcode:: | |||||
import megengine as mge | |||||
import megengine.random as rand | |||||
x = rand.normal(mean=0, std=1, size=(2, 2)) | |||||
print(x.numpy()) | |||||
Outputs: | |||||
.. testoutput:: | |||||
:options: +SKIP | |||||
[[-1.4010863 -0.9874344 ] | |||||
[ 0.56373274 0.79656655]] | |||||
""" | |||||
_seed = self._seed() if callable(self._seed) else self._seed | |||||
return _normal( | return _normal( | ||||
mean=mean, | mean=mean, | ||||
std=std, | std=std, | ||||
size=size, | size=size, | ||||
seed=self.seed, | |||||
device=self.device, | |||||
handle=self.handle, | |||||
seed=_seed, | |||||
device=self._device, | |||||
handle=self._handle, | |||||
) | |||||
def gamma( | |||||
self, | |||||
shape: Union[Tensor, float], | |||||
scale: Union[Tensor, float] = 1, | |||||
size: Optional[Iterable[int]] = None, | |||||
): | |||||
r""" | |||||
Random variable with Gamma distribution :math:`\Gamma(k, \theta)`. | |||||
The corresponding probability density function is | |||||
.. math:: | |||||
            p(x)=x^{k-1} \frac{e^{-x / \theta}}{\theta^{k} \Gamma(k)} | |||||
            \quad \text{ for } x>0, \quad k, \theta>0, | |||||
        where :math:`\Gamma(k)` is the gamma function, which for positive integer :math:`k` reduces to | |||||
        .. math:: | |||||
            \Gamma(k)=(k-1)!. | |||||
        :param shape: the shape parameter (sometimes designated "k") of the distribution. | |||||
            Must be positive. | |||||
        :param scale: the scale parameter (sometimes designated "theta") of the distribution. | |||||
            Must be positive. Default: 1 | |||||
:param size: the size of output tensor. If shape and scale are scalars and given size is, e.g., | |||||
`(m, n)`, then the output shape is `(m, n)`. If shape or scale is a Tensor and given size | |||||
is, e.g., `(m, n)`, then the output shape is `(m, n) + broadcast(shape, scale).shape`. | |||||
The broadcast rules are consistent with `numpy.broadcast`. Default: None | |||||
:return: the output tensor. | |||||
Examples: | |||||
.. testcode:: | |||||
import megengine as mge | |||||
import megengine.random as rand | |||||
x = rand.gamma(shape=2, scale=1, size=(2, 2)) | |||||
print(x.numpy()) | |||||
shape = mge.Tensor([[ 1], | |||||
[10]], dtype="float32") | |||||
scale = mge.Tensor([1,5], dtype="float32") | |||||
x = rand.gamma(shape=shape, scale=scale) | |||||
print(x.numpy()) | |||||
x = rand.gamma(shape=shape, scale=scale, size=2) | |||||
print(x.numpy()) | |||||
Outputs: | |||||
.. testoutput:: | |||||
:options: +SKIP | |||||
[[1.5064533 4.0689363 ] | |||||
[0.71639484 1.4551026 ]] | |||||
[[ 0.4352188 11.399335 ] | |||||
[ 9.1888 52.009277 ]] | |||||
[[[ 1.1726005 3.9654975 ] | |||||
[13.656933 36.559006 ]] | |||||
[[ 0.25848487 2.5540342 ] | |||||
[11.960409 21.031536 ]]] | |||||
""" | |||||
_seed = self._seed() if callable(self._seed) else self._seed | |||||
return _gamma( | |||||
shape=shape, scale=scale, size=size, seed=_seed, handle=self._handle | |||||
) | |||||
def beta( | |||||
self, | |||||
alpha: Union[Tensor, float], | |||||
beta: Union[Tensor, float], | |||||
size: Optional[Iterable[int]] = None, | |||||
): | |||||
r""" | |||||
Random variable with Beta distribution :math:`\operatorname{Beta}(\alpha, \beta)`. | |||||
The corresponding probability density function is | |||||
.. math:: | |||||
            p(x)=\frac{1}{\mathrm{B}(\alpha, \beta)} x^{\alpha-1}(1-x)^{\beta-1} | |||||
            \quad \text{ for } \alpha, \beta>0, | |||||
        where :math:`\mathrm{B}(\alpha, \beta)` is the beta function, | |||||
        .. math:: | |||||
            \mathrm{B}(\alpha, \beta)=\int_{0}^{1} t^{\alpha-1}(1-t)^{\beta-1} d t. | |||||
        :param alpha: the alpha parameter of the distribution. Must be positive. | |||||
        :param beta: the beta parameter of the distribution. Must be positive. | |||||
:param size: the size of output tensor. If alpha and beta are scalars and given size is, e.g., | |||||
`(m, n)`, then the output shape is `(m, n)`. If alpha or beta is a Tensor and given size | |||||
is, e.g., `(m, n)`, then the output shape is `(m, n) + broadcast(alpha, beta).shape`. | |||||
The broadcast rules are consistent with `numpy.broadcast`. Default: None | |||||
:return: the output tensor. | |||||
Examples: | |||||
.. testcode:: | |||||
import megengine as mge | |||||
import megengine.random as rand | |||||
x = rand.beta(alpha=2, beta=1, size=(2, 2)) | |||||
print(x.numpy()) | |||||
alpha = mge.Tensor([[0.5], | |||||
[ 3]], dtype="float32") | |||||
beta = mge.Tensor([0.5,5], dtype="float32") | |||||
x = rand.beta(alpha=alpha, beta=beta) | |||||
print(x.numpy()) | |||||
x = rand.beta(alpha=alpha, beta=beta, size=2) | |||||
print(x.numpy()) | |||||
Outputs: | |||||
.. testoutput:: | |||||
:options: +SKIP | |||||
[[0.582565 0.91763186] | |||||
[0.86963767 0.6088103 ]] | |||||
[[0.41503012 0.16438372] | |||||
[0.90159506 0.47588003]] | |||||
[[[0.55195075 0.01111084] | |||||
[0.95298755 0.25048104]] | |||||
[[0.11680304 0.13859665] | |||||
[0.997879 0.43259275]]] | |||||
""" | |||||
_seed = self._seed() if callable(self._seed) else self._seed | |||||
return _beta(alpha=alpha, beta=beta, size=size, seed=_seed, handle=self._handle) | |||||
def poisson(self, lam: Union[float, Tensor], size: Optional[Iterable[int]] = None): | |||||
r""" | |||||
        Random variable with Poisson distribution :math:`\operatorname{Poisson}(\lambda)`. | |||||
The corresponding probability density function is | |||||
.. math:: | |||||
            f(k; \lambda)=\frac{\lambda^{k} e^{-\lambda}}{k!}, | |||||
        where :math:`k` is the number of occurrences :math:`(k = 0, 1, 2, \ldots)`. | |||||
        :param lam: the lambda parameter of the distribution. Must be positive. | |||||
:param size: the size of output tensor. If lam is a scalar and given size is, e.g., `(m, n)`, | |||||
then the output shape is `(m, n)`. If lam is a Tensor with shape `(k, v)` and given | |||||
size is, e.g., `(m, n)`, then the output shape is `(m, n, k, v)`. Default: None. | |||||
:return: the output tensor. | |||||
Examples: | |||||
.. testcode:: | |||||
import megengine as mge | |||||
import megengine.random as rand | |||||
x = rand.poisson(lam=2., size=(1, 3)) | |||||
print(x.numpy()) | |||||
lam = mge.Tensor([[1.,1.], | |||||
[10,10]], dtype="float32") | |||||
x = rand.poisson(lam=lam) | |||||
print(x.numpy()) | |||||
x = rand.poisson(lam=lam, size=(1,3)) | |||||
print(x.numpy()) | |||||
Outputs: | |||||
.. testoutput:: | |||||
:options: +SKIP | |||||
[[3. 1. 3.]] | |||||
[[ 2. 2.] | |||||
[12. 11.]] | |||||
[[[[ 1. 1.] | |||||
[11. 4.]] | |||||
[[ 0. 0.] | |||||
[ 9. 13.]] | |||||
[[ 0. 1.] | |||||
[ 7. 12.]]]] | |||||
""" | |||||
_seed = self._seed() if callable(self._seed) else self._seed | |||||
return _poisson(lam=lam, size=size, seed=_seed, handle=self._handle) | |||||
def permutation(self, n: int, *, dtype: str = "int32"): | |||||
r""" | |||||
Generates a random permutation of integers from :math:`0` to :math:`n - 1`. | |||||
:param n: the upper bound. Must be larger than 0. | |||||
:param dtype: the output data type. int32, int16 and float32 are | |||||
supported. Default: int32 | |||||
:return: the output tensor. | |||||
Examples: | |||||
.. testcode:: | |||||
import megengine as mge | |||||
import megengine.random as rand | |||||
x = rand.permutation(n=10, dtype="int32") | |||||
print(x.numpy()) | |||||
x = rand.permutation(n=10, dtype="float32") | |||||
print(x.numpy()) | |||||
Outputs: | |||||
.. testoutput:: | |||||
:options: +SKIP | |||||
[4 5 0 7 3 8 6 1 9 2] | |||||
[3. 4. 9. 0. 6. 8. 7. 1. 5. 2.] | |||||
""" | |||||
_seed = self._seed() if callable(self._seed) else self._seed | |||||
return _permutation( | |||||
n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype | |||||
) | ) | ||||
def __del__(self): | def __del__(self): | ||||
_delete_rng_handle(self.handle) | |||||
if self._handle != 0: | |||||
_delete_rng_handle(self._handle) | |||||
def _default_rng(): | |||||
r"""Default constructor for :class:`RNG`.""" | |||||
return RNG(seed=None, device=None) | |||||
_default_handle = _default_rng() | |||||
uniform = _default_handle.uniform | |||||
normal = _default_handle.normal | |||||
gamma = _default_handle.gamma | |||||
beta = _default_handle.beta | |||||
poisson = _default_handle.poisson | |||||
permutation = _default_handle.permutation | |||||
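The module-level `uniform`, `normal`, `gamma`, `beta`, `poisson`, and `permutation` are bound methods of one default `RNG` constructed with `seed=None`, so `_seed` stays callable and is re-read from the global seed on every call, while `_handle` is 0 (the global handle). A dedicated `RNG(seed=...)` instead owns its own handle; a sketch of the observable difference:

    import megengine.random as rand

    rand.seed(123)              # affects the module-level samplers
    a = rand.normal(size=(4,))  # drawn through the global handle

    own = rand.RNG(seed=123)    # private handle, unaffected by rand.seed()
    b = own.normal(size=(4,))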
def _random_seed_generator(): | def _random_seed_generator(): | ||||
@@ -476,4 +476,5 @@ void init_ops(py::module m) { | |||||
}, py::call_guard<py::gil_scoped_release>()); | }, py::call_guard<py::gil_scoped_release>()); | ||||
m.def("set_global_rng_seed", &rng::set_global_rng_seed); | m.def("set_global_rng_seed", &rng::set_global_rng_seed); | ||||
m.def("get_global_rng_seed", &rng::get_global_rng_seed); | m.def("get_global_rng_seed", &rng::get_global_rng_seed); | ||||
m.def("get_rng_handle_compnode", &rng::get_rng_handle_compnode); | |||||
} | } |
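`get_rng_handle_compnode` is the binding that lets the Python layer above place scalar distribution parameters on the handle's device before `apply`; the call pattern, restated for clarity (a sketch mirroring the usage in `rng.py`):

    from megengine import Tensor
    from megengine.core._imperative_rt import CompNode
    from megengine.core._imperative_rt.ops import (
        delete_rng_handle,
        get_rng_handle_compnode,
        new_rng_handle,
    )

    handle = new_rng_handle(CompNode("xpux"), 42)
    cn = get_rng_handle_compnode(handle)           # CompNode the handle lives on
    lam = Tensor(2.0, dtype="float32", device=cn)  # scalar param lands on that device
    delete_rng_handle(handle)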
@@ -9,8 +9,8 @@ | |||||
import numpy as np | import numpy as np | ||||
import pytest | import pytest | ||||
import megengine | |||||
from megengine import is_cuda_available, tensor | |||||
import megengine.functional as F | |||||
from megengine import Tensor | |||||
from megengine.core._imperative_rt import CompNode | from megengine.core._imperative_rt import CompNode | ||||
from megengine.core._imperative_rt.core2 import apply | from megengine.core._imperative_rt.core2 import apply | ||||
from megengine.core._imperative_rt.ops import ( | from megengine.core._imperative_rt.ops import ( | ||||
@@ -18,10 +18,16 @@ from megengine.core._imperative_rt.ops import ( | |||||
get_global_rng_seed, | get_global_rng_seed, | ||||
new_rng_handle, | new_rng_handle, | ||||
) | ) | ||||
from megengine.core.ops.builtin import GaussianRNG, UniformRNG | |||||
from megengine.core.ops.builtin import ( | |||||
BetaRNG, | |||||
GammaRNG, | |||||
GaussianRNG, | |||||
PermutationRNG, | |||||
PoissonRNG, | |||||
UniformRNG, | |||||
) | |||||
from megengine.distributed.helper import get_device_count_by_fork | from megengine.distributed.helper import get_device_count_by_fork | ||||
from megengine.random import RNG | from megengine.random import RNG | ||||
from megengine.random.rng import _normal, _uniform | |||||
@pytest.mark.skipif( | @pytest.mark.skipif( | ||||
@@ -34,22 +40,24 @@ def test_gaussian_op(): | |||||
11, | 11, | ||||
12, | 12, | ||||
) | ) | ||||
shape = tensor(shape, dtype="int32") | |||||
op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0) | |||||
shape = Tensor(shape, dtype="int32") | |||||
op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0, dtype="float32") | |||||
(output,) = apply(op, shape) | (output,) = apply(op, shape) | ||||
assert np.fabs(output.numpy().mean() - 1.0) < 1e-1 | assert np.fabs(output.numpy().mean() - 1.0) < 1e-1 | ||||
assert np.sqrt(output.numpy().var()) - 3.0 < 1e-1 | |||||
assert np.fabs(np.sqrt(output.numpy().var()) - 3.0) < 1e-1 | |||||
assert str(output.device) == str(CompNode("xpux")) | assert str(output.device) == str(CompNode("xpux")) | ||||
assert output.dtype == np.float32 | |||||
cn = CompNode("xpu2") | cn = CompNode("xpu2") | ||||
seed = 233333 | seed = 233333 | ||||
h = new_rng_handle(cn, seed) | h = new_rng_handle(cn, seed) | ||||
op = GaussianRNG(seed=seed, mean=3.0, std=1.0, handle=h) | |||||
op = GaussianRNG(seed=seed, mean=3.0, std=1.0, dtype="float32", handle=h) | |||||
(output,) = apply(op, shape) | (output,) = apply(op, shape) | ||||
delete_rng_handle(h) | delete_rng_handle(h) | ||||
assert np.fabs(output.numpy().mean() - 3.0) < 1e-1 | assert np.fabs(output.numpy().mean() - 3.0) < 1e-1 | ||||
assert np.sqrt(output.numpy().var()) - 1.0 < 1e-1 | |||||
assert np.fabs(np.sqrt(output.numpy().var()) - 1.0) < 1e-1 | |||||
assert str(output.device) == str(cn) | assert str(output.device) == str(cn) | ||||
assert output.dtype == np.float32 | |||||
@pytest.mark.skipif( | @pytest.mark.skipif( | ||||
@@ -62,20 +70,138 @@ def test_uniform_op(): | |||||
11, | 11, | ||||
12, | 12, | ||||
) | ) | ||||
shape = tensor(shape, dtype="int32") | |||||
op = UniformRNG(seed=get_global_rng_seed()) | |||||
shape = Tensor(shape, dtype="int32") | |||||
op = UniformRNG(seed=get_global_rng_seed(), dtype="float32") | |||||
(output,) = apply(op, shape) | (output,) = apply(op, shape) | ||||
assert np.fabs(output.numpy().mean() - 0.5) < 1e-1 | assert np.fabs(output.numpy().mean() - 0.5) < 1e-1 | ||||
assert str(output.device) == str(CompNode("xpux")) | assert str(output.device) == str(CompNode("xpux")) | ||||
assert output.dtype == np.float32 | |||||
cn = CompNode("xpu2") | cn = CompNode("xpu2") | ||||
seed = 233333 | seed = 233333 | ||||
h = new_rng_handle(cn, seed) | h = new_rng_handle(cn, seed) | ||||
op = UniformRNG(seed=seed, handle=h) | |||||
op = UniformRNG(seed=seed, dtype="float32", handle=h) | |||||
(output,) = apply(op, shape) | (output,) = apply(op, shape) | ||||
delete_rng_handle(h) | delete_rng_handle(h) | ||||
assert np.fabs(output.numpy().mean() - 0.5) < 1e-1 | assert np.fabs(output.numpy().mean() - 0.5) < 1e-1 | ||||
assert str(output.device) == str(cn) | assert str(output.device) == str(cn) | ||||
assert output.dtype == np.float32 | |||||
@pytest.mark.skipif( | |||||
get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", | |||||
) | |||||
def test_gamma_op(): | |||||
_shape, _scale = 2, 0.8 | |||||
_expected_mean, _expected_std = _shape * _scale, np.sqrt(_shape) * _scale | |||||
shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32") | |||||
scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32") | |||||
op = GammaRNG(seed=get_global_rng_seed(), handle=0) | |||||
(output,) = apply(op, shape, scale) | |||||
assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1 | |||||
assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1 | |||||
assert str(output.device) == str(CompNode("xpux")) | |||||
cn = CompNode("xpu2") | |||||
seed = 233333 | |||||
h = new_rng_handle(cn, seed) | |||||
shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32", device="xpu2") | |||||
scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32", device="xpu2") | |||||
op = GammaRNG(seed=seed, handle=h) | |||||
(output,) = apply(op, shape, scale) | |||||
delete_rng_handle(h) | |||||
assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1 | |||||
assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1 | |||||
assert str(output.device) == str(cn) | |||||
@pytest.mark.skipif( | |||||
get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", | |||||
) | |||||
def test_beta_op(): | |||||
_alpha, _beta = 2, 0.8 | |||||
_expected_mean = _alpha / (_alpha + _beta) | |||||
_expected_std = np.sqrt( | |||||
_alpha * _beta / ((_alpha + _beta) ** 2 * (_alpha + _beta + 1)) | |||||
) | |||||
alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32") | |||||
beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32") | |||||
op = BetaRNG(seed=get_global_rng_seed()) | |||||
(output,) = apply(op, alpha, beta) | |||||
assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1 | |||||
assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1 | |||||
assert str(output.device) == str(CompNode("xpux")) | |||||
cn = CompNode("xpu2") | |||||
seed = 233333 | |||||
h = new_rng_handle(cn, seed) | |||||
alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32", device=cn) | |||||
beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32", device=cn) | |||||
op = BetaRNG(seed=seed, handle=h) | |||||
(output,) = apply(op, alpha, beta) | |||||
delete_rng_handle(h) | |||||
assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1 | |||||
assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1 | |||||
assert str(output.device) == str(cn) | |||||
@pytest.mark.skipif( | |||||
get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", | |||||
) | |||||
def test_poisson_op(): | |||||
lam = F.full([8, 9, 11, 12], value=2, dtype="float32") | |||||
op = PoissonRNG(seed=get_global_rng_seed()) | |||||
(output,) = apply(op, lam) | |||||
assert np.fabs(output.numpy().mean() - 2.0) < 1e-1 | |||||
assert np.fabs(np.sqrt(output.numpy().var()) - np.sqrt(2.0)) < 1e-1 | |||||
assert str(output.device) == str(CompNode("xpux")) | |||||
cn = CompNode("xpu2") | |||||
seed = 233333 | |||||
h = new_rng_handle(cn, seed) | |||||
lam = F.full([8, 9, 11, 12], value=2, dtype="float32", device=cn) | |||||
op = PoissonRNG(seed=seed, handle=h) | |||||
(output,) = apply(op, lam) | |||||
delete_rng_handle(h) | |||||
assert np.fabs(output.numpy().mean() - 2.0) < 1e-1 | |||||
assert np.fabs(np.sqrt(output.numpy().var()) - np.sqrt(2.0)) < 1e-1 | |||||
assert str(output.device) == str(cn) | |||||
@pytest.mark.skipif( | |||||
get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", | |||||
) | |||||
def test_permutation_op(): | |||||
n = 1000 | |||||
def test_permutation_op_dtype(dtype): | |||||
def sum_result(res, fun): | |||||
return sum([1 if i == v else 0 for i, v in enumerate(fun(res.numpy()))]) | |||||
shape = Tensor((n,), dtype="int32") | |||||
op = PermutationRNG(seed=get_global_rng_seed(), dtype=dtype) | |||||
(output,) = apply(op, shape) | |||||
assert sum_result(output, lambda x: x) < 500 | |||||
assert sum_result(output, np.sort) == n | |||||
assert str(output.device) == str(CompNode("xpux")) | |||||
assert output.dtype == dtype | |||||
cn = CompNode("xpu2") | |||||
seed = 233333 | |||||
h = new_rng_handle(cn, seed) | |||||
op = PermutationRNG(seed=seed, handle=h, dtype=dtype) | |||||
(output,) = apply(op, shape) | |||||
delete_rng_handle(h) | |||||
assert sum_result(output, lambda x: x) < 500 | |||||
assert sum_result(output, np.sort) == n | |||||
assert str(output.device) == str(cn) | |||||
assert output.dtype == dtype | |||||
test_permutation_op_dtype(np.float32) | |||||
test_permutation_op_dtype(np.int32) | |||||
test_permutation_op_dtype(np.int16) | |||||
@pytest.mark.skipif( | @pytest.mark.skipif( | ||||
@@ -133,3 +259,131 @@ def test_NormalRNG(): | |||||
assert all(out.shape.numpy() == np.array([20, 30, 40])) | assert all(out.shape.numpy() == np.array([20, 30, 40])) | ||||
assert np.abs(out.mean().numpy() - mean) / std < 0.1 | assert np.abs(out.mean().numpy() - mean) / std < 0.1 | ||||
assert np.abs(np.std(out.numpy()) - std) < 0.1 | assert np.abs(np.std(out.numpy()) - std) < 0.1 | ||||
@pytest.mark.skipif( | |||||
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", | |||||
) | |||||
def test_GammaRNG(): | |||||
m1 = RNG(seed=111, device="xpu0") | |||||
m2 = RNG(seed=111, device="xpu1") | |||||
m3 = RNG(seed=222, device="xpu0") | |||||
out1 = m1.gamma(2, size=(100,)) | |||||
out1_ = m1.uniform(size=(100,)) | |||||
out2 = m2.gamma(2, size=(100,)) | |||||
out3 = m3.gamma(2, size=(100,)) | |||||
np.testing.assert_equal(out1.numpy(), out2.numpy()) | |||||
assert out1.device == "xpu0" and out2.device == "xpu1" | |||||
assert not (out1.numpy() == out3.numpy()).all() | |||||
assert not (out1.numpy() == out1_.numpy()).all() | |||||
shape = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0") | |||||
scale = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0") | |||||
expected_mean = (shape * scale).numpy() | |||||
expected_std = (F.sqrt(shape) * scale).numpy() | |||||
out = m1.gamma(shape=shape, scale=scale, size=(20, 30, 40)) | |||||
out_shp = out.shape | |||||
if isinstance(out_shp, tuple): | |||||
assert out_shp == (20, 30, 40, 2, 3) | |||||
else: | |||||
assert all(out.shape.numpy() == np.array([20, 30, 40, 2, 3])) | |||||
assert ( | |||||
np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std | |||||
).mean() < 0.1 | |||||
assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1 | |||||
@pytest.mark.skipif( | |||||
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", | |||||
) | |||||
def test_BetaRNG(): | |||||
m1 = RNG(seed=111, device="xpu0") | |||||
m2 = RNG(seed=111, device="xpu1") | |||||
m3 = RNG(seed=222, device="xpu0") | |||||
out1 = m1.beta(2, 1, size=(100,)) | |||||
out1_ = m1.uniform(size=(100,)) | |||||
out2 = m2.beta(2, 1, size=(100,)) | |||||
out3 = m3.beta(2, 1, size=(100,)) | |||||
np.testing.assert_equal(out1.numpy(), out2.numpy()) | |||||
assert out1.device == "xpu0" and out2.device == "xpu1" | |||||
assert not (out1.numpy() == out3.numpy()).all() | |||||
assert not (out1.numpy() == out1_.numpy()).all() | |||||
alpha = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0") | |||||
beta = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0") | |||||
expected_mean = (alpha / (alpha + beta)).numpy() | |||||
expected_std = ( | |||||
F.sqrt(alpha * beta / (F.pow(alpha + beta, 2) * (alpha + beta + 1))) | |||||
).numpy() | |||||
out = m1.beta(alpha=alpha, beta=beta, size=(20, 30)) | |||||
out_shp = out.shape | |||||
if isinstance(out_shp, tuple): | |||||
assert out_shp == (20, 30, 2, 3) | |||||
else: | |||||
assert all(out.shape.numpy() == np.array([20, 30, 2, 3])) | |||||
assert ( | |||||
np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std | |||||
).mean() < 0.1 | |||||
assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1 | |||||
@pytest.mark.skipif( | |||||
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", | |||||
) | |||||
def test_PoissonRNG(): | |||||
m1 = RNG(seed=111, device="xpu0") | |||||
m2 = RNG(seed=111, device="xpu1") | |||||
m3 = RNG(seed=222, device="xpu0") | |||||
lam = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32) | |||||
out1 = m1.poisson(lam.to("xpu0"), size=(100,)) | |||||
out2 = m2.poisson(lam.to("xpu1"), size=(100,)) | |||||
out3 = m3.poisson(lam.to("xpu0"), size=(100,)) | |||||
np.testing.assert_equal(out1.numpy(), out2.numpy()) | |||||
assert out1.device == "xpu0" and out2.device == "xpu1" | |||||
assert not (out1.numpy() == out3.numpy()).all() | |||||
out = m1.poisson(lam.to("xpu0"), size=(20, 30)) | |||||
out_shp = out.shape | |||||
expected_shape = (20, 30) + lam._tuple_shape | |||||
if isinstance(out_shp, tuple): | |||||
assert out_shp == expected_shape | |||||
else: | |||||
assert all(out.shape.numpy() == np.array(expected_shape)) | |||||
lam = lam.numpy() | |||||
assert (np.abs(out.mean(axis=(0, 1)).numpy() - lam) / np.sqrt(lam)).mean() < 0.1 | |||||
assert np.abs(np.std(out.numpy(), axis=(0, 1)) - np.sqrt(lam)).mean() < 0.1 | |||||
@pytest.mark.skipif( | |||||
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", | |||||
) | |||||
def test_PermutationRNG(): | |||||
m1 = RNG(seed=111, device="xpu0") | |||||
m2 = RNG(seed=111, device="xpu1") | |||||
m3 = RNG(seed=222, device="xpu0") | |||||
out1 = m1.permutation(n=1000) | |||||
out1_ = m1.uniform(size=(1000,)) | |||||
out2 = m2.permutation(n=1000) | |||||
out3 = m3.permutation(n=1000) | |||||
np.testing.assert_equal(out1.numpy(), out2.numpy()) | |||||
assert out1.device == "xpu0" and out2.device == "xpu1" | |||||
assert not (out1.numpy() == out3.numpy()).all() | |||||
assert not (out1.numpy() == out1_.numpy()).all() | |||||
out = m1.permutation(n=1000) | |||||
out_shp = out.shape | |||||
if isinstance(out_shp, tuple): | |||||
assert out_shp == (1000,) | |||||
else: | |||||
assert all(out.shape.numpy() == np.array([1000])) | |||||
def sum_result(res, fun): | |||||
return sum([1 if i == v else 0 for i, v in enumerate(fun(res.numpy()))]) | |||||
assert sum_result(out, lambda x: x) < 500 | |||||
assert sum_result(out, np.sort) == 1000 |
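The `sum_result(out, lambda x: x)` assertions above count fixed points: a uniform random permutation of n elements has exactly 1 expected fixed point (by linearity of expectation), so `< 500` is a deliberately loose bound, while `sum_result(out, np.sort) == n` verifies the output is a true permutation. The same check against NumPy (illustration only, not part of the test suite):

    import numpy as np

    perm = np.random.default_rng(0).permutation(1000)
    fixed = int((perm == np.arange(1000)).sum())     # expected value is 1
    assert fixed < 500
    assert (np.sort(perm) == np.arange(1000)).all()  # sorting recovers the identity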
@@ -180,6 +180,20 @@ struct OpMeth<UniformRNG> { | |||||
mgb_assert(handle_seed == rng.seed, | mgb_assert(handle_seed == rng.seed, | ||||
"inconsistent rng seed: rng op: %lu handle: %lu", | "inconsistent rng seed: rng op: %lu handle: %lu", | ||||
handle_seed, rng.seed); | handle_seed, rng.seed); | ||||
return {handle_seed, rng.dtype.enumv()}; | |||||
} | |||||
}; | |||||
template <> | |||||
struct OpMeth<PoissonRNG> { | |||||
using DnnOp = megdnn::PoissonRNG; | |||||
using Param = DnnOp::Param; | |||||
using OpNode = mgb::opr::PoissonRNG; | |||||
static Param make_param(const PoissonRNG& rng) { | |||||
auto handle_seed = RNGDnnOpManager::get_seed(rng.handle); | |||||
mgb_assert(handle_seed == rng.seed, | |||||
"inconsistent rng seed: rng op: %lu handle: %lu", | |||||
handle_seed, rng.seed); | |||||
return {handle_seed}; | return {handle_seed}; | ||||
} | } | ||||
}; | }; | ||||
@@ -194,16 +208,168 @@ struct OpMeth<GaussianRNG> { | |||||
mgb_assert(handle_seed == rng.seed, | mgb_assert(handle_seed == rng.seed, | ||||
"inconsistent rng seed: rng op: %lu handle: %lu", | "inconsistent rng seed: rng op: %lu handle: %lu", | ||||
handle_seed, rng.seed); | handle_seed, rng.seed); | ||||
return {handle_seed, rng.mean, rng.std}; | |||||
return {handle_seed, rng.mean, rng.std, rng.dtype.enumv()}; | |||||
} | |||||
}; | |||||
template <> | |||||
struct OpMeth<GammaRNG> { | |||||
using DnnOp = megdnn::GammaRNG; | |||||
using Param = DnnOp::Param; | |||||
using OpNode = mgb::opr::GammaRNG; | |||||
static Param make_param(const GammaRNG& rng) { | |||||
auto handle_seed = RNGDnnOpManager::get_seed(rng.handle); | |||||
mgb_assert(handle_seed == rng.seed, | |||||
"inconsistent rng seed: rng op: %lu handle: %lu", | |||||
handle_seed, rng.seed); | |||||
return {handle_seed}; | |||||
} | } | ||||
}; | }; | ||||
template <> | |||||
struct OpMeth<PermutationRNG> { | |||||
using DnnOp = megdnn::PermutationRNG; | |||||
using Param = DnnOp::Param; | |||||
using OpNode = mgb::opr::PermutationRNG; | |||||
static Param make_param(const PermutationRNG& rng) { | |||||
auto handle_seed = RNGDnnOpManager::get_seed(rng.handle); | |||||
mgb_assert(handle_seed == rng.seed, | |||||
"inconsistent rng seed: rng op: %lu handle: %lu", | |||||
handle_seed, rng.seed); | |||||
return {handle_seed, rng.dtype.enumv()}; | |||||
} | |||||
}; | |||||
template <> | |||||
struct OpMeth<BetaRNG> { | |||||
using DnnOp = megdnn::BetaRNG; | |||||
using Param = DnnOp::Param; | |||||
using OpNode = mgb::opr::BetaRNG; | |||||
static Param make_param(const BetaRNG& rng) { | |||||
auto handle_seed = RNGDnnOpManager::get_seed(rng.handle); | |||||
mgb_assert(handle_seed == rng.seed, | |||||
"inconsistent rng seed: rng op: %lu handle: %lu", | |||||
handle_seed, rng.seed); | |||||
return {handle_seed}; | |||||
} | |||||
}; | |||||
template <bool> | |||||
struct _InferLayout; | |||||
template <int nr_in> | |||||
struct _RNGOprMaker; | |||||
template <int nr_in> | |||||
struct _RNGOprInvoker; | |||||
template<> | |||||
struct _InferLayout<true> | |||||
{ | |||||
template<typename Op> | |||||
static TensorLayout do_infer(const TensorPtr& inp, const Op& rng){ | |||||
TensorShape tshape; | |||||
auto hv = inp->get_value().proxy_to_default_cpu(); | |||||
cg::copy_tensor_value_to_shape(tshape, hv); | |||||
return TensorLayout(tshape, rng.dtype); | |||||
} | |||||
template<typename Op> | |||||
static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng){ | |||||
TensorLayout out_layout = inp.layout; | |||||
out_layout.dtype = rng.dtype; | |||||
if (inp.layout.ndim == 0 || inp.value.empty()) { | |||||
out_layout.ndim = 0; | |||||
return out_layout; | |||||
} | |||||
mgb_assert( | |||||
inp.layout.ndim == 1, | |||||
"target shape of %s expects ndim=1; got ndim=%lu actually", | |||||
rng.dyn_typeinfo()->name, | |||||
inp.layout.ndim); | |||||
size_t target_ndim = inp.layout.shape[0]; | |||||
out_layout.ndim = target_ndim; | |||||
auto* ptr = inp.value.ptr<dt_int32>(); | |||||
for (size_t i = 0; i < target_ndim; ++i) { | |||||
out_layout.shape[i] = ptr[i]; | |||||
} | |||||
return out_layout; | |||||
} | |||||
}; | |||||
template<> | |||||
struct _InferLayout<false> | |||||
{ | |||||
template<typename Op> | |||||
static TensorLayout do_infer(const TensorPtr& inp, const Op& rng){ | |||||
return inp->layout(); | |||||
} | |||||
template<typename Op> | |||||
static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng){ | |||||
size_t size = inp.layout.total_nr_elems(); | |||||
mgb_assert( | |||||
size > 0, | |||||
"target size of %s expects size>0; got size=%lu actually", | |||||
rng.dyn_typeinfo()->name, | |||||
size); | |||||
return inp.layout; | |||||
} | |||||
}; | |||||
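The two `_InferLayout` specializations split the operators by input kind: `true` for shape-input ops (Uniform/Gaussian/Permutation, whose single input is a 1-D int32 tensor holding the target shape) and `false` for param-input ops (Gamma/Poisson/Beta, whose output layout is simply the broadcasted input layout). The `true` branch, restated in Python for clarity (a sketch, not the C++):

    import numpy as np

    def infer_shape_input_layout(shape_value: np.ndarray, out_dtype: str):
        # mirrors _InferLayout<true>: the input's *value* is the output shape
        assert shape_value.ndim == 1, "target shape expects ndim=1"
        return tuple(int(d) for d in shape_value), out_dtype

    assert infer_shape_input_layout(np.array([2, 3, 4]), "float32") == ((2, 3, 4), "float32")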
#define _INST_RNG_INVOKER(DNN_NR_INPUTS) \ | |||||
template<> \ | |||||
struct _RNGOprInvoker<DNN_NR_INPUTS> { \ | |||||
template<typename Opr> \ | |||||
static void exec(Opr *dnn_op, const SmallVector<TensorPtr>& inputs,const TensorPtr& dest){ \ | |||||
size_t wk_size = 0; \ | |||||
wk_size = dnn_op->get_workspace_in_bytes(_FOR_EACH_IN(->layout())dest->layout()); \ | |||||
auto workspace = Blob::make(dest->comp_node(), wk_size); \ | |||||
megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \ | |||||
dnn_op->exec(_FOR_EACH_IN(->dev_tensor().as_megdnn()) \ | |||||
dest->dev_tensor().as_megdnn(), dnn_wk); \ | |||||
} \ | |||||
}; | |||||
#define _INST_RNG_MAKER(MGB_NR_INPUTS) \ | |||||
template<> \ | |||||
struct _RNGOprMaker<MGB_NR_INPUTS> { \ | |||||
template<typename Op> \ | |||||
static SymbolVar make(const VarNodeArray& inputs, const Op& rng){ \ | |||||
auto param = OpMeth<Op>::make_param(rng); \ | |||||
OperatorNodeConfig config; \ | |||||
if (rng.handle) { \ | |||||
config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)}; \ | |||||
} else { \ | |||||
config = {rng.make_name()}; \ | |||||
} \ | |||||
return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config); \ | |||||
} \ | |||||
}; | |||||
#define _FOR_EACH_IN(subfix) | |||||
_INST_RNG_INVOLKER(0) | |||||
#undef _FOR_EACH_IN | |||||
#define _FOR_EACH_IN(subfix) inputs[0] subfix, | |||||
_INST_RNG_INVOLKER(1) | |||||
_INST_RNG_MAKER(1) | |||||
#undef _FOR_EACH_IN | |||||
#define _FOR_EACH_IN(subfix) inputs[0] subfix, inputs[1] subfix, | |||||
_INST_RNG_INVOLKER(2) | |||||
_INST_RNG_MAKER(2) | |||||
#undef _FOR_EACH_IN | |||||
#undef _INST_RNG_INVOLKER | |||||
#undef _INST_RNG_MAKER | |||||
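// For reference, instantiating _INST_RNG_INVOKER(2) with the two-input
// _FOR_EACH_IN expands the body (roughly) to:
//     wk_size = dnn_op->get_workspace_in_bytes(inputs[0]->layout(),
//             inputs[1]->layout(), dest->layout());
//     dnn_op->exec(inputs[0]->dev_tensor().as_megdnn(),
//             inputs[1]->dev_tensor().as_megdnn(),
//             dest->dev_tensor().as_megdnn(), dnn_wk);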
template <typename Op>
void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
        const SmallVector<TensorPtr>& outputs) {
    auto&& rng = op.cast_final_safe<Op>();
    auto dest = outputs[0];
    auto cn = dest->comp_node();
    auto handle = rng.handle;
    if (!handle) {
@@ -224,38 +390,40 @@ void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
                handle_seed, dnn_op->param().seed);
    }
    dnn_op->param() = OpMeth<Op>::make_param(rng);
    // allocate workspace
    size_t wk_size = dnn_op->get_workspace_in_bytes(dest->layout());
    auto workspace = Blob::make(cn, wk_size);
    megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size);
    dnn_op->exec(dest->dev_tensor().as_megdnn(), dnn_wk);
    _RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS>::exec(dnn_op, inputs, dest);
}
template <typename Op>
SmallVector<LogicalTensorDesc> infer_output_attrs(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    LogicalTensorDesc dest;
    auto handle = op.cast_final_safe<Op>().handle;
    auto&& rng = op.cast_final_safe<Op>();
    auto handle = rng.handle;
    if (handle) {
        dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
    } else {
        dest.comp_node = inputs[0]->comp_node();
    }
    auto hv = inputs[0]->get_value().proxy_to_default_cpu();
    TensorShape tshape;
    cg::copy_tensor_value_to_shape(tshape, hv);
    dest.layout = TensorLayout(tshape, dtype::Float32());
    constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
    if (!rng_with_shape) {
        for (size_t i = 0; i < inputs.size(); ++i) {
            mgb_assert(inputs[i]->comp_node() == dest.comp_node,
                    "%s expects the device of inputs[%zu] to be same as the device of handle; "
                    "got %s and %s actually", rng.dyn_typeinfo()->name, i,
                    inputs[i]->comp_node().to_string().c_str(),
                    dest.comp_node.to_string().c_str());
        }
    }
    dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], rng);
    return {dest};
}
template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    auto desc = infer_output_attrs<Op>(def, inputs);
    SmallVector<TensorPtr> outputs;
    SmallVector<LogicalTensorDesc> desc;
    desc = infer_output_attrs<Op>(def, inputs);
    for (auto&& i : desc) {
        outputs.push_back(Tensor::make(i.layout, i.comp_node));
    }
@@ -268,51 +436,32 @@ SymbolVar apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    size_t nr_inp = inputs.size();
    constexpr size_t dnn_nr_inp = OpMeth<Op>::DnnOp::NR_INPUTS;
    auto&& rng = def.cast_final_safe<Op>();
    mgb_assert(nr_inp == 1, "%s expects 1 input; got %lu actually",
            rng.dyn_typeinfo()->name,
            nr_inp);
    auto param = OpMeth<Op>::make_param(rng);
    OperatorNodeConfig config;
    if (rng.handle) {
        config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)};
    } else {
        config = {rng.make_name()};
    if (dnn_nr_inp == 0) {
        mgb_assert(nr_inp == 1, "%s expects 1 input; got %lu actually",
                rng.dyn_typeinfo()->name,
                nr_inp);
    }
    return OpMeth<Op>::OpNode::make(inputs[0], param, config);
    constexpr size_t mgb_nr_inp = dnn_nr_inp + !dnn_nr_inp;
    return _RNGOprMaker<mgb_nr_inp>::make(inputs, rng);
}
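// Note the index arithmetic: `dnn_nr_inp + !dnn_nr_inp` maps 0 to 1 and
// leaves positive counts unchanged. Shape-style RNGs (UniformRNG,
// GaussianRNG, PermutationRNG) have zero megdnn inputs but still take one
// graph input (the target shape), so they dispatch to _RNGOprMaker<1>;
// PoissonRNG also uses _RNGOprMaker<1>, while GammaRNG/BetaRNG use
// _RNGOprMaker<2>.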
template<typename T>
template<typename Op>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& xxx_rng_def = def.cast_final_safe<T>();
    LogicalTensorDesc dest;
    auto&& xxx_rng_def = def.cast_final_safe<Op>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "%s expects 1 input; got %lu actually",
            xxx_rng_def.dyn_typeinfo()->name,
            nr_inp);
    auto&& tshp = inputs[0];
    TensorLayout out_layout = tshp.layout;
    out_layout.dtype = dtype::Float32();
    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
        out_layout.ndim = 0;
        return {{{out_layout, tshp.comp_node}}, true};
    constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
    if (rng_with_shape) {
        mgb_assert(nr_inp == 1, "%s expects 1 input; got %lu actually",
                xxx_rng_def.dyn_typeinfo()->name,
                nr_inp);
    }
    mgb_assert(
            tshp.layout.ndim == 1,
            "target shape of %s expects ndim=1; got ndim=%lu actually",
            xxx_rng_def.dyn_typeinfo()->name,
            tshp.layout.ndim);
    size_t target_ndim = tshp.layout.shape[0];
    out_layout.ndim = target_ndim;
    auto* ptr = tshp.value.ptr<dt_int32>();
    for (size_t i = 0; i < target_ndim; ++i) {
        out_layout.shape[i] = ptr[i];
    }
    return {{{out_layout, tshp.comp_node}}, true};
    dest.comp_node = inputs[0].comp_node;
    dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
    return {{dest}, true};
}

} // anonymous namespace
@@ -333,6 +482,10 @@ uint64_t get_global_rng_seed() {
    return RNGDnnOpManager::get_glob_default_seed();
}

CompNode get_rng_handle_compnode(Handle handle) {
    return RNGDnnOpManager::get_comp_node(handle);
}

#define REG_RNG_OP(NAME)\
namespace { \
OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
@@ -344,6 +497,11 @@ OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
REG_RNG_OP(UniformRNG)
REG_RNG_OP(GaussianRNG)
REG_RNG_OP(GammaRNG)
REG_RNG_OP(PermutationRNG)
REG_RNG_OP(PoissonRNG)
REG_RNG_OP(BetaRNG)
#undef REG_RNG_OP

} // namespace mgb::imperative::rng

@@ -22,5 +22,6 @@ Handle new_handle(CompNode comp_node, uint64_t seed);
size_t delete_handle(Handle handle);
void set_global_rng_seed(uint64_t seed);
uint64_t get_global_rng_seed();
CompNode get_rng_handle_compnode(Handle handle);

} // namespace mgb::imperative::rng
@@ -42,14 +42,72 @@ void check_rng_basic(Args&& ...args) {
    }
}

template<typename Op, typename ...Args>
void check_rng_with_input_basic(const CompNode &cn,
        const SmallVector<TensorPtr> &inputs, Args&& ...args) {
    Handle h = new_handle(cn, 123);
    auto op = Op::make(std::forward<Args>(args)..., h);
    auto outputs = OpDef::apply_on_physical_tensor(*op, inputs);
    ASSERT_TRUE(outputs[0]->layout().eq_shape(inputs[0]->shape()));
    ASSERT_TRUE(cn == outputs[0]->comp_node());
    // sync before delete handle
    for (auto&& p : outputs) {
        p->get_value();
    }
    delete_handle(h);
}
TEST(TestImperative, PoissonRNGBasic) {
    REQUIRE_XPU(2);
    for (auto&& cn : {CompNode::load("xpu0"), CompNode::load("xpu1")}) {
        TensorShape shape{5, 3000};
        HostTensorND lam{cn, shape, dtype::Float32()};
        auto lam_ptr = lam.ptr<float>();
        for (int i = 0; i < 5 * 3000; ++i) lam_ptr[i] = 2;
        SmallVector<TensorPtr> inputs{Tensor::make(lam)};
        check_rng_with_input_basic<PoissonRNG>(cn, inputs, 123);
    }
}

TEST(TestImperative, BetaRNGBasic) {
    REQUIRE_XPU(2);
    for (auto&& cn : {CompNode::load("xpu0"), CompNode::load("xpu1")}) {
        TensorShape shape{5, 3000};
        HostTensorND alpha{cn, shape, dtype::Float32()},
                beta{cn, shape, dtype::Float32()};
        auto alpha_ptr = alpha.ptr<float>(), beta_ptr = beta.ptr<float>();
        for (int i = 0; i < 5 * 3000; ++i) alpha_ptr[i] = 2, beta_ptr[i] = 2;
        SmallVector<TensorPtr> inputs{Tensor::make(alpha), Tensor::make(beta)};
        check_rng_with_input_basic<BetaRNG>(cn, inputs, 123);
    }
}

TEST(TestImperative, GammaRNGBasic) {
    REQUIRE_XPU(2);
    for (auto&& cn : {CompNode::load("xpu0"), CompNode::load("xpu1")}) {
        TensorShape size{5, 3000};
        HostTensorND shape{cn, size, dtype::Float32()},
                scale{cn, size, dtype::Float32()};
        auto shape_ptr = shape.ptr<float>(), scale_ptr = scale.ptr<float>();
        for (int i = 0; i < 5 * 3000; ++i) shape_ptr[i] = 2, scale_ptr[i] = 2;
        SmallVector<TensorPtr> inputs{Tensor::make(shape), Tensor::make(scale)};
        check_rng_with_input_basic<GammaRNG>(cn, inputs, 123);
    }
}
TEST(TestImperative, UniformRNGBasic) {
    REQUIRE_XPU(2);
    check_rng_basic<UniformRNG>(123);
    check_rng_basic<UniformRNG>(123, dtype::Float32());
}

TEST(TestImperative, GaussianRNGBasic) {
    REQUIRE_XPU(2);
    check_rng_basic<GaussianRNG>(123, 2.f, 3.f);
    check_rng_basic<GaussianRNG>(123, 2.f, 3.f, dtype::Float32());
}

TEST(TestImperative, PermutationRNGBasic) {
    REQUIRE_XPU(2);
    check_rng_basic<PermutationRNG>(123, dtype::Int32());
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -123,9 +123,13 @@ def UniformRNG: MgbHashableOp<"UniformRNG", [UniformRNGParam]> {
  let hashFunction = [{
    return mgb::hash_pair_combine(
      mgb::hash($_self.dyn_typeinfo()),
      mgb::hash($_self.handle));
      mgb::hash_pair_combine(
        mgb::hash($_self.handle),
        mgb::hash($_self.dtype.enumv())
      )
    );
  }];
  let cmpFunction = [{return $0.handle == $1.handle;}];
  let cmpFunction = [{return $0.handle == $1.handle && $0.dtype == $1.dtype;}];
}

def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> {
@@ -139,11 +143,70 @@ def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> {
      mgb::hash($_self.handle),
      mgb::hash_pair_combine(
        mgb::hash($_self.mean),
        mgb::hash($_self.std))
        mgb::hash_pair_combine(
          mgb::hash($_self.std),
          mgb::hash($_self.dtype.enumv())
        )
      )
    )
  );
  }];
  let cmpFunction = [{return $0.handle == $1.handle && $0.mean == $1.mean && $0.std == $1.std;}];
  let cmpFunction = [{return $0.handle == $1.handle && $0.mean == $1.mean && $0.std == $1.std && $0.dtype == $1.dtype;}];
}
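// Note on the pattern above: hash_pair_combine takes exactly two values, so
// the fields are folded right-to-left, effectively
// H(typeinfo, H(handle, H(mean, H(std, dtype)))). Keeping hashFunction and
// cmpFunction over the same field set keeps dedup effective: two ops that
// differ only in the newly added dtype now hash differently and compare
// unequal.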
def GammaRNG: MgbHashableOp<"GammaRNG", [GammaRNGParam]> {
  let extraArguments = (ins
    MgbSizeTAddr:$handle
  );
  let hashFunction = [{
    return mgb::hash_pair_combine(
      mgb::hash($_self.dyn_typeinfo()),
      mgb::hash($_self.handle)
    );
  }];
  let cmpFunction = [{return $0.handle == $1.handle;}];
}

def PoissonRNG: MgbHashableOp<"PoissonRNG", [PoissonRNGParam]> {
  let extraArguments = (ins
    MgbSizeTAddr:$handle
  );
  let hashFunction = [{
    return mgb::hash_pair_combine(
      mgb::hash($_self.dyn_typeinfo()),
      mgb::hash($_self.handle)
    );
  }];
  let cmpFunction = [{return $0.handle == $1.handle;}];
}

def BetaRNG: MgbHashableOp<"BetaRNG", [BetaRNGParam]> {
  let extraArguments = (ins
    MgbSizeTAddr:$handle
  );
  let hashFunction = [{
    return mgb::hash_pair_combine(
      mgb::hash($_self.dyn_typeinfo()),
      mgb::hash($_self.handle)
    );
  }];
  let cmpFunction = [{return $0.handle == $1.handle;}];
}

def PermutationRNG: MgbHashableOp<"PermutationRNG", [PermutationRNGParam]> {
  let extraArguments = (ins
    MgbSizeTAddr:$handle
  );
  let hashFunction = [{
    return mgb::hash_pair_combine(
      mgb::hash($_self.dyn_typeinfo()),
      mgb::hash_pair_combine(
        mgb::hash($_self.handle),
        mgb::hash($_self.dtype.enumv())
      )
    );
  }];
  let cmpFunction = [{return $0.handle == $1.handle && $0.dtype == $1.dtype;}];
}

def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> {
@@ -19,46 +19,21 @@ using namespace mgb;
using namespace opr;
using namespace intl;

namespace {
template<class MegDNNOpr>
struct RNGName;

template<>
struct RNGName<megdnn::UniformRNG> {
    static constexpr const char* name = "uniform_rng";
};

template<>
struct RNGName<megdnn::GaussianRNG> {
    static constexpr const char* name = "gaussian_rng";
};
} // anonymous namespace

RNGOprBase::RNGOprBase(const OperatorNodeBaseCtorParam &opr, VarNode *shape):
    Super(opr)
template<typename MegDNNOpr>
RNGOprBase<MegDNNOpr>::RNGOprBase(const OperatorNodeBaseCtorParam &opr, const Param &param):
    Super(opr), m_param(param)
{
    add_input({shape});
    add_output(None)->dtype(dtype::Float32());
    cg::add_workspace_output(this);
    // disable dedup
    add_equivalence_component<ScalarHash<void*>>(this);
}

RNGOprBase::~RNGOprBase() {
}

cg::OperatorNodeBase::NodeProp* RNGOprBase::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
    prop->add_flag(NodeProp::Flag::IMPURE_FUNC);
    prop->reset_dep_type(input(), {NodeProp::DepType::HOST_VALUE});
    return prop;
template<class MegDNNOpr>
UniqPtrWithCN<MegDNNOpr> RNGOprBase<MegDNNOpr>::create_megdnn_opr() {
    auto opr = intl::create_megdnn_opr<MegDNNOpr>(comp_node());
    opr->param() = param();
    return opr;
}

void RNGOprBase::ensure_megdnn_opr() {
template<typename MegDNNOpr>
void RNGOprBase<MegDNNOpr>::ensure_megdnn_opr() {
    if (!m_dnn_opr || m_dnn_opr.comp_node() != comp_node()) {
        // activate comp_node for curandCreateGenerator in create_megdnn_opr
        comp_node().activate();
@@ -66,53 +41,120 @@ void RNGOprBase::ensure_megdnn_opr() {
    }
}
void RNGOprBase::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto &&mgr = owner_graph()->static_infer_manager();
    auto infer_out = [](TensorShape &dest, const InpVal &inp) {
        cg::copy_tensor_value_to_shape(dest, inp.val.at(0).value());
        return true;
    };
    auto infer_wk = [this](TensorShape &dest, const InpVal &inp) {
        ensure_megdnn_opr();
        dest.ndim = 1;
        dest.shape[0] = m_dnn_opr->get_workspace_in_bytes(
                {inp.val.at(0).shape(), output(0)->dtype()});
        return true;
    };
    mgr.register_shape_infer(output(0),
            {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_out});
    mgr.register_shape_infer(output(1),
            {SourceType::DEP, {{output(0), DepType::SHAPE}}, infer_wk});

/* ================= RNG with shape ================= */
#define _INST_RNG_OPR_WITH_SHAPE(RNGOpr, name) \
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNGOpr); \
cg::OperatorNodeBase::NodeProp* RNGOpr::do_make_node_prop() const { \
    auto prop = Super::do_make_node_prop(); \
    prop->add_flag(NodeProp::Flag::IMPURE_FUNC); \
    prop->reset_dep_type(input(), {NodeProp::DepType::HOST_VALUE}); \
    return prop; \
} \
RNGOpr::RNGOpr(VarNode *shape, const Param &param, \
        const OperatorNodeConfig &config): \
    Super({shape->owner_graph(), config, (name), {shape}}, param) \
{ \
    DType dtype = DType::from_enum(param.dtype); \
    add_input({shape}); \
    add_output(None)->dtype(dtype); \
    cg::add_workspace_output(this); \
    add_equivalence_component<ScalarHash<void*>>(this); \
} \
SymbolVar RNGOpr::make(SymbolVar shape, const Param &param, \
        const OperatorNodeConfig &config){ \
    return shape.insert_single_output_opr<RNGOpr>(shape.node(), param, config); \
} \
void RNGOpr::init_output_static_infer_desc() { \
    using namespace cg::static_infer; \
    auto &&mgr = owner_graph()->static_infer_manager(); \
    auto infer_out = [](TensorShape &dest, const InpVal &inp) { \
        cg::copy_tensor_value_to_shape(dest, inp.val.at(0).value()); \
        return true; \
    }; \
    auto infer_wk = [this](TensorShape &dest, const InpVal &inp) { \
        ensure_megdnn_opr(); \
        dest.ndim = 1; \
        dest.shape[0] = m_dnn_opr->get_workspace_in_bytes( \
                {inp.val.at(0).shape(), output(0)->dtype()}); \
        return true; \
    }; \
    mgr.register_shape_infer(output(0), \
            {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_out}); \
    mgr.register_shape_infer(output(1), \
            {SourceType::DEP, {{output(0), DepType::SHAPE}}, infer_wk}); \
} \
void RNGOpr::scn_do_execute() { \
    m_dnn_opr->exec(output(0)->dev_tensor().as_megdnn(), \
            get_megdnn_workspace_from_var(output(1))); \
}
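/* For reference, _INST_RNG_OPR_WITH_SHAPE(UniformRNG, "uniform_rng") below
 * stamps out the full operator: the dyn-type impl, a HOST_VALUE dependency
 * on the shape input, an output whose dtype comes from param.dtype, and
 * static inference of output(0) from the shape value plus output(1) (the
 * workspace) from output(0)'s shape. */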
void RNGOprBase::scn_do_execute() {
    m_dnn_opr->exec(
            output(0)->dev_tensor().as_megdnn(),
            get_megdnn_workspace_from_var(output(1)));
}

template<class MegDNNOpr>
RNGOpr<MegDNNOpr>::RNGOpr(VarNode *shape, const Param &param,
        const OperatorNodeConfig &config):
    Super({shape->owner_graph(), config, RNGName<MegDNNOpr>::name, {shape}},
          shape),
    m_param(param)
{
}

template<class MegDNNOpr>
SymbolVar RNGOpr<MegDNNOpr>::make(SymbolVar shape, const Param &param,
        const OperatorNodeConfig &config) {
    return shape.insert_single_output_opr<RNGOpr>(shape.node(), param, config);
}

template<class MegDNNOpr>
UniqPtrWithCN<megdnn::RNGBase> RNGOpr<MegDNNOpr>::create_megdnn_opr() {
    auto opr = intl::create_megdnn_opr<MegDNNOpr>(comp_node());
    opr->param() = param();
    return opr;
}

_INST_RNG_OPR_WITH_SHAPE(UniformRNG, "uniform_rng")
_INST_RNG_OPR_WITH_SHAPE(GaussianRNG, "gaussian_rng")
_INST_RNG_OPR_WITH_SHAPE(PermutationRNG, "permutation_rng")
#undef _INST_RNG_OPR_WITH_SHAPE
/* ================= RNG with input ================= */
#define _AS_MEGDNN(idx) input((idx))->dev_tensor().as_megdnn()
#define _INFER_WK_DEPS(idx) {input((idx)), DepType::SHAPE}
#define _INFER_WK_ARGS(idx) {inp.val.at((idx)).shape(), input((idx))->dtype()}

#define _INST_RNG_OPR_WITH_INPUT(RNGOpr, name) \
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNGOpr); \
RNGOpr::RNGOpr(_INPUTS(VarNode*,), const Param &param, \
        const OperatorNodeConfig &config): \
    Super({i0->owner_graph(), config, (name), {_INPUTS(,)}}, param) \
{ \
    add_input({_INPUTS(,)}); \
    add_output(None)->dtype(i0->dtype()); \
    cg::add_workspace_output(this); \
    add_equivalence_component<ScalarHash<void*>>(this); \
} \
SymbolVar RNGOpr::make(_INPUTS(SymbolVar,), const Param &param, \
        const OperatorNodeConfig &config){ \
    return i0.insert_single_output_opr<RNGOpr>(_INPUTS(,.node()), param, config); \
} \
void RNGOpr::init_output_static_infer_desc() { \
    using namespace cg::static_infer; \
    auto &&mgr = owner_graph()->static_infer_manager(); \
    auto infer_wk = [this](TensorShape &dest, const InpVal &inp) { \
        ensure_megdnn_opr(); \
        dest.ndim = 1; \
        dest.shape[0] = m_dnn_opr->get_workspace_in_bytes( \
                _FOR_EACH(_INFER_WK_ARGS), \
                {output(0)->shape(), output(0)->dtype()}); \
        return true; \
    }; \
    mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0))); \
    mgr.register_shape_infer(output(1), {SourceType::DEP, {_FOR_EACH(_INFER_WK_DEPS)}, \
            infer_wk}); \
} \
void RNGOpr::add_input_layout_constraint(){ \
    for (auto i : input()) i->add_layout_constraint_contiguous(); \
}; \
void RNGOpr::scn_do_execute() { \
    m_dnn_opr->exec(_FOR_EACH(_AS_MEGDNN), output(0)->dev_tensor().as_megdnn(), \
            get_megdnn_workspace_from_var(output(1))); \
}

/* ================= 1 input ================= */
#define _INPUTS(prefix, suffix) prefix i0 suffix
#define _FOR_EACH(cb) cb(0)
_INST_RNG_OPR_WITH_INPUT(PoissonRNG, "poisson_rng")
#undef _INPUTS
#undef _FOR_EACH

/* ================= 2 input ================= */
#define _INPUTS(prefix, suffix) prefix i0 suffix, prefix i1 suffix
#define _FOR_EACH(cb) cb(0), cb(1)
_INST_RNG_OPR_WITH_INPUT(BetaRNG, "beta_rng")
_INST_RNG_OPR_WITH_INPUT(GammaRNG, "gamma_rng")
#undef _INPUTS
#undef _FOR_EACH

#undef _AS_MEGDNN
#undef _INFER_WK_DEPS
#undef _INFER_WK_ARGS
#undef _INST_RNG_OPR_WITH_INPUT
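/* For reference, with the two-input _FOR_EACH, scn_do_execute for BetaRNG
 * expands (roughly) to:
 *     m_dnn_opr->exec(input(0)->dev_tensor().as_megdnn(),
 *             input(1)->dev_tensor().as_megdnn(),
 *             output(0)->dev_tensor().as_megdnn(),
 *             get_megdnn_workspace_from_var(output(1)));
 * and output(0) inherits its shape from input(0) via make_identity. */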
#define IMPL(_cls) \
    MGB_IMPL_OPR_GRAD(_cls) { \
@@ -123,13 +165,21 @@ UniqPtrWithCN<megdnn::RNGBase> RNGOpr<MegDNNOpr>::create_megdnn_opr() {
namespace mgb {
namespace opr {
namespace intl {
template class RNGOpr<::megdnn::GaussianRNG>;
template class RNGOpr<::megdnn::UniformRNG>;
template class RNGOprBase<::megdnn::GaussianRNG>;
template class RNGOprBase<::megdnn::UniformRNG>;
template class RNGOprBase<::megdnn::GammaRNG>;
template class RNGOprBase<::megdnn::PermutationRNG>;
template class RNGOprBase<::megdnn::BetaRNG>;
template class RNGOprBase<::megdnn::PoissonRNG>;

#if MGB_ENABLE_GRAD
IMPL(GaussianRNG);
IMPL(UniformRNG);
IMPL(GammaRNG);
IMPL(PoissonRNG);
IMPL(PermutationRNG);
IMPL(BetaRNG);
#endif
}
}
}
}
@@ -17,6 +17,10 @@ namespace opr {
MGB_SEREG_OPR(UniformRNG, 1);
MGB_SEREG_OPR(GaussianRNG, 1);
MGB_SEREG_OPR(GammaRNG, 2);
MGB_SEREG_OPR(PoissonRNG, 1);
MGB_SEREG_OPR(PermutationRNG, 1);
MGB_SEREG_OPR(BetaRNG, 2);
} // namespace opr
} // namespace mgb
@@ -14,7 +14,6 @@
#include "megbrain/graph.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megdnn/oprs.h"

namespace mgb {

@@ -22,60 +21,81 @@ namespace opr {
namespace intl {

template<typename MegDNNOpr>
MGB_DEFINE_CLS_WITH_SUPER(RNGOprBase, cg::SingleCNOperatorNodeBase) // {
    UniqPtrWithCN<megdnn::RNGBase> m_dnn_opr;
    void ensure_megdnn_opr();
    void init_output_static_infer_desc() override;
    void scn_do_execute() override final;
    protected:
        RNGOprBase(const OperatorNodeBaseCtorParam &opr, VarNode *shape);
        ~RNGOprBase();
        NodeProp* do_make_node_prop() const override;
        virtual UniqPtrWithCN<megdnn::RNGBase> create_megdnn_opr() = 0;
};

template<class MegDNNOpr>
MGB_DEFINE_OPR_CLASS(RNGOpr, RNGOprBase) // {
    public:
        using Param = typename MegDNNOpr::Param;
        RNGOpr(VarNode *shape, const Param &param,
                const OperatorNodeConfig &config);
        static SymbolVar make(SymbolVar shape, const Param &param = {},
                const OperatorNodeConfig &config = {});
        static SymbolVar make(ComputingGraph &graph, const TensorShape &shape,
                const OperatorNodeConfig &config,
                const Param &param = {}) {
            return make(var_from_tensor_shape(graph, config, "rng", shape),
                    param, config);
        }
        const Param& param() const {
            return m_param;
        }
    private:
        Param m_param;
        UniqPtrWithCN<megdnn::RNGBase> create_megdnn_opr() override;
        UniqPtrWithCN<MegDNNOpr> create_megdnn_opr();
    protected:
        ~RNGOprBase() {}
        RNGOprBase(const OperatorNodeBaseCtorParam &opr, const Param &param);
        void ensure_megdnn_opr();
        UniqPtrWithCN<MegDNNOpr> m_dnn_opr;
};

/* ================= RNG with shape ================= */
#define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \
MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>) \
    cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \
    public: \
        RNG(VarNode *shape, const Param &param, const OperatorNodeConfig &config); \
        static SymbolVar make(SymbolVar shape, const Param &param = {}, \
                const OperatorNodeConfig &config = {}); \
        static SymbolVar make(ComputingGraph &graph, const TensorShape &shape, \
                const OperatorNodeConfig &config, \
                const Param &param = {}) { \
            return make(var_from_tensor_shape(graph, config, "rng", shape), \
                    param, config); \
        } \
        void init_output_static_infer_desc() override; \
        void scn_do_execute() override; \
};

#undef _MGB_DYN_TYPE_OBJ_FINAL_IMPL_TPL
#define _MGB_DYN_TYPE_OBJ_FINAL_IMPL_TPL template<class MegDNNOpr>
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNGOpr<MegDNNOpr>);
#undef _MGB_DYN_TYPE_OBJ_FINAL_IMPL_TPL
#define _MGB_DYN_TYPE_OBJ_FINAL_IMPL_TPL

_DEFINE_RNG_OPR_WITH_SHAPE_CLASS(UniformRNG)
_DEFINE_RNG_OPR_WITH_SHAPE_CLASS(GaussianRNG)
_DEFINE_RNG_OPR_WITH_SHAPE_CLASS(PermutationRNG)
#undef _DEFINE_RNG_OPR_WITH_SHAPE_CLASS

/* ================= RNG with input ================= */
#define _DEFINE_RNG_OPR_WITH_INPUT_CLASS(RNG) \
MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>) \
    void add_input_layout_constraint() override; \
    public: \
        RNG(_INPUTS(VarNode*), const Param &param, \
                const OperatorNodeConfig &config); \
        static SymbolVar make(_INPUTS(SymbolVar), const Param &param = {}, \
                const OperatorNodeConfig &config = {}); \
        void init_output_static_infer_desc() override; \
        void scn_do_execute() override; \
};

} // intl
/* ================= 1 input ================= */
#define _INPUTS(prefix) prefix i0
_DEFINE_RNG_OPR_WITH_INPUT_CLASS(PoissonRNG)
#undef _INPUTS

using UniformRNG = intl::RNGOpr<megdnn::UniformRNG>;
using GaussianRNG = intl::RNGOpr<megdnn::GaussianRNG>;

/* ================= 2 input ================= */
#define _INPUTS(prefix) prefix i0, prefix i1
_DEFINE_RNG_OPR_WITH_INPUT_CLASS(BetaRNG)
_DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG)
#undef _INPUTS
#undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS

} // intl
using UniformRNG = intl::UniformRNG;
using GaussianRNG = intl::GaussianRNG;
using GammaRNG = intl::GammaRNG;
using PermutationRNG = intl::PermutationRNG;
using PoissonRNG = intl::PoissonRNG;
using BetaRNG = intl::BetaRNG;

} // namespace opr
} // namespace mgb
@@ -19,84 +19,76 @@
using namespace mgb;

namespace {
struct BasicStat {
    double mean, std, min, max;
    static BasicStat make(const float *ptr, size_t size,
            double mean_expect = 0) {
        double sum = 0, sum2 = 0,
               min = std::numeric_limits<double>::max(),
               max = std::numeric_limits<double>::lowest();
        for (size_t i = 0; i < size; ++ i) {
            double cur = ptr[i];
            min = std::min(min, cur);
            max = std::max(max, cur);
            cur -= mean_expect;
            sum += cur;
            sum2 += cur * cur;
        }
        double mean = sum / size + mean_expect,
               std = sqrt((sum2 - sum * sum / size) / (size - 1));
        return {mean, std, min, max};
    }
};

void check_reproducibility(
        thin_function<SymbolVar(SymbolVar, uint64_t seed)> make) {
    auto graph = ComputingGraph::make();
    constexpr size_t SIZE = 123;
    // out[func][opr][run]
    HostTensorND out[2][2][2];
    auto run = [&](int fid) {
        SymbolVar
            o0 = make(cg::var_from_tensor_shape(*graph,
                    {CompNode::load("xpu0")}, "shp0", {SIZE}), 0),
            o1 = make(cg::var_from_tensor_shape(*graph,
                    {CompNode::load("xpu0")}, "shp0", {SIZE}), 1);
        HostTensorND host_o0, host_o1;
        auto func = graph->compile({
                make_callback_copy(o0, host_o0),
                make_callback_copy(o1, host_o1)});
        for (int i = 0; i < 2; ++ i) {
            func->execute();
            out[fid][0][i].copy_from(host_o0);
            out[fid][1][i].copy_from(host_o1);
        }
    };
    run(0);
    run(1);
    for (int i = 0; i < 2; ++ i) {
        for (int j = 0; j < 2; ++ j)
            MGB_ASSERT_TENSOR_EQ(out[0][i][j], out[1][i][j]);
    }
    auto max_diff = [&](int off0, int off1) {
        float diff = 0;
        auto p0 = out[0][off0 / 2][off0 % 2].ptr<float>(),
             p1 = out[0][off1 / 2][off1 % 2].ptr<float>();
        for (size_t i = 0; i < SIZE; ++ i) {
            update_max(diff, std::abs(p0[i] - p1[i]));
        }
        return diff;
    };
    for (int i = 0; i < 4; ++ i) {
        for (int j = i + 1; j < 4; ++ j)
            ASSERT_GT(max_diff(i, j), 0.3) << i << " " << j;
    }
}
struct BasicStat {
    double mean, std, min, max;
    static BasicStat make(const float* ptr, size_t size,
                          double mean_expect = 0) {
        double sum = 0, sum2 = 0, min = std::numeric_limits<double>::max(),
               max = std::numeric_limits<double>::lowest();
        for (size_t i = 0; i < size; ++i) {
            double cur = ptr[i];
            min = std::min(min, cur);
            max = std::max(max, cur);
            cur -= mean_expect;
            sum += cur;
            sum2 += cur * cur;
        }
        double mean = sum / size + mean_expect,
               std = sqrt((sum2 - sum * sum / size) / (size - 1));
        return {mean, std, min, max};
    }
};

void check_reproducibility(std::shared_ptr<ComputingGraph> graph, size_t size,
                           thin_function<SymbolVar(uint64_t seed)> make) {
    // out[func][opr][run]
    HostTensorND out[2][2][2];
    auto run = [&](int fid) {
        SymbolVar o0 = make(0), o1 = make(1);
        HostTensorND host_o0, host_o1;
        auto func = graph->compile({make_callback_copy(o0, host_o0),
                                    make_callback_copy(o1, host_o1)});
        for (int i = 0; i < 2; ++i) {
            func->execute();
            out[fid][0][i].copy_from(host_o0);
            out[fid][1][i].copy_from(host_o1);
        }
    };
    run(0);
    run(1);
    for (int i = 0; i < 2; ++i) {
        for (int j = 0; j < 2; ++j)
            MGB_ASSERT_TENSOR_EQ(out[0][i][j], out[1][i][j]);
    }
    auto max_diff = [&](int off0, int off1) {
        float diff = 0;
        auto p0 = out[0][off0 / 2][off0 % 2].ptr<float>(),
             p1 = out[0][off1 / 2][off1 % 2].ptr<float>();
        for (size_t i = 0; i < size; ++i) {
            update_max(diff, std::abs(p0[i] - p1[i]));
        }
        return diff;
    };
    for (int i = 0; i < 4; ++i) {
        for (int j = i + 1; j < 4; ++j)
            ASSERT_GT(max_diff(i, j), 0.3) << i << " " << j;
    }
}
} // anonymous namespace
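// The out[f][o][r] array indexes function instance f, operator o, and run r:
// the MGB_ASSERT_TENSOR_EQ loop checks that the two compiled functions
// reproduce each other exactly, while the max_diff loop checks that the four
// (opr, run) streams within one function are pairwise distinct, i.e. the two
// seeds and the per-run generator state must not collide.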
TEST(TestOprRand, Uniform) {
    static constexpr size_t M = 128, N = 64;
    auto graph = ComputingGraph::make();
    SymbolVar dev_out = opr::UniformRNG::make(
            *graph, {M, N}, {CompNode::load("xpu0")});
            *graph, {M, N}, {CompNode::load("xpu0")}, {23, DTypeEnum::Float32});
    HostTensorND host_out;
    auto func = graph->compile({make_callback_copy(dev_out, host_out)});
@@ -115,9 +107,10 @@ TEST(TestOprRand, Gaussian) {
    static constexpr size_t SIZE = 123451;
    constexpr float MEAN = 1, STD = 2;
    auto graph = ComputingGraph::make();
    auto y = opr::GaussianRNG::make(
            SymbolVar::make_scalar(int(SIZE), *graph, {CompNode::load("xpu0")}),
            {23, MEAN, STD});
            {23, MEAN, STD, DTypeEnum::Float32});
    HostTensorND host_y;
    auto func = graph->compile({make_callback_copy(y, host_y)});
@@ -130,17 +123,212 @@
    ASSERT_LT(fabs(stat.std - STD), 0.1);
}

TEST(TestOprRand, Gamma) {
    std::shared_ptr<HostTensorND> shape_host(new HostTensorND{
            CompNode::load("xpux"), TensorShape{2000000*5}, dtype::Float32()});
    std::shared_ptr<HostTensorND> scale_host(new HostTensorND{
            CompNode::load("xpux"), TensorShape{2000000*5}, dtype::Float32()});
    auto shape_ptr = shape_host->ptr<float>();
    auto scale_ptr = scale_host->ptr<float>();
    for (int i = 0; i < 5; ++i) {
        for (int j = 0; j < 2000000; ++j) {
            shape_ptr[i * 2000000 + j] = 2 * 0.3 * i + 0.5;
            scale_ptr[i * 2000000 + j] = i * 0.3 + 0.5;
        }
    }
    auto graph = ComputingGraph::make();
    auto shape_sym = opr::Host2DeviceCopy::make(*graph, shape_host);
    auto scale_sym = opr::Host2DeviceCopy::make(*graph, scale_host);
    auto y = opr::GammaRNG::make(shape_sym, scale_sym, {10});
    HostTensorND host_y;
    auto func = graph->compile({make_callback_copy(y, host_y)});
    func->execute();
    ASSERT_EQ(TensorShape({2000000*5}), host_y.shape());
    for (int i = 0; i < 5; ++i) {
        float a = 2 * 0.3 * i + 0.5, b = i * 0.3 + 0.5;
        float mean = a * b;
        float var = a * (b * b);
        auto stat = BasicStat::make(host_y.ptr<float>() + 2000000 * i,
                2000000, mean);
        ASSERT_LT(fabs(stat.mean - mean), 0.01);
        ASSERT_LT(fabs(stat.std - sqrt(var)), 0.01);
    }
}
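// Sanity of the expected moments above: for Gamma(shape=a, scale=b),
// E[X] = a*b and Var[X] = a*b^2, hence the sqrt(var) bound on the sample std.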
TEST(TestOprRand, Poisson) {
    std::shared_ptr<HostTensorND> lam_host(new HostTensorND{
            CompNode::load("xpux"), TensorShape{200000*5}, dtype::Float32()});
    auto lam_ptr = lam_host->ptr<float>();
    for (int i = 0; i < 5; ++i) {
        for (int j = 0; j < 200000; ++j) {
            lam_ptr[i * 200000 + j] = i + 1;
        }
    }
    auto graph = ComputingGraph::make();
    auto lam_sym = opr::Host2DeviceCopy::make(*graph, lam_host);
    auto y = opr::PoissonRNG::make(lam_sym, {10});
    HostTensorND host_y;
    auto func = graph->compile({make_callback_copy(y, host_y)});
    func->execute();
    ASSERT_EQ(TensorShape({200000*5}), host_y.shape());
    for (int i = 0; i < 5; ++i) {
        float lambda = i + 1;
        auto stat = BasicStat::make(host_y.ptr<float>() + 200000 * i,
                200000, lambda);
        ASSERT_LT(fabs(stat.mean - lambda), 0.01);
        ASSERT_LT(fabs(stat.std - sqrt(lambda)), 0.1);
    }
}
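// For Poisson(lambda), E[X] = Var[X] = lambda, which is why the sample std
// is checked against sqrt(lambda).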
TEST(TestOprRand, Beta) {
    std::shared_ptr<HostTensorND> alpha_host(new HostTensorND{
            CompNode::load("xpux"), TensorShape{200000*5}, dtype::Float32()});
    std::shared_ptr<HostTensorND> beta_host(new HostTensorND{
            CompNode::load("xpux"), TensorShape{200000*5}, dtype::Float32()});
    auto alpha_ptr = alpha_host->ptr<float>();
    auto beta_ptr = beta_host->ptr<float>();
    for (int i = 0; i < 5; ++i) {
        for (int j = 0; j < 200000; ++j) {
            alpha_ptr[i * 200000 + j] = 0.3 * i + 0.1;
            beta_ptr[i * 200000 + j] = 2 * i * 0.3 + 0.1;
        }
    }
    auto graph = ComputingGraph::make();
    auto alpha_sym = opr::Host2DeviceCopy::make(*graph, alpha_host);
    auto beta_sym = opr::Host2DeviceCopy::make(*graph, beta_host);
    auto y = opr::BetaRNG::make(alpha_sym, beta_sym, {10});
    HostTensorND host_y;
    auto func = graph->compile({make_callback_copy(y, host_y)});
    func->execute();
    ASSERT_EQ(TensorShape({200000*5}), host_y.shape());
    for (int i = 0; i < 5; ++i) {
        float a = 0.3 * i + 0.1, b = 2 * i * 0.3 + 0.1;
        float mean = a / (a + b);
        float var = a * b / ((a + b) * (a + b) * (a + b + 1));
        auto stat = BasicStat::make(host_y.ptr<float>() + 200000 * i,
                200000, mean);
        ASSERT_LT(fabs(stat.mean - mean), 0.01);
        ASSERT_LT(fabs(stat.std - sqrt(var)), 0.01);
    }
}
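// For Beta(a, b), E[X] = a / (a + b) and
// Var[X] = a*b / ((a + b)^2 * (a + b + 1)), matching `mean` and `var` above.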
TEST(TestOprRand, PermutationRNG) {
    static constexpr size_t SIZE = 123451;
    auto graph = ComputingGraph::make();
    auto y = opr::PermutationRNG::make(
            SymbolVar::make_scalar(int(SIZE), *graph, {CompNode::load("xpu0")}),
            {23, DTypeEnum::Int32});
    HostTensorND host_y;
    auto func = graph->compile({make_callback_copy(y, host_y)});
    func->execute();
    ASSERT_EQ(TensorShape({SIZE}), host_y.shape());
    auto ptr = host_y.ptr<int32_t>();
    std::vector<int32_t> res(SIZE);
    int not_same = 0;
    for (size_t i = 0; i < SIZE; ++i) {
        if (ptr[i] != int32_t(i)) not_same++;
        res[i] = ptr[i];
    }
    ASSERT_GT(not_same, 5000);
    // sorting the output must recover the identity: it is a permutation of 0..SIZE-1
    std::sort(res.begin(), res.end());
    for (size_t i = 0; i < SIZE; ++i) {
        ASSERT_EQ(res[i], int32_t(i));
    }
}
TEST(TestOprRand, UniformReprod) {
    check_reproducibility([](SymbolVar shp, uint64_t seed) {
    static constexpr size_t SIZE = 123;
    auto graph = ComputingGraph::make();
    auto shp = cg::var_from_tensor_shape(*graph, {CompNode::load("xpu0")},
            "shp0", {SIZE});
    check_reproducibility(graph, SIZE, [&shp](uint64_t seed) {
        return opr::UniformRNG::make(shp, {seed});
    });
}

TEST(TestOprRand, GaussianReprod) {
    check_reproducibility([](SymbolVar shp, uint64_t seed) {
    static constexpr size_t SIZE = 123;
    auto graph = ComputingGraph::make();
    auto shp = cg::var_from_tensor_shape(*graph, {CompNode::load("xpu0")},
            "shp0", {SIZE});
    check_reproducibility(graph, SIZE, [&shp](uint64_t seed) {
        return opr::GaussianRNG::make(shp, {seed});
    });
}
TEST(TestOprRand, GammaReprod) {
    static constexpr size_t SIZE = 123;
    std::shared_ptr<HostTensorND> shape_host(new HostTensorND{
            CompNode::load("xpux"), TensorShape{SIZE}, dtype::Float32()});
    std::shared_ptr<HostTensorND> scale_host(new HostTensorND{
            CompNode::load("xpux"), TensorShape{SIZE}, dtype::Float32()});
    auto shape_ptr = shape_host->ptr<float>();
    auto scale_ptr = scale_host->ptr<float>();
    for (size_t i = 0; i < SIZE; ++i) {
        shape_ptr[i] = 0.5;
        scale_ptr[i] = 1.2;
    }
    auto graph = ComputingGraph::make();
    auto shape_sym = opr::Host2DeviceCopy::make(*graph, shape_host);
    auto scale_sym = opr::Host2DeviceCopy::make(*graph, scale_host);
    check_reproducibility(graph, SIZE, [&shape_sym, &scale_sym](uint64_t seed) {
        return opr::GammaRNG::make(shape_sym, scale_sym, {seed});
    });
}

TEST(TestOprRand, PoissonReprod) {
    static constexpr size_t SIZE = 123;
    std::shared_ptr<HostTensorND> lam_host(new HostTensorND{
            CompNode::load("xpux"), TensorShape{SIZE}, dtype::Float32()});
    auto lam_ptr = lam_host->ptr<float>();
    for (size_t i = 0; i < SIZE; ++i)
        lam_ptr[i] = 2;
    auto graph = ComputingGraph::make();
    auto lam_sym = opr::Host2DeviceCopy::make(*graph, lam_host);
    check_reproducibility(graph, SIZE, [&lam_sym](uint64_t seed) {
        return opr::PoissonRNG::make(lam_sym, {seed});
    });
}

TEST(TestOprRand, BetaReprod) {
    static constexpr size_t SIZE = 123;
    std::shared_ptr<HostTensorND> alpha_host(new HostTensorND{
            CompNode::load("xpux"), TensorShape{SIZE}, dtype::Float32()});
    std::shared_ptr<HostTensorND> beta_host(new HostTensorND{
            CompNode::load("xpux"), TensorShape{SIZE}, dtype::Float32()});
    auto alpha_ptr = alpha_host->ptr<float>();
    auto beta_ptr = beta_host->ptr<float>();
    for (size_t i = 0; i < SIZE; ++i) {
        alpha_ptr[i] = 0.5;
        beta_ptr[i] = 1.2;
    }
    auto graph = ComputingGraph::make();
    auto alpha_sym = opr::Host2DeviceCopy::make(*graph, alpha_host);
    auto beta_sym = opr::Host2DeviceCopy::make(*graph, beta_host);
    check_reproducibility(graph, SIZE, [&alpha_sym, &beta_sym](uint64_t seed) {
        return opr::BetaRNG::make(alpha_sym, beta_sym, {seed});
    });
}

TEST(TestOprRand, PermutationReprod) {
    static constexpr size_t SIZE = 123;
    auto graph = ComputingGraph::make();
    auto shp = cg::var_from_tensor_shape(*graph, {CompNode::load("xpu0")},
            "shp0", {SIZE});
    check_reproducibility(graph, SIZE, [&shp](uint64_t seed) {
        return opr::PermutationRNG::make(shp, {seed, DTypeEnum::Float32});
    });
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -108,6 +108,10 @@ union OperatorParam {
    param.TQT = 74,
    param.Correlation = 75,
    param.LSQ = 76,
    param.GammaRNG = 77,
    param.PoissonRNG = 78,
    param.PermutationRNG = 79,
    param.BetaRNG = 80,
}

table Operator {