GitOrigin-RevId: 44a0adddba
release-1.7
@@ -29,17 +29,24 @@ bool is_transpose_single( | |||||
* assuming contig layout is: | * assuming contig layout is: | ||||
* shape: b, m, n, c | * shape: b, m, n, c | ||||
* stride: mnc, nc, c, 1 | * stride: mnc, nc, c, 1 | ||||
* assuming non-contig layout is: | |||||
* shape: b, m, n, c | |||||
* stride: m*stride_m*c, stride_m*c, c, 1 | |||||
* | * | ||||
* then given layout should be: | * then given layout should be: | ||||
* shape: b, n, m, c | * shape: b, n, m, c | ||||
* stride: mnc, c, nc, 1 | * stride: mnc, c, nc, 1 | ||||
* non-contig stride: m*stride_m*c, c, stride_m*c, 1 | |||||
* | * | ||||
* if c == 1: | * if c == 1: | ||||
* shape: b, n, m | * shape: b, n, m | ||||
* stride: mn, 1, n | * stride: mn, 1, n | ||||
* non-contig stride: m*stride_m, 1, stride_m | |||||
* | |||||
* if b == 1: | * if b == 1: | ||||
* shape: n, m, c | * shape: n, m, c | ||||
* stride: c, nc, 1 | * stride: c, nc, 1 | ||||
* non-contig stride: c, stride_m*c, 1 | |||||
* | * | ||||
* if b == 1 && c == 1: | * if b == 1 && c == 1: | ||||
* shape: n, m | * shape: n, m | ||||
@@ -65,7 +72,16 @@ bool is_transpose_single( | |||||
p.n = layout[1]; | p.n = layout[1]; | ||||
p.m = layout[2]; | p.m = layout[2]; | ||||
p.c = 1; | p.c = 1; | ||||
return strd(2, p.n) && strd(0, p.m * p.n); | |||||
if (strd(2, p.n) && strd(0, p.m * p.n)) { | |||||
return true; | |||||
} else if ( | |||||
allow_no_contig && (size_t)(layout.stride[2]) >= p.n && | |||||
strd(0, p.m * (size_t)(layout.stride[2])) && strd(1, 1)) { | |||||
p.stride_m = layout.stride[2]; | |||||
return true; | |||||
} | |||||
return false; | |||||
} | } | ||||
if (strd(2, 1)) { | if (strd(2, 1)) { | ||||
// b == 1 | // b == 1 | ||||
@@ -41,6 +41,20 @@ TEST_F(AARCH64, Relayout) { | |||||
} | } | ||||
} | } | ||||
TEST_F(AARCH64, RelayoutNonContig) { | |||||
Checker<Relayout> checker(handle()); | |||||
std::vector<::megdnn::DType> dtype_vec; | |||||
dtype_vec.push_back(dtype::Float32()); | |||||
dtype_vec.push_back(dtype::Int16()); | |||||
dtype_vec.push_back(dtype::Uint16()); | |||||
dtype_vec.push_back(dtype::Int8()); | |||||
for (auto dtype : dtype_vec) { | |||||
TensorLayout src({4, 90, 15, 29}, {41760, 1, 2784, 96}, dtype); | |||||
TensorLayout dst({4, 90, 15, 29}, {39150, 435, 29, 1}, dtype); | |||||
checker.execl({src, dst}); | |||||
} | |||||
} | |||||
TEST_F(AARCH64, RelayoutBig) { | TEST_F(AARCH64, RelayoutBig) { | ||||
Checker<Relayout> checker(handle()); | Checker<Relayout> checker(handle()); | ||||
ConsecutiveRNG rng; | ConsecutiveRNG rng; | ||||