Browse Source

feat(mge/examples): add trace & dump example of cifar10 quantization

GitOrigin-RevId: cfc5e3483a
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
b3889938dc
2 changed files with 15 additions and 7 deletions
  1. +11
    -3
      imperative/python/megengine/jit/tracing.py
  2. +4
    -4
      imperative/python/megengine/quantization/observer.py

+ 11
- 3
imperative/python/megengine/jit/tracing.py View File

@@ -332,6 +332,7 @@ class trace:
need_reset_nodes = self._need_reset_nodes = []
# links enforce ordering of I/O nodes
links = ()
readers = []

if self._capture_as_const:
for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
@@ -345,7 +346,6 @@ class trace:

for op, ihandles, ohandles in self._seq:
ivars = []
readers = []
for h in ihandles:
info = self._tinfo[h]
if not hasattr(info, "varnode"):
@@ -431,11 +431,19 @@ class trace:
if output_names and not isinstance(output_names, collections.Sequence):
output_names = (output_names,)
if output_names and len(output_names) != len(self._output_bindings):
raise ValueError("wrong number of output_names")
raise ValueError(
"wrong number of output_names, should be {} values".format(
len(self._output_bindings)
)
)
if arg_names and not isinstance(arg_names, collections.Sequence):
arg_names = (arg_names,)
if arg_names and len(arg_names) != len(self._arg_bindings):
raise ValueError("wrong number of arg_names")
raise ValueError(
"wrong number of arg_names, should be {} values".format(
len(self._arg_bindings)
)
)
output_names = output_names or self._output_names

h2v = {}


+ 4
- 4
imperative/python/megengine/quantization/observer.py View File

@@ -118,8 +118,8 @@ class MinMaxObserver(Observer):
# stop gradient
x = x_orig.detach()
# find max and min
self.min_val = F.minimum(self.min_val, x.min())
self.max_val = F.maximum(self.max_val, x.max())
self.min_val.set_value(F.minimum(self.min_val, x.min()))
self.max_val.set_value(F.maximum(self.max_val, x.max()))
return x_orig


@@ -144,11 +144,11 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
# stop gradient
x = x_orig.detach()
# Exponential Moving Average
self.min_val = (
self.min_val.set_value(
self.min_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * x.min()
)
self.max_val = (
self.max_val.set_value(
self.max_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * x.max()
)


Loading…
Cancel
Save