Browse Source

fix(mgb/sereg): fix rng operator compatibility

GitOrigin-RevId: 66d1694035
release-1.5
Megvii Engine Team 3 years ago
parent
commit
287cab49c2
3 changed files with 21 additions and 11 deletions
  1. +9
    -2
      dnn/scripts/opr_param_defs.py
  2. +2
    -2
      src/opr/impl/rand.oprdecl
  3. +10
    -7
      src/opr/impl/rand.sereg.h

+ 9
- 2
dnn/scripts/opr_param_defs.py View File

@@ -745,13 +745,20 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
'dtype', Doc('dtype', 'data type of output value'), 'dtype', Doc('dtype', 'data type of output value'),
'DTypeEnum::Float32')) 'DTypeEnum::Float32'))


(pdef('UniformRNG').
(pdef('UniformRNG', version=0, is_legacy=True).
add_fields('uint64', 'seed', 0))

(pdef('UniformRNG', version=1).
add_fields('uint64', 'seed', 0). add_fields('uint64', 'seed', 0).
add_fields( add_fields(
'dtype', Doc('dtype', 'The dtype of output Tensor. Only support Float32.'), 'dtype', Doc('dtype', 'The dtype of output Tensor. Only support Float32.'),
'DTypeEnum::Float32')) 'DTypeEnum::Float32'))


(pdef('GaussianRNG').
(pdef('GaussianRNG', version=0, is_legacy=True).
add_fields('uint64', 'seed', 0).
add_fields('float32', 'mean', 0, 'std', 1))

(pdef('GaussianRNG', version=1).
add_fields('uint64', 'seed', 0). add_fields('uint64', 'seed', 0).
add_fields('float32', 'mean', 0, 'std', 1). add_fields('float32', 'mean', 0, 'std', 1).
add_fields( add_fields(


+ 2
- 2
src/opr/impl/rand.oprdecl View File

@@ -1,12 +1,12 @@
decl_opr('UniformRNG', pyname='_uniform_rng', decl_opr('UniformRNG', pyname='_uniform_rng',
inputs=['shape'], inputs=['shape'],
params='UniformRNG', params='UniformRNG',
canonize_input_vars='canonize_shape_input')
canonize_input_vars='canonize_shape_input', version=1)


decl_opr('GaussianRNG', pyname='_gaussian_rng', decl_opr('GaussianRNG', pyname='_gaussian_rng',
inputs=['shape'], inputs=['shape'],
params='GaussianRNG', params='GaussianRNG',
canonize_input_vars='canonize_shape_input')
canonize_input_vars='canonize_shape_input', version=1)


inputs = [ inputs = [
Doc('shape', Doc('shape',


+ 10
- 7
src/opr/impl/rand.sereg.h View File

@@ -13,18 +13,21 @@
#include "megbrain/serialization/sereg.h" #include "megbrain/serialization/sereg.h"


namespace mgb { namespace mgb {


namespace opr { namespace opr {


MGB_SEREG_OPR(UniformRNG, 1);
MGB_SEREG_OPR(GaussianRNG, 1);
MGB_SEREG_OPR(GammaRNG, 2);
MGB_SEREG_OPR(PoissonRNG, 1);
MGB_SEREG_OPR(PermutationRNG, 1);
MGB_SEREG_OPR(BetaRNG, 2);
using UniformRNGV1 = opr::UniformRNG;
MGB_SEREG_OPR(UniformRNGV1, 1);
using GaussianRNGV1 = opr::GaussianRNG;
MGB_SEREG_OPR(GaussianRNGV1, 1);
MGB_SEREG_OPR(GammaRNG, 2);
MGB_SEREG_OPR(PoissonRNG, 1);
MGB_SEREG_OPR(PermutationRNG, 1);
MGB_SEREG_OPR(BetaRNG, 2);


} // namespace opr } // namespace opr
} // namespace mgb } // namespace mgb



// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}



Loading…
Cancel
Save