Browse Source

feat(dnn): add fill kernel

GitOrigin-RevId: d2cee3a7a0
tags/v1.6.0-rc1
Megvii Engine Team 3 years ago
parent
commit
1af350c6d2
21 changed files with 512 additions and 2 deletions
  1. +15
    -0
      dnn/include/megdnn/oprs/general.h
  2. +1
    -1
      dnn/scripts/opr_param_defs.py
  3. +25
    -0
      dnn/src/common/fill.cpp
  4. +2
    -1
      dnn/src/common/handle_impl.h
  5. +1
    -0
      dnn/src/common/opr_trait.h
  6. +45
    -0
      dnn/src/cuda/fill/kern.cu
  7. +25
    -0
      dnn/src/cuda/fill/kern.cuh
  8. +37
    -0
      dnn/src/cuda/fill/opr_impl.cpp
  9. +31
    -0
      dnn/src/cuda/fill/opr_impl.h
  10. +1
    -0
      dnn/src/cuda/handle_create.cpp
  11. +46
    -0
      dnn/src/naive/fill/opr_impl.cpp
  12. +33
    -0
      dnn/src/naive/fill/opr_impl.h
  13. +1
    -0
      dnn/src/naive/handle.cpp
  14. +51
    -0
      dnn/src/rocm/fill/fill.cpp.hip
  15. +25
    -0
      dnn/src/rocm/fill/fill.h.hip
  16. +36
    -0
      dnn/src/rocm/fill/opr_impl.cpp
  17. +26
    -0
      dnn/src/rocm/fill/opr_impl.h
  18. +2
    -0
      dnn/src/rocm/handle.cpp
  19. +37
    -0
      dnn/test/common/fill.h
  20. +36
    -0
      dnn/test/cuda/fill.cpp
  21. +36
    -0
      dnn/test/rocm/fill.cpp

+ 15
- 0
dnn/include/megdnn/oprs/general.h View File

@@ -1338,6 +1338,21 @@ class CheckHasInf: public OperatorBase {
void check_exec(const TensorLayout &src, const TensorLayout &dst,
size_t workspace_in_bytes);
};

/*!
* \brief fill the tensor with a scalar value
*/
class Fill: public OperatorBase {
DEF_OPR_PARAM(Fill);
DEF_OPR_IMPL(Fill, OperatorBase, 0, 1);

public:
virtual void exec(_megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0;
protected:
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
};
} // namespace megdnn

#include "megdnn/internal/opr_header_epilogue.h"


+ 1
- 1
dnn/scripts/opr_param_defs.py View File

@@ -1170,4 +1170,4 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
add_fields('int32', 'qmin', '-2147483648').
add_fields('int32', 'qmax', '2147483647')
)
pdef('Fill').add_fields('float32', 'value', '0')

+ 25
- 0
dnn/src/common/fill.cpp View File

