@@ -39,6 +39,7 @@ void exec_internal(T* dst, size_t m, size_t n, int k, cudaStream_t stream) { | |||||
#define INST(T) template void exec_internal<T>(T*, size_t, size_t, int, cudaStream_t); | #define INST(T) template void exec_internal<T>(T*, size_t, size_t, int, cudaStream_t); | ||||
#define cb(DType) INST(typename DTypeTrait<DType>::ctype) | #define cb(DType) INST(typename DTypeTrait<DType>::ctype) | ||||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
cb(::megdnn::dtype::Bool) | |||||
} // namespace eye | } // namespace eye | ||||
} // namespace cuda | } // namespace cuda | ||||
@@ -26,6 +26,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||||
cuda_stream(handle())); \ | cuda_stream(handle())); \ | ||||
} | } | ||||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
cb(::megdnn::dtype::Bool) | |||||
#undef cb | #undef cb | ||||
} | } | ||||
@@ -31,6 +31,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>(dst.ptr<ctype>(), m, n)); \ | MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>(dst.ptr<ctype>(), m, n)); \ | ||||
} | } | ||||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
cb(::megdnn::dtype::Bool) | |||||
#undef cb | #undef cb | ||||
} | } | ||||
@@ -44,7 +44,7 @@ void exec_internal(T* dst, size_t m, size_t n, int k, hipStream_t stream) { | |||||
template void exec_internal<T>(T*, size_t, size_t, int, hipStream_t); | template void exec_internal<T>(T*, size_t, size_t, int, hipStream_t); | ||||
#define cb(DType) INST(typename DTypeTrait<DType>::ctype) | #define cb(DType) INST(typename DTypeTrait<DType>::ctype) | ||||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
cb(::megdnn::dtype::Bool) | |||||
} // namespace eye | } // namespace eye | ||||
} // namespace rocm | } // namespace rocm | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -27,6 +27,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||||
hip_stream(handle())); \ | hip_stream(handle())); \ | ||||
} | } | ||||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
cb(::megdnn::dtype::Bool) | |||||
#undef cb | #undef cb | ||||
} | } | ||||
@@ -24,21 +24,22 @@ from megengine.utils.network_node import VarNode | |||||
def test_eye(): | def test_eye(): | ||||
dtype = np.float32 | |||||
dtypes = [np.float32, np.bool] | |||||
cases = [{"input": [10, 20]}, {"input": [30]}] | cases = [{"input": [10, 20]}, {"input": [30]}] | ||||
for case in cases: | |||||
np.testing.assert_allclose( | |||||
F.eye(case["input"], dtype=dtype).numpy(), | |||||
np.eye(*case["input"]).astype(dtype), | |||||
) | |||||
np.testing.assert_allclose( | |||||
F.eye(*case["input"], dtype=dtype).numpy(), | |||||
np.eye(*case["input"]).astype(dtype), | |||||
) | |||||
np.testing.assert_allclose( | |||||
F.eye(tensor(case["input"]), dtype=dtype).numpy(), | |||||
np.eye(*case["input"]).astype(dtype), | |||||
) | |||||
for dtype in dtypes: | |||||
for case in cases: | |||||
np.testing.assert_allclose( | |||||
F.eye(case["input"], dtype=dtype).numpy(), | |||||
np.eye(*case["input"]).astype(dtype), | |||||
) | |||||
np.testing.assert_allclose( | |||||
F.eye(*case["input"], dtype=dtype).numpy(), | |||||
np.eye(*case["input"]).astype(dtype), | |||||
) | |||||
np.testing.assert_allclose( | |||||
F.eye(tensor(case["input"]), dtype=dtype).numpy(), | |||||
np.eye(*case["input"]).astype(dtype), | |||||
) | |||||
def test_full(): | def test_full(): | ||||