- /**
- * \file dnn/include/megdnn/oprs/utils.h
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- * implied.
- */
- #pragma once
- #include "megdnn/internal/opr_header_prologue.h"
-
- namespace megdnn {
-
- /*!
- * \brief base class for random number generators whose only tensor is the
- * output
- *
- * exec() writes generated values into \p dst and takes no input tensor.
- * Distributions parameterized by input tensors (PoissonRNG, BetaRNG,
- * GammaRNG) derive from OperatorBase directly instead of from this class.
- */
- class RNGBase : public OperatorBase {
- DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase);
-
- public:
- //! write generated values into \p dst; \p workspace is the scratch buffer
- //! presumably sized by get_workspace_in_bytes()
- virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
- //! number of workspace bytes exec() requires for an output of layout \p dst
- virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0;
-
- protected:
- //! hook presumably validating \p dst and the workspace size before exec();
- //! every concrete RNG must provide it
- virtual void check_exec(const TensorLayout& dst, size_t workspace_in_bytes) = 0;
- };
-
- //! sample from poisson distribution
- //!
- //! \p lam is the single input tensor holding the distribution parameter
- //! (lambda) and \p dst is the single output receiving the samples.
- //! NOTE(review): presumably one sample per element of lam, with dst shaped
- //! like lam; confirm against check_exec.
- class PoissonRNG : public OperatorBase {
- DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1);
- DEF_OPR_PARAM(PoissonRNG);
-
- public:
- //! draw samples parameterized by \p lam into \p dst
- virtual void exec(
- _megdnn_tensor_in lam, _megdnn_tensor_out dst,
- _megdnn_workspace workspace) = 0;
- //! number of workspace bytes exec() requires for these layouts
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& lam, const TensorLayout& dst) = 0;
-
- protected:
- //! layout / workspace-size check for exec(); non-virtual and defined
- //! outside this header
- void check_exec(
- const TensorLayout& lam, const TensorLayout& dst,
- size_t workspace_in_bytes);
- };
-
- //! sample from beta distribution
- //!
- //! Takes two input tensors holding the \p alpha and \p beta parameters and
- //! writes samples into the single output \p dst.
- //! NOTE(review): presumably elementwise over alpha/beta of matching shape;
- //! confirm against check_exec.
- class BetaRNG : public OperatorBase {
- DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1);
- DEF_OPR_PARAM(BetaRNG);
-
- public:
- //! draw samples parameterized by \p alpha and \p beta into \p dst
- virtual void exec(
- _megdnn_tensor_in alpha, _megdnn_tensor_in beta, _megdnn_tensor_out dst,
- _megdnn_workspace workspace) = 0;
- //! number of workspace bytes exec() requires for these layouts
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& alpha, const TensorLayout& beta,
- const TensorLayout& dst) = 0;
-
- protected:
- //! layout / workspace-size check for exec(); non-virtual and defined
- //! outside this header
- void check_exec(
- const TensorLayout& alpha, const TensorLayout& beta,
- const TensorLayout& dst, size_t workspace_in_bytes);
- };
-
- //! sample from gamma distribution
- //!
- //! Takes two input tensors holding the \p shape and \p scale parameters
- //! (distribution parameters, not tensor shapes) and writes samples into the
- //! single output \p dst.
- //! NOTE(review): presumably elementwise over shape/scale of matching
- //! layout; confirm against check_exec.
- class GammaRNG : public OperatorBase {
- DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1);
- DEF_OPR_PARAM(GammaRNG);
-
- public:
- //! draw samples parameterized by \p shape and \p scale into \p dst
- virtual void exec(
- _megdnn_tensor_in shape, _megdnn_tensor_in scale, _megdnn_tensor_out dst,
- _megdnn_workspace workspace) = 0;
- //! number of workspace bytes exec() requires for these layouts
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& shape, const TensorLayout& scale,
- const TensorLayout& dst) = 0;
-
- protected:
- //! layout / workspace-size check for exec(); non-virtual and defined
- //! outside this header
- void check_exec(
- const TensorLayout& shape, const TensorLayout& scale,
- const TensorLayout& dst, size_t workspace_in_bytes);
- };
-
- /*!
- * \brief sample from uniform distribution on the interval (0, 1]
- *
- * Only the check_exec override is declared here; exec() and
- * get_workspace_in_bytes() are the pure virtuals inherited from RNGBase.
- */
- class UniformRNG : public RNGBase {
- DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1);
- DEF_OPR_PARAM(UniformRNG);
-
- protected:
- //! `override` makes the compiler reject any drift from the signature of
- //! RNGBase::check_exec instead of silently hiding it with a new function
- void check_exec(const TensorLayout& dst, size_t workspace_in_bytes) override;
- };
-
- /*!
- * \brief sample from gaussian distribution
- *
- * Only the check_exec override is declared here; exec() and
- * get_workspace_in_bytes() are the pure virtuals inherited from RNGBase.
- * NOTE(review): the distribution's mean / std are presumably carried by
- * param::GaussianRNG; confirm against the param definition.
- */
- class GaussianRNG : public RNGBase {
- DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1);
- DEF_OPR_PARAM(GaussianRNG);
-
- protected:
- //! `override` makes the compiler reject any drift from the signature of
- //! RNGBase::check_exec instead of silently hiding it with a new function
- void check_exec(const TensorLayout& dst, size_t workspace_in_bytes) override;
- };
-
- /*!
- * \brief random number generator producing a permutation
- *
- * NOTE(review): judging by the name, exec() presumably fills \p dst with a
- * random permutation of 0..N-1; confirm against the implementations.
- * Only the check_exec override is declared here; exec() and
- * get_workspace_in_bytes() are the pure virtuals inherited from RNGBase.
- */
- class PermutationRNG : public RNGBase {
- DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1);
- DEF_OPR_PARAM(PermutationRNG);
-
- protected:
- //! `override` makes the compiler reject any drift from the signature of
- //! RNGBase::check_exec instead of silently hiding it with a new function
- void check_exec(const TensorLayout& dst, size_t workspace_in_bytes) override;
- };
-
- /*!
- * \brief forward pass of the random shuffle operator
- *
- * Takes one input (\p src) and produces two outputs: \p dst and \p indices.
- * NOTE(review): presumably dst is src with its elements (or slices along the
- * first axis) randomly reordered, and indices records the order used so
- * that ShuffleRNGBackward can route gradients back; confirm against the
- * implementations.
- */
- class ShuffleRNGForward : public OperatorBase {
- DEF_OPR_IMPL(ShuffleRNGForward, OperatorBase, 1, 2);
- DEF_OPR_PARAM(ShuffleRNG);
-
- public:
- //! shuffle \p src into \p dst and write the chosen order to \p indices
- virtual void exec(
- _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_tensor_out indices,
- _megdnn_workspace workspace) = 0;
- //! compute the layouts of \p dst and \p indices from \p src
- void deduce_layout(
- const TensorLayout& src, TensorLayout& dst, TensorLayout& indices);
- //! number of workspace bytes exec() requires for these layouts
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& src, const TensorLayout& dst,
- const TensorLayout& indices) = 0;
-
- protected:
- //! layout / workspace-size check for exec(); non-virtual and defined
- //! outside this header
- void check_exec(
- const TensorLayout& src, const TensorLayout& dst,
- const TensorLayout& indices, size_t workspace_in_bytes);
- };
- //! short alias for the forward shuffle operator
- using ShuffleRNG = ShuffleRNGForward;
-
- /*!
- * \brief backward pass of the random shuffle operator
- *
- * Takes two inputs (\p diff and \p indices) and produces \p grad. It shares
- * param::ShuffleRNG with ShuffleRNGForward.
- * NOTE(review): presumably indices is the tensor produced by the forward
- * pass and grad receives diff un-shuffled according to it; confirm against
- * the implementations.
- */
- class ShuffleRNGBackward : public OperatorBase {
- DEF_OPR_IMPL(ShuffleRNGBackward, OperatorBase, 2, 1);
- DEF_OPR_PARAM(ShuffleRNG);
-
- public:
- //! compute \p grad from \p diff using the shuffle order in \p indices
- virtual void exec(
- _megdnn_tensor_in diff, _megdnn_tensor_in indices, _megdnn_tensor_out grad,
- _megdnn_workspace workspace) = 0;
- //! number of workspace bytes exec() requires for these layouts
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& diff, const TensorLayout& indices,
- const TensorLayout& grad) = 0;
-
- protected:
- //! layout / workspace-size check for exec(); non-virtual and defined
- //! outside this header
- void check_exec(
- const TensorLayout& diff, const TensorLayout& indices,
- const TensorLayout& grad, size_t workspace_in_bytes);
- };
-
- /*!
- * \brief sleep for a specific time on the computing device; useful for
- * testing async problems
- *
- * Has no tensor inputs or outputs. NOTE(review): the duration presumably
- * comes from param::Sleep (field not visible here); confirm.
- */
- class SleepForward : public OperatorBase {
- DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0);
- DEF_OPR_PARAM(Sleep);
-
- public:
- //! perform the sleep on the computing device this operator belongs to
- virtual void exec() = 0;
- };
- //! short alias for the sleep operator
- using Sleep = SleepForward;
-
- /*!
- * \brief calculating checksum of a tensor
- *
- * data must be a one-dimensional contiguous tensor with dtype byte
- *
- * The result is returned by value (opr_result::Checksum) rather than being
- * written to an output tensor.
- */
- class ChecksumForward : public OperatorBase {
- DEF_OPR_PARAM(Empty);
- // NOTE(review): exec() reads `data` as an input (_megdnn_tensor_in), yet
- // the counts below declare 0 inputs and 1 output, while e.g. PoissonRNG
- // declares (1, 1) for one input and one output. Confirm this split is
- // intentional before changing it; generic tooling may rely on it.
- DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1);
-
- public:
- using Result = opr_result::Checksum;
-
- //! number of workspace bytes exec() requires for \p data
- virtual size_t get_workspace_in_bytes(const TensorLayout& data) = 0;
-
- //! compute and return the checksum of \p data
- virtual Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) = 0;
-
- protected:
- //! layout / workspace-size check for exec(); non-virtual and defined
- //! outside this header
- void check_exec(const TensorLayout& layout, size_t workspace_in_bytes);
- };
- //! short alias for the checksum operator
- using Checksum = ChecksumForward;
-
- /*!
- * \brief calculating max absolute difference of the two input tensors
- *
- * src1 and src2 must be one-dimensional contiguous tensors.
- *
- * The result is returned by value as a float rather than being written to
- * an output tensor.
- */
- class MaxTensorDiff : public OperatorBase {
- DEF_OPR_PARAM(Empty);
- // NOTE(review): both tensors are taken as inputs (_megdnn_tensor_in) by
- // exec(), yet the counts below declare 0 inputs and 2 outputs. Confirm
- // this split is intentional before changing it; generic tooling may rely
- // on it.
- DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2);
-
- public:
- //! number of workspace bytes exec() requires for the two input layouts
- virtual size_t get_workspace_in_bytes(
- const TensorLayout& layout1, const TensorLayout& layout2) = 0;
-
- //! return the maximum absolute difference between \p src1 and \p src2
- virtual float exec(
- _megdnn_tensor_in src1, _megdnn_tensor_in src2,
- _megdnn_workspace workspace) = 0;
-
- protected:
- //! layout / workspace-size check for exec(); non-virtual and defined
- //! outside this header
- void check_exec(
- const TensorLayout& layout1, const TensorLayout& layout2,
- size_t workspace_in_bytes);
- };
-
- /*!
- * \brief check whether \p bias is shared along the channel dimension for the
- * given ConvBias tensor \p format
- *
- * NOTE(review): presumably returns true when bias is a per-channel broadcast
- * (e.g. {1, C, 1, 1} for NCHW); confirm against the definition.
- *
- * \param bias layout of the bias tensor
- * \param format data format of the ConvBias operator
- * \return presumably true if bias is channel-shared
- *
- * `format` is taken by value. A top-level const would not be part of the
- * function type, so it is omitted here; the definition may still declare it
- * const.
- */
- bool check_bias_share_in_channel(
- const TensorLayout& bias, param::ConvBias::Format format);
-
- } // namespace megdnn
-
- #include "megdnn/internal/opr_header_epilogue.h"
-
- // vim: syntax=cpp.doxygen
|