GitOrigin-RevId: f36495a46a
tags/v0.4.0
@@ -14,43 +14,7 @@ | |||
#include "src/cuda/utils.h" | |||
#include "src/cuda/handle.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace local { | |||
// Precondition check invoked by LocalForwardImpl::exec right before | |||
// local::forward_proxy_weiming. | |||
// Only IH and IW take part in the check: the function asserts IH*IW <= 768. | |||
// Every other parameter is handed to megdnn_ignore() and not used otherwise | |||
// (presumably megdnn_ignore only silences unused-parameter warnings -- TODO confirm). | |||
// On violation megdnn_assert fires with the message below. | |||
void check_input(size_t N, | |||
size_t IC, size_t IH, size_t IW, | |||
size_t OC, size_t OH, size_t OW, | |||
size_t FH, size_t FW, | |||
size_t INs, size_t ONs, | |||
size_t PH, size_t PW, | |||
size_t SH, size_t SW, | |||
bool is_xcorr) | |||
{ | |||
megdnn_ignore(N); | |||
megdnn_ignore(IC); | |||
megdnn_ignore(IH); | |||
megdnn_ignore(IW); | |||
megdnn_ignore(OC); | |||
megdnn_ignore(OH); | |||
megdnn_ignore(OW); | |||
megdnn_ignore(FH); | |||
megdnn_ignore(FW); | |||
megdnn_ignore(INs); | |||
megdnn_ignore(ONs); | |||
megdnn_ignore(PH); | |||
megdnn_ignore(PW); | |||
megdnn_ignore(SH); | |||
megdnn_ignore(SW); | |||
megdnn_ignore(is_xcorr); | |||
// shared memory constraint | |||
// 768 equals 49152 / (4 * 4 * 4), i.e. the bound in the disabled assert below | |||
// solved for IH*IW. | |||
megdnn_assert(IH*IW <= 768, "spatial size should not be larger than 768."); | |||
// Disabled byte-based form of the same bound. NOTE(review): the factors look | |||
// like Ns * ICs * sizeof(float) and 49152 like a 48 KiB shared-memory budget | |||
// -- confirm against the kernel. | |||
// megdnn_assert(4 * 4 * 4 * IH * IW <= 49152); | |||
} | |||
} // namespace local | |||
} // namespace cuda | |||
} // namespace megdnn | |||
#include "src/common/utils.cuh" | |||
namespace megdnn { | |||
namespace cuda { | |||
@@ -94,13 +58,9 @@ void LocalForwardImpl::exec(_megdnn_tensor_in src, | |||
param().stride_h, param().stride_w, | |||
cublas, stream, | |||
one, zero); | |||
} else { | |||
local::check_input(N, IC, IH, IW, OC, OH, OW, FH, FW, | |||
IC*IH*IW, OC*OH*OW, | |||
param().pad_h, param().pad_w, | |||
param().stride_h, param().stride_w, | |||
is_xcorr); | |||
local::forward_proxy_weiming(src.ptr<dt_float32>(), | |||
} else if (local::forward_proxy_default_share_mem_in_bytes(IH, IW) <= | |||
handle->device_prop().sharedMemPerBlock) { | |||
local::forward_proxy_default(src.ptr<dt_float32>(), | |||
filter.ptr<dt_float32>(), | |||
dst.ptr<dt_float32>(), | |||
N, | |||
@@ -112,6 +72,11 @@ void LocalForwardImpl::exec(_megdnn_tensor_in src, | |||
param().stride_h, param().stride_w, | |||
is_xcorr, | |||
stream); | |||
} else { | |||
megdnn_throw(ssprintf( | |||
"No usable kernel for local conv, src: %s filter: %s \n", | |||
src.layout.to_string().c_str(), | |||
filter.layout.to_string().c_str())); | |||
} | |||
} | |||
@@ -18,6 +18,12 @@ namespace megdnn { | |||
namespace cuda { | |||
namespace local { | |||
// Compile-time parameters shared by the shared-memory size estimate below and | |||
// by the forward_kernel<Ns, ICs, ...> launch; Ns is also the divisor of N in | |||
// the launch grid (DIVUP(N, Ns)). | |||
// Presumably Ns is a batch tile and ICs an input-channel tile -- TODO confirm. | |||
constexpr size_t Ns = 4, ICs = 4; | |||
// Returns Ns * ICs * sizeof(float) * IH * IW, a byte count. | |||
// LocalForwardImpl::exec compares this value against | |||
// handle->device_prop().sharedMemPerBlock to decide whether | |||
// forward_proxy_default may be used. | |||
// NOTE(review): plain size_t arithmetic, no overflow check on the product. | |||
size_t forward_proxy_default_share_mem_in_bytes(size_t IH, size_t IW) { | |||
return Ns * ICs * sizeof(float) * IH * IW; | |||
} | |||
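// blockIdx.y is OC*OH*OW/blockDim.x (rounded up) | |||
// blockIdx.x is N/4 (rounded up) | |||
// threadIdx.x is [0, blockDim.x); NOTE(review): this comment used to say 1024, but the launcher in this file uses threads = 256 -- confirm | |||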
@@ -96,7 +102,7 @@ __global__ void forward_kernel(const float * __restrict__ src, | |||
} | |||
} | |||
void forward_proxy_weiming(const float *src, const float *filter, float *dst, | |||
void forward_proxy_default(const float *src, const float *filter, float *dst, | |||
size_t N, | |||
size_t IC, size_t IH, size_t IW, | |||
size_t OC, size_t OH, size_t OW, | |||
@@ -108,7 +114,6 @@ void forward_proxy_weiming(const float *src, const float *filter, float *dst, | |||
cudaStream_t stream) | |||
{ | |||
size_t threads = 256; | |||
const size_t Ns = 4, ICs = 4; | |||
dim3 blocks = dim3(DIVUP(N, Ns), DIVUP(OC*OH*OW, threads)); | |||
if (is_xcorr) { | |||
forward_kernel<Ns, ICs, true><<<blocks, threads, | |||
@@ -17,17 +17,10 @@ namespace megdnn { | |||
namespace cuda { | |||
namespace local { | |||
// Input check for the local forward path. | |||
// Argument meaning as passed at the call site in LocalForwardImpl::exec: | |||
//   INs = IC*IH*IW, ONs = OC*OH*OW, | |||
//   PH/PW = param().pad_h / pad_w, SH/SW = param().stride_h / stride_w. | |||
// The remaining names (N, IC, IH, IW, OC, OH, OW, FH, FW, is_xcorr) are | |||
// forwarded unchanged from the caller's variables of the same name. | |||
void check_input(size_t N, | |||
size_t IC, size_t IH, size_t IW, | |||
size_t OC, size_t OH, size_t OW, | |||
size_t FH, size_t FW, | |||
size_t INs, size_t ONs, | |||
size_t PH, size_t PW, | |||
size_t SH, size_t SW, | |||
bool is_xcorr); | |||
// Byte count that LocalForwardImpl::exec compares against the device's | |||
// sharedMemPerBlock before choosing forward_proxy_default; computed from the | |||
// input height IH and width IW only. | |||
size_t forward_proxy_default_share_mem_in_bytes(size_t IH, size_t IW); | |||
void forward_proxy_weiming(const float *src, const float *filter, float *dst, | |||
size_t N, | |||
void forward_proxy_default(const float *src, const float *filter, float *dst, | |||
size_t N, | |||
size_t IC, size_t IH, size_t IW, | |||
size_t OC, size_t OH, size_t OW, | |||
size_t FH, size_t FW, | |||
@@ -39,7 +32,7 @@ void forward_proxy_weiming(const float *src, const float *filter, float *dst, | |||
/// forward | |||
bool can_forward_proxy_convnet(size_t N, | |||
bool can_forward_proxy_convnet(size_t N, | |||
size_t IC, size_t IH, size_t IW, | |||
size_t OC, size_t OH, size_t OW, | |||
size_t FH, size_t FW, | |||
@@ -70,7 +63,7 @@ size_t get_workspace_in_floats_forward_proxy_convnet(size_t N, | |||
/// bwd data | |||
bool can_backward_data_proxy_convnet(size_t N, | |||
bool can_backward_data_proxy_convnet(size_t N, | |||
size_t IC, size_t IH, size_t IW, | |||
size_t OC, size_t OH, size_t OW, | |||
size_t FH, size_t FW, | |||
@@ -78,7 +71,7 @@ bool can_backward_data_proxy_convnet(size_t N, | |||
size_t PH, size_t PW, | |||
size_t SH, size_t SW); | |||
void backward_data_proxy_convnet(const float *filter, | |||
void backward_data_proxy_convnet(const float *filter, | |||
const float *diff, | |||
float *grad, | |||
float *workspace, | |||
@@ -103,7 +96,7 @@ size_t get_workspace_in_floats_backward_data_proxy_convnet(size_t N, | |||
/// bwd filter | |||
bool can_backward_filter_proxy_convnet(size_t N, | |||
bool can_backward_filter_proxy_convnet(size_t N, | |||
size_t IC, size_t IH, size_t IW, | |||
size_t OC, size_t OH, size_t OW, | |||
size_t FH, size_t FW, | |||