Browse Source

fix(mge/sdk): fix README.md and add missed example

GitOrigin-RevId: 0836029312
tags/v0.3.2
Megvii Engine Team 5 years ago
parent
commit
20450817da
4 changed files with 185 additions and 45 deletions
  1. +47
    -42
      sdk/load-and-run/README.md
  2. +2
    -2
      sdk/xor-deploy/README.md
  3. +136
    -0
      sdk/xor-deploy/xornet.py
  4. +0
    -1
      src/jit/README.md

+ 47
- 42
sdk/load-and-run/README.md View File

@@ -3,60 +3,65 @@
Load a model and run, for testing/debugging/profiling.

## Build
*megvii3 build*
```sh
bazel build //brain/megbrain:load_and_run
```

See [mnist-example](../mnist-example) for detailed explanations on build.
<!--
-->

### Build with cmake

Build MegEngine from source following [README.md](../../README.md). It will also produce the executable, `load_and_run`, which loads a model and runs the test cases attached to the model.


## Dump Model
<!--
-->

There are two methods to dump model:
## Dump Model with Test Cases Using [dump_with_testcase_mge.py](dump_with_testcase_mge.py)

1. Dump by `MegHair/utils/debug/load_network_and_run.py --dump-cpp-model
/path/to/output`, to test on random inputs. Useful for profiling.
2. Pack model as specified by
[`dump_with_testcase.py`](dump_with_testcase.py), and use
that script to dump model. This is useful for checking correctness on
different platforms.
### Step 1

### Input File for `dump_with_testcase.py`
Dump the model by calling the python API `megengine.jit.trace.dump()`.

The input file must be a python pickle. It can be in one of the following two
formats:
### Step 2

1. Contain a network that can be loaded by `meghair.utils.io.load_network`; in
such case, `--data` must be given and network output evaulated on current
computing device is used as groundtruth. All output vars would be checked.
The input data can be one of the following:
1. In the format `var0:file0;var1:file1...` meaning that `var0` should use
image file `file0`, `var1` should use image `file1` and so on. If there
is only one input var, the var name can be omitted. This can be combined
with `--resize-input` option.
2. In the format `var0:#rand(min, max, shape...);var1:#rand(min, max)...`
meaning to fill the corresponding input vars with uniform random numbers
in the range `[min, max)`, optionally overriding its shape.
2. Contain a dict in the format `{"outputs": [], "testcases": []}`, where
`outputs` is a list of output `VarNode`s and `testcases` is a list of test
cases. Each test case should be a dict that maps input var names to
corresponding values as `numpy.ndarray`. The expected outputs should also be
provided as inputs, and correctness should be checked by `AssertEqual`. You
can find more details in `dump_with_testcase.py`.
Append the test cases to the dumped model using [dump_with_testcase_mge.py](dump_with_testcase_mge.py).

### Input File for `dump_with_testcase_mge.py`
The basic usage of [dump_with_testcase_mge.py](dump_with_testcase_mge.py) is

```
python3 dump_with_testcase_mge.py model -d input_description -o model_with_testcases

```

The input file is obtained by calling `megengine.jit.trace.dump()`.
`--data` must be given.
where `model` is the file dumped at step 1, `input_description` describes the input data of the test cases, and `model_with_testcases` is the saved model with test cases.

## Example
`input_description` can be provided in the following ways:

1. Obtain the model file by running [xornet.py](../../python_module/examples/xor/xornet.py)
1. In the format `var0:file0;var1:file1...` meaning that `var0` should use
image file `file0`, `var1` should use image `file1` and so on. If there
is only one input var, the var name can be omitted. This can be combined
with `--resize-input` option.
2. In the format `var0:#rand(min, max, shape...);var1:#rand(min, max)...`
meaning to fill the corresponding input vars with uniform random numbers
in the range `[min, max)`, optionally overriding its shape.

For more usages, run

```
python3 dump_with_testcase_mge.py --help
```

### Example

1. Obtain the model file by running [xornet.py](../xor-deploy/xornet.py).

2. Dump the file with test cases attached to the model.

```
python3 dump_with_testcase_mge.py xornet_deploy.mge -o xornet.mge -d "#rand(0.1, 0.8, 4, 2)"
```
```
python3 dump_with_testcase_mge.py xornet_deploy.mge -o xornet.mge -d "#rand(0.1, 0.8, 4, 2)"
```

3. Verify the correctness by running `load_and_run` at the target platform.

The dumped file `xornet.mge` can be loaded by `load_and_run`.
```
load_and_run xornet.mge
```

+ 2
- 2
sdk/xor-deploy/README.md View File

