GitOrigin-RevId: 5e2acd4052
tags/v1.0.0-rc1
@@ -78,9 +78,9 @@ class Graph(_imperative_rt.ComputingGraph):
         opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self)
         return opnode.outputs[0]

-    def make_h2d(self, *, dtype, device):
+    def make_h2d(self, *, dtype, device, shape=None, name=None):
         device = as_device(device).to_c()
-        return self._wrap(_imperative_rt.make_h2d(self, device, dtype))
+        return self._wrap(_imperative_rt.make_h2d(self, device, dtype, shape, name))

     def dump(*args):
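
Note: the two new keyword-only parameters are optional, so existing call sites keep working; when given, shape pins a static tensor shape on the host-to-device input node and name labels it for serialization. A minimal usage sketch follows; the "xpux" device string and the import path are assumptions for illustration, not taken from this diff.

    # Sketch only: exercising the extended Graph.make_h2d signature.
    import numpy as np
    from megengine.core.tensor import megbrain_graph as G  # assumed import path

    graph = G.Graph()
    # shape and name default to None, so the old form
    # graph.make_h2d(dtype=..., device=...) remains valid.
    inp = graph.make_h2d(dtype=np.float32, device="xpux", shape=(64, 2), name="data")
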
@@ -51,6 +51,7 @@ class TensorInfo:
         "value_read",
         "device",
         "dtype",
+        "shape",
         "bound_data",
         # resources for execution
         "varnode",
@@ -107,8 +108,8 @@ class trace:
         self._active_tensors = weakref.WeakSet()
         self._tensor_remaps = None
         self._inputs_to_restore = None
-        self._args_bindings = None
-        self._kwargs_bindings = None
+        self._arg_bindings = None
+        self._kwarg_bindings = None
         self._output_bindings = None
         self._output_names = None
@@ -329,9 +330,7 @@ class trace:
             links = ()

         if self._capture_as_const:
-            for h in itertools.chain(
-                self._args_bindings, self._kwargs_bindings.values()
-            ):
+            for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
                 info = self._tinfo[h]
                 opnode = info.data_setter = G.InputNode(
                     device=info.device, dtype=info.dtype, graph=graph
@@ -434,15 +433,19 @@ class trace:
         h2v = {}
         graph = G.Graph()

-        for i, h in enumerate(self._args_bindings):
+        for i, h in enumerate(self._arg_bindings):
             info = self._tinfo[h]
-            h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
-            if arg_names:
-                h2v[h].name = arg_names[i]
-        for k, h in self._kwargs_bindings.items():
+            h2v[h] = graph.make_h2d(
+                dtype=info.dtype,
+                device=info.device,
+                shape=info.shape,
+                name=arg_names[i] if arg_names else None,
+            )
+        for k, h in self._kwarg_bindings.items():
             info = self._tinfo[h]
-            h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
-            h2v[h].name = k
+            h2v[h] = graph.make_h2d(
+                dtype=info.dtype, device=info.device, shape=info.shape, name=k
+            )

         for op, ihandles, ohandles in self._seq:
             ivars = []
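
Taken together with the new TensorInfo "shape" slot above, this hunk gives dumped graphs statically shaped, named inputs: positional arguments are named from the arg_names list passed to dump, while keyword arguments are always named after their keyword. A hedged sketch of that rule; the function, tensor values, and file name here are illustrative assumptions, not part of this change.

    # Sketch only: how input naming falls out of the code above.
    import numpy as np
    from megengine import tensor
    from megengine.jit import trace

    @trace(symbolic=True, capture_as_const=True)
    def f(x, scale):
        return x * scale

    f(tensor(np.ones((2, 2), np.float32)), scale=tensor(np.ones((1,), np.float32)))
    # "input" names the positional argument; the keyword argument's
    # H2D node is named "scale" automatically.
    f.dump("model.mge", arg_names=["input"])
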
@@ -479,11 +482,12 @@ class trace:
                 info.external = False
                 info.device = x.device
                 info.dtype = x.dtype
+                info.shape = x.shape
                 TraceMixin._TraceMixin__inject(x, h)
                 self._inputs_to_restore.append(x)
             return h

-        self._args_bindings = []
+        self._arg_bindings = []
         for i, x in enumerate(args):
             x = find_raw_tensor(x)
             if x is None:
@@ -491,20 +495,20 @@ class trace:
                     "positional arguments should all be tensor "
                     "but args[%d] cannot be recognized as one" % i
                 )
-            self._args_bindings.append(record_input(x))
+            self._arg_bindings.append(record_input(x))

-        self._kwargs_bindings = {}
+        self._kwarg_bindings = {}
         for k, x in kwargs.items():
             x = find_raw_tensor(x)
             if x is not None:
-                self._kwargs_bindings[k] = record_input(x)
+                self._kwarg_bindings[k] = record_input(x)
         else:
-            if len(args) != len(self._args_bindings):
+            if len(args) != len(self._arg_bindings):
                 raise TraceMismatchError("positional argument length mismatch")

             self._tensor_remaps = {}
-            for i, (h, x) in enumerate(zip(self._args_bindings, args)):
+            for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
                 x = find_raw_tensor(x)
                 if x is None:
                     raise TypeError(
@@ -524,9 +528,9 @@ class trace:
                 x = find_raw_tensor(x)
                 if x is not None:
                     kwargs_tensors[k] = x
-            if set(kwargs_tensors) != set(self._kwargs_bindings):
-                too_many = set(kwargs_tensors) - set(self._kwargs_bindings)
-                too_few = set(self._kwargs_bindings) - set(kwargs_tensors)
+            if set(kwargs_tensors) != set(self._kwarg_bindings):
+                too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
+                too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
                 if too_many:
                     raise TraceMismatchError(
                         "keyword arguments found to be tensor this time "
@@ -537,7 +541,7 @@ class trace:
                         "keyword arguments found to be non-tensor this time "
                         "but were tensor previously: %s" % " ".join(too_few)
                     )
-            for k, h in self._kwargs_bindings.items():
+            for k, h in self._kwarg_bindings.items():
                 x = kwargs_tensors[k]
                 info = self._tinfo[h]
                 if x.dtype != info.dtype:
@@ -237,7 +237,7 @@ void init_graph_rt(py::module m) {
         return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
     });

-    m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, std::optional<std::string> name) {
+    m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, TensorShape shape, std::optional<std::string> name) {
         if (!cn.valid()) {
             throw py::type_error("device must be valid");
         }
@@ -248,8 +248,8 @@ void init_graph_rt(py::module m) {
         if (name) {
             config.name(*name);
         }
-        return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, dtype), config).node();
-    }, py::arg(), py::arg(), py::arg(), py::arg() = py::none());
+        return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, shape, dtype), config).node();
+    }, py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::none());

     m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
                                              const CompNode& comp_node,
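
Because both trailing py::arg() defaults are py::none(), the low-level binding stays callable without a shape or name, which keeps older Python callers working. A sketch of the two call shapes at the rt level; the import path, the tuple-to-TensorShape conversion, and the graph/device/dtype placeholders are assumptions.

    # Sketch only: legacy and extended invocations of the raw binding.
    from megengine.core import _imperative_rt as rt  # assumed import path

    var_legacy = rt.make_h2d(graph, device, dtype)                   # shape=None, name=None
    var_shaped = rt.make_h2d(graph, device, dtype, (64, 2), "data")  # static shape + name
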
@@ -0,0 +1,136 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import contextlib
+import os
+import tempfile
+
+import numpy as np
+
+import megengine as mge
+import megengine.functional as F
+import megengine.module as M
+import megengine.optimizer as optim
+from megengine import tensor
+from megengine.jit import trace
+
+
+@contextlib.contextmanager
+def mkstemp():
+    fd, path = tempfile.mkstemp()
+    try:
+        os.close(fd)
+        yield path
+    finally:
+        os.remove(path)
+
+
+def minibatch_generator(batch_size):
+    while True:
+        inp_data = np.zeros((batch_size, 2))
+        label = np.zeros(batch_size, dtype=np.int32)
+        for i in range(batch_size):
+            # XOR-style labels: 1 when the two coordinates have opposite signs
+            inp_data[i, :] = np.random.rand(2) * 2 - 1
+            label[i] = 1 if np.prod(inp_data[i]) < 0 else 0
+        yield {"data": inp_data.astype(np.float32), "label": label.astype(np.int32)}
+
+
+class XORNet(M.Module):
+    def __init__(self):
+        self.mid_dim = 14
+        self.num_class = 2
+        super().__init__()
+        self.fc0 = M.Linear(self.num_class, self.mid_dim, bias=True)
+        self.fc1 = M.Linear(self.mid_dim, self.mid_dim, bias=True)
+        self.fc2 = M.Linear(self.mid_dim, self.num_class, bias=True)
+
+    def forward(self, x):
+        x = self.fc0(x)
+        x = F.tanh(x)
+        x = self.fc1(x)
+        x = F.tanh(x)
+        x = self.fc2(x)
+        return x
+
+
+def test_xornet_trace_dump():
+    net = XORNet()
+    opt = optim.SGD(net.parameters(requires_grad=True), lr=0.01, momentum=0.9)
+    batch_size = 64
+    train_dataset = minibatch_generator(batch_size)
+    val_dataset = minibatch_generator(batch_size)
+
+    @trace
+    def train_fun(data, label):
+        with opt.record():
+            net.train()
+            pred = net(data)
+            loss = F.cross_entropy_with_softmax(pred, label)
+            opt.backward(loss)
+        return pred, loss
+
+    @trace
+    def val_fun(data, label):
+        net.eval()
+        pred = net(data)
+        loss = F.cross_entropy_with_softmax(pred, label)
+        return pred, loss
+
+    @trace(symbolic=True, capture_as_const=True)
+    def pred_fun(data):
+        net.eval()
+        pred = net(data)
+        pred_normalized = F.softmax(pred)
+        return pred_normalized
+
+    train_loss = []
+    val_loss = []
+    for step, minibatch in enumerate(train_dataset):
+        if step > 100:
+            break
+        data = tensor(minibatch["data"])
+        label = tensor(minibatch["label"])
+        opt.zero_grad()
+        _, loss = train_fun(data, label)
+        train_loss.append((step, loss.numpy()))
+        if step % 50 == 0:
+            # validate on the freshly drawn minibatch, not the last training batch
+            minibatch = next(val_dataset)
+            data = tensor(minibatch["data"])
+            label = tensor(minibatch["label"])
+            _, loss = val_fun(data, label)
+            loss = loss.numpy()[0]
+            val_loss.append((step, loss))
+            print("Step: {} loss={}".format(step, loss))
+        opt.step()
+
+    test_data = np.array(
+        [
+            (0.5, 0.5),
+            (0.3, 0.7),
+            (0.1, 0.9),
+            (-0.5, -0.5),
+            (-0.3, -0.7),
+            (-0.9, -0.1),
+            (0.5, -0.5),
+            (0.3, -0.7),
+            (0.9, -0.1),
+            (-0.5, 0.5),
+            (-0.3, 0.7),
+            (-0.1, 0.9),
+        ]
+    )
+
+    data = tensor(test_data.astype(np.float32))
+    out = pred_fun(data)
+    pred_output = out.numpy()
+    pred_label = np.argmax(pred_output, 1)
+
+    with np.printoptions(precision=4, suppress=True):
+        print("Predicted probability:")
+        print(pred_output)
+
+    with mkstemp() as out:
+        pred_fun.dump(out, arg_names=["data"], output_names=["label"])