|
|
@@ -14,7 +14,7 @@ using device_reduce::CheckNonFiniteOp; |
|
|
|
template <typename T> |
|
|
|
size_t CheckNonFiniteImpl::_get_workspace_in_bytes() { |
|
|
|
// Call the _get_workspace_in_bytes to reduce the loop fetch workspace bytes |
|
|
|
typedef CheckNonFiniteOp<T, dt_float32, dt_int32, dt_int32> Op; |
|
|
|
typedef CheckNonFiniteOp<T, dt_int32, dt_int32> Op; |
|
|
|
megdnn_assert(m_size > 0); |
|
|
|
WorkspaceBundle bundle( |
|
|
|
nullptr, { |
|
|
@@ -59,7 +59,7 @@ void CheckNonFiniteImpl::_exec( |
|
|
|
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst, |
|
|
|
_megdnn_workspace workspace) { |
|
|
|
check_exec(srcs, dst, workspace.size); |
|
|
|
typedef CheckNonFiniteOp<T, dt_float32, dt_int32, dt_int32> Op; |
|
|
|
typedef CheckNonFiniteOp<T, dt_int32, dt_int32> Op; |
|
|
|
auto stream = cuda_stream(this->handle()); |
|
|
|
SmallVector<size_t> workspace_sizes{ |
|
|
|
sizeof(T*) * m_size, |
|
|
|