|
- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import numpy as np
- import torch
-
- from megengine.core import tensor
- from megengine.utils import prod
-
-
- def _uniform(shape):
- return np.random.random(shape).astype(np.float32)
-
-
- def init_with_same_value(mge_param, torch_param, initializer=_uniform):
- mge_shape = mge_param.shape
- torch_shape = torch_param.shape
- assert prod(mge_shape) == prod(torch_shape)
- weight = initializer(mge_shape)
- mge_param.set_value(weight)
- torch_param.data = torch.Tensor(weight.reshape(torch_shape))
-
-
- def gen_same_input(shape, initializer=_uniform):
- data = initializer(shape)
- mge_input = tensor(data)
- torch_input = torch.Tensor(data)
- return mge_input, torch_input
|