@@ -1338,6 +1338,21 @@ class CheckHasInf: public OperatorBase { | |||
void check_exec(const TensorLayout &src, const TensorLayout &dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
/*! | |||
* \brief fill the tensor with a scalar value | |||
*/ | |||
class Fill: public OperatorBase { | |||
DEF_OPR_PARAM(Fill); | |||
DEF_OPR_IMPL(Fill, OperatorBase, 0, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); | |||
}; | |||
} // namespace megdnn | |||
#include "megdnn/internal/opr_header_epilogue.h" | |||
@@ -1170,4 +1170,4 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o | |||
add_fields('int32', 'qmin', '-2147483648'). | |||
add_fields('int32', 'qmax', '2147483647') | |||
) | |||
# Param of the Fill operator: a single float32 field `value`, default 0.
(pdef('Fill').
 add_fields('float32', 'value', '0')
 )
@@ -0,0 +1,25 @@ | |||
/** | |||
* \file dnn/src/common/fill.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
void Fill::check_exec(const TensorLayout& dst, size_t workspace_in_bytes) { | |||
megdnn_assert_contiguous(dst); | |||
auto required_workspace_in_bytes = get_workspace_in_bytes(dst); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
} | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -216,7 +216,8 @@ private: | |||
cb(TQTBackward) \ | |||
cb(CheckHasInf) \ | |||
cb(LSQForward) \ | |||
cb(LSQBackward) | |||
cb(LSQBackward) \ | |||
cb(Fill) | |||
/*! | |||
* \brief specialize HandleImpl::create_operator for a single opr type; | |||
@@ -130,6 +130,7 @@ DEF(ChecksumForward, 1, true, false); | |||
DEF(CheckHasInf, 2, true, true); | |||
DEF(LSQForward, 5, true, true); | |||
DEF(LSQBackward, 7, true, false); | |||
// Trait entry for the Fill operator.
// NOTE(review): the "1" presumably is the tensor arity (Fill only has its dst
// tensor); the meaning of the two booleans is fixed by the DEF macro, which is
// defined outside this view -- confirm there.
DEF(Fill, 1, true, false); | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,45 @@ | |||
/** | |||
* \file dnn/src/cuda/fill/kern.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/cuda/fill/kern.cuh" | |||
#include "megdnn/dtype.h" | |||
#include "src/cuda/utils.cuh" | |||
namespace {

/*!
 * \brief store \p value into each of the first \p size elements of \p dst
 *
 * Expected launch layout: a 1-D grid of 1-D blocks with one element per
 * thread. Threads whose flat index is not below \p size do nothing, so the
 * grid is allowed to overshoot the element count. No shared memory is used.
 *
 * \param dst   device-accessible buffer with room for \p size elements
 * \param value scalar written to every element
 * \param size  number of elements to write
 */
template <typename T>
__global__ void kernel(T* dst, T value, size_t size) {
    // The count stays size_t (the host passes a size_t) and the index is
    // widened before the multiply, so neither is truncated to 32 bits and the
    // bound check no longer mixes signed and unsigned operands.
    const size_t idx =
            static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx < size) {
        dst[idx] = value;
    }
}

}  // anonymous namespace
namespace megdnn { | |||
namespace cuda { | |||
namespace fill { | |||
/*!
 * \brief enqueue a fill of \p size elements of \p dst with \p value
 *
 * The kernel is launched on \p stream with NR_THREADS threads per block and
 * just enough blocks to cover \p size; after_kernel_launch() is invoked right
 * after the launch. Nothing in this function waits for the kernel to finish.
 *
 * \param dst    device buffer to overwrite
 * \param value  scalar to store in every element
 * \param size   number of elements; zero is accepted and is a no-op
 * \param stream CUDA stream the kernel is issued on
 */
template <typename T>
void exec_internal(T* dst, T value, size_t size, cudaStream_t stream) {
    // An empty tensor would ask for a grid of zero blocks, which is not a
    // valid launch configuration; there is nothing to write in that case.
    if (size == 0) {
        return;
    }
    kernel<T><<<DIVUP(size, NR_THREADS), NR_THREADS, 0, stream>>>(
            dst, value, size);
    after_kernel_launch();
}

// Explicit instantiation for the ctype of every dtype listed by
// MEGDNN_FOREACH_COMPUTING_DTYPE; other translation units only see the
// declaration in kern.cuh.
#define INST(T) template void exec_internal<T>(T*, T, size_t, cudaStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
#undef INST
} // namespace fill | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,25 @@ | |||
/** | |||
* \file dnn/src/cuda/fill/kern.cuh | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <stdint.h> | |||
#include <cuda_runtime_api.h> | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace fill { | |||
/*!
 * \brief fill \p size elements starting at \p dst with \p value, issuing the
 *        work on \p stream
 *
 * Only the declaration lives in this header.
 * NOTE(review): the definition and its explicit instantiations appear to be
 * in kern.cu -- check the dtype list there before using a new T.
 */
template <typename T>
void exec_internal(T* dst, T value, size_t size, cudaStream_t stream);
} // namespace fill | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,37 @@ | |||
/** | |||
* \file dnn/src/cuda/fill/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/cuda/fill/kern.cuh" | |||
#include "src/cuda/fill/opr_impl.h" | |||
#include "src/cuda/utils.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
/*!
 * \brief fill \p dst with param().value on the CUDA stream of this handle
 *
 * The scalar is converted to the element type of \p dst before it is handed
 * to fill::exec_internal. A dtype that MEGDNN_FOREACH_COMPUTING_DTYPE does not
 * enumerate matches none of the branches below, so \p dst is then left
 * untouched.
 */
void FillImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(dst.layout, workspace.size);

