Browse Source

fix(mgb/imperative): fix repeat bug in trace mode

GitOrigin-RevId: 9547fc6102
HuaHua404-patch-4
Megvii Engine Team 2 years ago
parent
commit
bccda5c427
2 changed files with 28 additions and 4 deletions
  1. +16
    -4
      imperative/python/src/tensor_utils.cpp
  2. +12
    -0
      imperative/python/test/unit/jit/test_tracing.py

+ 16
- 4
imperative/python/src/tensor_utils.cpp View File

@@ -530,11 +530,23 @@ py::object _astensor1d_cpp(
return get_res_by_refhdl(value, dtype, device, ref);
}
if (lis.size() > 1) {
std::vector<PyObject*> c_args(lis.size() + 1);
for (size_t i = 0; i < lis.size(); ++i) {
c_args[i] = lis[i].ptr();
py::list flat_list;
for (auto item : lis) {
if (!PyList_Check(item.ptr())) {
flat_list.append(item);
} else {
py::list sub_lis =
py::reinterpret_steal<py::list>(PySequence_List(item.ptr()));
for (auto sub_item : sub_lis) {
flat_list.append(sub_item);
}
}
}
std::vector<PyObject*> c_args(flat_list.size() + 1);
for (size_t i = 0; i < flat_list.size(); ++i) {
c_args[i] = flat_list[i].ptr();
}
c_args[lis.size()] = Py_None;
c_args[flat_list.size()] = Py_None;
py::tuple inp_tup = py::reinterpret_steal<py::tuple>(
convert_inputs_cpp(NULL, c_args.data(), c_args.size()));
if (device_obj.is_none()) {


+ 12
- 0
imperative/python/test/unit/jit/test_tracing.py View File

@@ -161,6 +161,18 @@ def test_elemwise_fuse_in_grad(trace_mode):
y.numpy()


def test_repeat_in_trace():
@trace(symbolic=False)
def fun(data, repeats):
F.repeat(data, repeats)

data = tensor(np.random.random([1, 2, 3]).astype(np.float32))

for i in range(1, 5):
repeats = tensor(i)
fun(data, repeats)


def test_print_in_trace():
for symbolic in [False]: # cannot read value in symbolic mode



Loading…
Cancel
Save