GitOrigin-RevId: 43016ffa2b
tags/v1.8.0
@@ -998,6 +998,28 @@ protected: | |||||
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); | void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); | ||||
}; | }; | ||||
class Diag : public OperatorBase {
    DEF_OPR_IMPL(Diag, OperatorBase, 1, 1);
    DEF_OPR_PARAM(Diag);
public:
    /**
     * \brief extract the k-th diagonal of a matrix input, or build a square
     *      matrix carrying a vector input on its k-th diagonal
     * \see https://docs.scipy.org/doc/numpy/reference/generated/numpy.diag.html
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    //! infer dst from src: (n,) -> (n+|k|, n+|k|); (m, n) -> (diag length,)
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
protected:
    //! validate dtype equality, layout vs. deduce_layout, and workspace size
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
class IndexingOneHotBase : public OperatorBase { | class IndexingOneHotBase : public OperatorBase { | ||||
DEF_OPR_IMPL_CTOR(IndexingOneHotBase, OperatorBase); | DEF_OPR_IMPL_CTOR(IndexingOneHotBase, OperatorBase); | ||||
DEF_OPR_PARAM(Axis); | DEF_OPR_PARAM(Axis); | ||||
@@ -759,6 +759,14 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0) | |||||
'dtype', Doc('dtype', 'data type of output value'), | 'dtype', Doc('dtype', 'data type of output value'), | ||||
'DTypeEnum::Float32')) | 'DTypeEnum::Float32')) | ||||
(pdef('Diag'). | |||||
add_fields( | |||||
'int32', | |||||
Doc('k', 'Index of the diagonal: 0 (the default) refers to the main ' | |||||
'diagonal, a positive value refers to an upper diagonal, and a ' | |||||
'negative value to a lower diagonal.'), | |||||
0)) | |||||
(pdef('UniformRNG', version=0, is_legacy=True). | (pdef('UniformRNG', version=0, is_legacy=True). | ||||
add_fields('uint64', 'seed', 0)) | add_fields('uint64', 'seed', 0)) | ||||
@@ -0,0 +1,47 @@ | |||||
/** | |||||
* \file dnn/src/common/diag.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megdnn/oprs.h" | |||||
#include "src/common/utils.h" | |||||
namespace megdnn { | |||||
void Diag::deduce_layout(const TensorLayout& src, TensorLayout& dst) { | |||||
megdnn_assert( | |||||
src.ndim == 1 || src.ndim == 2, "Only support vector or matrix as input."); | |||||
int k = param().k; | |||||
if (src.ndim == 1) { | |||||
size_t o = src.total_nr_elems() + std::abs(k); | |||||
dst = TensorLayout(TensorShape({o, o}), src.dtype); | |||||
} else { // src.ndim == 2 | |||||
size_t m = src.shape[0]; | |||||
size_t n = src.shape[1]; | |||||
size_t o = (k >= 0 ? std::min(n - k, m) : std::min(m + k, n)); | |||||
megdnn_assert(o > 0, "The moved diagonal is out of the input matrix."); | |||||
dst = TensorLayout(TensorShape({o}), src.dtype); | |||||
} | |||||
} | |||||
void Diag::check_exec( | |||||
const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { | |||||
TensorLayout dst_expected; | |||||
megdnn_assert_eq_dtype(src, dst); | |||||
deduce_layout(src, dst_expected); | |||||
megdnn_assert_eq_layout(dst_expected, dst); | |||||
megdnn_assert_contiguous(dst); | |||||
auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); | |||||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||||
} | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -146,6 +146,7 @@ private: | |||||
cb(BatchedSetMeshIndexing) \ | cb(BatchedSetMeshIndexing) \ | ||||
cb(Linspace) \ | cb(Linspace) \ | ||||
cb(Eye) \ | cb(Eye) \ | ||||
cb(Diag) \ | |||||
cb(SleepForward) \ | cb(SleepForward) \ | ||||
cb(UniformRNG) \ | cb(UniformRNG) \ | ||||
cb(GaussianRNG) \ | cb(GaussianRNG) \ | ||||
@@ -88,6 +88,7 @@ DEF(IndexingRemapForward, 3, true, true); | |||||
DEF(IndexingRemapBackward, 3, true, false); | DEF(IndexingRemapBackward, 3, true, false); | ||||
DEF(Linspace, 1, true, false); | DEF(Linspace, 1, true, false); | ||||
DEF(Eye, 1, true, false); | DEF(Eye, 1, true, false); | ||||
DEF(Diag, 2, true, true); | |||||
DEF(Flip, 2, true, true); | DEF(Flip, 2, true, true); | ||||
DEF(ROICopy, 2, true, true); | DEF(ROICopy, 2, true, true); | ||||
DEF(Rotate, 2, true, true); | DEF(Rotate, 2, true, true); | ||||
@@ -0,0 +1,87 @@ | |||||
/** | |||||
* \file dnn/src/cuda/diag/diag.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megdnn/dtype.h" | |||||
#include "src/cuda/diag/diag.cuh" | |||||
#include "src/cuda/utils.cuh" | |||||
namespace {

//! gather the k-th diagonal of a strided matrix into a strided vector;
//! one thread per diagonal element
template <typename T>
__global__ void kernel_to_vector(
        T* src, T* dst, ptrdiff_t start, ptrdiff_t size, ptrdiff_t stride_sum,
        ptrdiff_t dst_stride) {
    const ptrdiff_t i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i < size) {
        dst[dst_stride * i] = src[start + stride_sum * i];
    }
}

//! fill an n x n matrix: source vector on the k-th diagonal, zeros elsewhere;
//! one thread per matrix element
template <typename T>
__global__ void kernel_to_matrix(
        T* src, T* dst, ptrdiff_t offset, ptrdiff_t n, ptrdiff_t k,
        ptrdiff_t dst_stride0, ptrdiff_t dst_stride1, ptrdiff_t src_stride) {
    const ptrdiff_t i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i < n * n) {
        const ptrdiff_t col = i % n;
        const ptrdiff_t row = i / n;
        const ptrdiff_t p = dst_stride0 * row + dst_stride1 * col;
        if (row + k == col)
            dst[p] = src[src_stride * (row - offset)];
        else
            dst[p] = 0;
    }
}

}  // anonymous namespace
namespace megdnn { | |||||
namespace cuda { | |||||
namespace diag { | |||||
//! launch the matrix -> diagonal-vector kernel on the given stream
template <typename T>
void exec_internal_to_vector(
        T* src, T* dst, ptrdiff_t start, ptrdiff_t size, ptrdiff_t stride_sum,
        ptrdiff_t dst_stride, cudaStream_t stream) {
    auto nr_blocks = DIVUP(size, NR_THREADS);  // one thread per element
    kernel_to_vector<T><<<nr_blocks, NR_THREADS, 0, stream>>>(
            src, dst, start, size, stride_sum, dst_stride);
    after_kernel_launch();
}
//! launch the vector -> diagonal-matrix kernel on the given stream
template <typename T>
void exec_internal_to_matrix(
        T* src, T* dst, ptrdiff_t offset, ptrdiff_t n, ptrdiff_t k,
        ptrdiff_t dst_stride0, ptrdiff_t dst_stride1, ptrdiff_t src_stride,
        cudaStream_t stream) {
    auto nr_blocks = DIVUP(n * n, NR_THREADS);  // one thread per output element
    kernel_to_matrix<T><<<nr_blocks, NR_THREADS, 0, stream>>>(
            src, dst, offset, n, k, dst_stride0, dst_stride1, src_stride);
    after_kernel_launch();
}
#define INST(T) \ | |||||
template void exec_internal_to_vector<T>( \ | |||||
T*, T*, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, cudaStream_t); | |||||
#define cb(DType) INST(typename DTypeTrait<DType>::ctype) | |||||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||||
cb(::megdnn::dtype::Bool) | |||||
#undef INST | |||||
#undef cb | |||||
#define INST(T) \ | |||||
template void exec_internal_to_matrix<T>( \ | |||||
T*, T*, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, \ | |||||
cudaStream_t); | |||||
#define cb(DType) INST(typename DTypeTrait<DType>::ctype) | |||||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) cb(::megdnn::dtype::Bool) | |||||
} // namespace diag | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,33 @@ | |||||
/** | |||||
* \file dnn/src/cuda/diag/diag.cuh | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <cuda_runtime_api.h> | |||||
#include <stdint.h> | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
namespace diag { | |||||
//! copy the k-th diagonal of a strided matrix into a (strided) vector;
//! \p start is the element offset of the first diagonal entry in \p src,
//! \p stride_sum is src_stride0 + src_stride1 (the per-step diagonal stride)
template <typename T>
void exec_internal_to_vector(
        T* src, T* dst, ptrdiff_t start, ptrdiff_t size, ptrdiff_t stride_sum,
        ptrdiff_t dst_stride, cudaStream_t stream);
//! build an n x n matrix holding \p src on its k-th diagonal, zeros elsewhere;
//! \p start is the row index of the first diagonal element (non-zero for k < 0)
template <typename T>
void exec_internal_to_matrix(
        T* src, T* dst, ptrdiff_t start, ptrdiff_t n, ptrdiff_t k,
        ptrdiff_t dst_stride0, ptrdiff_t dst_stride1, ptrdiff_t src_stride,
        cudaStream_t stream);
} // namespace diag | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,61 @@ | |||||
/** | |||||
* \file dnn/src/cuda/diag/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "src/cuda/diag/opr_impl.h" | |||||
#include "src/cuda/diag/diag.cuh" | |||||
#include "src/cuda/utils.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
/*!
 * \brief CUDA Diag: dispatch on dtype to the matrix->vector or
 *      vector->matrix kernel launcher
 *
 * Fix over the previous version: an unsupported dtype used to fall through
 * all dispatch branches silently, leaving dst uninitialized; it now fails
 * loudly.
 */
void DiagImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);
    auto stream = cuda_stream(handle());
    if (src.layout.ndim == 2) {
        // matrix -> vector: gather the k-th diagonal
        auto src_stride0 = src.layout.stride[0];
        auto src_stride1 = src.layout.stride[1];
        auto dst_stride = dst.layout.stride[0];
        // element offset of the first diagonal entry inside src
        auto start =
                (param().k >= 0) ? param().k * src_stride1 : -param().k * src_stride0;
#define cb(DType)                                                               \
    if (dst.layout.dtype.enumv() == DTypeTrait<DType>::enumv) {                 \
        using ctype = typename DTypeTrait<DType>::ctype;                        \
        diag::exec_internal_to_vector<ctype>(                                   \
                src.ptr<ctype>(), dst.ptr<ctype>(), start, dst.layout.shape[0], \
                src_stride0 + src_stride1, dst_stride, stream);                 \
        return;                                                                 \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
        cb(::megdnn::dtype::Bool)
#undef cb
    } else {
        // vector -> matrix: scatter src onto the k-th diagonal, zero elsewhere
        auto n = dst.layout.shape[0];
        auto src_stride = src.layout.stride[0];
        auto dst_stride0 = dst.layout.stride[0];
        auto dst_stride1 = dst.layout.stride[1];
        // row index of the first diagonal element (non-zero only for k < 0)
        auto offset = (param().k >= 0) ? 0 : -param().k;
#define cb(DType)                                                                      \
    if (dst.layout.dtype.enumv() == DTypeTrait<DType>::enumv) {                        \
        using ctype = typename DTypeTrait<DType>::ctype;                               \
        diag::exec_internal_to_matrix<ctype>(                                          \
                src.ptr<ctype>(), dst.ptr<ctype>(), offset, n, param().k, dst_stride0, \
                dst_stride1, src_stride, stream);                                      \
        return;                                                                        \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
        cb(::megdnn::dtype::Bool)
#undef cb
    }
    // no dtype branch matched: fail instead of silently leaving dst untouched
    megdnn_throw("bad dtype for Diag");
}
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,31 @@ | |||||
/** | |||||
* \file dnn/src/cuda/diag/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
class DiagImpl final : public Diag { | |||||
public: | |||||
using Diag::Diag; | |||||
void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& dst) override { | |||||
return 0; | |||||
} | |||||
}; | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -33,6 +33,7 @@ | |||||
#include "src/cuda/dct/opr_impl.h" | #include "src/cuda/dct/opr_impl.h" | ||||
#include "src/cuda/deformable_conv/opr_impl.h" | #include "src/cuda/deformable_conv/opr_impl.h" | ||||
#include "src/cuda/deformable_ps_roi_pooling/opr_impl.h" | #include "src/cuda/deformable_ps_roi_pooling/opr_impl.h" | ||||
#include "src/cuda/diag/opr_impl.h" | |||||
#include "src/cuda/dot/opr_impl.h" | #include "src/cuda/dot/opr_impl.h" | ||||
#include "src/cuda/dropout/opr_impl.h" | #include "src/cuda/dropout/opr_impl.h" | ||||
#include "src/cuda/elemwise/opr_impl.h" | #include "src/cuda/elemwise/opr_impl.h" | ||||
@@ -154,6 +155,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedIncrMeshIndexing); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedSetMeshIndexing); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedSetMeshIndexing); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Eye); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(Eye); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Diag); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(UniformRNG); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(UniformRNG); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GaussianRNG); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(GaussianRNG); | ||||
@@ -0,0 +1,60 @@ | |||||
/** | |||||
* \file dnn/src/naive/diag/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "src/naive/diag/opr_impl.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/naive/handle.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
template <typename ctype> | |||||
void DiagImpl::exec_internal( | |||||
ctype* src, const TensorLayout& src_layout, ctype* dst, | |||||
const TensorLayout& dst_layout, size_t input_ndim, int k) { | |||||
if (input_ndim == 1) { | |||||
size_t l = src_layout.shape[0]; | |||||
size_t s0 = dst_layout.stride[0]; | |||||
size_t s1 = dst_layout.stride[1]; | |||||
size_t start = (k >= 0) ? (k * s1) : (-k * s0); | |||||
for (size_t i = 0; i < dst_layout.shape[0]; ++i) | |||||
for (size_t j = 0; j < dst_layout.shape[1]; ++j) | |||||
dst[i * s0 + j * s1] = 0; | |||||
for (size_t i = 0; i < l; ++i) | |||||
dst[start + i * (s0 + s1)] = src[i]; | |||||
} else { | |||||
size_t l = dst_layout.shape[0]; | |||||
size_t s0 = src_layout.stride[0]; | |||||
size_t s1 = src_layout.stride[1]; | |||||
size_t start = (k >= 0) ? (k * s1) : (-k * s0); | |||||
for (size_t i = 0; i < l; ++i) | |||||
dst[i] = src[start + i * (s0 + s1)]; | |||||
} | |||||
} | |||||
/*!
 * \brief naive Diag entry point: validate, then dispatch exec_internal on the
 *      source dtype
 *
 * Fix over the previous version: an unsupported dtype used to fall through
 * all dispatch branches silently, leaving dst uninitialized; it now fails
 * loudly.
 */
void DiagImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);
#define cb(DType)                                                   \
    if (src.layout.dtype == DType()) {                              \
        using ctype = typename DTypeTrait<DType>::ctype;            \
        MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>(          \
                src.ptr<ctype>(), src.layout, dst.ptr<ctype>(),     \
                dst.layout, src.layout.ndim, param().k));           \
        return;                                                     \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
    cb(::megdnn::dtype::Bool)