    auto cu_stream = cuda_stream(handle());
    auto nr_elems = dst.layout.total_nr_elems();

    // One branch per computing dtype; only the one equal to the dtype of dst
    // launches anything.
#define FILL_IF_DTYPE(DType)                                          \
    if (dst.layout.dtype == DType()) {                                \
        using elem_t = typename DTypeTrait<DType>::ctype;             \
        elem_t* out = dst.ptr<elem_t>();                              \
        const elem_t fill_value = static_cast<elem_t>(param().value); \
        fill::exec_internal<elem_t>(out, fill_value, nr_elems, cu_stream); \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(FILL_IF_DTYPE)
#undef FILL_IF_DTYPE
}
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,31 @@ | |||
/** | |||
* \file dnn/src/cuda/fill/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
class FillImpl final : public Fill { | |||
public: | |||
using Fill::Fill; | |||
void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout &) override { | |||
return 0; | |||
} | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen | |||
@@ -38,6 +38,7 @@ | |||
#include "src/cuda/elemwise_multi_type/opr_impl.h" | |||
#include "src/cuda/eye/opr_impl.h" | |||
#include "src/cuda/fake_quant/opr_impl.h" | |||
#include "src/cuda/fill/opr_impl.h" | |||
#include "src/cuda/flip/opr_impl.h" | |||
#include "src/cuda/gaussian_blur/opr_impl.h" | |||
#include "src/cuda/group_local/opr_impl.h" | |||
@@ -0,0 +1,46 @@ | |||
/** | |||
* \file dnn/src/naive/fill/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/naive/fill/opr_impl.h" | |||
#include "src/naive/handle.h" | |||
#include "src/common/utils.h" | |||
#include <cstring> | |||
#include <limits> | |||
namespace megdnn { | |||
namespace naive { | |||
/*!
 * \brief reference implementation of Fill::exec
 *
 * After validating the arguments, the matching exec_internal<ctype> call is
 * handed to MEGDNN_DISPATCH_CPU_KERN_OPR. A dtype outside
 * MEGDNN_FOREACH_COMPUTING_DTYPE matches no branch and leaves \p dst as is.
 */
void FillImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(dst.layout, workspace.size);

    size_t nr_elems = dst.layout.total_nr_elems();

