@@ -13,6 +13,7 @@ from .._imperative_rt.core2 import ( | |||||
astype_cpp, | astype_cpp, | ||||
batched_matmul_cpp, | batched_matmul_cpp, | ||||
broadcast_cpp, | broadcast_cpp, | ||||
expand_dims_cpp, | |||||
getitem_cpp, | getitem_cpp, | ||||
matmul_cpp, | matmul_cpp, | ||||
reshape_cpp, | reshape_cpp, | ||||
@@ -62,7 +63,6 @@ def _matmul( | |||||
assert dim1 > 0 and dim2 > 0 | assert dim1 > 0 and dim2 > 0 | ||||
maxdim = dim1 if dim1 > dim2 else dim2 | maxdim = dim1 if dim1 > dim2 else dim2 | ||||
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) | compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) | ||||
if dim1 == 1 and dim2 == 1: # dispatch to Dot | if dim1 == 1 and dim2 == 1: # dispatch to Dot | ||||
(result,) = apply(builtin.Dot(), inp1, inp2) | (result,) = apply(builtin.Dot(), inp1, inp2) | ||||
return result | return result | ||||
@@ -72,34 +72,44 @@ def _matmul( | |||||
# 2x2 | # 2x2 | ||||
# nx1(transpose_a=False), n>=3 | # nx1(transpose_a=False), n>=3 | ||||
# nx2(transpose_a=False), n>=3 | # nx2(transpose_a=False), n>=3 | ||||
return matmul_cpp( | |||||
inp1, | |||||
inp2, | |||||
dim1, | |||||
dim2, | |||||
ret = matmul_cpp( | |||||
inp1 if dim1 > 1 else expand_dims_cpp(inp1, 0), | |||||
inp2 if dim2 > 1 else expand_dims_cpp(inp2, -1), | |||||
max(dim1, 2), | |||||
max(dim2, 2), | |||||
transpose_a, | transpose_a, | ||||
transpose_b, | transpose_b, | ||||
compute_mode, | compute_mode, | ||||
_config._benchmark_kernel, | _config._benchmark_kernel, | ||||
_config._deterministic_kernel, | _config._deterministic_kernel, | ||||
) | ) | ||||
if dim1 == 1: | |||||
ret = squeeze_cpp(ret, -2) | |||||
elif dim2 == 1: | |||||
ret = squeeze_cpp(ret, -1) | |||||
return ret | |||||
else: # dispath to BatchedMatrixMul | else: # dispath to BatchedMatrixMul | ||||
# nx1(transpose_a=True), n>=3 | # nx1(transpose_a=True), n>=3 | ||||
# nx2(transpose_a=True), n>=3 | # nx2(transpose_a=True), n>=3 | ||||
# nxm,n>=3,m>=3 | # nxm,n>=3,m>=3 | ||||
# 1xm,m>=3 | # 1xm,m>=3 | ||||
# 2xm,m>=3 | # 2xm,m>=3 | ||||
return batched_matmul_cpp( | |||||
inp1, | |||||
inp2, | |||||
dim1, | |||||
dim2, | |||||
ret = batched_matmul_cpp( | |||||
inp1 if dim1 > 1 else expand_dims_cpp(inp1, 0), | |||||
inp2 if dim2 > 1 else expand_dims_cpp(inp2, -1), | |||||
max(dim1, 2), | |||||
max(dim2, 2), | |||||
transpose_a, | transpose_a, | ||||
transpose_b, | transpose_b, | ||||
compute_mode, | compute_mode, | ||||
_config._benchmark_kernel, | _config._benchmark_kernel, | ||||
_config._deterministic_kernel, | _config._deterministic_kernel, | ||||
) | ) | ||||
if dim1 == 1: | |||||
ret = squeeze_cpp(ret, -2) | |||||
elif dim2 == 1: | |||||
ret = squeeze_cpp(ret, -1) | |||||
return ret | |||||
def _unary_elwise(mode): | def _unary_elwise(mode): | ||||
@@ -87,6 +87,136 @@ ValueRef make_empty_tensor( | |||||
return res; | return res; | ||||
} | } | ||||
std::optional<ValueRefList> matrix_mul_grad_rule( | |||||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||||
CustomBackward& backward) { | |||||
auto&& matmul = op.cast_final_safe<MatrixMul>(); | |||||
size_t dimA = matmul.dimA; | |||||
size_t dimB = matmul.dimB; | |||||
auto&& param = matmul.param(); | |||||
auto&& policy = matmul.policy(); | |||||
mgb_assert(inputs.size() == 2); | |||||
std::array<ValueRef, 2> inps, input_shapes; | |||||
for (size_t i = 0; i < 2; ++i) { | |||||
if (inputs_require_grad[i ^ 1]) { | |||||
inps[i] = inputs[i]; | |||||
input_shapes[i] = get_shape(inputs[i]); | |||||
} | |||||
} | |||||
auto maker = CustomGradMaker(backward, inputs.size()); | |||||
maker.output_size(1).output_captured(0, false); | |||||
maker.backward([inps_ = std::move(inps), input_shapes_ = std::move(input_shapes), | |||||
param, policy, dimA, dimB](Span<ValueRef> grads) { | |||||
mgb_assert(grads.size() == 1); | |||||
ValueRef grad = grads[0]; | |||||
SmallVector<ValueRef> ret(2); | |||||
if (!grad) { | |||||
return ret; | |||||
} | |||||
size_t dimG = std::max(dimA, dimB); | |||||
if (inps_[1]) { | |||||
if (param.transposeA) { | |||||
// A^T(2) @ B(2) = G(2), A'(2) = B'(2) @ G'^T(2) -> MatrixMul | |||||
auto&& grad_op = MatrixMul::make( | |||||
param.transposeB, true, param.compute_mode, param.format, | |||||
policy.strategy, policy.workspace_limit, dimB, dimG); | |||||
ret[0] = imperative::apply(*grad_op, inps_[1], grad)[0]; | |||||
} else { | |||||
// A(>=2) @ B(2) = G(>=2), A'(>=2) = G'(>=2) @ B(2) -> MatrixMul | |||||
auto&& grad_op = MatrixMul::make( | |||||
false, !param.transposeB, param.compute_mode, param.format, | |||||
policy.strategy, policy.workspace_limit, dimG, dimB); | |||||
ret[0] = imperative::apply(*grad_op, grad, inps_[1])[0]; | |||||
} | |||||
} | |||||
if (inps_[0]) { | |||||
if (param.transposeB) { | |||||
// A(>=2) @ B^T(2) = G(>=2), B'(2) = G'^T(>=2) @ A(>=2) -> MatrixMul | |||||
// (specialized) | |||||
auto&& grad_op = MatrixMul::make( | |||||
true, param.transposeA, param.compute_mode, param.format, | |||||
policy.strategy, policy.workspace_limit, dimG, dimA); | |||||
ret[1] = imperative::apply(*grad_op, grad, inps_[0])[0]; | |||||
} else { | |||||
// A(>=2) @ B(2) = G(>=2), B'(2) = G'(>=2) @ A(>=2) -> MatrixMul | |||||
// (specialized) | |||||
auto&& grad_op = MatrixMul::make( | |||||
!param.transposeA, false, param.compute_mode, param.format, | |||||
policy.strategy, policy.workspace_limit, dimA, dimG); | |||||
ret[1] = imperative::apply(*grad_op, inps_[0], grad)[0]; | |||||
} | |||||
} | |||||
return ret; | |||||
}); | |||||
maker.finalize(); | |||||
return imperative::apply(ApplyOp(op), inputs); | |||||
} | |||||
std::optional<ValueRefList> batched_matrix_mul_grad_rule( | |||||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||||
CustomBackward& backward) { | |||||
auto&& bmm = op.cast_final_safe<BatchedMatrixMul>(); | |||||
size_t dimA = bmm.dimA; | |||||
size_t dimB = bmm.dimB; | |||||
auto&& param = bmm.param(); | |||||
auto&& policy = bmm.policy(); | |||||
mgb_assert(inputs.size() == 2); | |||||
std::array<ValueRef, 2> inps, input_shapes; | |||||
for (size_t i = 0; i < 2; ++i) { | |||||
if (inputs_require_grad[i ^ 1]) { | |||||
inps[i] = inputs[i]; | |||||
input_shapes[i] = get_shape(inputs[i]); | |||||
} | |||||
} | |||||
auto maker = CustomGradMaker(backward, inputs.size()); | |||||
maker.output_size(1).output_captured(0, false); | |||||
maker.backward([inps_ = std::move(inps), input_shapes_ = std::move(input_shapes), | |||||
param, policy, dimA, dimB](Span<ValueRef> grads) { | |||||
mgb_assert(grads.size() == 1); | |||||
ValueRef grad = grads[0]; | |||||
SmallVector<ValueRef> ret(2); | |||||
if (!grad) { | |||||
return ret; | |||||
} | |||||
size_t dimG = std::max(dimA, dimB); | |||||
if (inps_[1]) { | |||||
if (param.transposeA) { | |||||
auto&& grad_op = BatchedMatrixMul::make( | |||||
param.transposeB, true, param.compute_mode, param.format, | |||||
policy.strategy, policy.workspace_limit, dimB, dimG); | |||||
ret[0] = imperative::apply(*grad_op, inps_[1], grad)[0]; | |||||
} else { | |||||
auto&& grad_op = BatchedMatrixMul::make( | |||||
false, !param.transposeB, param.compute_mode, param.format, | |||||
policy.strategy, policy.workspace_limit, dimG, dimB); | |||||
ret[0] = imperative::apply(*grad_op, grad, inps_[1])[0]; | |||||
} | |||||
if (dimG != dimA) { | |||||
ret[0] = reduce_to(ret[0], input_shapes_[0]); | |||||
} | |||||
} | |||||
if (inps_[0]) { | |||||
if (param.transposeB) { | |||||
auto&& grad_op = BatchedMatrixMul::make( | |||||
true, param.transposeA, param.compute_mode, param.format, | |||||
policy.strategy, policy.workspace_limit, dimG, dimA); | |||||
ret[1] = imperative::apply(*grad_op, grad, inps_[0])[0]; | |||||
} else { | |||||
auto&& grad_op = BatchedMatrixMul::make( | |||||
!param.transposeA, false, param.compute_mode, param.format, | |||||
policy.strategy, policy.workspace_limit, dimA, dimG); | |||||
ret[1] = imperative::apply(*grad_op, inps_[0], grad)[0]; | |||||
} | |||||
if (dimG != dimB) { | |||||
ret[1] = reduce_to(ret[1], input_shapes_[1]); | |||||
} | |||||
} | |||||
return ret; | |||||
}); | |||||
maker.finalize(); | |||||
return imperative::apply(ApplyOp(op), inputs); | |||||
} | |||||
std::optional<ValueRefList> elemwise_grad_rule( | std::optional<ValueRefList> elemwise_grad_rule( | ||||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | ||||
CustomBackward& backward) { | CustomBackward& backward) { | ||||
@@ -395,6 +525,9 @@ struct Init { | |||||
FastpathCopy::typeinfo(), fastpathcopy_grad_rule); | FastpathCopy::typeinfo(), fastpathcopy_grad_rule); | ||||
CustomBackward::register_grad_rule( | CustomBackward::register_grad_rule( | ||||
PixelShuffle::typeinfo(), pixelShuffle_grad_rule); | PixelShuffle::typeinfo(), pixelShuffle_grad_rule); | ||||
CustomBackward::register_grad_rule(MatrixMul::typeinfo(), matrix_mul_grad_rule); | |||||
CustomBackward::register_grad_rule( | |||||
BatchedMatrixMul::typeinfo(), batched_matrix_mul_grad_rule); | |||||
} | } | ||||
} _; | } _; | ||||
@@ -511,3 +511,45 @@ def test_pixel_shuffle(): | |||||
y = f(x) | y = f(x) | ||||
grad(y, F.ones_like(y)) | grad(y, F.ones_like(y)) | ||||
np.testing.assert_equal(2 * x.numpy(), x.grad.numpy()) | np.testing.assert_equal(2 * x.numpy(), x.grad.numpy()) | ||||
def test_matmul(): | |||||
def test_one(xdim, ydim, transposeA, transposeB): | |||||
xshape = (1, 4) if xdim == 1 else (2,) * (xdim - 2) + (3, 4) | |||||
yshape = (4, 1) if ydim == 1 else (2,) * (ydim - 2) + (4, 5) | |||||
x = np.random.rand(*xshape).astype("float32") | |||||
y = np.random.rand(*yshape).astype("float32") | |||||
gshape = (x @ y).shape | |||||
g = np.random.rand(*gshape).astype("float32") | |||||
dx = g @ np.swapaxes(y, -1, -2) | |||||
dy = np.swapaxes(x, -1, -2) @ g | |||||
while dx.shape != x.shape: | |||||
dx = dx.sum(0) | |||||
while dy.shape != y.shape: | |||||
dy = dy.sum(0) | |||||
if transposeA: | |||||
x = np.swapaxes(x, -1, -2) | |||||
dx = np.swapaxes(dx, -1, -2) | |||||
if transposeB: | |||||
y = np.swapaxes(y, -1, -2) | |||||
dy = np.swapaxes(dy, -1, -2) | |||||
x = mge.Tensor(x.squeeze()) | |||||
y = mge.Tensor(y.squeeze()) | |||||
g = mge.Tensor(g.squeeze()) | |||||
with Grad() as grad: | |||||
grad.wrt(x, callback=save_to(x)) | |||||
grad.wrt(y, callback=save_to(y)) | |||||
z = F.matmul(x, y, transpose_a=transposeA, transpose_b=transposeB) | |||||
grad(z, g) | |||||
np.testing.assert_almost_equal(dx.squeeze(), x.grad.numpy(), decimal=5) | |||||
np.testing.assert_almost_equal(dy.squeeze(), y.grad.numpy(), decimal=5) | |||||
for xdim in [1, 2, 3, 4]: | |||||
for ydim in [1, 2, 3, 4]: | |||||
for transposeA in [False, True]: | |||||
if xdim == 1 and transposeA == True: | |||||
continue | |||||
for transposeB in [False, True]: | |||||
if ydim == 1 and transposeB == True: | |||||
continue | |||||
test_one(xdim, ydim, transposeA, transposeB) |
@@ -31,18 +31,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
DTypeScalar vi{-1}; | DTypeScalar vi{-1}; | ||||
auto graph = inputs[0]->owner_graph(); | auto graph = inputs[0]->owner_graph(); | ||||
bool remove_row = false, remove_col = false; | |||||
if (dim1 == 1) { | |||||
dim1 = 2; | |||||
remove_row = true; | |||||
inp1 = inp1.add_axis(0); | |||||
} | |||||
if (dim2 == 1) { | |||||
dim2 = 2; | |||||
remove_col = true; | |||||
inp2 = inp2.add_axis(1); | |||||
} | |||||
SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail; | SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail; | ||||
if (dim1 > 2) { | if (dim1 > 2) { | ||||
auto idx = opr::ImmutableTensor::make(*graph, vi, config); | auto idx = opr::ImmutableTensor::make(*graph, vi, config); | ||||
@@ -91,17 +79,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
result = result.reshape(tshp); | result = result.reshape(tshp); | ||||
} | } | ||||
auto maxdim = dim1 > dim2 ? dim1 : dim2; | |||||
if (remove_row) { | |||||
std::vector<Desc> remove_param; | |||||
remove_param.push_back(Desc::make_remove(maxdim - 2)); | |||||
result = opr::AxisAddRemove::make(result, remove_param); | |||||
} | |||||
if (remove_col) { | |||||
std::vector<Desc> remove_param; | |||||
remove_param.push_back(Desc::make_remove(maxdim - 1)); | |||||
result = opr::AxisAddRemove::make(result, remove_param); | |||||
} | |||||
return result; | return result; | ||||
} | } | ||||
@@ -113,6 +90,19 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | ||||
DType dst_dtype; | DType dst_dtype; | ||||
if (dim1 == dim2 && dim2 >= 3) { // only happens in backward | |||||
for (size_t i = 1; i + 1 < layout1.ndim; ++i) { | |||||
layout1[0] *= layout1[i]; | |||||
layout2[0] *= layout2[i]; | |||||
} | |||||
layout1[1] = layout1[layout1.ndim - 1]; | |||||
layout1.ndim = 2; | |||||
layout1.init_contiguous_stride(); | |||||
layout2[1] = layout2[layout2.ndim - 1]; | |||||
layout2.ndim = 2; | |||||
layout2.init_contiguous_stride(); | |||||
dim1 = dim2 = 2; | |||||
} | |||||
DnnOprCaller<megdnn::MatrixMul> dnn_opr(inputs[0].comp_node); | DnnOprCaller<megdnn::MatrixMul> dnn_opr(inputs[0].comp_node); | ||||
dnn_opr.op->param() = matmul.param(); | dnn_opr.op->param() = matmul.param(); | ||||
@@ -156,6 +146,19 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn); | DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn); | ||||
dnn_opr.op->param() = matmul.param(); | dnn_opr.op->param() = matmul.param(); | ||||
if (matmul.dimA == matmul.dimB && matmul.dimB >= 3) { // only happens in backward | |||||
for (size_t i = 1; i + 1 < layout1.ndim; ++i) { | |||||
layout1[0] *= layout1[i]; | |||||
layout2[0] *= layout2[i]; | |||||
} | |||||
layout1[1] = layout1[layout1.ndim - 1]; | |||||
layout1.ndim = 2; | |||||
layout1.init_contiguous_stride(); | |||||
layout2[1] = layout2[layout2.ndim - 1]; | |||||
layout2.ndim = 2; | |||||
layout2.init_contiguous_stride(); | |||||
} | |||||
DType dst_dtype; | DType dst_dtype; | ||||
dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); | dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); | ||||
@@ -191,11 +194,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
} | } | ||||
TensorLayout layout_a = layout1, layout_b = layout2; | TensorLayout layout_a = layout1, layout_b = layout2; | ||||
if (dim1 == 1) { | |||||
layout_a.add_axis_cont_inplace(0); | |||||
inp_tensornds[0] = inputs[0]->dnn_tensor(); | |||||
inp_tensornds[0].layout = layout_a; | |||||
} else if (dim1 > 2) { | |||||
if (dim1 > 2) { | |||||
size_t batch = std::accumulate( | size_t batch = std::accumulate( | ||||
layout1.shape, layout1.shape + dim1 - 1, (size_t)1, | layout1.shape, layout1.shape + dim1 - 1, (size_t)1, | ||||
std::multiplies<size_t>()); | std::multiplies<size_t>()); | ||||
@@ -216,13 +215,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
inp_tensornds[0] = inputs[0]->dnn_tensor(); | inp_tensornds[0] = inputs[0]->dnn_tensor(); | ||||
} | } | ||||
if (dim2 == 1) { | |||||
layout_b.add_axis_inplace(1, 1, 1); | |||||
inp_tensornds[1] = inputs[1]->dnn_tensor(); | |||||
inp_tensornds[1].layout = layout_b; | |||||
} else { | |||||
inp_tensornds[1] = inputs[1]->dnn_tensor(); | |||||
} | |||||
inp_tensornds[1] = inputs[1]->dnn_tensor(); | |||||
TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, dst_dtype); | TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, dst_dtype); | ||||
dst_layout.init_contiguous_stride(); | dst_layout.init_contiguous_stride(); | ||||
@@ -232,6 +225,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
if (matmul.transposeB) | if (matmul.transposeB) | ||||
std::swap(layout_b.shape[0], layout_b.shape[1]); | std::swap(layout_b.shape[0], layout_b.shape[1]); | ||||
if (matmul.dimA == matmul.dimB && matmul.dimB >= 3) { // only happens in backward | |||||
inp_tensornds[0].layout = layout_a; | |||||
inp_tensornds[1].layout = layout_b; | |||||
} | |||||
DeviceTensorND out = | DeviceTensorND out = | ||||
BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); | BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); | ||||
size_t sz = setup_algo<megdnn::MatrixMul>( | size_t sz = setup_algo<megdnn::MatrixMul>( | ||||
@@ -279,18 +277,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto graph = inputs[0]->owner_graph(); | auto graph = inputs[0]->owner_graph(); | ||||
auto idx = opr::ImmutableTensor::make(*graph, vi, config); | auto idx = opr::ImmutableTensor::make(*graph, vi, config); | ||||
bool remove_row = false, remove_col = false; | |||||
if (dim1 == 1) { | |||||
dim1 = 2; | |||||
remove_row = true; | |||||
inp1 = inp1.add_axis(0); | |||||
} | |||||
if (dim2 == 1) { | |||||
dim2 = 2; | |||||
remove_col = true; | |||||
inp2 = inp2.add_axis(1); | |||||
} | |||||
auto shp1 = inp1.symshape(); | auto shp1 = inp1.symshape(); | ||||
auto shp2 = inp2.symshape(); | auto shp2 = inp2.symshape(); | ||||
SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail; | SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail; | ||||
@@ -349,16 +335,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
result_shp = opr::Concat::make({batch_shape, shp_tail}, 0, cn); | result_shp = opr::Concat::make({batch_shape, shp_tail}, 0, cn); | ||||
result = result.reshape(result_shp); | result = result.reshape(result_shp); | ||||
} | } | ||||
if (remove_row) { | |||||
std::vector<Desc> remove_param; | |||||
remove_param.push_back(Desc::make_remove(maxdim - 2)); | |||||
result = opr::AxisAddRemove::make(result, remove_param); | |||||
} | |||||
if (remove_col) { | |||||
std::vector<Desc> remove_param; | |||||
remove_param.push_back(Desc::make_remove(maxdim - 1)); | |||||
result = opr::AxisAddRemove::make(result, remove_param); | |||||
} | |||||
return result; | return result; | ||||
} | } | ||||
@@ -418,21 +394,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
DType dst_dtype; | DType dst_dtype; | ||||
dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); | dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); | ||||
bool remove_row = false, remove_col = false; | |||||
if (dim1 == 1) { | |||||
dim1 = 2; | |||||
remove_row = true; | |||||
} | |||||
if (dim2 == 1) { | |||||
dim2 = 2; | |||||
remove_col = true; | |||||
} | |||||
if (remove_row) | |||||
layout1.add_axis_cont_inplace(0); | |||||
if (remove_col) | |||||
layout2.add_axis_inplace(1, 1, 1); | |||||
TensorShape tshp, batch_shp; | TensorShape tshp, batch_shp; | ||||
size_t j = 0; | size_t j = 0; | ||||
auto inp1 = inputs[0], inp2 = inputs[1]; | auto inp1 = inputs[0], inp2 = inputs[1]; | ||||
@@ -530,12 +491,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
if (maxdim > 3) { | if (maxdim > 3) { | ||||
dst_layout = dst_layout.reshape(shp1); | dst_layout = dst_layout.reshape(shp1); | ||||
} | } | ||||
if (remove_row) { | |||||
dst_layout = dst_layout.remove_axis(maxdim - 2); | |||||
} | |||||
if (remove_col) { | |||||
dst_layout = dst_layout.remove_axis(maxdim - 1); | |||||
} | |||||
return {Tensor::make(out.sub(SubTensorSpec::make_from_layout(dst_layout)))}; | return {Tensor::make(out.sub(SubTensorSpec::make_from_layout(dst_layout)))}; | ||||
} | } | ||||