
fix(imperative): fix matrix mul error under symbolic trace

GitOrigin-RevId: 3754dc5a65
master · Megvii Engine Team · 2 years ago
commit 7f28005847
5 changed files with 180 additions and 44 deletions
  1. +34 -0   imperative/python/test/integration/test_trace_dump.py
  2. +62 -44  imperative/src/impl/ops/matmul.cpp
  3. +25 -0   imperative/src/test/helper.cpp
  4. +3  -0   imperative/src/test/helper.h
  5. +56 -0   imperative/src/test/imperative.cpp
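
For context: the failure this commit targets appears when a matmul whose left or right operand has more than two dimensions (for example a Linear layer applied to an (N, L, C) activation, as in the ViT test added below) is executed under @trace(symbolic=True). The following is a minimal, hypothetical repro sketch, not part of the commit; the shapes and the use of F.matmul are my own choice.

# Hypothetical repro sketch (not from the commit): a matmul with a 3-d left
# operand inside a symbolic trace, which is the path the fix reworks.
import megengine as mge
import megengine.functional as F
from megengine.jit import trace


@trace(symbolic=True)
def run(x, w):
    # x: (batch, tokens, in_features), w: (in_features, out_features)
    # dimA = 3, dimB = 2, so apply_on_var_node takes the non-2x2 branch.
    return F.matmul(x, w)


x = mge.random.normal(size=(2, 196, 384))
w = mge.random.normal(size=(384, 1000))
out = run(x, w)
print(out.shape)  # expected (2, 196, 1000)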

+34 -0   imperative/python/test/integration/test_trace_dump.py

@@ -12,6 +12,7 @@ import megengine.optimizer as optim
 from megengine import tensor
 from megengine.autodiff import GradManager
 from megengine.jit import trace
+from megengine.optimizer import SGD


 @contextlib.contextmanager
@@ -138,3 +139,36 @@ def test_dump_bn_train_mode():
     bn_train(data)
     with pytest.raises(RuntimeError):
         bn_train.dump("test.mge")
+
+
+class ViTmode(M.Module):
+    def __init__(self, patch_size=16, in_chans=3, embed_dim=384):
+        super().__init__()
+        self.proj = M.Conv2d(
+            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
+        )
+        self.head = M.Linear(embed_dim, 1000)
+
+    def forward(self, x):
+        x = self.proj(x)
+        x = F.flatten(x, 2).transpose(0, 2, 1)
+        x = self.head(x)
+        return x
+
+
+def test_ViTmode_trace_train():
+    model = ViTmode(embed_dim=384)
+    data = mge.random.normal(size=(1, 3, 224, 224))
+    optim = SGD(model.parameters(), lr=0.01)
+    gm = GradManager()
+    gm.attach(model.parameters())
+
+    @trace(symbolic=True, capture_as_const=True)
+    def train():
+        for i in range(2):
+            with gm:
+                loss = model(data)
+                gm.backward(loss)
+            optim.step().clear_grad()
+
+    train()

+62 -44  imperative/src/impl/ops/matmul.cpp

@@ -22,62 +22,80 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     mgb_assert(inputs.size() == 2);
     auto inp1 = SymbolVar{inputs[0]}, inp2 = SymbolVar{inputs[1]};
     auto dim1 = matmul.dimA, dim2 = matmul.dimB;
+    mgb_assert(
+            dim1 >= 2 && dim2 >= 2,
+            "the dim of one matmul input is less than 2.");

     auto cn = inputs[0]->comp_node();
     using IndexDesc = opr::Subtensor::IndexDesc;
     OperatorNodeConfig config{matmul.make_name(), cn};

-    DTypeScalar vi{-1};
     auto graph = inputs[0]->owner_graph();

-    SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail;
-    if (dim1 > 2) {
-        auto idx = opr::ImmutableTensor::make(*graph, vi, config);
-        auto shp1 = inp1.symshape();
-        IndexDesc head_desc(1);
-        head_desc[0].end = idx;
-        shp1_head = opr::Subtensor::make(shp1, head_desc);
-        auto batch = opr::Reduce::make(shp1_head, {Reduce::Mode::PRODUCT, 0});
-        IndexDesc tail_desc(1);
-        tail_desc[0].begin = idx;
-        shp1_tail = opr::Subtensor::make(shp1, tail_desc);
-        auto tshp = opr::Concat::make({batch, shp1_tail}, 0, cn);
-        inp1 = inp1.reshape(tshp);
-    }
-    if (dim2 > 2) {
-        auto idx = opr::ImmutableTensor::make(*graph, vi, config);
-        auto shp2 = inp2.symshape();
-        IndexDesc head_desc(1);
-        head_desc[0].end = idx;
-        shp2_head = opr::Subtensor::make(shp2, head_desc);
-        auto batch = opr::Reduce::make(shp2_head, {Reduce::Mode::PRODUCT, 0});
-        IndexDesc tail_desc(1);
-        tail_desc[0].begin = idx;
-        auto shp2_tail = opr::Subtensor::make(shp2, tail_desc);
-        auto tshp = opr::Concat::make({batch, shp2_tail}, 0, cn);
-        inp2 = inp2.reshape(tshp);
+    if (dim1 == 2 && dim2 == 2) {
+        return opr::MatrixMul::make(
+                inp1, inp2, matmul.param(), matmul.policy(), config);
     }
-    auto result =
-            opr::MatrixMul::make(inp1, inp2, matmul.param(), matmul.policy(), config);
-    if (dim1 > 2) {
-        auto idx = opr::ImmutableTensor::make(*graph, vi, config);
-        auto result_shape = result.symshape();
-        IndexDesc tail_desc(1);
-        tail_desc[0].begin = idx;
-        auto shp_tail = opr::Subtensor::make(result_shape, tail_desc);
-        auto tshp = opr::Concat::make({shp1_head, shp_tail}, 0, cn);
-        result = result.reshape(tshp);
-    }
-    if (dim2 > 2) {
+    //! use batched matrix mul
+    SymbolVar shp_head, batch;
+    DTypeScalar vi{-2};
+    auto compress_shape = [&](SymbolVar inp) {
+        if (inp.shape().ndim > 3) {
+            auto idx = opr::ImmutableTensor::make(*graph, vi, config);
+            auto shp = inp.symshape();
+            IndexDesc head_desc(1);
+            head_desc[0].end = idx;
+            shp_head = opr::Subtensor::make(shp, head_desc);
+            batch = opr::Reduce::make(shp_head, {Reduce::Mode::PRODUCT, 0});
+            IndexDesc tail_desc(1);
+            tail_desc[0].begin = idx;
+            auto shp_tail = opr::Subtensor::make(shp, tail_desc);
+            auto tshp = opr::Concat::make({batch, shp_tail}, 0, cn);
+            return inp.reshape(tshp);
+        } else if (inp.shape().ndim == 3) {
+            auto idx = opr::ImmutableTensor::make(*graph, vi, config);
+            auto shp = inp.symshape();
+            IndexDesc head_desc(1);
+            head_desc[0].end = idx;
+            shp_head = opr::Subtensor::make(shp, head_desc);
+            batch = opr::Reduce::make(shp_head, {Reduce::Mode::PRODUCT, 0});
+            return inp;
+        } else {
+            return inp;
+        }
+    };
+
+    inp1 = compress_shape(inp1);
+    inp2 = compress_shape(inp2);
+
+    auto expand_shape = [&](SymbolVar inp) {
+        if (inp.shape().ndim < 3) {
+            auto shp = inp.symshape();
+            using Desc = opr::AxisAddRemove::AxisDesc;
+            std::vector<Desc> add_axis_param;
+            add_axis_param.push_back(Desc::make_add(0));
+            auto out = opr::AxisAddRemove::make(inp, add_axis_param);
+            auto target_shape = opr::Concat::make({batch, shp}, 0, cn);
+            return opr::Broadcast::make(out, target_shape);
+        } else {
+            return inp;
+        }
+    };
+    inp1 = expand_shape(inp1);
+    inp2 = expand_shape(inp2);
+
+    auto result = opr::BatchedMatrixMul::make(
+            inp1, inp2, matmul.param(), matmul.policy(), config);
+    size_t max_dim = std::max(dim1, dim2);
+
+    if (max_dim > 3) {
         auto idx = opr::ImmutableTensor::make(*graph, vi, config);
-        auto result_shape = result.symshape();
+        auto res_shp = result.symshape();
         IndexDesc tail_desc(1);
         tail_desc[0].begin = idx;
-        auto shp_tail = opr::Subtensor::make(result_shape, tail_desc);
-        auto tshp = opr::Concat::make({shp2_head, shp_tail}, 0, cn);
+        auto tail_shape = opr::Subtensor::make(res_shp, tail_desc);
+        auto tshp = opr::Concat::make({shp_head, tail_shape}, 0, cn);
         result = result.reshape(tshp);
     }

     return result;
 }
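
To make the shape handling in the new apply_on_var_node easier to follow, here is a rough NumPy sketch of the same logic (illustrative only: the function and variable names are mine, and the real code builds symbolic graph operators such as Subtensor/Concat/Broadcast rather than computing on eager arrays):

# Sketch of the new lowering: plain matmul for 2-d x 2-d, otherwise collapse
# leading dims to 3-d, broadcast any 2-d operand over the batch, run a
# batched matmul, and restore the leading dims when ndim was > 3.
import numpy as np


def batched_matmul_like(a, b):
    if a.ndim == 2 and b.ndim == 2:
        return np.matmul(a, b)                         # MatrixMul path
    max_dim = max(a.ndim, b.ndim)
    lead = None                                        # plays the role of shp_head

    def compress(x):
        nonlocal lead
        if x.ndim >= 3:
            lead = x.shape[:-2]
            batch = int(np.prod(lead))                 # Reduce(PRODUCT) over shp_head
            return x.reshape((batch,) + x.shape[-2:])
        return x

    a, b = compress(a), compress(b)
    batch = a.shape[0] if a.ndim == 3 else b.shape[0]

    def expand(x):
        if x.ndim < 3:                                 # AxisAddRemove + Broadcast
            return np.broadcast_to(x, (batch,) + x.shape)
        return x

    a, b = expand(a), expand(b)
    out = np.matmul(a, b)                              # BatchedMatrixMul
    if max_dim > 3:
        out = out.reshape(lead + out.shape[-2:])       # reshape back with shp_head
    return out


a = np.ones((2, 3, 7, 6))
b = np.ones((6, 8))
print(batched_matmul_like(a, b).shape)  # (2, 3, 7, 8)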




+25 -0   imperative/src/test/helper.cpp

@@ -150,6 +150,31 @@ void OprChecker::run(std::vector<InputSpec> inp_keys, std::set<size_t> bypass) {
     }
 }

+VarNodeArray OprChecker::run_apply_on_var_node(std::vector<InputSpec> inp_keys) {
+    HostTensorGenerator<> gen;
+    size_t nr_inps = inp_keys.size();
+    SmallVector<HostTensorND> host_inp(nr_inps);
+    VarNodeArray sym_inp(nr_inps);
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    for (size_t i = 0; i < nr_inps; ++i) {
+        // TODO: remove std::visit for support osx 10.12
+        host_inp[i] = std::visit(
+                [&gen](auto&& arg) -> HostTensorND {
+                    using T = std::decay_t<decltype(arg)>;
+                    if constexpr (std::is_same_v<TensorShape, T>) {
+                        return *gen(arg);
+                    } else {
+                        static_assert(std::is_same_v<HostTensorND, T>);
+                        return arg;
+                    }
+                },
+                inp_keys[i]);
+        sym_inp[i] = opr::SharedDeviceTensor::make(*graph, host_inp[i]).node();
+    }
+    return OpDef::apply_on_var_node(*m_op, sym_inp);
+}
+
 TEST(TestHelper, PyModule) {
     py::module m = PyEnv::get();
     py::print(m);


+3  -0   imperative/src/test/helper.h

@@ -14,6 +14,9 @@ public:
     OprChecker(std::shared_ptr<OpDef> opdef);
     void run(std::vector<InputSpec> inp_shapes, std::set<size_t> bypass = {});

+    //! test the interface of apply_on_var_node
+    VarNodeArray run_apply_on_var_node(std::vector<InputSpec> inp_shapes);
+
 private:
     std::shared_ptr<OpDef> m_op;
 };


+56 -0   imperative/src/test/imperative.cpp

@@ -1,9 +1,11 @@
#include "./helper.h" #include "./helper.h"
#include "megbrain/comp_node_env.h" #include "megbrain/comp_node_env.h"
#include "megbrain/imperative/blob_manager.h" #include "megbrain/imperative/blob_manager.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/basic_arith.h" #include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/basic_arith_wrapper.h" #include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/tensor_manip.h" #include "megbrain/opr/tensor_manip.h"
@@ -164,4 +166,58 @@ TEST(TestImperative, Defragment) {
 }
 #endif // MGB_CUDA && MGB_ENABLE_EXCEPTION

+TEST(TestImperative, MatrixMulApplyOnVarNode) {
+    using Param = opr::MatrixMul::Param;
+    Param param;
+    std::vector<std::pair<TensorShape, TensorShape>> shapes;
+    std::vector<TensorShape> target_shapes;
+    std::vector<Param> params;
+    //! testcase 0
+    params.push_back(param);
+    shapes.push_back({TensorShape{10, 5}, TensorShape{5, 10}});
+    target_shapes.push_back(TensorShape{10, 10});
+    //! testcase 1
+    params.push_back(param);
+    shapes.push_back({TensorShape{3, 10, 5}, TensorShape{5, 10}});
+    target_shapes.push_back(TensorShape{3, 10, 10});
+    //! testcase 2
+    param.transposeA = true;
+    param.transposeB = false;
+    params.push_back(param);
+    shapes.push_back({TensorShape{3, 7, 6}, TensorShape{7, 10}});
+    target_shapes.push_back(TensorShape{3, 6, 10});
+    //! testcase 3
+    param.transposeA = true;
+    param.transposeB = false;
+    params.push_back(param);
+    shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{7, 10}});
+    target_shapes.push_back(TensorShape{2, 3, 6, 10});
+    //! testcase 4
+    param.transposeA = false;
+    param.transposeB = true;
+    params.push_back(param);
+    shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{2, 3, 8, 6}});
+    target_shapes.push_back(TensorShape{2, 3, 7, 8});
+
+    //! testcase 5
+    param.transposeA = false;
+    param.transposeB = true;
+    params.push_back(param);
+    shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{8, 6}});
+    target_shapes.push_back(TensorShape{2, 3, 7, 8});
+
+    for (size_t i = 0; i < params.size(); i++) {
+        auto& shape = shapes[i];
+        auto op = MatrixMul::make(
+                params[i], ::megdnn::param::ExecutionPolicy{}, shape.first.ndim,
+                shape.second.ndim);
+        auto result = OprChecker(op).run_apply_on_var_node({shape.first, shape.second});
+        ASSERT_GT(result.size(), 0);
+        ASSERT_EQ(target_shapes[i].ndim, result[0]->shape().ndim);
+        for (size_t id = 0; id < target_shapes[i].ndim; id++) {
+            ASSERT_EQ(target_shapes[i][id], result[0]->shape()[id]);
+        }
+    }
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
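
As a quick cross-check of the target_shapes used in the test above, the same shape arithmetic can be reproduced with NumPy's broadcasting matmul; this is an illustrative sketch only, with transposeA/transposeB modeled by swapping the last two axes:

# Cross-check of the expected output shapes in MatrixMulApplyOnVarNode
# using NumPy's broadcasting matmul semantics.
import numpy as np


def expected_shape(shape_a, shape_b, transpose_a=False, transpose_b=False):
    a = np.zeros(shape_a)
    b = np.zeros(shape_b)
    if transpose_a:
        a = a.swapaxes(-1, -2)   # transposeA swaps the last two dims of A
    if transpose_b:
        b = b.swapaxes(-1, -2)   # transposeB swaps the last two dims of B
    return np.matmul(a, b).shape


print(expected_shape((10, 5), (5, 10)))                              # (10, 10)
print(expected_shape((3, 10, 5), (5, 10)))                           # (3, 10, 10)
print(expected_shape((3, 7, 6), (7, 10), transpose_a=True))          # (3, 6, 10)
print(expected_shape((2, 3, 7, 6), (7, 10), transpose_a=True))       # (2, 3, 6, 10)
print(expected_shape((2, 3, 7, 6), (2, 3, 8, 6), transpose_b=True))  # (2, 3, 7, 8)
print(expected_shape((2, 3, 7, 6), (8, 6), transpose_b=True))        # (2, 3, 7, 8)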