#undef cb
    // no dtype branch matched: fail instead of silently skipping the kernel
    megdnn_throw("bad dtype for Diag");
}
} // namespace naive | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,37 @@ | |||||
/** | |||||
* \file dnn/src/naive/diag/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
class DiagImpl : public Diag { | |||||
public: | |||||
using Diag::Diag; | |||||
void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { | |||||
return 0; | |||||
} | |||||
private: | |||||
template <typename ctype> | |||||
void exec_internal( | |||||
ctype* src, const TensorLayout& src_layout, ctype* dst, | |||||
const TensorLayout& dst_layout, size_t input_ndim, int k); | |||||
}; | |||||
} // namespace naive | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -34,6 +34,7 @@ | |||||
#include "src/naive/dct/opr_impl.h" | #include "src/naive/dct/opr_impl.h" | ||||
#include "src/naive/deformable_conv/opr_impl.h" | #include "src/naive/deformable_conv/opr_impl.h" | ||||
#include "src/naive/deformable_ps_roi_pooling/opr_impl.h" | #include "src/naive/deformable_ps_roi_pooling/opr_impl.h" | ||||
#include "src/naive/diag/opr_impl.h" | |||||
#include "src/naive/dot/opr_impl.h" | #include "src/naive/dot/opr_impl.h" | ||||
#include "src/naive/dropout/opr_impl.h" | #include "src/naive/dropout/opr_impl.h" | ||||
#include "src/naive/elemwise/opr_impl.h" | #include "src/naive/elemwise/opr_impl.h" | ||||
@@ -0,0 +1,42 @@ | |||||
/** | |||||
* \file dnn/test/cuda/diag.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "test/cuda/fixture.h" | |||||
#include "megdnn/oprs.h" | |||||
#include "test/common/checker.h" | |||||
namespace megdnn { | |||||
namespace test { | |||||
//! compare the CUDA Diag against the naive reference for several dtypes and
//! diagonal offsets
TEST_F(CUDA, DIAG) {
    Checker<Diag> checker(handle_cuda());
    for (DType dtype :
         std::vector<DType>{dtype::Float16(), dtype::Int32(), dtype::Float32()}) {
        for (int k = -5; k < 5; ++k) {
            checker.set_param({k});
            checker.set_dtype(0, dtype).set_dtype(1, dtype);
            // vector -> matrix: output is square with side 8 + |k|
            size_t absk = static_cast<size_t>(std::abs(k));
            checker.exec(TensorShapeArray{{8}, {8 + absk, 8 + absk}});
            // matrix -> vector cases.
            // NOTE(review): this helper returns {o, o} although the real
            // output of matrix->vector Diag is 1-D of length o; presumably
            // the Checker deduces the actual output layout itself -- verify.
            auto oshape = [&](int n, int m) -> TensorShape {
                size_t o = (k >= 0 ? std::min(n - k, m) : std::min(m + k, n));
                return {o, o};
            };
            checker.exec(TensorShapeArray{{8, 6}, oshape(8, 6)});
            checker.exec(TensorShapeArray{{6, 8}, oshape(6, 8)});
            checker.exec(TensorShapeArray{{8, 8}, oshape(8, 8)});
        }
    }
}
} // namespace test | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,111 @@ | |||||
/** | |||||
* \file dnn/test/naive/diag.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megdnn/dtype.h" | |||||
#include "megdnn/oprs.h" | |||||
#include "test/common/checker.h" | |||||
#include "test/naive/fixture.h" | |||||
namespace megdnn { | |||||
namespace test { | |||||
//! k = 0: a length-3 vector becomes a 3x3 matrix carrying it on the main
//! diagonal, zeros elsewhere
TEST_F(NAIVE, DiagVector2Matrix) {
    Checker<Diag> checker(handle(), /* check_dispatch */ false);
    Diag::Param param;
    param.k = 0;
    checker.set_param(param);
    checker.exect(
            Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}},
            Testcase{
                    {},
                    TensorValue(
                            {3, 3}, dtype::Float32(), {1, 0, 0, 0, 2, 0, 0, 0, 3})});
}
//! k = 1: a length-3 vector fills the first super-diagonal of a 4x4 matrix
TEST_F(NAIVE, DiagVector2Matrix_PositiveK) {
    Checker<Diag> checker(handle(), /* check_dispatch */ false);
    Diag::Param param;
    param.k = 1;
    checker.set_param(param);
    checker.exect(
            Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}},
            Testcase{
                    {},
                    TensorValue(
                            {4, 4}, dtype::Float32(),
                            {0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0})});
}
//! k = -1: a length-3 vector fills the first sub-diagonal of a 4x4 matrix
TEST_F(NAIVE, DiagVector2Matrix_NegativeK) {
    Checker<Diag> checker(handle(), /* check_dispatch */ false);
    Diag::Param param;
    param.k = -1;
    checker.set_param(param);
    checker.exect(
            Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}},
            Testcase{
                    {},
                    TensorValue(
                            {4, 4}, dtype::Float32(),
                            {0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0})});
}
//! k = 0: the main diagonal of a 3x3 matrix is extracted as a vector
TEST_F(NAIVE, DiagMatrix2Vector) {
    Checker<Diag> checker(handle(), /* check_dispatch */ false);
    Diag::Param param;
    param.k = 0;
    checker.set_param(param);
    checker.exect(
            Testcase{
                    TensorValue(
                            {3, 3}, dtype::Float32(), {1, 2, 3, 4, 5, 6, 7, 8, 9}),
                    {}},
            Testcase{{}, TensorValue({3}, dtype::Float32(), {1, 5, 9})});
}
//! k = 1: the first super-diagonal of a 3x3 matrix has two elements
TEST_F(NAIVE, DiagMatrix2Vector_PositiveK) {
    Checker<Diag> checker(handle(), /* check_dispatch */ false);
    Diag::Param param;
    param.k = 1;
    checker.set_param(param);
    checker.exect(
            Testcase{
                    TensorValue(
                            {3, 3}, dtype::Float32(), {1, 2, 3, 4, 5, 6, 7, 8, 9}),
                    {}},
            Testcase{{}, TensorValue({2}, dtype::Float32(), {2, 6})});
}
//! k = -1: the first sub-diagonal of a 3x3 matrix has two elements
TEST_F(NAIVE, DiagMatrix2Vector_NegativeK) {
    Checker<Diag> checker(handle(), /* check_dispatch */ false);
    Diag::Param param;
    param.k = -1;
    checker.set_param(param);
    checker.exect(
            Testcase{
                    TensorValue(
                            {3, 3}, dtype::Float32(), {1, 2, 3, 4, 5, 6, 7, 8, 9}),
                    {}},
            Testcase{{}, TensorValue({2}, dtype::Float32(), {4, 8})});
}
} // namespace test | |||||
} // namespace megdnn |
@@ -28,6 +28,7 @@ __all__ = [ | |||||
"concat", | "concat", | ||||
"cond_take", | "cond_take", | ||||
"cumsum", | "cumsum", | ||||
"diag", | |||||
"expand_dims", | "expand_dims", | ||||
"eye", | "eye", | ||||
"flatten", | "flatten", | ||||
@@ -53,6 +54,32 @@ __all__ = [ | |||||
] | ] | ||||
def diag(inp, k=0) -> Tensor:
    r"""Extract a diagonal or construct a diagonal tensor.

    For a 1D ``inp``, returns a 2D tensor whose ``k``-th diagonal holds the
    elements of ``inp``. For a 2D ``inp``, returns a 1D tensor containing the
    elements of its ``k``-th diagonal.

    Args:
        inp: input tensor (1D or 2D).
        k: which diagonal to consider. Use :math:`k=0` for the main diagonal,
            :math:`k>0` for diagonals above it, and :math:`k<0` for diagonals
            below it. Default: 0.

    Returns:
        the extracted diagonal or constructed diagonal array.

    Examples:
        >>> inp = F.arange(6, dtype='int32').reshape(2,3)
        >>> out = F.diag(inp, k=1)
        >>> out
        Tensor([1 5], dtype=int32, device=xpux:0)
        >>> F.diag(out)
        Tensor([[1 0]
         [0 5]], dtype=int32, device=xpux:0)
    """
    (result,) = apply(builtin.Diag(k=k), inp)
    return result
