
Revert "fix(imperative): fix the matrix mul error when symbolic trace"

This reverts commit 3754dc5a65.

GitOrigin-RevId: 1880b21c10
master
Megvii Engine Team, 2 years ago
commit 8fe5e74f5a
5 changed files with 44 additions and 180 deletions:

  1. imperative/python/test/integration/test_trace_dump.py (+0, -34)
  2. imperative/src/impl/ops/matmul.cpp (+44, -62)
  3. imperative/src/test/helper.cpp (+0, -25)
  4. imperative/src/test/helper.h (+0, -3)
  5. imperative/src/test/imperative.cpp (+0, -56)

imperative/python/test/integration/test_trace_dump.py (+0, -34)

@@ -12,7 +12,6 @@ import megengine.optimizer as optim
from megengine import tensor from megengine import tensor
from megengine.autodiff import GradManager from megengine.autodiff import GradManager
from megengine.jit import trace from megengine.jit import trace
from megengine.optimizer import SGD




@contextlib.contextmanager @contextlib.contextmanager
@@ -139,36 +138,3 @@ def test_dump_bn_train_mode():
bn_train(data) bn_train(data)
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
bn_train.dump("test.mge") bn_train.dump("test.mge")


class ViTmode(M.Module):
def __init__(self, patch_size=16, in_chans=3, embed_dim=384):
super().__init__()
self.proj = M.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
self.head = M.Linear(embed_dim, 1000)

def forward(self, x):
x = self.proj(x)
x = F.flatten(x, 2).transpose(0, 2, 1)
x = self.head(x)
return x


def test_ViTmode_trace_train():
model = ViTmode(embed_dim=384)
data = mge.random.normal(size=(1, 3, 224, 224))
optim = SGD(model.parameters(), lr=0.01)
gm = GradManager()
gm.attach(model.parameters())

@trace(symbolic=True, capture_as_const=True)
def train():
for i in range(2):
with gm:
loss = model(data)
gm.backward(loss)
optim.step().clear_grad()

train()
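
For context, the deleted test_ViTmode_trace_train exercised a Linear layer (a matmul) applied to a 3-D activation inside a symbolic trace. A minimal sketch of that code path, assuming only the public MegEngine APIs already used in this file (M.Linear, @trace, mge.random.normal); the function name f and the shapes are illustrative only:

import megengine as mge
import megengine.module as M
from megengine.jit import trace

linear = M.Linear(384, 1000)

@trace(symbolic=True)
def f(x):
    # x: (batch, tokens, channels); the matmul inside Linear sees a 3-D lhs
    return linear(x)

out = f(mge.random.normal(size=(1, 196, 384)))
print(out.shape)  # expected: (1, 196, 1000)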

imperative/src/impl/ops/matmul.cpp (+44, -62)

@@ -22,80 +22,62 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     mgb_assert(inputs.size() == 2);
     auto inp1 = SymbolVar{inputs[0]}, inp2 = SymbolVar{inputs[1]};
     auto dim1 = matmul.dimA, dim2 = matmul.dimB;
-    mgb_assert(
-            dim1 >= 2 && dim2 >= 2,
-            "the dim of one input the matmul operator dim is less than 2.");

     auto cn = inputs[0]->comp_node();
     using IndexDesc = opr::Subtensor::IndexDesc;
     OperatorNodeConfig config{matmul.make_name(), cn};

+    DTypeScalar vi{-1};
     auto graph = inputs[0]->owner_graph();
-    if (dim1 == 2 && dim2 == 2) {
-        return opr::MatrixMul::make(
-                inp1, inp2, matmul.param(), matmul.policy(), config);
-    }
-    //! use batched matrix mul
-    SymbolVar shp_head, batch;
-    DTypeScalar vi{-2};
-    auto compress_shape = [&](SymbolVar inp) {
-        if (inp.shape().ndim > 3) {
-            auto idx = opr::ImmutableTensor::make(*graph, vi, config);
-            auto shp = inp.symshape();
-            IndexDesc head_desc(1);
-            head_desc[0].end = idx;
-            shp_head = opr::Subtensor::make(shp, head_desc);
-            batch = opr::Reduce::make(shp_head, {Reduce::Mode::PRODUCT, 0});
-            IndexDesc tail_desc(1);
-            tail_desc[0].begin = idx;
-            auto shp_tail = opr::Subtensor::make(shp, tail_desc);
-            auto tshp = opr::Concat::make({batch, shp_tail}, 0, cn);
-            return inp.reshape(tshp);
-        } else if (inp.shape().ndim == 3) {
-            auto idx = opr::ImmutableTensor::make(*graph, vi, config);
-            auto shp = inp.symshape();
-            IndexDesc head_desc(1);
-            head_desc[0].end = idx;
-            shp_head = opr::Subtensor::make(shp, head_desc);
-            batch = opr::Reduce::make(shp_head, {Reduce::Mode::PRODUCT, 0});
-            return inp;
-        } else {
-            return inp;
-        }
-    };
-
-    inp1 = compress_shape(inp1);
-    inp2 = compress_shape(inp2);
-
-    auto expand_shape = [&](SymbolVar inp) {
-        if (inp.shape().ndim < 3) {
-            auto shp = inp.symshape();
-            using Desc = opr::AxisAddRemove::AxisDesc;
-            std::vector<Desc> add_axis_param;
-            add_axis_param.push_back(Desc::make_add(0));
-            auto out = opr::AxisAddRemove::make(inp, add_axis_param);
-            auto target_shape = opr::Concat::make({batch, shp}, 0, cn);
-            return opr::Broadcast::make(out, target_shape);
-        } else {
-            return inp;
-        }
-    };
-    inp1 = expand_shape(inp1);
-    inp2 = expand_shape(inp2);
-
-    auto result = opr::BatchedMatrixMul::make(
-            inp1, inp2, matmul.param(), matmul.policy(), config);
-    size_t max_dim = std::max(dim1, dim2);

-    if (max_dim > 3) {
+    SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail;
+    if (dim1 > 2) {
         auto idx = opr::ImmutableTensor::make(*graph, vi, config);
-        auto res_shp = result.symshape();
+        auto shp1 = inp1.symshape();
+        IndexDesc head_desc(1);
+        head_desc[0].end = idx;
+        shp1_head = opr::Subtensor::make(shp1, head_desc);
+        auto batch = opr::Reduce::make(shp1_head, {Reduce::Mode::PRODUCT, 0});
         IndexDesc tail_desc(1);
         tail_desc[0].begin = idx;
-        auto tail_shape = opr::Subtensor::make(res_shp, tail_desc);
-        auto tshp = opr::Concat::make({shp_head, tail_shape}, 0, cn);
+        shp1_tail = opr::Subtensor::make(shp1, tail_desc);
+        auto tshp = opr::Concat::make({batch, shp1_tail}, 0, cn);
+        inp1 = inp1.reshape(tshp);
+    }
+    if (dim2 > 2) {
+        auto idx = opr::ImmutableTensor::make(*graph, vi, config);
+        auto shp2 = inp2.symshape();
+        IndexDesc head_desc(1);
+        head_desc[0].end = idx;
+        shp2_head = opr::Subtensor::make(shp2, head_desc);
+        auto batch = opr::Reduce::make(shp2_head, {Reduce::Mode::PRODUCT, 0});
+        IndexDesc tail_desc(1);
+        tail_desc[0].begin = idx;
+        auto shp2_tail = opr::Subtensor::make(shp2, tail_desc);
+        auto tshp = opr::Concat::make({batch, shp2_tail}, 0, cn);
+        inp2 = inp2.reshape(tshp);
+    }
+    auto result =
+            opr::MatrixMul::make(inp1, inp2, matmul.param(), matmul.policy(), config);
+    if (dim1 > 2) {
+        auto idx = opr::ImmutableTensor::make(*graph, vi, config);
+        auto result_shape = result.symshape();
+        IndexDesc tail_desc(1);
+        tail_desc[0].begin = idx;
+        auto shp_tail = opr::Subtensor::make(result_shape, tail_desc);
+        auto tshp = opr::Concat::make({shp1_head, shp_tail}, 0, cn);
+        result = result.reshape(tshp);
+    }
+    if (dim2 > 2) {
+        auto idx = opr::ImmutableTensor::make(*graph, vi, config);
+        auto result_shape = result.symshape();
+        IndexDesc tail_desc(1);
+        tail_desc[0].begin = idx;
+        auto shp_tail = opr::Subtensor::make(result_shape, tail_desc);
+        auto tshp = opr::Concat::make({shp2_head, shp_tail}, 0, cn);
         result = result.reshape(tshp);
     }

     return result;
 }
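
The restored code above collapses every leading dimension of a higher-rank operand into the row axis, runs a plain 2-D MatrixMul, and then re-expands the leading dimensions on the result (the reverted version instead went through BatchedMatrixMul with a broadcast batch dimension). A rough NumPy sketch of that shape arithmetic, assuming the common case where only the first operand has more than two dimensions; the helper name flatten_leading_matmul is illustrative, not a MegEngine API:

import numpy as np

def flatten_leading_matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Collapse all leading dims of `a` into its row axis, as the dim1 > 2
    # branch does with Subtensor/Reduce/Concat on the symbolic shape.
    assert b.ndim == 2, "sketch covers only the (ndim > 2) x (ndim == 2) case"
    head = a.shape[:-1]                      # everything except the reduced dim
    a2d = a.reshape(int(np.prod(head)), a.shape[-1])
    out2d = a2d @ b                          # the plain 2-D MatrixMul
    # Re-attach the leading dims, mirroring Concat({shp1_head, shp_tail}).
    return out2d.reshape(*head, b.shape[-1])

# e.g. (2, 3, 7, 6) x (6, 10) -> (2, 3, 7, 10)
print(flatten_leading_matmul(np.ones((2, 3, 7, 6)), np.ones((6, 10))).shape)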




imperative/src/test/helper.cpp (+0, -25)

@@ -150,31 +150,6 @@ void OprChecker::run(std::vector<InputSpec> inp_keys, std::set<size_t> bypass) {
     }
 }

-VarNodeArray OprChecker::run_apply_on_var_node(std::vector<InputSpec> inp_keys) {
-    HostTensorGenerator<> gen;
-    size_t nr_inps = inp_keys.size();
-    SmallVector<HostTensorND> host_inp(nr_inps);
-    VarNodeArray sym_inp(nr_inps);
-    auto graph = ComputingGraph::make();
-    graph->options().graph_opt_level = 0;
-    for (size_t i = 0; i < nr_inps; ++i) {
-        // TODO: remove std::visit for support osx 10.12
-        host_inp[i] = std::visit(
-                [&gen](auto&& arg) -> HostTensorND {
-                    using T = std::decay_t<decltype(arg)>;
-                    if constexpr (std::is_same_v<TensorShape, T>) {
-                        return *gen(arg);
-                    } else {
-                        static_assert(std::is_same_v<HostTensorND, T>);
-                        return arg;
-                    }
-                },
-                inp_keys[i]);
-        sym_inp[i] = opr::SharedDeviceTensor::make(*graph, host_inp[i]).node();
-    }
-    return OpDef::apply_on_var_node(*m_op, sym_inp);
-}
-
 TEST(TestHelper, PyModule) {
     py::module m = PyEnv::get();
     py::print(m);


imperative/src/test/helper.h (+0, -3)

@@ -14,9 +14,6 @@ public:
     OprChecker(std::shared_ptr<OpDef> opdef);
     void run(std::vector<InputSpec> inp_shapes, std::set<size_t> bypass = {});

-    //! test the interface of apply_on_var_node
-    VarNodeArray run_apply_on_var_node(std::vector<InputSpec> inp_shapes);
-
 private:
     std::shared_ptr<OpDef> m_op;
 };


imperative/src/test/imperative.cpp (+0, -56)

@@ -1,11 +1,9 @@
 #include "./helper.h"
 #include "megbrain/comp_node_env.h"
 #include "megbrain/imperative/blob_manager.h"
-#include "megbrain/imperative/ops/autogen.h"
 #include "megbrain/imperative/ops/opr_attr.h"
 #include "megbrain/opr/basic_arith.h"
 #include "megbrain/opr/basic_arith_wrapper.h"
-#include "megbrain/opr/blas.h"
 #include "megbrain/opr/dnn/batch_norm.h"
 #include "megbrain/opr/dnn/convolution.h"
 #include "megbrain/opr/tensor_manip.h"
@@ -166,58 +164,4 @@ TEST(TestImperative, Defragment) {
 }
 #endif  // MGB_CUDA && MGB_ENABLE_EXCEPTION

-TEST(TestImperative, MatrixMulApplyOnVarNode) {
-    using Param = opr::MatrixMul::Param;
-    Param param;
-    std::vector<std::pair<TensorShape, TensorShape>> shapes;
-    std::vector<TensorShape> target_shapes;
-    std::vector<Param> params;
-    //! testcase 0
-    params.push_back(param);
-    shapes.push_back({TensorShape{10, 5}, TensorShape{5, 10}});
-    target_shapes.push_back(TensorShape{10, 10});
-    //! testcase 1
-    params.push_back(param);
-    shapes.push_back({TensorShape{3, 10, 5}, TensorShape{5, 10}});
-    target_shapes.push_back(TensorShape{3, 10, 10});
-    //! testcase 2
-    param.transposeA = true;
-    param.transposeB = false;
-    params.push_back(param);
-    shapes.push_back({TensorShape{3, 7, 6}, TensorShape{7, 10}});
-    target_shapes.push_back(TensorShape{3, 6, 10});
-    //! testcase 3
-    param.transposeA = true;
-    param.transposeB = false;
-    params.push_back(param);
-    shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{7, 10}});
-    target_shapes.push_back(TensorShape{2, 3, 6, 10});
-    //! testcase 4
-    param.transposeA = false;
-    param.transposeB = true;
-    params.push_back(param);
-    shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{2, 3, 8, 6}});
-    target_shapes.push_back(TensorShape{2, 3, 7, 8});
-
-    //! testcase 5
-    param.transposeA = false;
-    param.transposeB = true;
-    params.push_back(param);
-    shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{8, 6}});
-    target_shapes.push_back(TensorShape{2, 3, 7, 8});
-
-    for (size_t i = 0; i < params.size(); i++) {
-        auto& shape = shapes[i];
-        auto op = MatrixMul::make(
-                params[i], ::megdnn::param::ExecutionPolicy{}, shape.first.ndim,
-                shape.second.ndim);
-        auto result = OprChecker(op).run_apply_on_var_node({shape.first, shape.second});
-        ASSERT_GT(result.size(), 0);
-        ASSERT_EQ(target_shapes[i].ndim, result[0]->shape().ndim);
-        for (size_t id = 0; id < target_shapes[i].ndim; id++) {
-            ASSERT_EQ(target_shapes[i][id], result[0]->shape()[id]);
-        }
-    }
-}
-
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
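
The deleted TestImperative.MatrixMulApplyOnVarNode asserted that apply_on_var_node produces the usual matmul output shape for each (shape, transpose) combination. For reference, the expected shapes in those test cases follow the rule below; this is a hedged Python sketch, and matmul_out_shape is a hypothetical helper, not part of the test framework:

def matmul_out_shape(shape_a, shape_b, transpose_a=False, transpose_b=False):
    # Rows/cols of each operand after the optional transpose of its last two dims.
    m, k1 = (shape_a[-1], shape_a[-2]) if transpose_a else (shape_a[-2], shape_a[-1])
    k2, n = (shape_b[-1], shape_b[-2]) if transpose_b else (shape_b[-2], shape_b[-1])
    assert k1 == k2, "reduction dims must match"
    # Leading (batch) dims come from the higher-rank operand.
    head = shape_a[:-2] if len(shape_a) >= len(shape_b) else shape_b[:-2]
    return (*head, m, n)

# testcase 3: (2, 3, 7, 6) with transposeA x (7, 10) -> (2, 3, 6, 10)
assert matmul_out_shape((2, 3, 7, 6), (7, 10), transpose_a=True) == (2, 3, 6, 10)
# testcase 5: (2, 3, 7, 6) x (8, 6) with transposeB -> (2, 3, 7, 8)
assert matmul_out_shape((2, 3, 7, 6), (8, 6), transpose_b=True) == (2, 3, 7, 8)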