@@ -12,11 +12,11 @@

* Step 3: run with dumped model

The dumped model can be obtained by running [xornet.py](../../python_module/examples/xor/xornet.py)
The dumped model can be obtained by running [xornet.py](xornet.py)


```
LD_LIBRARY_PATH=$MGE_INSTALL_PATH:$LD_LIBRARY_PATH ./xor_deploy xornet_deploy.mge 0.6 0.9
LD_LIBRARY_PATH=$MGE_INSTALL_PATH/lib64:$LD_LIBRARY_PATH ./xor_deploy xornet_deploy.mge 0.6 0.9
```

Sample output:


+ 136
- 0
sdk/xor-deploy/xornet.py View File

@@ -0,0 +1,136 @@
import numpy as np

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine.jit import trace


def minibatch_generator(batch_size):
while True:
inp_data = np.zeros((batch_size, 2))
label = np.zeros(batch_size, dtype=np.int32)
for i in range(batch_size):
inp_data[i, :] = np.random.rand(2) * 2 - 1
label[i] = 1 if np.prod(inp_data[i]) < 0 else 0
yield {"data": inp_data.astype(np.float32), "label": label.astype(np.int32)}


class XORNet(M.Module):
def __init__(self):
self.mid_dim = 14
self.num_class = 2
super().__init__()
self.fc0 = M.Linear(self.num_class, self.mid_dim, bias=True)
self.fc1 = M.Linear(self.mid_dim, self.mid_dim, bias=True)
self.fc2 = M.Linear(self.mid_dim, self.num_class, bias=True)

def forward(self, x):
x = self.fc0(x)
x = F.tanh(x)
x = self.fc1(x)
x = F.tanh(x)
x = self.fc2(x)
return x


@trace(symbolic=True)
def train_fun(data, label, net=None, opt=None):
net.train()
pred = net(data)
loss = F.cross_entropy_with_softmax(pred, label)
opt.backward(loss)
return pred, loss


@trace(symbolic=True)
def val_fun(data, label, net=None):
net.eval()
pred = net(data)
loss = F.cross_entropy_with_softmax(pred, label)
return pred, loss


@trace(symbolic=True)
def pred_fun(data, net=None):
net.eval()
pred = net(data)
pred_normalized = F.softmax(pred)
return pred_normalized


def main():

if not mge.is_cuda_available():
mge.set_default_device("cpux")

net = XORNet()
opt = optim.SGD(net.parameters(requires_grad=True), lr=0.01, momentum=0.9)
batch_size = 64
train_dataset = minibatch_generator(batch_size)
val_dataset = minibatch_generator(batch_size)

data = mge.tensor()
label = mge.tensor(np.zeros((batch_size,)), dtype=np.int32)
train_loss = []
val_loss = []
for step, minibatch in enumerate(train_dataset):
if step > 1000:
break
data.set_value(minibatch["data"])
label.set_value(minibatch["label"])
opt.zero_grad()
_, loss = train_fun(data, label, net=net, opt=opt)
train_loss.append((step, loss.numpy()))
if step % 50 == 0:
minibatch = next(val_dataset)
_, loss = val_fun(data, label, net=net)
loss = loss.numpy()[0]
val_loss.append((step, loss))
print("Step: {} loss={}".format(step, loss))
opt.step()

test_data = np.array(
[
(0.5, 0.5),
(0.3, 0.7),
(0.1, 0.9),
(-0.5, -0.5),
(-0.3, -0.7),
(-0.9, -0.1),
(0.5, -0.5),
(0.3, -0.7),
(0.9, -0.1),
(-0.5, 0.5),
(-0.3, 0.7),
(-0.1, 0.9),
]
)

data.set_value(test_data)
out = pred_fun(data, net=net)
pred_output = out.numpy()
pred_label = np.argmax(pred_output, 1)

print("Test data")
print(test_data)

with np.printoptions(precision=4, suppress=True):
print("Predicated probability:")
print(pred_output)

print("Predicated label")
print(pred_label)

model_name = "xornet_deploy.mge"

if pred_fun.enabled:
print("Dump model as {}".format(model_name))
pred_fun.dump(model_name, arg_names=["data"])
else:
print("pred_fun must be run with trace enabled in order to dump model")


if __name__ == "__main__":
main()

+ 0
- 1
src/jit/README.md View File

@@ -33,7 +33,6 @@ elemwise expressions by JIT.
|speed |100%|122% | 124% |



## What does JIT do
Detection the subgraph can be fused and compiling the subgraph into a fusion
kernel are the most two important parts in JIT.


Loading…
Cancel
Save