GitOrigin-RevId: 43016ffa2b
tags/v1.8.0

@@ -998,6 +998,28 @@ protected:
    void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
};
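
// Diag mirrors numpy.diag: a 1-D input is expanded into a 2-D matrix holding
// the input on its k-th diagonal, and a 2-D input has its k-th diagonal
// extracted into a 1-D vector.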
class Diag : public OperatorBase {
    DEF_OPR_IMPL(Diag, OperatorBase, 1, 1);
    DEF_OPR_PARAM(Diag);

public:
    /**
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.diag.html
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};

class IndexingOneHotBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(IndexingOneHotBase, OperatorBase);
    DEF_OPR_PARAM(Axis);
@@ -759,6 +759,14 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
         'dtype', Doc('dtype', 'data type of output value'),
         'DTypeEnum::Float32'))

(pdef('Diag').
 add_fields(
     'int32',
     Doc('k', 'Index of the diagonal: 0 (the default) refers to the main '
         'diagonal, a positive value refers to an upper diagonal, and a '
         'negative value to a lower diagonal.'),
     0))
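
# Note: each pdef entry presumably generates a corresponding megdnn param
# struct, i.e. megdnn::param::Diag with a single int32 field `k`.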

(pdef('UniformRNG', version=0, is_legacy=True).
 add_fields('uint64', 'seed', 0))
@@ -0,0 +1,47 @@
/**
 * \file dnn/src/common/diag.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {

void Diag::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
    megdnn_assert(
            src.ndim == 1 || src.ndim == 2, "Only support vector or matrix as input.");
    int k = param().k;
    if (src.ndim == 1) {
        size_t o = src.total_nr_elems() + std::abs(k);
        dst = TensorLayout(TensorShape({o, o}), src.dtype);
    } else {  // src.ndim == 2
        size_t m = src.shape[0];
        size_t n = src.shape[1];
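        // Length of the k-th diagonal of an m x n matrix: an upper diagonal
        // (k >= 0) has min(n - k, m) elements, a lower one min(m + k, n).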
        size_t o = (k >= 0 ? std::min(n - k, m) : std::min(m + k, n));
        megdnn_assert(o > 0, "The requested diagonal is out of range for the input matrix.");
        dst = TensorLayout(TensorShape({o}), src.dtype);
    }
}

void Diag::check_exec(
        const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
    TensorLayout dst_expected;
    megdnn_assert_eq_dtype(src, dst);
    deduce_layout(src, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
    megdnn_assert_contiguous(dst);
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

}  // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -146,6 +146,7 @@ private:
    cb(BatchedSetMeshIndexing) \
    cb(Linspace) \
    cb(Eye) \
    cb(Diag) \
    cb(SleepForward) \
    cb(UniformRNG) \
    cb(GaussianRNG) \

@@ -88,6 +88,7 @@ DEF(IndexingRemapForward, 3, true, true);
DEF(IndexingRemapBackward, 3, true, false);
DEF(Linspace, 1, true, false);
DEF(Eye, 1, true, false);
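// DEF fields are assumed to be (opr, arity, has_workspace, can_deduce_layout):
// Diag takes two tensor arguments (src, dst), accepts a workspace, and can
// deduce its output layout.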
DEF(Diag, 2, true, true);
DEF(Flip, 2, true, true);
DEF(ROICopy, 2, true, true);
DEF(Rotate, 2, true, true);

@@ -0,0 +1,87 @@
/**
 * \file dnn/src/cuda/diag/diag.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/dtype.h"
#include "src/cuda/diag/diag.cuh"
#include "src/cuda/utils.cuh"

namespace {
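
// Copies the k-th diagonal of a strided matrix into a vector: `start` is the
// element offset of the first diagonal element and `stride_sum` equals
// src_stride0 + src_stride1, the step that advances one position along the
// diagonal.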
template <typename T>
__global__ void kernel_to_vector(
        T* src, T* dst, ptrdiff_t start, ptrdiff_t size, ptrdiff_t stride_sum,
        ptrdiff_t dst_stride) {
    ptrdiff_t i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i < size) {
        dst[dst_stride * i] = src[start + stride_sum * i];
    }
}
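
// Scatters a vector onto the k-th diagonal of an n x n matrix: each thread
// owns one destination element (x, y), so off-diagonal positions are
// zero-filled in the same pass; `offset` is the row of the first diagonal
// element (|k| for k < 0, else 0).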
template <typename T>
__global__ void kernel_to_matrix(
        T* src, T* dst, ptrdiff_t offset, ptrdiff_t n, ptrdiff_t k,
        ptrdiff_t dst_stride0, ptrdiff_t dst_stride1, ptrdiff_t src_stride) {
    ptrdiff_t i = threadIdx.x + blockIdx.x * blockDim.x;
    ptrdiff_t x = i % n;
    ptrdiff_t y = i / n;
    ptrdiff_t p = dst_stride0 * y + dst_stride1 * x;
    if (i < n * n) {
        if (y + k == x)
            dst[p] = src[src_stride * (y - offset)];
        else
            dst[p] = 0;
    }
}

}  // anonymous namespace

namespace megdnn {
namespace cuda {
namespace diag {

template <typename T>
void exec_internal_to_vector(
        T* src, T* dst, ptrdiff_t start, ptrdiff_t size, ptrdiff_t stride_sum,
        ptrdiff_t dst_stride, cudaStream_t stream) {
    kernel_to_vector<T><<<DIVUP(size, NR_THREADS), NR_THREADS, 0, stream>>>(
            src, dst, start, size, stride_sum, dst_stride);
    after_kernel_launch();
}

template <typename T>
void exec_internal_to_matrix(
        T* src, T* dst, ptrdiff_t offset, ptrdiff_t n, ptrdiff_t k,
        ptrdiff_t dst_stride0, ptrdiff_t dst_stride1, ptrdiff_t src_stride,
        cudaStream_t stream) {
    kernel_to_matrix<T><<<DIVUP(n * n, NR_THREADS), NR_THREADS, 0, stream>>>(
            src, dst, offset, n, k, dst_stride0, dst_stride1, src_stride);
    after_kernel_launch();
}

#define INST(T)                               \
    template void exec_internal_to_vector<T>( \
            T*, T*, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, cudaStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef INST
#undef cb

#define INST(T)                                                                       \
    template void exec_internal_to_matrix<T>(                                         \
            T*, T*, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, \
            cudaStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)

}  // namespace diag
}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

@@ -0,0 +1,33 @@
/**
 * \file dnn/src/cuda/diag/diag.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include <cuda_runtime_api.h>
#include <stdint.h>

namespace megdnn {
namespace cuda {
namespace diag {

template <typename T>
void exec_internal_to_vector(
        T* src, T* dst, ptrdiff_t start, ptrdiff_t size, ptrdiff_t stride_sum,
        ptrdiff_t dst_stride, cudaStream_t stream);

template <typename T>
void exec_internal_to_matrix(
        T* src, T* dst, ptrdiff_t offset, ptrdiff_t n, ptrdiff_t k,
        ptrdiff_t dst_stride0, ptrdiff_t dst_stride1, ptrdiff_t src_stride,
        cudaStream_t stream);

}  // namespace diag
}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

@@ -0,0 +1,61 @@
/**
 * \file dnn/src/cuda/diag/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/cuda/diag/opr_impl.h"
#include "src/cuda/diag/diag.cuh"
#include "src/cuda/utils.h"

namespace megdnn {
namespace cuda {

void DiagImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);
    if (src.layout.ndim == 2) {
        auto src_stride0 = src.layout.stride[0];
        auto src_stride1 = src.layout.stride[1];
        auto dst_stride = dst.layout.stride[0];
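        // Offset of the first diagonal element: k columns to the right for an
        // upper diagonal, |k| rows down for a lower one.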
        auto start =
                (param().k >= 0) ? param().k * src_stride1 : -param().k * src_stride0;
#define cb(DType)                                                               \
    if (dst.layout.dtype.enumv() == DTypeTrait<DType>::enumv) {                 \
        using ctype = typename DTypeTrait<DType>::ctype;                        \
        diag::exec_internal_to_vector<ctype>(                                   \
                src.ptr<ctype>(), dst.ptr<ctype>(), start, dst.layout.shape[0], \
                src_stride0 + src_stride1, dst_stride, cuda_stream(handle()));  \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
        cb(::megdnn::dtype::Bool)
#undef cb
    } else {
        auto n = dst.layout.shape[0];
        auto src_stride = src.layout.stride[0];
        auto dst_stride0 = dst.layout.stride[0];
        auto dst_stride1 = dst.layout.stride[1];
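        // Row index of the first diagonal element; the kernel reads
        // src[src_stride * (y - offset)] for positions on the diagonal.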
        auto offset = (param().k >= 0) ? 0 : -param().k;
#define cb(DType)                                                         \
    if (dst.layout.dtype.enumv() == DTypeTrait<DType>::enumv) {           \
        using ctype = typename DTypeTrait<DType>::ctype;                  \
        diag::exec_internal_to_matrix<ctype>(                             \
                src.ptr<ctype>(), dst.ptr<ctype>(), offset, n, param().k, \
                dst_stride0, dst_stride1, src_stride,                     \
                cuda_stream(handle()));                                   \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
        cb(::megdnn::dtype::Bool)
#undef cb
    }
}

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

@@ -0,0 +1,31 @@
/**
 * \file dnn/src/cuda/diag/opr_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "megdnn/oprs.h"

namespace megdnn {
namespace cuda {

class DiagImpl final : public Diag {
public:
    using Diag::Diag;
    void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) override {
        return 0;
    }
};

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
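
A minimal usage sketch (not part of the patch): it shows how the new operator
is driven through a megdnn handle, following the API used elsewhere in this
diff; the `handle`, `src`, and `dst` variables are assumed to exist:

    auto opr = handle->create_operator<megdnn::Diag>();
    opr->param().k = 1;  // first upper diagonal
    TensorLayout dst_layout;
    opr->deduce_layout(src.layout, dst_layout);  // e.g. (3,) with k=1 -> (4, 4)
    // no workspace is needed: get_workspace_in_bytes() returns 0
    opr->exec(src, dst, {});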

@@ -33,6 +33,7 @@
#include "src/cuda/dct/opr_impl.h"
#include "src/cuda/deformable_conv/opr_impl.h"
#include "src/cuda/deformable_ps_roi_pooling/opr_impl.h"
#include "src/cuda/diag/opr_impl.h"
#include "src/cuda/dot/opr_impl.h"
#include "src/cuda/dropout/opr_impl.h"
#include "src/cuda/elemwise/opr_impl.h"

@@ -154,6 +155,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedIncrMeshIndexing);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedSetMeshIndexing);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Eye);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Diag);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(UniformRNG);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GaussianRNG);

@@ -0,0 +1,60 @@
/**
 * \file dnn/src/naive/diag/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/naive/diag/opr_impl.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"

namespace megdnn {
namespace naive {

template <typename ctype>
void DiagImpl::exec_internal(
        ctype* src, const TensorLayout& src_layout, ctype* dst,
        const TensorLayout& dst_layout, size_t input_ndim, int k) {
    if (input_ndim == 1) {
        size_t l = src_layout.shape[0];
        size_t s0 = dst_layout.stride[0];
        size_t s1 = dst_layout.stride[1];
        size_t start = (k >= 0) ? (k * s1) : (-k * s0);
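        // Zero-fill the whole output matrix, then write the source vector
        // along the k-th diagonal (step s0 + s1 per element).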
        for (size_t i = 0; i < dst_layout.shape[0]; ++i)
            for (size_t j = 0; j < dst_layout.shape[1]; ++j)
                dst[i * s0 + j * s1] = 0;
        for (size_t i = 0; i < l; ++i)
            dst[start + i * (s0 + s1)] = src[i];
    } else {
        size_t l = dst_layout.shape[0];
        size_t s0 = src_layout.stride[0];
        size_t s1 = src_layout.stride[1];
        size_t start = (k >= 0) ? (k * s1) : (-k * s0);
        for (size_t i = 0; i < l; ++i)
            dst[i] = src[start + i * (s0 + s1)];
    }
}

void DiagImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);
#define cb(DType)                                                            \
    if (src.layout.dtype == DType()) {                                       \
        using ctype = typename DTypeTrait<DType>::ctype;                     \
        MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>(                   \
                src.ptr<ctype>(), src.layout, dst.ptr<ctype>(), dst.layout,  \
                src.layout.ndim, param().k));                                \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
    cb(::megdnn::dtype::Bool)
#undef cb
}

}  // namespace naive
}  // namespace megdnn

// vim: syntax=cpp.doxygen

@@ -0,0 +1,37 @@
/**
 * \file dnn/src/naive/diag/opr_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "megdnn/oprs.h"

namespace megdnn {
namespace naive {

class DiagImpl : public Diag {
public:
    using Diag::Diag;
    void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override {
        return 0;
    }

private:
    template <typename ctype>
    void exec_internal(
            ctype* src, const TensorLayout& src_layout, ctype* dst,
            const TensorLayout& dst_layout, size_t input_ndim, int k);
};

}  // namespace naive
}  // namespace megdnn

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

@@ -34,6 +34,7 @@
#include "src/naive/dct/opr_impl.h"
#include "src/naive/deformable_conv/opr_impl.h"
#include "src/naive/deformable_ps_roi_pooling/opr_impl.h"
#include "src/naive/diag/opr_impl.h"
#include "src/naive/dot/opr_impl.h"
#include "src/naive/dropout/opr_impl.h"
#include "src/naive/elemwise/opr_impl.h"

@@ -0,0 +1,42 @@
/**
 * \file dnn/test/cuda/diag.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "test/cuda/fixture.h"

#include "megdnn/oprs.h"
#include "test/common/checker.h"

namespace megdnn {
namespace test {

TEST_F(CUDA, DIAG) {
    Checker<Diag> checker(handle_cuda());
    for (DType dtype :
         std::vector<DType>{dtype::Float16(), dtype::Int32(), dtype::Float32()})
        for (int k = -5; k < 5; ++k) {
            checker.set_param({k});
            checker.set_dtype(0, dtype);
            checker.set_dtype(1, dtype);
            size_t absk = static_cast<size_t>(std::abs(k));
            checker.exec(TensorShapeArray{{8}, {8 + absk, 8 + absk}});
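
            // matrix -> vector: expected length of the k-th diagonal of an
            // m x n input, matching Diag::deduce_layout.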
            auto oshape = [&](int m, int n) -> TensorShape {
                size_t o = (k >= 0 ? std::min(n - k, m) : std::min(m + k, n));
                return {o};
            };
            checker.exec(TensorShapeArray{{8, 6}, oshape(8, 6)});
            checker.exec(TensorShapeArray{{6, 8}, oshape(6, 8)});
            checker.exec(TensorShapeArray{{8, 8}, oshape(8, 8)});
        }
}

}  // namespace test
}  // namespace megdnn

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

@@ -0,0 +1,111 @@
/**
 * \file dnn/test/naive/diag.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/naive/fixture.h"

namespace megdnn {
namespace test {

TEST_F(NAIVE, DiagVector2Matrix) {
    Checker<Diag> checker(handle(), false);
    Diag::Param param;
    param.k = 0;
    checker.set_param(param).exect(
            Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}},
            Testcase{
                    {},
                    // clang-format off
                    TensorValue({3, 3}, dtype::Float32(), {1, 0, 0,
                                                           0, 2, 0,
                                                           0, 0, 3})});
    // clang-format on
}

TEST_F(NAIVE, DiagVector2Matrix_PositiveK) {
    Checker<Diag> checker(handle(), false);
    Diag::Param param;
    param.k = 1;
    checker.set_param(param).exect(
            Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}},
            Testcase{
                    {},
                    // clang-format off
                    TensorValue({4, 4}, dtype::Float32(), {0, 1, 0, 0,
                                                           0, 0, 2, 0,
                                                           0, 0, 0, 3,
                                                           0, 0, 0, 0})});
    // clang-format on
}

TEST_F(NAIVE, DiagVector2Matrix_NegativeK) {
    Checker<Diag> checker(handle(), false);
    Diag::Param param;
    param.k = -1;
    checker.set_param(param).exect(
            Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}},
            Testcase{
                    {},
                    // clang-format off
                    TensorValue({4, 4}, dtype::Float32(), {0, 0, 0, 0,
                                                           1, 0, 0, 0,
                                                           0, 2, 0, 0,
                                                           0, 0, 3, 0})});
    // clang-format on
}

TEST_F(NAIVE, DiagMatrix2Vector) {
    Checker<Diag> checker(handle(), false);
    Diag::Param param;
    param.k = 0;
    checker.set_param(param).exect(
            // clang-format off
            Testcase{TensorValue({3, 3}, dtype::Float32(), {1, 2, 3,
                                                            4, 5, 6,
                                                            7, 8, 9}),
                     // clang-format on
                     {}},
            Testcase{{}, TensorValue({3}, dtype::Float32(), {1, 5, 9})});
}

TEST_F(NAIVE, DiagMatrix2Vector_PositiveK) {
    Checker<Diag> checker(handle(), false);
    Diag::Param param;
    param.k = 1;
    checker.set_param(param).exect(
            // clang-format off
            Testcase{TensorValue({3, 3}, dtype::Float32(), {1, 2, 3,
                                                            4, 5, 6,
                                                            7, 8, 9}),
                     // clang-format on
                     {}},
            Testcase{{}, TensorValue({2}, dtype::Float32(), {2, 6})});
}

TEST_F(NAIVE, DiagMatrix2Vector_NegativeK) {
    Checker<Diag> checker(handle(), false);
    Diag::Param param;
    param.k = -1;
    checker.set_param(param).exect(
            // clang-format off
            Testcase{TensorValue({3, 3}, dtype::Float32(), {1, 2, 3,
                                                            4, 5, 6,
                                                            7, 8, 9}),
                     // clang-format on
                     {}},
            Testcase{{}, TensorValue({2}, dtype::Float32(), {4, 8})});
}

}  // namespace test
}  // namespace megdnn

@@ -28,6 +28,7 @@ __all__ = [
    "concat",
    "cond_take",
    "cumsum",
    "diag",
    "expand_dims",
    "eye",
    "flatten",

@@ -53,6 +54,32 @@ __all__ = [
]


def diag(inp, k=0) -> Tensor:
    r"""If ``inp`` is a 1D tensor, returns a 2D tensor with the elements of ``inp`` as the diagonal.
    If ``inp`` is a 2D tensor, returns a 1D tensor with the diagonal elements of ``inp``.

    Args:
        inp: input tensor.
        k: diagonal to consider. Use :math:`k=0` for the main diagonal, :math:`k>0` for diagonals above the
           main diagonal, and :math:`k<0` for diagonals below the main diagonal. Default: 0.

    Returns:
        the extracted diagonal or the constructed diagonal tensor.

    Examples:
        >>> inp = F.arange(6, dtype='int32').reshape(2, 3)
        >>> out = F.diag(inp, k=1)
        >>> out
        Tensor([1 5], dtype=int32, device=xpux:0)
        >>> F.diag(out)
        Tensor([[1 0]
         [0 5]], dtype=int32, device=xpux:0)
    """
    op = builtin.Diag(k=k)
    (result,) = apply(op, inp)
    return result


def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
    r"""Returns a 2D tensor with ones on the diagonal and zeros elsewhere.

@@ -42,6 +42,26 @@ def test_eye():
    )
@pytest.mark.parametrize("is_varnode", [False, True]) | |||
def test_diag(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
shapes = [(10, 10), (6, 9), (8, 7), (8,)] | |||
cases = [] | |||
for shp in shapes: | |||
cases.append({"input": [np.random.random(shp).astype("float32")]}) | |||
for axis in range(-2, 3): | |||
def run(data): | |||
return F.diag(data, k=axis) | |||
opr_test(cases, run, ref_fn=lambda x: np.diag(x, axis), network=network) | |||


def test_full():
    shape = (2, 3)
    values = [True, 4, 5.0]

@@ -433,6 +433,19 @@ OP_TRAIT_REG(Eye, Eye).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace

namespace {
namespace diag {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Diag&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.make_name()};
    opr::Diag::Param param{op.k};
    return opr::Diag::make(inputs[0], param, config);
}
OP_TRAIT_REG(Diag, Diag).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace diag
}  // namespace

namespace {
namespace roi_pooling {
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const ROIPooling&>(def);

@@ -240,6 +240,8 @@ def Eye: MgbHashableOp<"Eye", [EyeParam]> {
  );
}

def Diag: MgbHashableOp<"Diag", [DiagParam]>;

def GetVarShape : MgbHashableOp<"GetVarShape", [OptionalAxisV1Param]>;

def Concat: MgbHashableOp<"Concat", [AxisParam]> {

@@ -75,6 +75,91 @@ struct MegDNNOprInitInputsModifier<IndexingSetOneHot>
}  // namespace opr
}  // namespace mgb

/* ==================== Diag ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Diag);
MEGDNN_OPR_INIT1(Diag, "diag")

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Diag) {
    if (wrt_idx == 0) {
        SymbolVar data_sym{opr.input(0)};
        return DiagBackward::make(data_sym.symshape(), out_grad[0], opr.param()).node();
    }
    return InvalidGrad::make(opr, wrt_idx);
}
#endif
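
// Grad note: DiagBackward is handed the symbolic shape of the forward input,
// so the gradient tensor is rebuilt in that shape (see the implementation
// below).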
/* ==================== DiagBackward ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(DiagBackward);

DiagBackward::DiagBackward(
        VarNode* shape, VarNode* value, const Param& param,
        const OperatorNodeConfig& config)
        : Super{shape->owner_graph(), config, "diag_backward", {shape, value}},
          m_param{param} {
    add_input({shape, value});
    add_output(None)->dtype(value->dtype());
    add_equivalence_component<PODHash<Param>>(&m_param);
}

SymbolVar DiagBackward::make(
        SymbolVar shape, SymbolVar value, const Param& param,
        const OperatorNodeConfig& config) {
    return shape.insert_single_output_opr<DiagBackward>(
            shape.node(), value.node(), param, config);
}

cg::OperatorNodeBase::NodeProp* DiagBackward::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
    using D = NodeProp::DepType;
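    // The shape input is consumed by value on the host: its runtime content,
    // not its storage, determines the output shape.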
    prop->add_dep_type(input(0), D::HOST_VALUE);
    return prop;
}
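
// If the forward input was a matrix (2-D gradient output), zero-fill the
// output and copy the incoming 1-D gradient onto its k-th diagonal through a
// strided sub-tensor; if it was a vector (1-D output), extracting the k-th
// diagonal of the incoming gradient is exactly the forward Diag, so a megdnn
// Diag opr is created lazily and reused.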
void DiagBackward::scn_do_execute() {
    auto&& dest = output(0)->dev_tensor();
    auto&& val = input(1)->dev_tensor();
    auto&& layout = dest.layout();
    mgb_assert(layout.ndim == 1 || layout.ndim == 2);
    if (layout.ndim == 2) {
        dev_tensor_memset(dest, 0);
        size_t offset = (m_param.k >= 0) ? (m_param.k * layout.stride[1])
                                         : (-m_param.k * layout.stride[0]);
        auto dest_sub = dest.sub(SubTensorSpec::make_from_offset_elem(
                {val.shape(), {layout.stride[0] + layout.stride[1]}, val.dtype()},
                offset));
        dest_sub.copy_from_fixlayout(val);
    } else {
        auto&& opr = m_dnn_opr;
        if (!opr) {
            opr = intl::create_megdnn_opr<megdnn::Diag>(comp_node());
            opr->param() = m_param;
        }
        opr->exec(val.as_megdnn(), dest.as_megdnn(), {});
    }
}

void DiagBackward::record_execute_deps(ExecDependencyArray& deps) {
    deps.emplace_back(std::make_unique<intl::MegDNNGraphDep>(std::move(m_dnn_opr)));
}

void DiagBackward::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    auto infer_shape = [](TensorShape& dest, const InpVal& inp) {
        cg::copy_tensor_value_to_shape(dest, inp.val.at(0).value());
        return true;
    };
auto infer_shape = [](TensorShape& dest, const InpVal& inp) { | |||
cg::copy_tensor_value_to_shape(dest, inp.val.at(0).value()); | |||
return true; | |||
}; | |||
mgr.register_shape_infer( | |||
output(0), {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_shape}); | |||
} | |||
#if MGB_ENABLE_GRAD | |||
MGB_IMPL_OPR_GRAD(DiagBackward) { | |||
return InvalidGrad::make(opr, wrt_idx); | |||
} | |||
#endif | |||
/* ==================== IndexingOneHot ==================== */ | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingOneHot); | |||
MEGDNN_OPR_INIT2(IndexingOneHot, "indexing_one_hot") | |||

@@ -1,3 +1,25 @@
decl_opr(
    'Diag',
    desc='Extract a diagonal or construct a diagonal array; param k selects '
         'the diagonal: 0 (the default) refers to the main diagonal, a '
         'positive value to an upper diagonal, and a negative value to a '
         'lower diagonal',
    inputs=['src'],
    params='Diag'
)

decl_opr(
    'DiagBackward',
    desc='backward function of Diag',
    inputs=['shape', 'value'],
    params='Diag'
)

decl_opr('IndexingOneHot', pyname='_indexing_one_hot',
         inputs=['src', 'index'],
         params=[('axis', 'Axis')])

@@ -25,6 +25,8 @@ MGB_SEREG_MODIFY_SUBTENSOR_OPR(BatchedSetMeshIndexing);
namespace mgb {
namespace opr {

MGB_SEREG_OPR(Diag, 1);
MGB_SEREG_OPR(DiagBackward, 2);
MGB_SEREG_OPR(IndexingOneHot, 2);
MGB_SEREG_OPR(IndexingRemap, 2);
MGB_SEREG_OPR(IndexingRemapBackward, 3);

@@ -19,6 +19,37 @@
namespace mgb {
namespace opr {

MGB_DEFINE_OPR_CLASS(Diag, intl::MegDNNOprWrapperFwd<megdnn::Diag>) // {
public:
    MGE_WIN_DECLSPEC_FUC Diag(
            VarNode* src, const Param& param, const OperatorNodeConfig& config);
    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar src, const Param& param, const OperatorNodeConfig& config = {});
};

MGB_DEFINE_OPR_CLASS(DiagBackward, cg::SingleCNOperatorNodeBase) // {
public:
    using Param = megdnn::Diag::Param;

    MGE_WIN_DECLSPEC_FUC DiagBackward(
            VarNode* shape, VarNode* value, const Param& param,
            const OperatorNodeConfig& config);
    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar shape, SymbolVar value, const Param& param,
            const OperatorNodeConfig& config = {});

    const Param& param() const { return m_param; }

private:
    Param m_param;
    intl::UniqPtrWithCN<megdnn::Diag> m_dnn_opr;

    void scn_do_execute() override;
    void init_output_static_infer_desc() override;
    NodeProp* do_make_node_prop() const override;
    void record_execute_deps(ExecDependencyArray& deps) override;
};

MGB_DEFINE_OPR_CLASS(
        IndexingOneHot, intl::MegDNNOprWrapperFwd<megdnn::IndexingOneHotForward>) // {
public:

@@ -52,6 +52,37 @@ void gen_index_onehot(int* max_value, HostTensorND& dest) {
    }
}

void test_diag(int32_t k, const TensorShapeArray& test_cases) {
    using Checker = AutoOprChecker<1, 1>;
    auto nopr = megdnn_naive_handle()->create_operator<megdnn::Diag>();
    nopr->param() = {k};

    auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
        return {opr::Diag::make(inputs[0], {k})};
    };

    auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
        auto&& src = *inp[0];
        TensorShape oshp(src.shape());
        if (oshp.ndim == 1) {
            size_t o = oshp.shape[0] + std::abs(k);
            oshp = {o, o};
        } else {
            size_t m = oshp.shape[0];
            size_t n = oshp.shape[1];
            size_t o = (k >= 0) ? std::min(n - k, m) : std::min(m + k, n);
            oshp = {o};
        }
        dest[0].resize(oshp);
        nopr->exec(src.as_megdnn(), dest[0].as_megdnn(), {});
    };

    Checker checker{make_graph, fwd};
    for (auto&& i : test_cases) {
        checker.run({i});
    }
}

void test_one_hot_get(int32_t axis, const TensorShapeArray& test_cases) {
    using Checker = AutoOprChecker<2, 1>;

@@ -145,6 +176,12 @@ void test_one_hot(int32_t axis, const TensorShapeArray& test_cases) {
}  // anonymous namespace

TEST(TestOprDiag, Diag) {
    TensorShapeArray cases = {{7, 7}, {7, 9}, {9, 7}, {8}};
    for (int32_t k = -3; k < 3; ++k)
        test_diag(k, cases);
}

TEST(TestOprIndexing, OneHot2D) {
    TensorShapeArray cases = {{1, 1}, {2, 2}, {10, 8}, {8, 10}};
    test_one_hot(0, cases);

@@ -122,6 +122,7 @@ union OperatorParam {
    param.RNN = 88,
    param.LSTM = 89,
    param.Softmax = 90,
    param.Diag = 91,
}

table Operator {