    // One branch per computing dtype, selected by the dtype enum of dst.
#define FILL_IF_DTYPE(DType)                                                \
    if (dst.layout.dtype.enumv() == DTypeTrait<DType>::enumv) {             \
        using elem_t = typename DTypeTrait<DType>::ctype;                   \
        elem_t* out = dst.ptr<elem_t>();                                    \
        MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<elem_t>(out, nr_elems)); \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(FILL_IF_DTYPE)
#undef FILL_IF_DTYPE
}
template <typename ctype> | |||
void FillImpl::exec_internal(ctype *dst, size_t size) { | |||
auto value = static_cast<ctype>(param().value); | |||
for (size_t i = 0; i < size; ++i) { | |||
dst[i] = value; | |||
} | |||
} | |||
} // namespace naive | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen | |||
@@ -0,0 +1,33 @@ | |||
/** | |||
* \file dnn/src/naive/fill/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace naive { | |||
class FillImpl : public Fill { | |||
public: | |||
using Fill::Fill; | |||
void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout &) override { | |||
return 0; | |||
} | |||
private: | |||
template <typename ctype> | |||
void exec_internal(ctype *dst, size_t size); | |||
}; | |||
} // namespace naive | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -40,6 +40,7 @@ | |||
#include "src/naive/elemwise_multi_type/opr_impl.h" | |||
#include "src/naive/eye/opr_impl.h" | |||
#include "src/naive/fake_quant/opr_impl.h" | |||
#include "src/naive/fill/opr_impl.h" | |||
#include "src/naive/flip/opr_impl.h" | |||
#include "src/naive/gaussian_blur/opr_impl.h" | |||
#include "src/naive/group_local/opr_impl.h" | |||
@@ -0,0 +1,51 @@ | |||
/** | |||
* \file dnn/src/rocm/fill/fill.cpp.hip | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "hcc_detail/hcc_defs_prologue.h" | |||
#include "hip_header.h" | |||
#include "megdnn/dtype.h" | |||
#include "src/rocm/fill/fill.h.hip" | |||
#include "src/rocm/utils.h.hip" | |||
namespace { | |||
template <typename T> | |||
__global__ void kernel(T *dst, T value, uint32_t size) { | |||
int32_t i = threadIdx.x + blockIdx.x * blockDim.x; | |||
if (i < size) { | |||
dst[i] = value; | |||
} | |||
} | |||
} // anonymous namespace | |||
namespace megdnn { | |||
namespace rocm { | |||
namespace fill { | |||
template <typename T> | |||
void exec_internal(T *dst, T value, size_t size, hipStream_t stream) { | |||
hipLaunchKernelGGL( | |||
(kernel<T>), | |||
dim3(DIVUP(size, NR_THREADS)), | |||
dim3(NR_THREADS), | |||
0, stream, dst, value, size); | |||
after_kernel_launch(); | |||
} | |||
#define INST(T) template void exec_internal<T>(T *, \ | |||
T, size_t, hipStream_t); | |||
#define cb(DType) INST(typename DTypeTrait<DType>::ctype) | |||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
} // namespace fill | |||
} // namespace rocm | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,25 @@ | |||
/** | |||
* \file dnn/src/rocm/fill/fill.h.hip | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <stdint.h> | |||
#include "hip_header.h" | |||
namespace megdnn { | |||
namespace rocm { | |||
namespace fill { | |||
template <typename T> | |||
void exec_internal(T *dst, T value, size_t size, hipStream_t stream); | |||
} // namespace fill | |||
} // namespace rocm | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,36 @@ | |||
/** | |||
* \file dnn/src/rocm/fill/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "hcc_detail/hcc_defs_prologue.h" | |||
#include "src/rocm/fill/opr_impl.h" | |||
#include "src/rocm/fill/fill.h.hip" | |||
#include "src/rocm/utils.h" | |||
namespace megdnn { | |||
namespace rocm { | |||
/*!
 * \brief fill \p dst with param().value on the HIP stream of this handle
 *
 * The scalar is converted to the element type of \p dst before it is handed
 * to fill::exec_internal. A dtype that MEGDNN_FOREACH_COMPUTING_DTYPE does not
 * enumerate matches none of the branches below, so \p dst is then left
 * untouched.
 */
void FillImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(dst.layout, workspace.size);

    auto hip_queue = hip_stream(handle());
    auto nr_elems = dst.layout.total_nr_elems();

    // One branch per computing dtype; only the one equal to the dtype of dst
    // launches anything.
