Compare commits

...

5 Commits

Author SHA1 Message Date
  Megvii Engine Team 5f08b82f2c fix(dnn/cuda): fix ptx mma algo compute bugs 2 years ago
  Megvii Engine Team d3e786ef0f feat(imperative): load_network_and_run enable weight preprocess 2 years ago
  Megvii Engine Team c6ff878d87 feat(mgb): add cu114 wheel 2 years ago
  Megvii Engine Team d1d8ddeeac fix(lite): fix lar multithread options invalid 2 years ago
  温娟 0467778f4e chore(release): bump version 2 years ago
12 changed files with 217 additions and 206 deletions
Unified View
  1. dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldg16_128x256_relu.cu (+14 -12)
  2. dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldg16_256x64_relu.cu (+32 -25)
  3. dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldgsts16_128x128_relu.cu (+32 -25)
  4. dnn/src/cuda/ptx/uint4_int4/imma8832_ldg16_128x256_relu.cu (+14 -11)
  5. dnn/src/cuda/ptx/uint4_int4/imma8832_ldg16_256x64_relu.cu (+14 -11)
  6. dnn/src/cuda/ptx/uint4_int4/imma8832_ldgsts16_128x128_relu.cu (+32 -30)
  7. dnn/src/cuda/ptx/uint4_int4/macro.cuh (+19 -81)
  8. imperative/python/megengine/tools/load_network_and_run.py (+15 -4)
  9. imperative/python/src/graph_rt.cpp (+4 -0)
 10. lite/load_and_run/src/options/device_options.cpp (+9 -5)
 11. scripts/whl/manylinux2014/build_wheel_common.sh (+31 -1)
 12. src/core/include/megbrain/version.h (+1 -1)

dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldg16_128x256_relu.cu (+14 -12)

@@ -476,6 +476,20 @@ extern "C" __global__ void __launch_bounds__(256)
         __syncthreads();
     }
 
+    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
+    const float* bias_ptr = bias + oc;
+
+    int4 load_bias0 = make_int4(0, 0, 0, 0);
+    int4 load_bias1 = make_int4(0, 0, 0, 0);
+    int4 load_bias2 = make_int4(0, 0, 0, 0);
+    int4 load_bias3 = make_int4(0, 0, 0, 0);
+    if (oc < param.oc) {
+        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
+        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
+        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
+        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
+    }
+
     // read fuse_z
     int2 reg_fuse_z[reg_m] = {make_int2(z_zero_point, z_zero_point),
                               make_int2(z_zero_point, z_zero_point),

@@ -595,18 +609,7 @@ extern "C" __global__ void __launch_bounds__(256)
     __syncthreads();
 
     /// output
-    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
-    const float* bias_ptr = bias + oc;
-
-    int4 load_bias0 = make_int4(0, 0, 0, 0);
-    int4 load_bias1 = make_int4(0, 0, 0, 0);
-    int4 load_bias2 = make_int4(0, 0, 0, 0);
-    int4 load_bias3 = make_int4(0, 0, 0, 0);
     if (oc < param.oc) {
-        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
-        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
-        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
-        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
         mul_v4(load_bias0, load_bias0, beta);
         mul_v4(load_bias1, load_bias1, beta);
         mul_v4(load_bias2, load_bias2, beta);

@@ -617,7 +620,6 @@ extern "C" __global__ void __launch_bounds__(256)
 
 #pragma unroll
     for (int y = 0; y < reg_m; y += 4) {
-        I2F_4x8(reg_acc, y, 0);
         FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
         FUSE_Z_4x8(reg_acc, y, 0, reg_fuse_z, gamma, z_zero_point);
         PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);

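Note: the change shared by the six ptx kernels in this set is that the per-channel bias handling is split in two. The guarded int4 loads of the bias are hoisted out of the store epilogue to just after the main loop (presumably so the global-memory latency of the bias read overlaps the remaining work), while the beta scaling stays next to the stores. A minimal, simplified sketch of the pattern, using float4 instead of the kernels' int4 reinterpretation; epilogue_bias_prefetch and OC are illustrative names, not from the diff:

    // Simplified sketch of the hoisted, guarded bias prefetch.
    // The zero default keeps out-of-range channels correct without an extra
    // branch in the epilogue; 16-byte alignment of bias + oc is assumed,
    // as it is in the kernels.
    __device__ __forceinline__ void epilogue_bias_prefetch(
            const float* bias, int oc, int OC, float beta, float4& b) {
        b = make_float4(0.f, 0.f, 0.f, 0.f);                 // safe default
        if (oc < OC) {                                       // guarded vector load
            b = *reinterpret_cast<const float4*>(bias + oc);
        }
        // ... remaining main-loop / shared-memory work can still be in flight ...
        if (oc < OC) {                                       // scale valid lanes only
            b.x *= beta;
            b.y *= beta;
            b.z *= beta;
            b.w *= beta;
        }
    }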

dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldg16_256x64_relu.cu (+32 -25)

@@ -657,6 +657,20 @@ extern "C" __global__ void __launch_bounds__(256)
         __syncthreads();
     }
 
+    size_t oc = bidy * BM + 16 * idx_in_quad;
+    const float* bias_ptr = bias + oc;
+
+    int4 load_bias0 = make_int4(0, 0, 0, 0);
+    int4 load_bias1 = make_int4(0, 0, 0, 0);
+    int4 load_bias2 = make_int4(0, 0, 0, 0);
+    int4 load_bias3 = make_int4(0, 0, 0, 0);
+    if (oc < param.oc) {
+        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
+        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
+        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
+        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
+    }
+
     // read fuse_z
     int2 reg_fuse_z[reg_m] = {make_int2(z_zero_point, z_zero_point),
                               make_int2(z_zero_point, z_zero_point),

@@ -712,6 +726,14 @@ extern "C" __global__ void __launch_bounds__(256)
             reg_flt[0][j] = make_int4(x, y, z, w);
         }
 
+        /// output
+        if (oc < param.oc) {
+            mul_v4(load_bias0, load_bias0, beta);
+            mul_v4(load_bias1, load_bias1, beta);
+            mul_v4(load_bias2, load_bias2, beta);
+            mul_v4(load_bias3, load_bias3, beta);
+        }
+
         // compute
 #pragma unroll
         for (int k_inner = 0; k_inner < BKd32; k_inner++) {

@@ -773,35 +795,20 @@ extern "C" __global__ void __launch_bounds__(256)
 
     __syncthreads();
 
-    /// output
-    size_t oc = bidy * BM + 16 * idx_in_quad;
-    const float* bias_ptr = bias + oc;
-
-    int4 load_bias0 = make_int4(0, 0, 0, 0);
-    int4 load_bias1 = make_int4(0, 0, 0, 0);
-    int4 load_bias2 = make_int4(0, 0, 0, 0);
-    int4 load_bias3 = make_int4(0, 0, 0, 0);
-    if (oc < param.oc) {
-        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
-        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
-        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
-        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
-        mul_v4(load_bias0, load_bias0, beta);
-        mul_v4(load_bias1, load_bias1, beta);
-        mul_v4(load_bias2, load_bias2, beta);
-        mul_v4(load_bias3, load_bias3, beta);
-    }
-
     int8_t* __restrict__ g_dst_ptr = dst + d_offset;
 
+    FMA_1x8(reg_acc, 0, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
+    fuse_z_1x8(reg_acc[0], 0, reg_fuse_z[0], gamma, z_zero_point);
+    PACK_F2I_WITH_RELU_1x8(reg_acc, 0, 0, relu, dst_zero_point);
+
 #pragma unroll
-    for (int y = 0; y < reg_m; y += 4) {
-        I2F_4x8(reg_acc, y, 0);
-        FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
-        FUSE_Z_4x8(reg_acc, y, 0, reg_fuse_z, gamma, z_zero_point);
-        PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
-        STG_AFTER_LDG_4x1(g_offset, reg_acc, y, 0);
+    for (int y = 1; y < reg_m; y += 1) {
+        FMA_1x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
+        fuse_z_1x8(reg_acc[y], 0, reg_fuse_z[y], gamma, z_zero_point);
+        PACK_F2I_WITH_RELU_1x8(reg_acc, y, 0, relu, dst_zero_point);
+        STG_AFTER_LDG(g_offset[y - 1], reg_acc[y - 1][0], stg_guard[y - 1]);
     }
+    STG_AFTER_LDG(g_offset[7], reg_acc[7][0], stg_guard[7]);
 #endif
 }
 } // namespace

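Note: in this kernel and in the two ldgsts16_128x128 kernels below, the store epilogue also drops the 4-row macros (I2F_4x8 / FMA_4x8 / FUSE_Z_4x8 / PACK_F2I_WITH_RELU_4x8 / STG_AFTER_LDG_4x1) in favour of per-row 1x8 calls arranged as a two-stage pipeline: row 0 is computed and packed up front, each loop iteration then prepares row y while the result of row y - 1 is stored, and the last row is flushed after the loop. A rough sketch of that shape; prepare_row and store_row are hypothetical stand-ins for the 1x8 macros, not names from the diff:

    // Assumed shape of the new per-row store pipeline over the reg_m output rows.
    __device__ void prepare_row(int y);   // stand-in: FMA_1x8 + fuse_z_1x8 + PACK_F2I_WITH_RELU_1x8
    __device__ void store_row(int y);     // stand-in: STG_AFTER_LDG

    template <int reg_m>
    __device__ __forceinline__ void store_pipeline() {
        prepare_row(0);                   // first row is prepared up front
    #pragma unroll
        for (int y = 1; y < reg_m; ++y) {
            prepare_row(y);               // compute/pack row y ...
            store_row(y - 1);             // ... while row y - 1 is written out
        }
        store_row(reg_m - 1);             // flush the final row
    }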

dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldgsts16_128x128_relu.cu (+32 -25)

@@ -437,7 +437,7 @@ extern "C" __global__ void __launch_bounds__(256)
         cp_async_fence();
     }
 
-    bool only_one_stage = (stage == 1) ? true : false;
+    bool only_one_stage = (stage == 1);
     if (stage >= 2) {
         cp_async_wait(stages - 2);
     } else {

@@ -844,6 +844,20 @@ extern "C" __global__ void __launch_bounds__(256)
         cp_async_wait(stages - 2);
     }
 
+    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
+    const float* bias_ptr = bias + oc;
+
+    int4 load_bias0 = make_int4(0, 0, 0, 0);
+    int4 load_bias1 = make_int4(0, 0, 0, 0);
+    int4 load_bias2 = make_int4(0, 0, 0, 0);
+    int4 load_bias3 = make_int4(0, 0, 0, 0);
+    if (oc < param.oc) {
+        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
+        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
+        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
+        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
+    }
+
     if (!only_one_stage) {
 #pragma unroll  // low
         for (int i = 0; i < reg_nd4; ++i) {

@@ -975,6 +989,13 @@ extern "C" __global__ void __launch_bounds__(256)
             reg_flt[0][j] = make_int4(x, y, z, w);
         }
 
+        if (oc < param.oc) {
+            mul_v4(load_bias0, load_bias0, beta);
+            mul_v4(load_bias1, load_bias1, beta);
+            mul_v4(load_bias2, load_bias2, beta);
+            mul_v4(load_bias3, load_bias3, beta);
+        }
+
         // compute
 #pragma unroll
         for (int k_inner = 0; k_inner < BKd32; k_inner++) {

@@ -1038,34 +1059,20 @@ extern "C" __global__ void __launch_bounds__(256)
     __syncthreads();
 
     /// output
-    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
-    const float* bias_ptr = bias + oc;
-
-    int4 load_bias0 = make_int4(0, 0, 0, 0);
-    int4 load_bias1 = make_int4(0, 0, 0, 0);
-    int4 load_bias2 = make_int4(0, 0, 0, 0);
-    int4 load_bias3 = make_int4(0, 0, 0, 0);
-    if (oc < param.oc) {
-        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
-        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
-        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
-        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
-        mul_v4(load_bias0, load_bias0, beta);
-        mul_v4(load_bias1, load_bias1, beta);
-        mul_v4(load_bias2, load_bias2, beta);
-        mul_v4(load_bias3, load_bias3, beta);
-    }
-
     int8_t* __restrict__ g_dst_ptr = dst + d_offset;
 
+    FMA_1x8(reg_acc, 0, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
+    fuse_z_1x8(reg_acc[0], 0, reg_fuse_z[0], gamma, z_zero_point);
+    PACK_F2I_WITH_RELU_1x8(reg_acc, 0, 0, relu, dst_zero_point);
+
 #pragma unroll
-    for (int y = 0; y < reg_m; y += 4) {
-        I2F_4x8(reg_acc, y, 0);
-        FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
-        FUSE_Z_4x8(reg_acc, y, 0, reg_fuse_z, gamma, z_zero_point);
-        PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
-        STG_AFTER_LDG_4x1(g_offset, reg_acc, y, 0);
+    for (int y = 1; y < reg_m; y += 1) {
+        FMA_1x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
+        fuse_z_1x8(reg_acc[y], 0, reg_fuse_z[y], gamma, z_zero_point);
+        PACK_F2I_WITH_RELU_1x8(reg_acc, y, 0, relu, dst_zero_point);
+        STG_AFTER_LDG(g_offset[y - 1], reg_acc[y - 1][0], stg_guard[y - 1]);
     }
+    STG_AFTER_LDG(g_offset[7], reg_acc[7][0], stg_guard[7]);
 #endif
 }
 } // namespace


dnn/src/cuda/ptx/uint4_int4/imma8832_ldg16_128x256_relu.cu (+14 -11)

@@ -475,6 +475,20 @@ extern "C" __global__ void __launch_bounds__(256)
         __syncthreads();
     }
 
+    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
+    const float* bias_ptr = bias + oc;
+
+    int4 load_bias0 = make_int4(0, 0, 0, 0);
+    int4 load_bias1 = make_int4(0, 0, 0, 0);
+    int4 load_bias2 = make_int4(0, 0, 0, 0);
+    int4 load_bias3 = make_int4(0, 0, 0, 0);
+    if (oc < param.oc) {
+        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
+        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
+        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
+        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
+    }
+
     guard = iter < 0;
 #pragma unroll
     for (int i = 0; i < reg_nd4; ++i) {

@@ -574,18 +588,8 @@ extern "C" __global__ void __launch_bounds__(256)
     size_t nhw_post3 = nhw_post0 + 24;
 
     size_t stg_oc = bidy * BM + (warp_y << 6);
-    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
-    const float* bias_ptr = bias + oc;
 
-    int4 load_bias0 = make_int4(0, 0, 0, 0);
-    int4 load_bias1 = make_int4(0, 0, 0, 0);
-    int4 load_bias2 = make_int4(0, 0, 0, 0);
-    int4 load_bias3 = make_int4(0, 0, 0, 0);
     if (oc < param.oc) {
-        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
-        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
-        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
-        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
         mul_v4(load_bias0, load_bias0, beta);
         mul_v4(load_bias1, load_bias1, beta);
         mul_v4(load_bias2, load_bias2, beta);

@@ -599,7 +603,6 @@ extern "C" __global__ void __launch_bounds__(256)
 
 #pragma unroll
     for (int y = 0; y < reg_m; y += 4) {
-        I2F_4x8(reg_acc, y, 0);
         FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
         PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
         STG_4x1(stg_ptr, reg_acc, y, 0);


dnn/src/cuda/ptx/uint4_int4/imma8832_ldg16_256x64_relu.cu (+14 -11)

@@ -659,6 +659,20 @@ extern "C" __global__ void __launch_bounds__(256)
         __syncthreads();
     }
 
+    size_t oc = bidy * BM + 16 * idx_in_quad;
+    const float* bias_ptr = bias + oc;
+
+    int4 load_bias0 = make_int4(0, 0, 0, 0);
+    int4 load_bias1 = make_int4(0, 0, 0, 0);
+    int4 load_bias2 = make_int4(0, 0, 0, 0);
+    int4 load_bias3 = make_int4(0, 0, 0, 0);
+    if (oc < param.oc) {
+        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
+        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
+        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
+        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
+    }
+
     guard = iter < 0;
 #pragma unroll  // low
     for (int i = 0; i < reg_nd4; ++i) {

@@ -755,18 +769,8 @@ extern "C" __global__ void __launch_bounds__(256)
     size_t nhw_post3 = nhw_post0 + 24;
 
     size_t stg_oc = bidy * BM;
-    size_t oc = bidy * BM + 16 * idx_in_quad;
-    const float* bias_ptr = bias + oc;
 
-    int4 load_bias0 = make_int4(0, 0, 0, 0);
-    int4 load_bias1 = make_int4(0, 0, 0, 0);
-    int4 load_bias2 = make_int4(0, 0, 0, 0);
-    int4 load_bias3 = make_int4(0, 0, 0, 0);
     if (oc < param.oc) {
-        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
-        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
-        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
-        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
         mul_v4(load_bias0, load_bias0, beta);
         mul_v4(load_bias1, load_bias1, beta);
         mul_v4(load_bias2, load_bias2, beta);

@@ -779,7 +783,6 @@ extern "C" __global__ void __launch_bounds__(256)
 
 #pragma unroll
     for (int y = 0; y < reg_m; y += 4) {
-        I2F_4x8(reg_acc, y, 0);
         FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
         PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
         STG_4x1(stg_ptr, reg_acc, y, 0);


dnn/src/cuda/ptx/uint4_int4/imma8832_ldgsts16_128x128_relu.cu (+32 -30)

@@ -449,15 +449,15 @@ extern "C" __global__ void __launch_bounds__(256)
     bool stg_guard[8];
 #pragma unroll
     for (int y = 0; y < reg_m; y += 4) {
-        COMPUTE_OFFSET_4x1(reg_fuse_z, g_offset, y)
+        COMPUTE_OFFSET_4x1(g_offset, y);
 
         nhw_post0 += 32;
         nhw_post1 += 32;
         nhw_post2 += 32;
         nhw_post3 += 32;
     }
 
-    bool only_one_stage = (stage == 1) ? true : false;
+    bool only_one_stage = (stage == 1);
     if (stage >= 2) {
         cp_async_wait(stages - 2);
     } else {

@@ -835,6 +835,20 @@ extern "C" __global__ void __launch_bounds__(256)
         cp_async_wait(stages - 2);
     }
 
+    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
+    const float* bias_ptr = bias + oc;
+
+    int4 load_bias0 = make_int4(0, 0, 0, 0);
+    int4 load_bias1 = make_int4(0, 0, 0, 0);
+    int4 load_bias2 = make_int4(0, 0, 0, 0);
+    int4 load_bias3 = make_int4(0, 0, 0, 0);
+    if (oc < param.oc) {
+        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
+        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
+        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
+        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
+    }
+
     if (!only_one_stage) {
 #pragma unroll  // low
         for (int i = 0; i < reg_nd4; ++i) {

@@ -965,6 +979,13 @@ extern "C" __global__ void __launch_bounds__(256)
             reg_flt[0][j] = make_int4(x, y, z, w);
         }
 
+        if (oc < param.oc) {
+            mul_v4(load_bias0, load_bias0, beta);
+            mul_v4(load_bias1, load_bias1, beta);
+            mul_v4(load_bias2, load_bias2, beta);
+            mul_v4(load_bias3, load_bias3, beta);
+        }
+
         // compute
 #pragma unroll
         for (int k_inner = 0; k_inner < BKd32; k_inner++) {

@@ -1028,38 +1049,19 @@ extern "C" __global__ void __launch_bounds__(256)
     __syncthreads();
 
     /// output
-    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
-    const float* bias_ptr = bias + oc;
-
-    int4 load_bias0 = make_int4(0, 0, 0, 0);
-    int4 load_bias1 = make_int4(0, 0, 0, 0);
-    int4 load_bias2 = make_int4(0, 0, 0, 0);
-    int4 load_bias3 = make_int4(0, 0, 0, 0);
-    if (oc < param.oc) {
-        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
-        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
-        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
-        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
-        mul_v4(load_bias0, load_bias0, beta);
-        mul_v4(load_bias1, load_bias1, beta);
-        mul_v4(load_bias2, load_bias2, beta);
-        mul_v4(load_bias3, load_bias3, beta);
-    }
 
     int8_t* __restrict__ g_dst_ptr = dst + d_offset;
 
-#pragma unroll
-    for (int y = 0; y < reg_m; y += 4) {
-        I2F_4x8(reg_acc, y, 0);
-        FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
-        PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
-        STG_AFTER_LDG_4x1(g_offset, reg_acc, y, 0);
+    FMA_1x8(reg_acc, 0, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
+    PACK_F2I_WITH_RELU_1x8(reg_acc, 0, 0, relu, dst_zero_point);
 
-        nhw_post0 += 32;
-        nhw_post1 += 32;
-        nhw_post2 += 32;
-        nhw_post3 += 32;
+#pragma unroll
+    for (int y = 1; y < reg_m; y += 1) {
+        FMA_1x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
+        PACK_F2I_WITH_RELU_1x8(reg_acc, y, 0, relu, dst_zero_point);
+        STG_AFTER_LDG(g_offset[y - 1], reg_acc[y - 1][0], stg_guard[y - 1]);
     }
+    STG_AFTER_LDG(g_offset[7], reg_acc[7][0], stg_guard[7]);
 #endif
 }
 } // namespace


dnn/src/cuda/ptx/uint4_int4/macro.cuh (+19 -81)

@@ -23,78 +23,26 @@ __device__ __forceinline__ void mul_v4<float>(
 __device__ __forceinline__ void fma2(
         int2& c0, const int2 a0, int2& c1, const int2 a1, const float alpha,
         const int4 b) {
-    asm("fma.rz.f32 %0, %1, %2, %3;"
-        : "=f"(((float*)&c0)[0])
-        : "f"(((float*)&a0)[0]), "f"(alpha), "f"(((float*)&b)[0]));
-    asm("fma.rz.f32 %0, %1, %2, %3;"
-        : "=f"(((float*)&c0)[1])
-        : "f"(((float*)&a0)[1]), "f"(alpha), "f"(((float*)&b)[1]));
-    asm("fma.rz.f32 %0, %1, %2, %3;"
-        : "=f"(((float*)&c1)[0])
-        : "f"(((float*)&a1)[0]), "f"(alpha), "f"(((float*)&b)[2]));
-    asm("fma.rz.f32 %0, %1, %2, %3;"
-        : "=f"(((float*)&c1)[1])
-        : "f"(((float*)&a1)[1]), "f"(alpha), "f"(((float*)&b)[3]));
-}
-
-__device__ __forceinline__ void fuse_z_1x8(
-        int4* a, const int& j, const int4& fuse_z, const float& gamma,
-        const int32_t& zero_point) {
-    const int2 z[2] = {
-            *reinterpret_cast<const int2*>(&fuse_z),
-            *(reinterpret_cast<const int2*>(&fuse_z) + 1)};
-    for (int k = 0; k < 4; k++) {
-        int f = ((z[0].x >> (k * 8)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k]))[0] += (f - zero_point) * gamma;
-        f = ((z[0].x >> (k * 8 + 4)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k]))[1] += (f - zero_point) * gamma;
-
-        f = ((z[1].x >> (k * 8)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k]))[2] += (f - zero_point) * gamma;
-        f = ((z[1].x >> (k * 8 + 4)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k]))[3] += (f - zero_point) * gamma;
-    }
-    for (int k = 0; k < 4; k++) {
-        int f = ((z[0].y >> (k * 8)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k + 4]))[0] += (f - zero_point) * gamma;
-        f = ((z[0].y >> (k * 8 + 4)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k + 4]))[1] += (f - zero_point) * gamma;
-
-        f = ((z[1].y >> (k * 8)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k + 4]))[2] += (f - zero_point) * gamma;
-        f = ((z[1].y >> (k * 8 + 4)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k + 4]))[3] += (f - zero_point) * gamma;
-    }
+    ((float*)&c0)[0] = a0.x * alpha + ((float*)&b)[0];
+    ((float*)&c0)[1] = a0.y * alpha + ((float*)&b)[1];
+    ((float*)&c1)[0] = a1.x * alpha + ((float*)&b)[2];
+    ((float*)&c1)[1] = a1.y * alpha + ((float*)&b)[3];
 }
 
 __device__ __forceinline__ void fuse_z_1x8(
         int2* a, const int& j, const int2& fuse_z, const float& gamma,
         const int32_t& zero_point) {
+    float x = zero_point * gamma;
 #pragma unroll
     for (int k = 0; k < 4; k++) {
         int f = ((fuse_z.x >> (k * 8)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k]))[0] += (f - zero_point) * gamma;
+        ((float*)&(a[j + k]))[0] += f * gamma - x;
         f = ((fuse_z.x >> (k * 8 + 4)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k]))[1] += (f - zero_point) * gamma;
-    }
-#pragma unroll
-    for (int k = 0; k < 4; k++) {
-        int f = ((fuse_z.y >> (k * 8)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k + 4]))[0] += (f - zero_point) * gamma;
+        ((float*)&(a[j + k]))[1] += f * gamma - x;
+        f = ((fuse_z.y >> (k * 8)) & 15);
+        ((float*)&(a[j + k + 4]))[0] += f * gamma - x;
         f = ((fuse_z.y >> (k * 8 + 4)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k + 4]))[1] += (f - zero_point) * gamma;
+        ((float*)&(a[j + k + 4]))[1] += f * gamma - x;
     }
 }

@@ -282,12 +230,6 @@ __device__ __forceinline__ void pack_f2i_with_relu(
     fuse_z_1x8(a[i + 2], j, fuse_z[i + 2], gamma, zero_point); \
     fuse_z_1x8(a[i + 3], j, fuse_z[i + 3], gamma, zero_point);
 
-#define FUSE_Z_4x8(a, i, j, fuse_z, gamma, zero_point)         \
-    fuse_z_1x8(a[i], j, fuse_z[i], gamma, zero_point);         \
-    fuse_z_1x8(a[i + 1], j, fuse_z[i + 1], gamma, zero_point); \
-    fuse_z_1x8(a[i + 2], j, fuse_z[i + 2], gamma, zero_point); \
-    fuse_z_1x8(a[i + 3], j, fuse_z[i + 3], gamma, zero_point);
-
 // 1x8 1x(2x8 int2) to 2 int2
 #define PACK_F2I_1x8(a, i, j) \
     pack_f2i(a[i][j].x, a[i][j].z, a[i][j], a[i][j + 1], a[i][j + 2], a[i][j + 3]); \

@@ -316,24 +258,20 @@ __device__ __forceinline__ void pack_f2i_with_relu(
             stg_guard[i + 2]) \
     LDG(d[i + 3], s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, stg_guard[i + 3])
 
-#define COMPUTE_OFFSET(d, s, idx, n_reuse, hw_reuse, g) \
+#define COMPUTE_OFFSET(s, idx, n_reuse, hw_reuse, g) \
     n_reuse = nhw_post##idx / param.div_ohow; \
     hw_reuse = nhw_post##idx % param.div_ohow; \
     s = n_reuse * param.obs + hw_reuse * (packed_channel >> 1); \
     g = nhw_post##idx < param.nhw;
 
-#define COMPUTE_OFFSET_4x1(d, s, i) \
-    COMPUTE_OFFSET( \
-            d[i], s[i], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i]) \
-    COMPUTE_OFFSET( \
-            d[i + 1], s[i + 1], 1, reg_src_cache[0].y, reg_src_cache[1].y, \
-            stg_guard[i + 1]) \
-    COMPUTE_OFFSET( \
-            d[i + 2], s[i + 2], 2, reg_src_cache[0].z, reg_src_cache[1].z, \
-            stg_guard[i + 2]) \
-    COMPUTE_OFFSET( \
-            d[i + 3], s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, \
-            stg_guard[i + 3])
+#define COMPUTE_OFFSET_4x1(s, i) \
+    COMPUTE_OFFSET(s[i], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i]) \
+    COMPUTE_OFFSET( \
+            s[i + 1], 1, reg_src_cache[0].y, reg_src_cache[1].y, stg_guard[i + 1]) \
+    COMPUTE_OFFSET( \
+            s[i + 2], 2, reg_src_cache[0].z, reg_src_cache[1].z, stg_guard[i + 2]) \
+    COMPUTE_OFFSET( \
+            s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, stg_guard[i + 3])
 
 #define STG_AFTER_LDG(d, s, g) \
     if (stg_oc < param.oc && g) { \

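Note: two things change in these helpers. fma2 drops the inline fma.rz.f32 PTX in favour of plain expressions, and because the new form multiplies the raw integer accumulator lane (a0.x) by the float alpha, the int-to-float conversion now happens inside the FMA itself, which is presumably why the separate I2F_4x8 pass disappears from the kernels above. fuse_z_1x8 hoists the constant zero_point * gamma out of the loop, relying on (f - zp) * gamma == f * gamma - zp * gamma, and no longer sign-extends the 4-bit nibble ((f << 28) >> 28), consistent with the fused z operand being treated as unsigned 4-bit on this uint4_int4 path. The old int4 overload of fuse_z_1x8 and a duplicate FUSE_Z_4x8 definition are removed outright. A standalone sketch (not kernel code) checking the arithmetic the rewrite relies on:

    // Sanity sketch for the two rewrites: hoisting the zero-point term, and
    // letting the int accumulator be converted to float inside the FMA.
    #include <cassert>

    int main() {
        const float gamma = 0.25f;
        const float zp = 8.0f;              // z_zero_point
        const float hoisted = zp * gamma;   // "float x = zero_point * gamma;"
        for (int f = 0; f < 16; ++f) {      // unsigned 4-bit nibble range
            assert((f - zp) * gamma == f * gamma - hoisted);
        }

        const int acc = -42;                // raw int32 accumulator lane (a0.x)
        const float alpha = 0.5f, bias = 1.0f;
        float c = acc * alpha + bias;       // implicit int -> float conversion
        assert(c == -20.0f);
        return 0;
    }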

imperative/python/megengine/tools/load_network_and_run.py (+15 -4)

@@ -121,6 +121,9 @@ def run_model(args, graph, inputs, outputs, data):
     # must use level0 to avoid unintended opr modification
     graph.options.graph_opt_level = 0
 
+    if args.weight_preprocess:
+        graph.enable_weight_preprocess()
+
     logger.info("input tensors: ")
     for k, v in data.items():
         logger.info(" {}: {}".format(k, v.shape))

@@ -161,8 +164,8 @@ def run_model(args, graph, inputs, outputs, data):
         func.wait()
         return [oup_node.get_value().numpy() for oup_node in output_dict.values()]
 
-    if args.warm_up:
-        logger.info("warming up")
+    for i in range(args.warm_up):
+        logger.info("warming up {}".format(i))
         run()
 
     total_time = 0

@@ -276,8 +279,9 @@ def main():
     )
     parser.add_argument(
         "--warm-up",
-        action="store_true",
-        help="warm up model before do timing " " for better estimation",
+        type=int,
+        default=0,
+        help="times of warm up model before do timing " " for better estimation",
     )
     parser.add_argument(
         "--verbose",

@@ -394,6 +398,13 @@ def main():
     parser.add_argument(
         "--custom-op-lib", type=str, help="path of the custom op",
     )
+    parser.add_argument(
+        "--weight-preprocess",
+        action="store_true",
+        help="Execute operators with weight preprocess, which can"
+        "optimize the operator execution time with algo of winograd,"
+        "im2col ,etc.,but it may consume more memory.",
+    )
 
     args = parser.parse_args()
 


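Note: --warm-up changes from a boolean switch to an integer count (type=int, default 0), so the warm-up loop now runs args.warm_up iterations before timing instead of at most one, and the new --weight-preprocess switch makes run_model call graph.enable_weight_preprocess() (the binding added in graph_rt.cpp below) before the graph is compiled. A hypothetical invocation, with the model path made up for illustration, would look like: python3 -m megengine.tools.load_network_and_run model.mge --warm-up 5 --weight-preprocess.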


imperative/python/src/graph_rt.cpp (+4 -0)

@@ -253,6 +253,10 @@ void init_graph_rt(py::module m) {
                         }
                         return graph.compile(spec);
                     })
+            .def("enable_weight_preprocess",
+                 [](cg::ComputingGraph& graph) {
+                     graph.options().graph_opt.enable_weight_preprocess();
+                 })
             .def_property_readonly(
                     "options",
                     py::overload_cast<>(&cg::ComputingGraph::options));

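Note: the new enable_weight_preprocess binding is a thin forwarder; calling it from Python (as the tool above does) has the same effect as flipping the existing graph option on the C++ side before compile(). A minimal sketch of that equivalent call; the header name is an assumption, while the forwarded call itself is taken verbatim from the binding:

    // Equivalent effect from C++ (sketch): enable weight pre-processing on a
    // ComputingGraph before it is compiled.
    #include "megbrain/graph.h"   // assumed header for cg::ComputingGraph

    void enable_weight_preprocess(mgb::cg::ComputingGraph& graph) {
        graph.options().graph_opt.enable_weight_preprocess();
    }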

lite/load_and_run/src/options/device_options.cpp (+9 -5)

@@ -28,7 +28,7 @@ void XPUDeviceOption::config_model_internel<ModelLite>(
             model->get_config().device_type = LiteDeviceType::LITE_CUDA;
         }
 #endif
-    } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
+    } else if (runtime_param.stage == RunStage::AFTER_NETWORK_CREATED) {
         auto&& network = model->get_lite_network();
         if (enable_cpu_default) {
             LITE_LOG("using cpu default device\n");

@@ -86,7 +86,7 @@ void XPUDeviceOption::config_model_internel<ModelMdl>(
             };
         }
         if (enable_multithread) {
-            mgb_log("using multithread device\n");
+            mgb_log("using multithread(threads number:%ld) device\n", thread_num);
             model->get_mdl_config().comp_node_mapper =
                     [&](mgb::CompNode::Locator& loc) {
                         loc.type = mgb::CompNode::DeviceType::MULTITHREAD;

@@ -217,11 +217,15 @@ void XPUDeviceOption::config_model(
             std::static_pointer_cast<lar::NumberInt32>(m_option["multithread"])
                     ->get_value();
     enable_multithread = num_of_thread >= 0;
-    num_of_thread =
+    int32_t num_of_thread_dft =
             std::static_pointer_cast<lar::NumberInt32>(m_option["multithread_default"])
                     ->get_value();
-    enable_multithread_default = num_of_thread >= 0;
-    thread_num = num_of_thread >= 0 ? num_of_thread : 0;
+    enable_multithread_default = num_of_thread_dft >= 0;
+    mgb_assert(
+            num_of_thread < 0 || num_of_thread_dft < 0,
+            "multithread and multithread_default should not bet set at the same time");
+    thread_num = num_of_thread >= 0 ? num_of_thread
+                                    : (num_of_thread_dft >= 0 ? num_of_thread_dft : -1);
     std::string core_id_str =
             std::static_pointer_cast<lar::String>(m_option["multi_thread_core_ids"])
                     ->get_value();

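Note: besides moving the Lite branch to the AFTER_NETWORK_CREATED stage and logging the chosen thread count, the option parsing no longer reuses num_of_thread for both readings: multithread and multithread_default are read into separate variables, an mgb_assert rejects setting both at once, and thread_num falls back from one to the other (staying -1 when neither is given) rather than being clobbered by the second read, which appears to be the bug the commit message ("fix lar multithread options invalid") refers to. A standalone sketch of just the resolution logic; names mirror the diff, the option plumbing is omitted:

    #include <cassert>
    #include <cstdint>

    // A negative value means "option not set"; setting both options at once is rejected.
    int32_t resolve_thread_num(int32_t num_of_thread, int32_t num_of_thread_dft) {
        assert(num_of_thread < 0 || num_of_thread_dft < 0);
        return num_of_thread >= 0
                       ? num_of_thread
                       : (num_of_thread_dft >= 0 ? num_of_thread_dft : -1);
    }

    int main() {
        assert(resolve_thread_num(4, -1) == 4);    // only multithread set
        assert(resolve_thread_num(-1, 2) == 2);    // only multithread_default set
        assert(resolve_thread_num(-1, -1) == -1);  // neither option given
        return 0;
    }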

scripts/whl/manylinux2014/build_wheel_common.sh (+31 -1)

@@ -12,7 +12,7 @@ CUDA_LIB_DIR="/usr/local/cuda/lib64/"
 TensorRT_LIB_DIR="/opt/tensorrt/lib/"
 
 SDK_NAME="unknown"
-x86_64_support_version="cu101 cu111 cu112 cpu cu111_cudnn821_tensorRT825"
+x86_64_support_version="cu101 cu111 cu112 cpu cu111_cudnn821_tensorRT825 cu114"
 aarch64_support_version="cu102_JetsonNano cu111 cpu"
 if [[ -z ${IN_CI} ]]
 then

@@ -193,6 +193,36 @@ elif [ $SDK_NAME == "cu112" ];then
     REQUIR_TENSORRT_VERSION="7.2.2.3"
     REQUIR_CUBLAS_VERSION="11.3.1.68"
 
+elif [ $SDK_NAME == "cu114" ];then
+    BUILD_GCC8="ON"
+    REQUIR_CUDA_VERSION="11040"
+    REQUIR_CUDNN_VERSION="8.2.1"
+    REQUIR_TENSORRT_VERSION="7.2.2.3"
+    REQUIR_CUBLAS_VERSION="11.6.5.2"
+
+    CUDA_COPY_LIB_LIST="\
+${CUDA_LIB_DIR}/libnvrtc.so.11.2:\
+${CUDA_LIB_DIR}/libcublasLt.so.11:\
+${CUDA_LIB_DIR}/libcublas.so.11:\
+${CUDNN_LIB_DIR}/libcudnn_adv_infer.so.8:\
+${CUDNN_LIB_DIR}/libcudnn_adv_train.so.8:\
+${CUDNN_LIB_DIR}/libcudnn_cnn_infer.so.8:\
+${CUDNN_LIB_DIR}/libcudnn_cnn_train.so.8:\
+${CUDNN_LIB_DIR}/libcudnn_ops_infer.so.8:\
+${CUDNN_LIB_DIR}/libcudnn_ops_train.so.8:\
+${CUDNN_LIB_DIR}/libcudnn.so.8"
+
+    EXTRA_CMAKE_FLAG=" -DMGE_WITH_CUDNN_SHARED=ON -DMGE_WITH_CUBLAS_SHARED=ON \
+        -DMGE_CUDA_GENCODE=\"-gencode arch=compute_61,code=sm_61 \
+        -gencode arch=compute_70,code=sm_70 \
+        -gencode arch=compute_75,code=sm_75 \
+        -gencode arch=compute_80,code=sm_80 \
+        -gencode arch=compute_86,code=sm_86 \
+        -gencode arch=compute_86,code=compute_86\" "
+
 elif [ $SDK_NAME == "cpu" ];then
     echo "use $SDK_NAME without cuda support"
     BUILD_WHL_CPU_ONLY="ON"

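Note: the new cu114 branch pins CUDA 11.4 (REQUIR_CUDA_VERSION=11040), cuDNN 8.2.1, TensorRT 7.2.2.3 and cuBLAS 11.6.5.2, builds with GCC 8, links cuDNN and cuBLAS as shared libraries, and generates SASS for SM 61/70/75/80/86 plus compute_86 PTX. Assuming the script keeps its usual SDK selection switch, a cu114 wheel would be built with something like ./scripts/whl/manylinux2014/build_wheel_common.sh -sdk cu114; the flag name is an assumption and is not part of this diff.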

src/core/include/megbrain/version.h (+1 -1)

@@ -3,7 +3,7 @@
 #include "megbrain_build_config.h"
 
 #define MGE_MAJOR 1
-#define MGE_MINOR 9999
+#define MGE_MINOR 11
 #define MGE_PATCH 0
 
 // for rc version, could be like "rc1", "rc2", etc
