|
|
@@ -16,7 +16,6 @@ from typing import Dict, List, Union |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from ...utils.comp_graph_tools import set_priority_to_id as _set_priority_to_id |
|
|
|
from .. import _imperative_rt |
|
|
|
from .._imperative_rt import GraphOptimizeOptions |
|
|
|
from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode |
|
|
@@ -26,6 +25,19 @@ from ..ops.builtin import OpDef |
|
|
|
from .core import OpBase, TensorBase |
|
|
|
|
|
|
|
|
|
|
|
def set_priority_to_id(dest_vars): |
|
|
|
""" |
|
|
|
For all oprs in the subgraph constructed by dest_vars, |
|
|
|
sets its priority to id if its original priority is zero. |
|
|
|
:param dest_vars: target vars representing the graph. |
|
|
|
""" |
|
|
|
dest_vec = [] |
|
|
|
for i in dest_vars: |
|
|
|
assert isinstance(i, _imperative_rt.VarNode) |
|
|
|
dest_vec.append(i) |
|
|
|
_imperative_rt.graph._set_priority_to_id(dest_vec) |
|
|
|
|
|
|
|
|
|
|
|
class Graph(_imperative_rt.ComputingGraph): |
|
|
|
def __init__(self): |
|
|
|
super().__init__() |
|
|
@@ -46,8 +58,8 @@ class Graph(_imperative_rt.ComputingGraph): |
|
|
|
cache[obj] = wrapper(obj) |
|
|
|
return cache[obj] |
|
|
|
|
|
|
|
def set_priority_to_id(self, dest_vars): |
|
|
|
_set_priority_to_id(_unwrap(dest_vars)) |
|
|
|
def _set_priority_to_id(self, dest_vars): |
|
|
|
set_priority_to_id(_unwrap(dest_vars)) |
|
|
|
|
|
|
|
def compile(self, *args): |
|
|
|
self._function = super().compile(_unwrap(args)) |
|
|
|