#define FILL_IF_DTYPE(DType)                                          \
    if (dst.layout.dtype == DType()) {                                \
        using elem_t = typename DTypeTrait<DType>::ctype;             \
        elem_t* out = dst.ptr<elem_t>();                              \
        const elem_t fill_value = static_cast<elem_t>(param().value); \
        fill::exec_internal<elem_t>(out, fill_value, nr_elems, hip_queue); \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(FILL_IF_DTYPE)
#undef FILL_IF_DTYPE
}
} // namespace rocm | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,26 @@ | |||
/** | |||
* \file dnn/src/rocm/fill/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace rocm { | |||
class FillImpl final : public Fill { | |||
public: | |||
using Fill::Fill; | |||
void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; } | |||
}; | |||
} // namespace rocm | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -37,6 +37,7 @@ | |||
#include "src/rocm/sleep/opr_impl.h" | |||
#include "src/rocm/batch_normalization/opr_impl.h" | |||
#include "src/rocm/param_pack/opr_impl.h" | |||
#include "src/rocm/fill/opr_impl.h" | |||
#include <miopen/version.h> | |||
#include <hip/hip_version.h> | |||
@@ -176,6 +177,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat); | |||
// NOTE(review): presumably makes the ROCm handle able to create the Fill
// operator, in the same way as the neighbouring entries -- confirm against the
// definition of MEGDNN_SPECIALIZE_CREATE_OPERATOR, which is outside this view.
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Fill); | |||
#pragma GCC diagnostic push | |||
#pragma GCC diagnostic ignored "-Wpragmas" | |||
@@ -0,0 +1,37 @@ | |||
/** | |||
* \file dnn/test/common/fill.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/handle.h" | |||
#include "megdnn/oprs/general.h" | |||
#include "src/common/opr_trait.h" | |||
#include "test/common/checker.h" | |||
namespace megdnn { | |||
namespace test { | |||
namespace fill { | |||
/*!
 * \brief run the Fill operator through Checker for a range of scalar values
 *
 * \param handle handle of the backend under test
 * \param dtype  dtype assigned to tensor 0 (the tensor being filled)
 *
 * Every value is exercised on three shapes: a single element, a small 3-D
 * tensor, and a tensor large enough to need several thread blocks on the GPU
 * backends.
 */
inline void run_fill_test(Handle* handle, DType dtype) {
    Checker<Fill> checker(handle);
    // float literals: the loop variable and the param field are both float
    for (float value : {-1.23f, 0.0f, 0.001f, 234.0f, 2021.072f}) {
        checker.set_param({value});
        checker.set_dtype(0, dtype);
        checker.exec(TensorShapeArray{{1, 1}});
        checker.exec(TensorShapeArray{{2, 3, 4}});
        // 23 * 131 = 3013 elements: more than the 1024 threads a single block
        // can hold and not a multiple of 32, so the kernel launch spans
        // several blocks and ends in a partially used one.
        checker.exec(TensorShapeArray{{23, 131}});
    }
}
} // namespace fill | |||
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,36 @@ | |||
/** | |||
* \file dnn/test/cuda/fill.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "test/common/fill.h" | |||
#include "test/cuda/fixture.h" | |||
namespace megdnn { | |||
namespace test { | |||
namespace fill { | |||
// Fill on the CUDA handle, one test per destination dtype; every test defers
// to the shared run_fill_test driver.

TEST_F(CUDA, FILL_F32) {
    const dtype::Float32 out_dtype{};
    run_fill_test(handle_cuda(), out_dtype);
}

TEST_F(CUDA, FILL_I32) {
    const dtype::Int32 out_dtype{};
    run_fill_test(handle_cuda(), out_dtype);
}

// Half precision is only covered when float16 support is compiled in.
#if !MEGDNN_DISABLE_FLOAT16
TEST_F(CUDA, FILL_F16) {
    const dtype::Float16 out_dtype{};
    run_fill_test(handle_cuda(), out_dtype);
}
#endif
} // namespace fill | |||
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,36 @@ | |||
/** | |||
* \file dnn/test/rocm/fill.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "test/common/fill.h" | |||
#include "test/rocm/fixture.h" | |||
namespace megdnn { | |||
namespace test { | |||
namespace fill { | |||
// Fill on the ROCm handle, one test per destination dtype; every test defers
// to the shared run_fill_test driver.

TEST_F(ROCM, FILL_F32) {
    const dtype::Float32 out_dtype{};
    run_fill_test(handle_rocm(), out_dtype);
}

TEST_F(ROCM, FILL_I32) {
    const dtype::Int32 out_dtype{};
    run_fill_test(handle_rocm(), out_dtype);
}

// Half precision is only covered when float16 support is compiled in.
#if !MEGDNN_DISABLE_FLOAT16
TEST_F(ROCM, FILL_F16) {
    const dtype::Float16 out_dtype{};
    run_fill_test(handle_rocm(), out_dtype);
}
#endif
} // namespace fill | |||
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |