GitOrigin-RevId: 5e2acd4052
tags/v1.0.0-rc1
@@ -78,9 +78,9 @@ class Graph(_imperative_rt.ComputingGraph):
         opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self)
         return opnode.outputs[0]

-    def make_h2d(self, *, dtype, device):
+    def make_h2d(self, *, dtype, device, shape=None, name=None):
         device = as_device(device).to_c()
-        return self._wrap(_imperative_rt.make_h2d(self, device, dtype))
+        return self._wrap(_imperative_rt.make_h2d(self, device, dtype, shape, name))

     def dump(*args):
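The hunk above extends the Python-level `make_h2d` wrapper with optional `shape` and `name` keywords that are forwarded to the C++ binding. A minimal usage sketch, assuming the internal `megengine.core.tensor.megbrain_graph` module path of this release and the "xpux" device shorthand; this is an internal API, not a stable public one:

    import numpy as np
    import megengine.core.tensor.megbrain_graph as G

    graph = G.Graph()
    # Before this change only dtype/device could be passed; shape and name
    # let the resulting Host2DeviceCopy carry a static shape and a stable name.
    inp = graph.make_h2d(dtype=np.float32, device="xpux", shape=(64, 2), name="data")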
@@ -51,6 +51,7 @@ class TensorInfo:
         "value_read",
         "device",
         "dtype",
+        "shape",
         "bound_data",
         # resources for execution
         "varnode",
@@ -107,8 +108,8 @@ class trace:
         self._active_tensors = weakref.WeakSet()
         self._tensor_remaps = None
         self._inputs_to_restore = None
-        self._args_bindings = None
-        self._kwargs_bindings = None
+        self._arg_bindings = None
+        self._kwarg_bindings = None
         self._output_bindings = None
         self._output_names = None
@@ -329,9 +330,7 @@ class trace:
         links = ()
         if self._capture_as_const:
-            for h in itertools.chain(
-                self._args_bindings, self._kwargs_bindings.values()
-            ):
+            for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
                 info = self._tinfo[h]
                 opnode = info.data_setter = G.InputNode(
                     device=info.device, dtype=info.dtype, graph=graph
@@ -434,15 +433,19 @@ class trace:
         h2v = {}
         graph = G.Graph()

-        for i, h in enumerate(self._args_bindings):
+        for i, h in enumerate(self._arg_bindings):
             info = self._tinfo[h]
-            h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
-            if arg_names:
-                h2v[h].name = arg_names[i]
-        for k, h in self._kwargs_bindings.items():
+            h2v[h] = graph.make_h2d(
+                dtype=info.dtype,
+                device=info.device,
+                shape=info.shape,
+                name=arg_names[i] if arg_names else None,
+            )
+        for k, h in self._kwarg_bindings.items():
             info = self._tinfo[h]
-            h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
-            h2v[h].name = k
+            h2v[h] = graph.make_h2d(
+                dtype=info.dtype, device=info.device, shape=info.shape, name=k
+            )

         for op, ihandles, ohandles in self._seq:
             ivars = []
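Net effect of this hunk: at dump time each positional input becomes a Host2DeviceCopy that already knows its dtype, device, static shape, and (when `arg_names` is given) its public name, while keyword inputs are named after their keyword, instead of having `.name` patched on after construction. A hedged sketch of the corresponding `dump` call (function and file names hypothetical; the signature matches the new test at the end of this diff):

    # traced with capture_as_const=True, then serialized; the single
    # positional input becomes an input var literally named "data"
    traced_fn.dump("xornet.mge", arg_names=["data"], output_names=["label"])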
@@ -479,11 +482,12 @@ class trace:
                 info.external = False
                 info.device = x.device
                 info.dtype = x.dtype
+                info.shape = x.shape
                 TraceMixin._TraceMixin__inject(x, h)
                 self._inputs_to_restore.append(x)
                 return h

-            self._args_bindings = []
+            self._arg_bindings = []
             for i, x in enumerate(args):
                 x = find_raw_tensor(x)
                 if x is None:
@@ -491,20 +495,20 @@ class trace: | |||
"positional arguments should all be tensor " | |||
"but args[%d] cannot be recognized as one" % i | |||
) | |||
self._args_bindings.append(record_input(x)) | |||
self._arg_bindings.append(record_input(x)) | |||
self._kwargs_bindings = {} | |||
self._kwarg_bindings = {} | |||
for k, x in kwargs.items(): | |||
x = find_raw_tensor(x) | |||
if x is not None: | |||
self._kwargs_bindings[k] = record_input(x) | |||
self._kwarg_bindings[k] = record_input(x) | |||
else: | |||
if len(args) != len(self._args_bindings): | |||
if len(args) != len(self._arg_bindings): | |||
raise TraceMismatchError("positional argument length mismatch") | |||
self._tensor_remaps = {} | |||
for i, (h, x) in enumerate(zip(self._args_bindings, args)): | |||
for i, (h, x) in enumerate(zip(self._arg_bindings, args)): | |||
x = find_raw_tensor(x) | |||
if x is None: | |||
raise TypeError( | |||
@@ -524,9 +528,9 @@ class trace:
                 x = find_raw_tensor(x)
                 if x is not None:
                     kwargs_tensors[k] = x
-            if set(kwargs_tensors) != set(self._kwargs_bindings):
-                too_many = set(kwargs_tensors) - set(self._kwargs_bindings)
-                too_few = set(self._kwargs_bindings) - set(kwargs_tensors)
+            if set(kwargs_tensors) != set(self._kwarg_bindings):
+                too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
+                too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
                 if too_many:
                     raise TraceMismatchError(
                         "keyword arguments found to be tensor this time "
@@ -537,7 +541,7 @@ class trace:
                         "keyword arguments found to be non-tensor this time "
                         "but were tensor previously: %s" % " ".join(too_few)
                     )
-            for k, h in self._kwargs_bindings.items():
+            for k, h in self._kwarg_bindings.items():
                 x = kwargs_tensors[k]
                 info = self._tinfo[h]
                 if x.dtype != info.dtype:
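The `too_many`/`too_few` bookkeeping above is plain set difference over kwarg names. A toy rendition of the check (values hypothetical):

    traced = {"label", "mask"}     # kwargs that were tensors while tracing
    current = {"label", "weight"}  # kwargs that are tensors on this call

    too_many = current - traced    # {"weight"}: tensor now, was not before
    too_few = traced - current     # {"mask"}: tensor before, is not now
    # either set being nonempty constitutes a trace mismatch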
@@ -237,7 +237,7 @@ void init_graph_rt(py::module m) {
         return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
     });

-    m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, std::optional<std::string> name) {
+    m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, TensorShape shape, std::optional<std::string> name) {
         if (!cn.valid()) {
             throw py::type_error("device must be valid");
         }
@@ -248,8 +248,8 @@ void init_graph_rt(py::module m) {
         if (name) {
             config.name(*name);
         }
-        return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, dtype), config).node();
-    }, py::arg(), py::arg(), py::arg(), py::arg() = py::none());
+        return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, shape, dtype), config).node();
+    }, py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::none());

     m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
                                              const CompNode& comp_node,
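On the pybind11 side, the two trailing `py::arg() = py::none()` defaults mean both `shape` and `name` may be omitted or passed as None from Python, which is how the wrapper's `shape=None, name=None` defaults flow through. A sketch of the calling convention, assuming the `megengine.core._imperative_rt` module path used by the wrapper and that `graph`/`device`/`dtype` are already the low-level objects the binding expects (internal API):

    from megengine.core._imperative_rt import make_h2d

    node = make_h2d(graph, device, dtype)                   # shape=None, name=None
    node = make_h2d(graph, device, dtype, (64, 2), "data")  # fully specified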
@@ -0,0 +1,136 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import contextlib
+import os
+import tempfile
+
+import numpy as np
+
+import megengine as mge
+import megengine.functional as F
+import megengine.module as M
+import megengine.optimizer as optim
+from megengine import tensor
+from megengine.jit import trace
+
+
+@contextlib.contextmanager
+def mkstemp():
+    fd, path = tempfile.mkstemp()
+    try:
+        os.close(fd)
+        yield path
+    finally:
+        os.remove(path)
+
+
+def minibatch_generator(batch_size):
+    while True:
+        inp_data = np.zeros((batch_size, 2))
+        label = np.zeros(batch_size, dtype=np.int32)
+        for i in range(batch_size):
+            inp_data[i, :] = np.random.rand(2) * 2 - 1
+            label[i] = 1 if np.prod(inp_data[i]) < 0 else 0
+        yield {"data": inp_data.astype(np.float32), "label": label.astype(np.int32)}
+
+
+class XORNet(M.Module):
+    def __init__(self):
+        self.mid_dim = 14
+        self.num_class = 2
+        super().__init__()
+        self.fc0 = M.Linear(self.num_class, self.mid_dim, bias=True)
+        self.fc1 = M.Linear(self.mid_dim, self.mid_dim, bias=True)
+        self.fc2 = M.Linear(self.mid_dim, self.num_class, bias=True)
+
+    def forward(self, x):
+        x = self.fc0(x)
+        x = F.tanh(x)
+        x = self.fc1(x)
+        x = F.tanh(x)
+        x = self.fc2(x)
+        return x
+
+
+def test_xornet_trace_dump():
+    net = XORNet()
+    opt = optim.SGD(net.parameters(requires_grad=True), lr=0.01, momentum=0.9)
+    batch_size = 64
+    train_dataset = minibatch_generator(batch_size)
+    val_dataset = minibatch_generator(batch_size)
+
+    @trace
+    def train_fun(data, label):
+        with opt.record():
+            net.train()
+            pred = net(data)
+            loss = F.cross_entropy_with_softmax(pred, label)
+            opt.backward(loss)
+        return pred, loss
+
+    @trace
+    def val_fun(data, label):
+        net.eval()
+        pred = net(data)
+        loss = F.cross_entropy_with_softmax(pred, label)
+        return pred, loss
+
+    @trace(symbolic=True, capture_as_const=True)
+    def pred_fun(data):
+        net.eval()
+        pred = net(data)
+        pred_normalized = F.softmax(pred)
+        return pred_normalized
+
+    train_loss = []
+    val_loss = []
+    for step, minibatch in enumerate(train_dataset):
+        if step > 100:
+            break
+        data = tensor(minibatch["data"])
+        label = tensor(minibatch["label"])
+        opt.zero_grad()
+        _, loss = train_fun(data, label)
+        train_loss.append((step, loss.numpy()))
+        if step % 50 == 0:
+            # evaluate on a fresh validation minibatch, not the training batch
+            minibatch = next(val_dataset)
+            data = tensor(minibatch["data"])
+            label = tensor(minibatch["label"])
+            _, loss = val_fun(data, label)
+            loss = loss.numpy()[0]
+            val_loss.append((step, loss))
+            print("Step: {} loss={}".format(step, loss))
+        opt.step()
+    test_data = np.array(
+        [
+            (0.5, 0.5),
+            (0.3, 0.7),
+            (0.1, 0.9),
+            (-0.5, -0.5),
+            (-0.3, -0.7),
+            (-0.9, -0.1),
+            (0.5, -0.5),
+            (0.3, -0.7),
+            (0.9, -0.1),
+            (-0.5, 0.5),
+            (-0.3, 0.7),
+            (-0.1, 0.9),
+        ]
+    )
+
+    data = tensor(test_data.astype(np.float32))
+    out = pred_fun(data)
+    pred_output = out.numpy()
+    pred_label = np.argmax(pred_output, 1)
+    with np.printoptions(precision=4, suppress=True):
+        print("Predicted probability:")
+        print(pred_output)
+
+    with mkstemp() as out:
+        pred_fun.dump(out, arg_names=["data"], output_names=["label"])
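The final `dump` call is the point of the new test: it traces `pred_fun` with `capture_as_const=True` and serializes the graph with a named input ("data") and output ("label"), exercising the `shape`/`name` plumbing added in the hunks above. A hedged smoke check one might append (hypothetical; kept file-level to avoid depending on the loader API):

    with mkstemp() as out:
        pred_fun.dump(out, arg_names=["data"], output_names=["label"])
        assert os.path.getsize(out) > 0  # dump should produce a non-empty file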