Browse Source

feat(dnn): enable eye to support bool

GitOrigin-RevId: 76d874d5b7
tags/v1.7.0.m1
Megvii Engine Team 3 years ago
parent
commit
1f0cc891b0
6 changed files with 20 additions and 15 deletions
  1. +1
    -0
      dnn/src/cuda/eye/eye.cu
  2. +1
    -0
      dnn/src/cuda/eye/opr_impl.cpp
  3. +1
    -0
      dnn/src/naive/eye/opr_impl.cpp
  4. +1
    -1
      dnn/src/rocm/eye/eye.cpp.hip
  5. +1
    -0
      dnn/src/rocm/eye/opr_impl.cpp
  6. +15
    -14
      imperative/python/test/unit/functional/test_tensor.py

+ 1
- 0
dnn/src/cuda/eye/eye.cu View File

@@ -39,6 +39,7 @@ void exec_internal(T* dst, size_t m, size_t n, int k, cudaStream_t stream) {
#define INST(T) template void exec_internal<T>(T*, size_t, size_t, int, cudaStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)

} // namespace eye
} // namespace cuda


+ 1
- 0
dnn/src/cuda/eye/opr_impl.cpp View File

@@ -26,6 +26,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) {
cuda_stream(handle())); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
}



+ 1
- 0
dnn/src/naive/eye/opr_impl.cpp View File

@@ -31,6 +31,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) {
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>(dst.ptr<ctype>(), m, n)); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
}



+ 1
- 1
dnn/src/rocm/eye/eye.cpp.hip View File

@@ -44,7 +44,7 @@ void exec_internal(T* dst, size_t m, size_t n, int k, hipStream_t stream) {
template void exec_internal<T>(T*, size_t, size_t, int, hipStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
} // namespace eye
} // namespace rocm
} // namespace megdnn


+ 1
- 0
dnn/src/rocm/eye/opr_impl.cpp View File

@@ -27,6 +27,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) {
hip_stream(handle())); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
}



+ 15
- 14
imperative/python/test/unit/functional/test_tensor.py View File

@@ -24,21 +24,22 @@ from megengine.utils.network_node import VarNode


def test_eye():
dtype = np.float32
dtypes = [np.float32, np.bool]
cases = [{"input": [10, 20]}, {"input": [30]}]
for case in cases:
np.testing.assert_allclose(
F.eye(case["input"], dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)
np.testing.assert_allclose(
F.eye(*case["input"], dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)
np.testing.assert_allclose(
F.eye(tensor(case["input"]), dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)
for dtype in dtypes:
for case in cases:
np.testing.assert_allclose(
F.eye(case["input"], dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)
np.testing.assert_allclose(
F.eye(*case["input"], dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)
np.testing.assert_allclose(
F.eye(tensor(case["input"]), dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)


def test_full():


Loading…
Cancel
Save