Browse Source

fix(mgb/sereg): fix rng operator compatibility

GitOrigin-RevId: 66d1694035
release-1.5
Megvii Engine Team 3 years ago
parent
commit
287cab49c2
3 changed files with 21 additions and 11 deletions
  1. +9
    -2
      dnn/scripts/opr_param_defs.py
  2. +2
    -2
      src/opr/impl/rand.oprdecl
  3. +10
    -7
      src/opr/impl/rand.sereg.h

+ 9
- 2
dnn/scripts/opr_param_defs.py View File

@@ -745,13 +745,20 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
'dtype', Doc('dtype', 'data type of output value'),
'DTypeEnum::Float32'))

(pdef('UniformRNG').
(pdef('UniformRNG', version=0, is_legacy=True).
add_fields('uint64', 'seed', 0))

(pdef('UniformRNG', version=1).
add_fields('uint64', 'seed', 0).
add_fields(
'dtype', Doc('dtype', 'The dtype of output Tensor. Only support Float32.'),
'DTypeEnum::Float32'))

(pdef('GaussianRNG').
(pdef('GaussianRNG', version=0, is_legacy=True).
add_fields('uint64', 'seed', 0).
add_fields('float32', 'mean', 0, 'std', 1))

(pdef('GaussianRNG', version=1).
add_fields('uint64', 'seed', 0).
add_fields('float32', 'mean', 0, 'std', 1).
add_fields(


+ 2
- 2
src/opr/impl/rand.oprdecl View File

@@ -1,12 +1,12 @@
decl_opr('UniformRNG', pyname='_uniform_rng',
inputs=['shape'],
params='UniformRNG',
canonize_input_vars='canonize_shape_input')
canonize_input_vars='canonize_shape_input', version=1)

decl_opr('GaussianRNG', pyname='_gaussian_rng',
inputs=['shape'],
params='GaussianRNG',
canonize_input_vars='canonize_shape_input')
canonize_input_vars='canonize_shape_input', version=1)

inputs = [
Doc('shape',


+ 10
- 7
src/opr/impl/rand.sereg.h View File

@@ -13,18 +13,21 @@
#include "megbrain/serialization/sereg.h"

namespace mgb {


namespace opr {

MGB_SEREG_OPR(UniformRNG, 1);
MGB_SEREG_OPR(GaussianRNG, 1);
MGB_SEREG_OPR(GammaRNG, 2);
MGB_SEREG_OPR(PoissonRNG, 1);
MGB_SEREG_OPR(PermutationRNG, 1);
MGB_SEREG_OPR(BetaRNG, 2);
using UniformRNGV1 = opr::UniformRNG;
MGB_SEREG_OPR(UniformRNGV1, 1);
using GaussianRNGV1 = opr::GaussianRNG;
MGB_SEREG_OPR(GaussianRNGV1, 1);
MGB_SEREG_OPR(GammaRNG, 2);
MGB_SEREG_OPR(PoissonRNG, 1);
MGB_SEREG_OPR(PermutationRNG, 1);
MGB_SEREG_OPR(BetaRNG, 2);

} // namespace opr
} // namespace mgb


// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}


Loading…
Cancel
Save