@@ -0,0 +1,25 @@
/**
* \file dnn/src/common/fill.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {

void Fill::check_exec(const TensorLayout& dst, size_t workspace_in_bytes) {
megdnn_assert_contiguous(dst);
auto required_workspace_in_bytes = get_workspace_in_bytes(dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 2
- 1
dnn/src/common/handle_impl.h View File

@@ -216,7 +216,8 @@ private:
cb(TQTBackward) \
cb(CheckHasInf) \
cb(LSQForward) \
cb(LSQBackward)
cb(LSQBackward) \
cb(Fill)

/*!
* \brief specialize HandleImpl::create_operator for a single opr type;


+ 1
- 0
dnn/src/common/opr_trait.h View File

@@ -130,6 +130,7 @@ DEF(ChecksumForward, 1, true, false);
DEF(CheckHasInf, 2, true, true);
DEF(LSQForward, 5, true, true);
DEF(LSQBackward, 7, true, false);
DEF(Fill, 1, true, false);
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 45
- 0
dnn/src/cuda/fill/kern.cu View File

@@ -0,0 +1,45 @@
/**
* \file dnn/src/cuda/fill/kern.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/fill/kern.cuh"
#include "megdnn/dtype.h"
#include "src/cuda/utils.cuh"

namespace {

template <typename T>
__global__ void kernel(T *dst, T value, uint32_t size) {
int32_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i < size) {
dst[i] = value;
}
}

} // anonymous namespace

namespace megdnn {
namespace cuda {
namespace fill {

template <typename T>
void exec_internal(T *dst, T value, size_t size, cudaStream_t stream) {
kernel<T><<<DIVUP(size, NR_THREADS), NR_THREADS, 0, stream>>>(dst, value, size);
after_kernel_launch();
}

#define INST(T) template void exec_internal<T>(T *, \
T, size_t, cudaStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)

} // namespace fill
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 25
- 0
dnn/src/cuda/fill/kern.cuh View File

@@ -0,0 +1,25 @@
/**
* \file dnn/src/cuda/fill/kern.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <stdint.h>
#include <cuda_runtime_api.h>

namespace megdnn {
namespace cuda {
namespace fill {

template <typename T>
void exec_internal(T *dst, T value, size_t size, cudaStream_t stream);

} // namespace fill
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 37
- 0
dnn/src/cuda/fill/opr_impl.cpp View File

@@ -0,0 +1,37 @@
/**
* \file dnn/src/cuda/fill/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/cuda/fill/kern.cuh"
#include "src/cuda/fill/opr_impl.h"

#include "src/cuda/utils.h"

namespace megdnn {
namespace cuda {

void FillImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(dst.layout, workspace.size);
auto stream = cuda_stream(handle());
auto size = dst.layout.total_nr_elems();
#define cb(DType) \
if (dst.layout.dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
fill::exec_internal<ctype>(dst.ptr<ctype>(), \
static_cast<ctype>(param().value), size, stream); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
}

} // namespace cuda
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 31
- 0
dnn/src/cuda/fill/opr_impl.h View File

@@ -0,0 +1,31 @@
/**
* \file dnn/src/cuda/fill/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once
#include "megdnn/oprs.h"

namespace megdnn {
namespace cuda {

class FillImpl final : public Fill {
public:
using Fill::Fill;
void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout &) override {
return 0;
}
};

} // namespace cuda
} // namespace megdnn

// vim: syntax=cpp.doxygen


+ 1
- 0
dnn/src/cuda/handle_create.cpp View File

@@ -38,6 +38,7 @@
#include "src/cuda/elemwise_multi_type/opr_impl.h"
#include "src/cuda/eye/opr_impl.h"
#include "src/cuda/fake_quant/opr_impl.h"
#include "src/cuda/fill/opr_impl.h"
#include "src/cuda/flip/opr_impl.h"
#include "src/cuda/gaussian_blur/opr_impl.h"
#include "src/cuda/group_local/opr_impl.h"


+ 46
- 0
dnn/src/naive/fill/opr_impl.cpp View File

@@ -0,0 +1,46 @@
/**
* \file dnn/src/naive/fill/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/naive/fill/opr_impl.h"
#include "src/naive/handle.h"
#include "src/common/utils.h"

#include <cstring>
#include <limits>

namespace megdnn {
namespace naive {

void FillImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(dst.layout, workspace.size);
size_t size = dst.layout.total_nr_elems();
#define cb(DType) \
if (dst.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
using ctype = typename DTypeTrait<DType>::ctype; \
ctype *ptr = dst.ptr<ctype>(); \
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>(ptr, size)); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
}

template <typename ctype>
void FillImpl::exec_internal(ctype *dst, size_t size) {
auto value = static_cast<ctype>(param().value);
for (size_t i = 0; i < size; ++i) {
dst[i] = value;
}
}

} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen


+ 33
- 0
dnn/src/naive/fill/opr_impl.h View File

@@ -0,0 +1,33 @@
/**
* \file dnn/src/naive/fill/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"

namespace megdnn {
namespace naive {

class FillImpl : public Fill {
public:
using Fill::Fill;

void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout &) override {
return 0;
}
private:
template <typename ctype>
void exec_internal(ctype *dst, size_t size);
};

} // namespace naive
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 1
- 0
dnn/src/naive/handle.cpp View File

@@ -40,6 +40,7 @@
#include "src/naive/elemwise_multi_type/opr_impl.h"
#include "src/naive/eye/opr_impl.h"
#include "src/naive/fake_quant/opr_impl.h"
#include "src/naive/fill/opr_impl.h"
#include "src/naive/flip/opr_impl.h"
#include "src/naive/gaussian_blur/opr_impl.h"
#include "src/naive/group_local/opr_impl.h"


+ 51
- 0
dnn/src/rocm/fill/fill.cpp.hip View File

@@ -0,0 +1,51 @@
/**
* \file dnn/src/rocm/fill/fill.cpp.hip
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "hcc_detail/hcc_defs_prologue.h"
#include "hip_header.h"
#include "megdnn/dtype.h"
#include "src/rocm/fill/fill.h.hip"
#include "src/rocm/utils.h.hip"

namespace {

template <typename T>
__global__ void kernel(T *dst, T value, uint32_t size) {
int32_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i < size) {
dst[i] = value;
}
}

} // anonymous namespace

namespace megdnn {
namespace rocm {
namespace fill {

template <typename T>
void exec_internal(T *dst, T value, size_t size, hipStream_t stream) {
hipLaunchKernelGGL(
(kernel<T>),
dim3(DIVUP(size, NR_THREADS)),
dim3(NR_THREADS),
0, stream, dst, value, size);
after_kernel_launch();
}

#define INST(T) template void exec_internal<T>(T *, \
T, size_t, hipStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)

} // namespace fill
} // namespace rocm
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 25
- 0
dnn/src/rocm/fill/fill.h.hip View File

@@ -0,0 +1,25 @@
/**
* \file dnn/src/rocm/fill/fill.h.hip
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <stdint.h>
#include "hip_header.h"

namespace megdnn {
namespace rocm {
namespace fill {

template <typename T>
void exec_internal(T *dst, T value, size_t size, hipStream_t stream);

} // namespace fill
} // namespace rocm
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 36
- 0
dnn/src/rocm/fill/opr_impl.cpp View File

@@ -0,0 +1,36 @@
/**
* \file dnn/src/rocm/fill/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "hcc_detail/hcc_defs_prologue.h"
#include "src/rocm/fill/opr_impl.h"

#include "src/rocm/fill/fill.h.hip"
#include "src/rocm/utils.h"

namespace megdnn {
namespace rocm {

void FillImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(dst.layout, workspace.size);
auto stream = hip_stream(handle());
auto size = dst.layout.total_nr_elems();
#define cb(DType) \
if (dst.layout.dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
fill::exec_internal<ctype>(dst.ptr<ctype>(), \
static_cast<ctype>(param().value), size, stream); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
}

} // namespace rocm
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 26
- 0
dnn/src/rocm/fill/opr_impl.h View File

@@ -0,0 +1,26 @@
/**
* \file dnn/src/rocm/fill/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"

namespace megdnn {
namespace rocm {

class FillImpl final : public Fill {
public:
using Fill::Fill;
void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; }
};

} // namespace rocm
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 2
- 0
dnn/src/rocm/handle.cpp View File

@@ -37,6 +37,7 @@
#include "src/rocm/sleep/opr_impl.h"
#include "src/rocm/batch_normalization/opr_impl.h"
#include "src/rocm/param_pack/opr_impl.h"
#include "src/rocm/fill/opr_impl.h"

#include <miopen/version.h>
#include <hip/hip_version.h>
@@ -176,6 +177,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Fill);

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"


+ 37
- 0
dnn/test/common/fill.h View File

@@ -0,0 +1,37 @@
/**
* \file dnn/test/common/fill.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include "megdnn/handle.h"
#include "megdnn/oprs/general.h"

#include "src/common/opr_trait.h"
#include "test/common/checker.h"

namespace megdnn {
namespace test {
namespace fill {

inline void run_fill_test(Handle* handle, DType dtype) {
Checker<Fill> checker(handle);
for (float value : {-1.23, 0.0, 0.001, 234.0, 2021.072}) {
checker.set_param({value});
checker.set_dtype(0, dtype);
checker.exec(TensorShapeArray{{1, 1}});
checker.exec(TensorShapeArray{{2, 3, 4}});
}
}

} // namespace fill
} // namespace test
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 36
- 0
dnn/test/cuda/fill.cpp View File

@@ -0,0 +1,36 @@
/**
* \file dnn/test/cuda/fill.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "test/common/fill.h"

#include "test/cuda/fixture.h"

namespace megdnn {
namespace test {
namespace fill {

TEST_F(CUDA, FILL_F32) {
run_fill_test(handle_cuda(), dtype::Float32{});
}

TEST_F(CUDA, FILL_I32) {
run_fill_test(handle_cuda(), dtype::Int32{});
}

#if !MEGDNN_DISABLE_FLOAT16
TEST_F(CUDA, FILL_F16) {
run_fill_test(handle_cuda(), dtype::Float16{});
}
#endif

} // namespace fill
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 36
- 0
dnn/test/rocm/fill.cpp View File

@@ -0,0 +1,36 @@
/**
* \file dnn/test/rocm/fill.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "test/common/fill.h"

#include "test/rocm/fixture.h"

namespace megdnn {
namespace test {
namespace fill {

TEST_F(ROCM, FILL_F32) {
run_fill_test(handle_rocm(), dtype::Float32{});
}

TEST_F(ROCM, FILL_I32) {
run_fill_test(handle_rocm(), dtype::Int32{});
}

#if !MEGDNN_DISABLE_FLOAT16
TEST_F(ROCM, FILL_F16) {
run_fill_test(handle_rocm(), dtype::Float16{});
}
#endif

} // namespace fill
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

Loading…
Cancel
Save