From 7f28005847c6d806b418f56d4cfb434a630b721d Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 9 Nov 2022 19:45:53 +0800 Subject: [PATCH] fix(imperative): fix the matrix mul error when symbolic trace GitOrigin-RevId: 3754dc5a658718a4b00bd16d58e581116574adb7 --- .../python/test/integration/test_trace_dump.py | 34 +++++++ imperative/src/impl/ops/matmul.cpp | 106 ++++++++++++--------- imperative/src/test/helper.cpp | 25 +++++ imperative/src/test/helper.h | 3 + imperative/src/test/imperative.cpp | 56 +++++++++++ 5 files changed, 180 insertions(+), 44 deletions(-) diff --git a/imperative/python/test/integration/test_trace_dump.py b/imperative/python/test/integration/test_trace_dump.py index eb6dc1a0..72604fa3 100644 --- a/imperative/python/test/integration/test_trace_dump.py +++ b/imperative/python/test/integration/test_trace_dump.py @@ -12,6 +12,7 @@ import megengine.optimizer as optim from megengine import tensor from megengine.autodiff import GradManager from megengine.jit import trace +from megengine.optimizer import SGD @contextlib.contextmanager @@ -138,3 +139,36 @@ def test_dump_bn_train_mode(): bn_train(data) with pytest.raises(RuntimeError): bn_train.dump("test.mge") + + +class ViTmode(M.Module): + def __init__(self, patch_size=16, in_chans=3, embed_dim=384): + super().__init__() + self.proj = M.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + self.head = M.Linear(embed_dim, 1000) + + def forward(self, x): + x = self.proj(x) + x = F.flatten(x, 2).transpose(0, 2, 1) + x = self.head(x) + return x + + +def test_ViTmode_trace_train(): + model = ViTmode(embed_dim=384) + data = mge.random.normal(size=(1, 3, 224, 224)) + optim = SGD(model.parameters(), lr=0.01) + gm = GradManager() + gm.attach(model.parameters()) + + @trace(symbolic=True, capture_as_const=True) + def train(): + for i in range(2): + with gm: + loss = model(data) + gm.backward(loss) + optim.step().clear_grad() + + train() diff --git a/imperative/src/impl/ops/matmul.cpp b/imperative/src/impl/ops/matmul.cpp index a5b28b31..07061f93 100644 --- a/imperative/src/impl/ops/matmul.cpp +++ b/imperative/src/impl/ops/matmul.cpp @@ -22,62 +22,80 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { mgb_assert(inputs.size() == 2); auto inp1 = SymbolVar{inputs[0]}, inp2 = SymbolVar{inputs[1]}; auto dim1 = matmul.dimA, dim2 = matmul.dimB; + mgb_assert( + dim1 >= 2 && dim2 >= 2, + "the dim of one input the matmul operator dim is less than 2."); auto cn = inputs[0]->comp_node(); using IndexDesc = opr::Subtensor::IndexDesc; OperatorNodeConfig config{matmul.make_name(), cn}; - DTypeScalar vi{-1}; auto graph = inputs[0]->owner_graph(); - - SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail; - if (dim1 > 2) { - auto idx = opr::ImmutableTensor::make(*graph, vi, config); - auto shp1 = inp1.symshape(); - IndexDesc head_desc(1); - head_desc[0].end = idx; - shp1_head = opr::Subtensor::make(shp1, head_desc); - auto batch = opr::Reduce::make(shp1_head, {Reduce::Mode::PRODUCT, 0}); - IndexDesc tail_desc(1); - tail_desc[0].begin = idx; - shp1_tail = opr::Subtensor::make(shp1, tail_desc); - auto tshp = opr::Concat::make({batch, shp1_tail}, 0, cn); - inp1 = inp1.reshape(tshp); - } - if (dim2 > 2) { - auto idx = opr::ImmutableTensor::make(*graph, vi, config); - auto shp2 = inp2.symshape(); - IndexDesc head_desc(1); - head_desc[0].end = idx; - shp2_head = opr::Subtensor::make(shp2, head_desc); - auto batch = opr::Reduce::make(shp2_head, {Reduce::Mode::PRODUCT, 0}); - IndexDesc tail_desc(1); - tail_desc[0].begin = idx; - auto shp2_tail = opr::Subtensor::make(shp2, tail_desc); - auto tshp = opr::Concat::make({batch, shp2_tail}, 0, cn); - inp2 = inp2.reshape(tshp); + if (dim1 == 2 && dim2 == 2) { + return opr::MatrixMul::make( + inp1, inp2, matmul.param(), matmul.policy(), config); } - auto result = - opr::MatrixMul::make(inp1, inp2, matmul.param(), matmul.policy(), config); - if (dim1 > 2) { - auto idx = opr::ImmutableTensor::make(*graph, vi, config); - auto result_shape = result.symshape(); - IndexDesc tail_desc(1); - tail_desc[0].begin = idx; - auto shp_tail = opr::Subtensor::make(result_shape, tail_desc); - auto tshp = opr::Concat::make({shp1_head, shp_tail}, 0, cn); - result = result.reshape(tshp); - } - if (dim2 > 2) { + //! use batched matrix mul + SymbolVar shp_head, batch; + DTypeScalar vi{-2}; + auto compress_shape = [&](SymbolVar inp) { + if (inp.shape().ndim > 3) { + auto idx = opr::ImmutableTensor::make(*graph, vi, config); + auto shp = inp.symshape(); + IndexDesc head_desc(1); + head_desc[0].end = idx; + shp_head = opr::Subtensor::make(shp, head_desc); + batch = opr::Reduce::make(shp_head, {Reduce::Mode::PRODUCT, 0}); + IndexDesc tail_desc(1); + tail_desc[0].begin = idx; + auto shp_tail = opr::Subtensor::make(shp, tail_desc); + auto tshp = opr::Concat::make({batch, shp_tail}, 0, cn); + return inp.reshape(tshp); + } else if (inp.shape().ndim == 3) { + auto idx = opr::ImmutableTensor::make(*graph, vi, config); + auto shp = inp.symshape(); + IndexDesc head_desc(1); + head_desc[0].end = idx; + shp_head = opr::Subtensor::make(shp, head_desc); + batch = opr::Reduce::make(shp_head, {Reduce::Mode::PRODUCT, 0}); + return inp; + } else { + return inp; + } + }; + + inp1 = compress_shape(inp1); + inp2 = compress_shape(inp2); + + auto expand_shape = [&](SymbolVar inp) { + if (inp.shape().ndim < 3) { + auto shp = inp.symshape(); + using Desc = opr::AxisAddRemove::AxisDesc; + std::vector add_axis_param; + add_axis_param.push_back(Desc::make_add(0)); + auto out = opr::AxisAddRemove::make(inp, add_axis_param); + auto target_shape = opr::Concat::make({batch, shp}, 0, cn); + return opr::Broadcast::make(out, target_shape); + } else { + return inp; + } + }; + inp1 = expand_shape(inp1); + inp2 = expand_shape(inp2); + + auto result = opr::BatchedMatrixMul::make( + inp1, inp2, matmul.param(), matmul.policy(), config); + size_t max_dim = std::max(dim1, dim2); + + if (max_dim > 3) { auto idx = opr::ImmutableTensor::make(*graph, vi, config); - auto result_shape = result.symshape(); + auto res_shp = result.symshape(); IndexDesc tail_desc(1); tail_desc[0].begin = idx; - auto shp_tail = opr::Subtensor::make(result_shape, tail_desc); - auto tshp = opr::Concat::make({shp2_head, shp_tail}, 0, cn); + auto tail_shape = opr::Subtensor::make(res_shp, tail_desc); + auto tshp = opr::Concat::make({shp_head, tail_shape}, 0, cn); result = result.reshape(tshp); } - return result; } diff --git a/imperative/src/test/helper.cpp b/imperative/src/test/helper.cpp index 5aa8f7ff..52e51395 100644 --- a/imperative/src/test/helper.cpp +++ b/imperative/src/test/helper.cpp @@ -150,6 +150,31 @@ void OprChecker::run(std::vector inp_keys, std::set bypass) { } } +VarNodeArray OprChecker::run_apply_on_var_node(std::vector inp_keys) { + HostTensorGenerator<> gen; + size_t nr_inps = inp_keys.size(); + SmallVector host_inp(nr_inps); + VarNodeArray sym_inp(nr_inps); + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + for (size_t i = 0; i < nr_inps; ++i) { + // TODO: remove std::visit for support osx 10.12 + host_inp[i] = std::visit( + [&gen](auto&& arg) -> HostTensorND { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return *gen(arg); + } else { + static_assert(std::is_same_v); + return arg; + } + }, + inp_keys[i]); + sym_inp[i] = opr::SharedDeviceTensor::make(*graph, host_inp[i]).node(); + } + return OpDef::apply_on_var_node(*m_op, sym_inp); +} + TEST(TestHelper, PyModule) { py::module m = PyEnv::get(); py::print(m); diff --git a/imperative/src/test/helper.h b/imperative/src/test/helper.h index 2cd0a460..1700dd4d 100644 --- a/imperative/src/test/helper.h +++ b/imperative/src/test/helper.h @@ -14,6 +14,9 @@ public: OprChecker(std::shared_ptr opdef); void run(std::vector inp_shapes, std::set bypass = {}); + //! test the interface of apply_on_var_node + VarNodeArray run_apply_on_var_node(std::vector inp_shapes); + private: std::shared_ptr m_op; }; diff --git a/imperative/src/test/imperative.cpp b/imperative/src/test/imperative.cpp index a7f6023b..c702b33c 100644 --- a/imperative/src/test/imperative.cpp +++ b/imperative/src/test/imperative.cpp @@ -1,9 +1,11 @@ #include "./helper.h" #include "megbrain/comp_node_env.h" #include "megbrain/imperative/blob_manager.h" +#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/opr/basic_arith.h" #include "megbrain/opr/basic_arith_wrapper.h" +#include "megbrain/opr/blas.h" #include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/tensor_manip.h" @@ -164,4 +166,58 @@ TEST(TestImperative, Defragment) { } #endif // MGB_CUDA && MGB_ENABLE_EXCEPTION +TEST(TestImperative, MatrixMulApplyOnVarNode) { + using Param = opr::MatrixMul::Param; + Param param; + std::vector> shapes; + std::vector target_shapes; + std::vector params; + //! testcase 0 + params.push_back(param); + shapes.push_back({TensorShape{10, 5}, TensorShape{5, 10}}); + target_shapes.push_back(TensorShape{10, 10}); + //! testcase 1 + params.push_back(param); + shapes.push_back({TensorShape{3, 10, 5}, TensorShape{5, 10}}); + target_shapes.push_back(TensorShape{3, 10, 10}); + //! testcase 2 + param.transposeA = true; + param.transposeB = false; + params.push_back(param); + shapes.push_back({TensorShape{3, 7, 6}, TensorShape{7, 10}}); + target_shapes.push_back(TensorShape{3, 6, 10}); + //! testcase 3 + param.transposeA = true; + param.transposeB = false; + params.push_back(param); + shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{7, 10}}); + target_shapes.push_back(TensorShape{2, 3, 6, 10}); + //! testcase 4 + param.transposeA = false; + param.transposeB = true; + params.push_back(param); + shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{2, 3, 8, 6}}); + target_shapes.push_back(TensorShape{2, 3, 7, 8}); + + //! testcase 5 + param.transposeA = false; + param.transposeB = true; + params.push_back(param); + shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{8, 6}}); + target_shapes.push_back(TensorShape{2, 3, 7, 8}); + + for (size_t i = 0; i < params.size(); i++) { + auto& shape = shapes[i]; + auto op = MatrixMul::make( + params[i], ::megdnn::param::ExecutionPolicy{}, shape.first.ndim, + shape.second.ndim); + auto result = OprChecker(op).run_apply_on_var_node({shape.first, shape.second}); + ASSERT_GT(result.size(), 0); + ASSERT_EQ(target_shapes[i].ndim, result[0]->shape().ndim); + for (size_t id = 0; id < target_shapes[i].ndim; id++) { + ASSERT_EQ(target_shapes[i][id], result[0]->shape()[id]); + } + } +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}