def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor: | def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor: | ||||
r"""Returns a 2D tensor with ones on the diagonal and zeros elsewhere. | r"""Returns a 2D tensor with ones on the diagonal and zeros elsewhere. | ||||
@@ -42,6 +42,26 @@ def test_eye(): | |||||
) | ) | ||||
@pytest.mark.parametrize("is_varnode", [False, True])
def test_diag(is_varnode):
    # exercise both eager mode and trace/varnode mode
    network = Network() if is_varnode else None

    # square, wide, tall and 1-D inputs
    shapes = [(10, 10), (6, 9), (8, 7), (8,)]
    cases = [{"input": [np.random.random(shp).astype("float32")]} for shp in shapes]

    for k in range(-2, 3):

        def run(data):
            return F.diag(data, k=k)

        # numpy.diag is the reference implementation
        opr_test(cases, run, ref_fn=lambda x: np.diag(x, k), network=network)
def test_full(): | def test_full(): | ||||
shape = (2, 3) | shape = (2, 3) | ||||
values = [True, 4, 5.0] | values = [True, 4, 5.0] | ||||
@@ -433,6 +433,19 @@ OP_TRAIT_REG(Eye, Eye).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace | } // namespace | ||||
namespace { | namespace { | ||||
namespace diag { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const Diag&>(def); | |||||
mgb_assert(inputs.size() == 1); | |||||
cg::OperatorNodeConfig config{op.make_name()}; | |||||
opr::Diag::Param param{op.k}; | |||||
return opr::Diag::make(inputs[0], param, config); | |||||
} | |||||
OP_TRAIT_REG(Diag, Diag).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace diag | |||||
} // namespace | |||||
namespace { | |||||
namespace roi_pooling { | namespace roi_pooling { | ||||
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | ||||
auto&& op = static_cast<const ROIPooling&>(def); | auto&& op = static_cast<const ROIPooling&>(def); | ||||
@@ -240,6 +240,8 @@ def Eye: MgbHashableOp<"Eye", [EyeParam]> { | |||||
); | ); | ||||
} | } | ||||
def Diag: MgbHashableOp<"Diag", [DiagParam]>; | |||||
def GetVarShape : MgbHashableOp<"GetVarShape", [OptionalAxisV1Param]>; | def GetVarShape : MgbHashableOp<"GetVarShape", [OptionalAxisV1Param]>; | ||||
def Concat: MgbHashableOp<"Concat", [AxisParam]> { | def Concat: MgbHashableOp<"Concat", [AxisParam]> { | ||||
@@ -75,6 +75,91 @@ struct MegDNNOprInitInputsModifier<IndexingSetOneHot> | |||||
} // namespace opr | } // namespace opr | ||||
} // namespace mgb | } // namespace mgb | ||||
/* ==================== Diag ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Diag);
MEGDNN_OPR_INIT1(Diag, "diag")
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Diag) {
    // gradient w.r.t. the single data input: DiagBackward scatters/gathers
    // the incoming gradient back into the input's shape; the input shape is
    // passed symbolically so the backward opr can infer its output shape
    if (wrt_idx == 0) {
        SymbolVar data_sym{opr.input(0)};
        return DiagBackward::make(data_sym.symshape(), out_grad[0], opr.param()).node();
    }
    return InvalidGrad::make(opr, wrt_idx);
}
#endif
/* ==================== DiagBackward ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(DiagBackward);
DiagBackward::DiagBackward(
        VarNode* shape, VarNode* value, const Param& param,
        const OperatorNodeConfig& config)
        : Super{shape->owner_graph(), config, "diag_backward", {shape, value}},
          m_param{param} {
    // input(0): shape of the forward input (consumed as a host value);
    // input(1): gradient w.r.t. the forward output
    add_input({shape, value});
    add_output(None)->dtype(value->dtype());
    // include k in the equivalence key so oprs with different diagonal
    // offsets are never deduplicated by the graph
    add_equivalence_component<PODHash<Param>>(&m_param);
}
SymbolVar DiagBackward::make(
        SymbolVar shape, SymbolVar value, const Param& param,
        const OperatorNodeConfig& config) {
    return shape.insert_single_output_opr<DiagBackward>(
            shape.node(), value.node(), param, config);
}
cg::OperatorNodeBase::NodeProp* DiagBackward::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
    using D = NodeProp::DepType;
    // the shape input is only read on the host for static shape inference;
    // its device value is never used
    prop->add_dep_type(input(0), D::HOST_VALUE);
    return prop;
}
void DiagBackward::scn_do_execute() {
    auto&& dest = output(0)->dev_tensor();
    auto&& val = input(1)->dev_tensor();
    auto&& layout = dest.layout();
    mgb_assert(layout.ndim == 1 || layout.ndim == 2);
    if (layout.ndim == 2) {
        // forward extracted a diagonal from a matrix: scatter the vector
        // gradient onto the k-th diagonal of a zero-filled matrix
        dev_tensor_memset(dest, 0);
        // element offset of the first diagonal entry inside dest
        size_t offset = (m_param.k >= 0) ? (m_param.k * layout.stride[1])
                                         : (-m_param.k * layout.stride[0]);
        // view the k-th diagonal as a 1-D sub-tensor of stride s0 + s1
        auto dest_sub = dest.sub(SubTensorSpec::make_from_offset_elem(
                {val.shape(), {layout.stride[0] + layout.stride[1]}, val.dtype()},
                offset));
        dest_sub.copy_from_fixlayout(val);
    } else {
        // forward built a matrix from a vector: its backward is exactly the
        // forward Diag applied to the gradient matrix
        auto&& opr = m_dnn_opr;
        if (!opr) {
            // lazily create and cache the megdnn opr on first execution
            opr = intl::create_megdnn_opr<megdnn::Diag>(comp_node());
            opr->param() = m_param;
        }
        opr->exec(val.as_megdnn(), dest.as_megdnn(), {});
    }
}
void DiagBackward::record_execute_deps(ExecDependencyArray& deps) {
    // keep the cached megdnn opr alive until async execution has finished
    deps.emplace_back(std::make_unique<intl::MegDNNGraphDep>(std::move(m_dnn_opr)));
}
void DiagBackward::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    // the output shape is exactly the value carried by the shape input
    auto infer_shape = [](TensorShape& dest, const InpVal& inp) {
        cg::copy_tensor_value_to_shape(dest, inp.val.at(0).value());
        return true;
    };
    mgr.register_shape_infer(
            output(0), {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_shape});
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(DiagBackward) {
    // no second-order gradient is provided
    return InvalidGrad::make(opr, wrt_idx);
}
#endif
/* ==================== IndexingOneHot ==================== */ | /* ==================== IndexingOneHot ==================== */ | ||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingOneHot); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingOneHot); | ||||
MEGDNN_OPR_INIT2(IndexingOneHot, "indexing_one_hot") | MEGDNN_OPR_INIT2(IndexingOneHot, "indexing_one_hot") | ||||
@@ -1,3 +1,25 @@ | |||||
# NOTE(review): both decl_opr calls below document the scalar attribute `k`
# under `inputs`, but `k` is a field of the 'Diag' param struct and the
# operators' actual inputs are tensors (src for Diag; shape/value for
# DiagBackward) -- verify against the decl_opr conventions used elsewhere
# in this file.
decl_opr(
    'Diag',
    desc='Extract a diagonal or construct a diagonal array',
    inputs=[
        Doc('k', 'Index of the diagonal: 0 (the default) refers to the main '
            'diagonal, a positive value refers to an upper diagonal, and a '
            'negative value to a lower diagonal.')
    ],
    params='Diag'
)
decl_opr(
    'DiagBackward',
    desc='backward function of Diag',
    inputs=[
        Doc('k', 'Index of the diagonal: 0 (the default) refers to the main '
            'diagonal, a positive value refers to an upper diagonal, and a '
            'negative value to a lower diagonal.')
    ],
    params='Diag'
)
decl_opr('IndexingOneHot', pyname='_indexing_one_hot', | decl_opr('IndexingOneHot', pyname='_indexing_one_hot', | ||||
inputs=['src', 'index'], | inputs=['src', 'index'], | ||||
params=[('axis', 'Axis')]) | params=[('axis', 'Axis')]) | ||||
@@ -25,6 +25,8 @@ MGB_SEREG_MODIFY_SUBTENSOR_OPR(BatchedSetMeshIndexing); | |||||
namespace mgb { | namespace mgb { | ||||
namespace opr { | namespace opr { | ||||
MGB_SEREG_OPR(Diag, 1); | |||||
MGB_SEREG_OPR(DiagBackward, 2); | |||||
MGB_SEREG_OPR(IndexingOneHot, 2); | MGB_SEREG_OPR(IndexingOneHot, 2); | ||||
MGB_SEREG_OPR(IndexingRemap, 2); | MGB_SEREG_OPR(IndexingRemap, 2); | ||||
MGB_SEREG_OPR(IndexingRemapBackward, 3); | MGB_SEREG_OPR(IndexingRemapBackward, 3); | ||||
@@ -19,6 +19,37 @@ | |||||
namespace mgb { | namespace mgb { | ||||
namespace opr { | namespace opr { | ||||
//! forward Diag graph opr: thin wrapper over megdnn::Diag
MGB_DEFINE_OPR_CLASS(Diag, intl::MegDNNOprWrapperFwd<megdnn::Diag>) // {
public:
    MGE_WIN_DECLSPEC_FUC Diag(
            VarNode* src, const Param& param, const OperatorNodeConfig& config);
    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar src, const Param& param, const OperatorNodeConfig& config = {});
};
//! gradient of Diag; input(0) is the forward-input shape (read as a host
//! value for static shape inference), input(1) the incoming gradient
MGB_DEFINE_OPR_CLASS(DiagBackward, cg::SingleCNOperatorNodeBase) // {
public:
    using Param = megdnn::Diag::Param;
    MGE_WIN_DECLSPEC_FUC DiagBackward(
            VarNode* shape, VarNode* value, const Param& param,
            const OperatorNodeConfig& config);
    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar shape, SymbolVar value, const Param& param,
            const OperatorNodeConfig& config = {});
    const Param& param() const { return m_param; }
private:
    Param m_param;
    // cached megdnn opr, lazily created in scn_do_execute
    intl::UniqPtrWithCN<megdnn::Diag> m_dnn_opr;
    void scn_do_execute() override;
    void init_output_static_infer_desc() override;
    NodeProp* do_make_node_prop() const override;
    void record_execute_deps(ExecDependencyArray& deps) override;
};
MGB_DEFINE_OPR_CLASS( | MGB_DEFINE_OPR_CLASS( | ||||
IndexingOneHot, intl::MegDNNOprWrapperFwd<megdnn::IndexingOneHotForward>) // { | IndexingOneHot, intl::MegDNNOprWrapperFwd<megdnn::IndexingOneHotForward>) // { | ||||
public: | public: | ||||
@@ -52,6 +52,37 @@ void gen_index_onehot(int* max_value, HostTensorND& dest) { | |||||
} | } | ||||
} | } | ||||
void test_diag(int32_t axis, const TensorShapeArray& test_cases) { | |||||
using Checker = AutoOprChecker<1, 1>; | |||||
auto nopr = megdnn_naive_handle()->create_operator<megdnn::Diag>(); | |||||
nopr->param() = {axis}; | |||||
auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { | |||||
return {opr::Diag::make(inputs[0], {axis})}; | |||||
}; | |||||
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { | |||||
auto&& src = *inp[0]; | |||||
TensorShape oshp(src.shape()); | |||||
if (oshp.ndim == 1) { | |||||
size_t o = oshp.shape[0] + std::abs(axis); | |||||
oshp = {o, o}; | |||||
} else { | |||||
size_t m = oshp.shape[0]; | |||||
size_t n = oshp.shape[1]; | |||||
size_t o = (axis >= 0) ? std::min(n - axis, m) : std::min(m + axis, n); | |||||
oshp = {o}; | |||||
} | |||||
dest[0].resize(oshp); | |||||
nopr->exec(src.as_megdnn(), dest[0].as_megdnn(), {}); | |||||
}; | |||||
Checker checker{make_graph, fwd}; | |||||
for (auto&& i : test_cases) { | |||||
checker.run({i}); | |||||
} | |||||
} | |||||
void test_one_hot_get(int32_t axis, const TensorShapeArray& test_cases) { | void test_one_hot_get(int32_t axis, const TensorShapeArray& test_cases) { | ||||
using Checker = AutoOprChecker<2, 1>; | using Checker = AutoOprChecker<2, 1>; | ||||
@@ -145,6 +176,12 @@ void test_one_hot(int32_t axis, const TensorShapeArray& test_cases) { | |||||
} // anonymous namespace | } // anonymous namespace | ||||
//! run the Diag checker over square, wide, tall and 1-D inputs for a range
//! of diagonal offsets
TEST(TestOprDiag, Diag) {
    const TensorShapeArray cases{{7, 7}, {7, 9}, {9, 7}, {8}};
    for (int32_t k = -3; k < 3; ++k) {
        test_diag(k, cases);
    }
}
TEST(TestOprIndexing, OneHot2D) { | TEST(TestOprIndexing, OneHot2D) { | ||||
TensorShapeArray cases = {{1, 1}, {2, 2}, {10, 8}, {8, 10}}; | TensorShapeArray cases = {{1, 1}, {2, 2}, {10, 8}, {8, 10}}; | ||||
test_one_hot(0, cases); | test_one_hot(0, cases); | ||||
@@ -122,6 +122,7 @@ union OperatorParam { | |||||
param.RNN = 88, | param.RNN = 88, | ||||
param.LSTM = 89, | param.LSTM = 89, | ||||
param.Softmax = 90, | param.Softmax = 90, | ||||
param.Diag = 91, | |||||
} | } | ||||
table Operator { | table Operator { | ||||