From b3889938dc8310e7ce505dae635b0034b754a708 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sun, 6 Sep 2020 17:18:14 +0800 Subject: [PATCH] feat(mge/examples): add trace & dump example of cifar10 quantization GitOrigin-RevId: cfc5e3483a66915e92458d1f502f1dde64ffe564 --- imperative/python/megengine/jit/tracing.py | 14 +++++++++++--- imperative/python/megengine/quantization/observer.py | 8 ++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 66107d62..1b40822f 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -332,6 +332,7 @@ class trace: need_reset_nodes = self._need_reset_nodes = [] # links enforce ordering of I/O nodes links = () + readers = [] if self._capture_as_const: for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()): @@ -345,7 +346,6 @@ class trace: for op, ihandles, ohandles in self._seq: ivars = [] - readers = [] for h in ihandles: info = self._tinfo[h] if not hasattr(info, "varnode"): @@ -431,11 +431,19 @@ class trace: if output_names and not isinstance(output_names, collections.Sequence): output_names = (output_names,) if output_names and len(output_names) != len(self._output_bindings): - raise ValueError("wrong number of output_names") + raise ValueError( + "wrong number of output_names, should be {} values".format( + len(self._output_bindings) + ) + ) if arg_names and not isinstance(arg_names, collections.Sequence): arg_names = (arg_names,) if arg_names and len(arg_names) != len(self._arg_bindings): - raise ValueError("wrong number of arg_names") + raise ValueError( + "wrong number of arg_names, should be {} values".format( + len(self._arg_bindings) + ) + ) output_names = output_names or self._output_names h2v = {} diff --git a/imperative/python/megengine/quantization/observer.py b/imperative/python/megengine/quantization/observer.py index 3aa61082..65f88384 100644 --- a/imperative/python/megengine/quantization/observer.py +++ b/imperative/python/megengine/quantization/observer.py @@ -118,8 +118,8 @@ class MinMaxObserver(Observer): # stop gradient x = x_orig.detach() # find max and min - self.min_val = F.minimum(self.min_val, x.min()) - self.max_val = F.maximum(self.max_val, x.max()) + self.min_val.set_value(F.minimum(self.min_val, x.min())) + self.max_val.set_value(F.maximum(self.max_val, x.max())) return x_orig @@ -144,11 +144,11 @@ class ExponentialMovingAverageObserver(MinMaxObserver): # stop gradient x = x_orig.detach() # Exponential Moving Average - self.min_val = ( + self.min_val.set_value( self.min_val * self.runtime_momentum + (1 - self.runtime_momentum) * x.min() ) - self.max_val = ( + self.max_val.set_value( self.max_val * self.runtime_momentum + (1 - self.runtime_momentum) * x.max() )