diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py
index e8a1c31f..b9af63a6 100644
--- a/imperative/python/megengine/core/tensor/megbrain_graph.py
+++ b/imperative/python/megengine/core/tensor/megbrain_graph.py
@@ -78,9 +78,9 @@ class Graph(_imperative_rt.ComputingGraph):
         opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self)
         return opnode.outputs[0]

-    def make_h2d(self, *, dtype, device):
+    def make_h2d(self, *, dtype, device, shape=None, name=None):
         device = as_device(device).to_c()
-        return self._wrap(_imperative_rt.make_h2d(self, device, dtype))
+        return self._wrap(_imperative_rt.make_h2d(self, device, dtype, shape, name))


 def dump(*args):
diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py
index ca71c2c9..b3a1af89 100644
--- a/imperative/python/megengine/jit/tracing.py
+++ b/imperative/python/megengine/jit/tracing.py
@@ -51,6 +51,7 @@ class TensorInfo:
         "value_read",
         "device",
         "dtype",
+        "shape",
         "bound_data",
         # resources for execution
         "varnode",
@@ -107,8 +108,8 @@ class trace:
         self._active_tensors = weakref.WeakSet()
         self._tensor_remaps = None
         self._inputs_to_restore = None
-        self._args_bindings = None
-        self._kwargs_bindings = None
+        self._arg_bindings = None
+        self._kwarg_bindings = None
         self._output_bindings = None
         self._output_names = None

@@ -329,9 +330,7 @@ class trace:

         links = ()
         if self._capture_as_const:
-            for h in itertools.chain(
-                self._args_bindings, self._kwargs_bindings.values()
-            ):
+            for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
                 info = self._tinfo[h]
                 opnode = info.data_setter = G.InputNode(
                     device=info.device, dtype=info.dtype, graph=graph
@@ -434,15 +433,19 @@ class trace:

         h2v = {}
         graph = G.Graph()
-        for i, h in enumerate(self._args_bindings):
+        for i, h in enumerate(self._arg_bindings):
             info = self._tinfo[h]
-            h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
-            if arg_names:
-                h2v[h].name = arg_names[i]
-        for k, h in self._kwargs_bindings.items():
+            h2v[h] = graph.make_h2d(
+                dtype=info.dtype,
+                device=info.device,
+                shape=info.shape,
+                name=arg_names[i] if arg_names else None,
+            )
+        for k, h in self._kwarg_bindings.items():
             info = self._tinfo[h]
-            h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
-            h2v[h].name = k
+            h2v[h] = graph.make_h2d(
+                dtype=info.dtype, device=info.device, shape=info.shape, name=k
+            )

         for op, ihandles, ohandles in self._seq:
             ivars = []
@@ -479,11 +482,12 @@ class trace:
                 info.external = False
                 info.device = x.device
                 info.dtype = x.dtype
+                info.shape = x.shape
                 TraceMixin._TraceMixin__inject(x, h)
                 self._inputs_to_restore.append(x)
                 return h

-            self._args_bindings = []
+            self._arg_bindings = []
             for i, x in enumerate(args):
                 x = find_raw_tensor(x)
                 if x is None:
@@ -490,21 +494,21 @@ class trace:
                     raise TypeError(
                         "positional arguments should all be tensor "
                         "but args[%d] cannot be recognized as one" % i
                     )
-                self._args_bindings.append(record_input(x))
+                self._arg_bindings.append(record_input(x))

-            self._kwargs_bindings = {}
+            self._kwarg_bindings = {}
             for k, x in kwargs.items():
                 x = find_raw_tensor(x)
                 if x is not None:
-                    self._kwargs_bindings[k] = record_input(x)
+                    self._kwarg_bindings[k] = record_input(x)
         else:
-            if len(args) != len(self._args_bindings):
+            if len(args) != len(self._arg_bindings):
                 raise TraceMismatchError("positional argument length mismatch")

             self._tensor_remaps = {}

-            for i, (h, x) in enumerate(zip(self._args_bindings, args)):
+            for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
                 x = find_raw_tensor(x)
                 if x is None:
                     raise TypeError(
@@ -524,9 +528,9 @@ class trace:
                 x = find_raw_tensor(x)
                 if x is not None:
                     kwargs_tensors[k] = x
-            if set(kwargs_tensors) != set(self._kwargs_bindings):
-                too_many = set(kwargs_tensors) - set(self._kwargs_bindings)
-                too_few = set(self._kwargs_bindings) - set(kwargs_tensors)
+            if set(kwargs_tensors) != set(self._kwarg_bindings):
+                too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
+                too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
                 if too_many:
                     raise TraceMismatchError(
                         "keyword arguments found to be tensor this time "
@@ -537,7 +541,7 @@ class trace:
                         "keyword arguments found to be non-tensor this time "
                         "but were tensor previously: %s" % " ".join(too_few)
                     )
-            for k, h in self._kwargs_bindings.items():
+            for k, h in self._kwarg_bindings.items():
                 x = kwargs_tensors[k]
                 info = self._tinfo[h]
                 if x.dtype != info.dtype:
diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp
index 500d8dd3..f6073644 100644
--- a/imperative/python/src/graph_rt.cpp
+++ b/imperative/python/src/graph_rt.cpp
@@ -237,7 +237,7 @@ void init_graph_rt(py::module m) {
         return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
     });

-    m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, std::optional<std::string> name) {
+    m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, TensorShape shape, std::optional<std::string> name) {
         if (!cn.valid()) {
             throw py::type_error("device must be valid");
         }
@@ -248,8 +248,8 @@ void init_graph_rt(py::module m) {
         if (name) {
             config.name(*name);
         }
-        return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, dtype), config).node();
-    }, py::arg(), py::arg(), py::arg(), py::arg() = py::none());
+        return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, shape, dtype), config).node();
+    }, py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::none());

     m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
                                              const CompNode& comp_node,
diff --git a/imperative/python/test/integration/test_trace_dump.py b/imperative/python/test/integration/test_trace_dump.py
new file mode 100644
index 00000000..6af911d8
--- /dev/null
+++ b/imperative/python/test/integration/test_trace_dump.py
@@ -0,0 +1,139 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
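+
+# This test trains a small XOR classifier under trace, then dumps the traced
+# predict function, which exercises the shape- and name-aware make_h2d path.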
+
+import contextlib
+import os
+import tempfile
+
+import numpy as np
+
+import megengine as mge
+import megengine.functional as F
+import megengine.module as M
+import megengine.optimizer as optim
+from megengine import tensor
+from megengine.jit import trace
+
+
+@contextlib.contextmanager
+def mkstemp():
+    fd, path = tempfile.mkstemp()
+    try:
+        os.close(fd)
+        yield path
+    finally:
+        os.remove(path)
+
+
+def minibatch_generator(batch_size):
+    while True:
+        inp_data = np.zeros((batch_size, 2))
+        label = np.zeros(batch_size, dtype=np.int32)
+        for i in range(batch_size):
+            inp_data[i, :] = np.random.rand(2) * 2 - 1
+            label[i] = 1 if np.prod(inp_data[i]) < 0 else 0
+        yield {"data": inp_data.astype(np.float32), "label": label.astype(np.int32)}
+
+
+class XORNet(M.Module):
+    def __init__(self):
+        self.mid_dim = 14
+        self.num_class = 2
+        super().__init__()
+        self.fc0 = M.Linear(self.num_class, self.mid_dim, bias=True)
+        self.fc1 = M.Linear(self.mid_dim, self.mid_dim, bias=True)
+        self.fc2 = M.Linear(self.mid_dim, self.num_class, bias=True)
+
+    def forward(self, x):
+        x = self.fc0(x)
+        x = F.tanh(x)
+        x = self.fc1(x)
+        x = F.tanh(x)
+        x = self.fc2(x)
+        return x
+
+
+def test_xornet_trace_dump():
+    net = XORNet()
+    opt = optim.SGD(net.parameters(requires_grad=True), lr=0.01, momentum=0.9)
+    batch_size = 64
+    train_dataset = minibatch_generator(batch_size)
+    val_dataset = minibatch_generator(batch_size)
+
+    @trace
+    def train_fun(data, label):
+        with opt.record():
+            net.train()
+            pred = net(data)
+            loss = F.cross_entropy_with_softmax(pred, label)
+            opt.backward(loss)
+        return pred, loss
+
+    @trace
+    def val_fun(data, label):
+        net.eval()
+        pred = net(data)
+        loss = F.cross_entropy_with_softmax(pred, label)
+        return pred, loss
+
+    @trace(symbolic=True, capture_as_const=True)
+    def pred_fun(data):
+        net.eval()
+        pred = net(data)
+        pred_normalized = F.softmax(pred)
+        return pred_normalized
+
+    train_loss = []
+    val_loss = []
+    for step, minibatch in enumerate(train_dataset):
+        if step > 100:
+            break
+        data = tensor(minibatch["data"])
+        label = tensor(minibatch["label"])
+        opt.zero_grad()
+        _, loss = train_fun(data, label)
+        train_loss.append((step, loss.numpy()))
+        if step % 50 == 0:
+            minibatch = next(val_dataset)
+            _, loss = val_fun(tensor(minibatch["data"]), tensor(minibatch["label"]))
+            loss = loss.numpy()[0]
+            val_loss.append((step, loss))
+            print("Step: {} loss={}".format(step, loss))
+        opt.step()
+
+    test_data = np.array(
+        [
+            (0.5, 0.5),
+            (0.3, 0.7),
+            (0.1, 0.9),
+            (-0.5, -0.5),
+            (-0.3, -0.7),
+            (-0.9, -0.1),
+            (0.5, -0.5),
+            (0.3, -0.7),
+            (0.9, -0.1),
+            (-0.5, 0.5),
+            (-0.3, 0.7),
+            (-0.1, 0.9),
+        ]
+    )
+
+    data = tensor(test_data.astype(np.float32))
+    out = pred_fun(data)
+    pred_output = out.numpy()
+    pred_label = np.argmax(pred_output, 1)
+
+    with np.printoptions(precision=4, suppress=True):
+        print("Predicted probability:")
+        print(pred_output)
+
+    with mkstemp() as out:
+        pred_fun.dump(out, arg_names=["data"], output_names=["label"])