GitOrigin-RevId: 6e23456250
tags/v1.6.0-rc1
@@ -227,19 +227,19 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         gopt_level = None  # disable jit and compile
     binary_ops = {
-        "+": builtin.Elemwise(mode="add"),
-        "-": builtin.Elemwise(mode="sub"),
-        "*": builtin.Elemwise(mode="mul"),
-        "/": builtin.Elemwise(mode="true_div"),
-        "//": builtin.Elemwise(mode="floor_div"),
-        "**": builtin.Elemwise(mode="pow"),
-        "√": builtin.Elemwise(mode="expm1"),
-        "max": builtin.Elemwise(mode="max"),
-        "additive": builtin.Elemwise(mode="add"),
+        "+": lambda: builtin.Elemwise(mode="add"),
+        "-": lambda: builtin.Elemwise(mode="sub"),
+        "*": lambda: builtin.Elemwise(mode="mul"),
+        "/": lambda: builtin.Elemwise(mode="true_div"),
+        "//": lambda: builtin.Elemwise(mode="floor_div"),
+        "**": lambda: builtin.Elemwise(mode="pow"),
+        "√": lambda: builtin.Elemwise(mode="expm1"),
+        "max": lambda: builtin.Elemwise(mode="max"),
+        "additive": lambda: builtin.Elemwise(mode="add"),
     }
     unary_ops = {
-        "-": builtin.Elemwise(mode="negate"),
+        "-": lambda: builtin.Elemwise(mode="negate"),
     }
 
     def decorator(func):
@@ -248,9 +248,9 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         def apply_expr(op, *args):
             if isinstance(op, str):
                 if len(args) == 2:
-                    op = binary_ops[op]
+                    op = binary_ops[op]()
                 elif len(args) == 1:
-                    op = unary_ops[op]
+                    op = unary_ops[op]()
             return builder.apply(op, args, 1)[0]
 
         def apply_const(value, dtype=dtype, device=device):
@@ -261,8 +261,8 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         builder.outputs(outputs)
         builder.outputs_has_grad(outputs_has_grad)
         if gopt_level is None:
-            return builder.get()
+            return lambda: builder.get()
         else:
-            return builder.compile(gopt_level)
+            return lambda: builder.compile(gopt_level)
 
     return decorator
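In the hunks above, the operator tables now map each symbol to a zero-argument factory instead of a pre-built Elemwise op, and apply_expr invokes the factory (binary_ops[op]()) so every expression gets its own op instance; likewise the decorator now returns lambda: builder.get(), a factory rather than a finished OpDef. A minimal sketch of the pattern, using a toy Op class rather than MegEngine's builtin.Elemwise:

class Op:
    # stand-in for builtin.Elemwise; real ops may carry per-use state
    def __init__(self, mode):
        self.mode = mode

# eager style: one shared instance per symbol, built when the dict is defined
eager_ops = {"+": Op("add")}

# lazy style: a factory per symbol; construction is deferred to each call site
lazy_ops = {"+": lambda: Op("add")}

a = lazy_ops["+"]()   # fresh instance
b = lazy_ops["+"]()   # another fresh instance
assert a is not b
assert eager_ops["+"] is eager_ops["+"]   # always the same shared object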
@@ -767,6 +767,19 @@ def matinv(inp: Tensor) -> Tensor:
     return result
 
 
+class _Hashable:
+    def __init__(self, value) -> None:
+        self.value = value
+
+    def __hash__(self) -> int:
+        return hash(str(self.value))
+
+    def __eq__(self, o: object) -> bool:
+        if not isinstance(o, _Hashable):
+            return False
+        return self.value == o.value
+
+
 @lru_cache(maxsize=None)
 def _get_extentedMatrixMulOp(
     device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
@@ -833,7 +846,7 @@ def _get_extentedMatrixMulOp(
             transposeB=transpose_b,
             compute_mode=compute_mode,
             format=format,
-            strategy=strategy,
+            strategy=strategy.value,
         )
         result = f(op, inp1, inp2)
         result_shape = f(GetVarShape(), result)
@@ -954,7 +967,7 @@ def _get_extentedBatchedMatrixMulOp(
             transposeB=transpose_b,
             compute_mode=compute_mode,
             format=format,
-            strategy=strategy,
+            strategy=strategy.value,
         )
         result = f(op, inp1, inp2)
@@ -1051,9 +1064,9 @@ def matmul(
             transpose_b,
             compute_mode,
             format,
-            strategy=get_execution_strategy(),
+            strategy=_Hashable(get_execution_strategy()),
         )
-        (result,) = apply(extentedMatrixMulOp, inp1, inp2)
+        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
         return result
     else:  # dispath to BatchedMatrixMul
         extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
@@ -1065,9 +1078,9 @@ def matmul(
             transpose_b,
            compute_mode,
             format,
-            strategy=get_execution_strategy(),
+            strategy=_Hashable(get_execution_strategy()),
         )
-        (result,) = apply(extentedBatchedMatrixMulOp, inp1, inp2)
+        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
         return result
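_get_extentedMatrixMulOp and _get_extentedBatchedMatrixMulOp are wrapped in @lru_cache, so every argument must be hashable; matmul passes the execution strategy wrapped as _Hashable(get_execution_strategy()) and the cached builder reads strategy.value where the raw value is needed. A small sketch of the same wrapper idea with a generic unhashable argument (build_op and its arguments are illustrative names, not MegEngine APIs):

from functools import lru_cache

class Hashable:
    # wrap an arbitrary value so it can serve as an lru_cache key
    def __init__(self, value):
        self.value = value

    def __hash__(self):
        return hash(str(self.value))

    def __eq__(self, other):
        return isinstance(other, Hashable) and self.value == other.value

@lru_cache(maxsize=None)
def build_op(shape, strategy):
    # strategy arrives wrapped; unwrap it where the raw value is needed
    return ("op", shape, strategy.value)

cfg = ["HEURISTIC", "REPRODUCIBLE"]          # a list is not hashable on its own
op1 = build_op((2, 3), Hashable(cfg))
op2 = build_op((2, 3), Hashable(list(cfg)))  # equal wrapper -> same cache entry
assert op1 is op2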
@@ -1328,7 +1328,7 @@ def sync_batch_norm(
         syncbn_split_stats,
     ) = _get_sync_bn_ops(_device, _dtype, eps_mode, _ndim, _channels)
 
-    reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0, inp)
+    reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0(), inp)
 
     eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)
@@ -1338,19 +1338,28 @@ def sync_batch_norm(
     if training:
         if is_distributed():
             # reduce all nodes' data to calculate mean and variance
-            (stat,) = apply(syncbn_concat_stats, reduce_size, channel_x1s, channel_x2s)
+            (stat,) = apply(
+                syncbn_concat_stats(), reduce_size, channel_x1s, channel_x2s
+            )
             stat = all_reduce_sum(stat, group)
-            reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats, stat)
+            reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats(), stat)
         outvar, channel_mean, *_ = apply(
-            syncbn_stage1, inp, reduce_size, channel_x1s, channel_x2s, eps, weight, bias
+            syncbn_stage1(),
+            inp,
+            reduce_size,
+            channel_x1s,
+            channel_x2s,
+            eps,
+            weight,
+            bias,
         )
     else:
         assert running_var is not None and running_mean is not None
         channel_mean = running_mean
         channel_var = running_var
         outvar, *_ = apply(
-            syncbn_stage1_inference, inp, channel_mean, channel_var, eps, weight, bias
+            syncbn_stage1_inference(), inp, channel_mean, channel_var, eps, weight, bias
         )
 
     # outvar = output * weight + bias
@@ -1362,7 +1371,7 @@ def sync_batch_norm(
     if training and running_var is not None and running_mean is not None:
         momentum = convert_single_value(momentum, dtype=inp.dtype, device=inp.device)
         running_mean[...], running_var[...] = apply(
-            syncbn_stage2,
+            syncbn_stage2(),
             running_mean,
             running_var,
             momentum,
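The sync_batch_norm call sites follow the same convention: the ops returned by _get_sync_bn_ops are now factories, so each apply receives a freshly built op (syncbn_stage0(), syncbn_stage1(), ...) while the underlying graph construction, which appears to be cached like the @lru_cache helpers above, is not re-run per call. A hedged sketch of that composition with toy names (get_stage0_op and ToyOp are not MegEngine identifiers):

from functools import lru_cache

class ToyOp:
    # stand-in for an OpDef handed to apply(); one fresh instance per call
    def __init__(self, graph):
        self.graph = graph

@lru_cache(maxsize=None)
def get_stage0_op(device, dtype):
    graph = ("graph built once for", device, dtype)  # built once per cache key
    return lambda: ToyOp(graph)      # the cached value is a cheap factory

stage0 = get_stage0_op("xpu0", "float32")
a, b = stage0(), stage0()
assert a is not b and a.graph is b.graph             # fresh ops, shared graph
assert get_stage0_op("xpu0", "float32") is stage0    # the factory itself is cached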
@@ -482,9 +482,15 @@ void init_ops(py::module m) {
     struct PySubgraphBuilder {
         explicit PySubgraphBuilder(std::string name) : name{name}{}
         std::string name;
-        Subgraph graph;
+        std::shared_ptr<Subgraph> graph_storage = std::make_shared<Subgraph>();
+        std::shared_ptr<UniqueKey> graph_key = std::make_shared<UniqueKey>();
+        Subgraph& graph = *graph_storage;
         mgb::SmallVector<bool> output_grad_mask;
         Subgraph::var_t next_var = 1;
+
+        std::shared_ptr<OpDef> build() const {
+            return SubgraphOp::make(name, graph_storage, output_grad_mask, graph_key);
+        }
     };
 
     py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
@@ -518,10 +524,9 @@ void init_ops(py::module m) {
            self.output_grad_mask = outputs_has_grad;
        })
        .def("get", [](PySubgraphBuilder& self){
-            return (std::shared_ptr<OpDef>)SubgraphOp::make(self.name, self.graph, self.output_grad_mask);
+            return (std::shared_ptr<OpDef>)self.build();
        })
        .def("compile", [](PySubgraphBuilder& self, int gopt_level){
-            auto op = SubgraphOp::make(self.name, self.graph, self.output_grad_mask);
-            return (std::shared_ptr<OpDef>)CompiledOp::make(op, gopt_level);
+            return (std::shared_ptr<OpDef>)CompiledOp::make(self.build(), gopt_level);
        });
 }
@@ -181,7 +181,7 @@ OP_TRAIT_REG(Identity, Identity)
 namespace { namespace subgraph {
 
 EncodedSubraph make_forward_graph(const OpDef& def, SmallVector<LogicalTensorDesc> inputs) {
-    return EncodedSubraph::make(def.cast_final_safe<SubgraphOp>().graph);
+    return EncodedSubraph::make(*def.cast_final_safe<SubgraphOp>().graph);
 }
 
 EncodedSubraph make_backward_graph(
@@ -197,16 +197,19 @@ EncodedSubraph make_backward_graph(
         }
     }
     auto bgraph = subgraph_detail::make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
-    return EncodedSubraph::make_single(SubgraphOp::make(op.name+"Grad", bgraph.graph), bgraph.input_mask, bgraph.output_mask);
+    return EncodedSubraph::make_single(
+            SubgraphOp::make(op.name + "Grad",
+                             std::make_shared<Subgraph>(bgraph.graph)),
+            bgraph.input_mask, bgraph.output_mask);
 }
 
 std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
     auto& op = def.cast_final_safe<SubgraphOp>();
     return {
         {"name", op.name},
-        {"inputs", mgb::imperative::to_string(op.graph.inputs)},
-        {"exprs", mgb::imperative::to_string(op.graph.exprs)},
-        {"outputs", mgb::imperative::to_string(op.graph.outputs)},
+        {"inputs", mgb::imperative::to_string(op.graph->inputs)},
+        {"exprs", mgb::imperative::to_string(op.graph->exprs)},
+        {"outputs", mgb::imperative::to_string(op.graph->outputs)},
     };
 }
@@ -222,7 +225,7 @@ std::string make_name(const OpDef& def) {
 auto hash(const OpDef& def) {
     auto& op = def.cast_final_safe<SubgraphOp>();
     if (!op.graph_key) {
-        return (size_t)reinterpret_cast<uintptr_t>(&op.graph);
+        return (size_t)reinterpret_cast<uintptr_t>(op.graph.get());
     }
     return op.graph_key->hash();
 }
@@ -238,7 +241,7 @@ auto is_same_st(const OpDef& def, const OpDef& another) {
     if (has_graph_key) {
         graph_same = rhs.graph_key && lhs.graph_key->is_same(*rhs.graph_key);
     } else {
-        graph_same = !rhs.graph_key && &lhs.graph == &rhs.graph;
+        graph_same = !rhs.graph_key && lhs.graph.get() == rhs.graph.get();
     }
     return graph_same;
 }
@@ -354,7 +357,9 @@ auto apply_on_physical_tensor(
 auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
-    return OpDef::apply_on_var_node(*def.cast_final_safe<CompiledOp>().op, inputs);
+    auto& op = def.cast_final_safe<CompiledOp>();
+    op.op->set_scope(op.scope());
+    return OpDef::apply_on_var_node(*op.op, inputs);
 }
 
 auto infer_output_attrs_fallible(
@@ -397,7 +402,9 @@ EncodedSubraph make_backward_graph(
     if (backward_graph.graph.is_single()) {
         bgraph_op = backward_graph.graph.as_single();
     } else {
-        bgraph_op = SubgraphOp::make(name+"Grad", backward_graph.graph, grad_outputs_has_grad, key);
+        bgraph_op = SubgraphOp::make(
+                name + "Grad", std::make_shared<Subgraph>(backward_graph.graph),
+                grad_outputs_has_grad, key);
     }
     auto compiled_op = CompiledOp::make(bgraph_op, op.gopt_level);
     auto encoded_graph = EncodedSubraph::make_single(compiled_op, backward_graph.input_mask, backward_graph.output_mask);
@@ -431,6 +438,8 @@ OP_TRAIT_REG(CompiledOp, CompiledOp)
     .fallback();
 }}
 
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(UniqueKey);
+
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(SubgraphOp);
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardOpKey);
@@ -28,7 +28,8 @@ VarNodeArray apply_on_var_node(
     for (auto&& input: inputs) {
         input_descs.push_back({TensorLayout{input->dtype()}, input->comp_node()});
     }
-    auto apply_functor = [](const std::shared_ptr<OpDef>& op, const VarNodeArray& inputs, size_t nr_outputs){
+    auto apply_functor = [&](const std::shared_ptr<OpDef>& op, const VarNodeArray& inputs, size_t nr_outputs){
+        op->set_scope(def.scope());
         return OpDef::apply_on_var_node(*op, inputs);
     };
     auto const_functor = [&](const TensorPtr& value) {
@@ -48,16 +48,28 @@ struct ShapeInfer final : OpDefImplBase<ShapeInfer> {
     MGB_DYN_TYPE_OBJ_FINAL_DECL;
 };
 
+struct UniqueKey final: Hashable {
+public:
+    size_t hash() const override {
+        return reinterpret_cast<uintptr_t>(this);
+    }
+protected:
+    bool is_same_st(const Hashable& rhs) const override {
+        return this == &rhs.cast_final_safe<UniqueKey>();
+    }
+    MGB_DYN_TYPE_OBJ_FINAL_DECL;
+};
+
 struct SubgraphOp final: OpDefImplBase<SubgraphOp> {
     std::string name;
-    Subgraph graph;
+    std::shared_ptr<Subgraph> graph;
     SmallVector<bool> output_grad_mask;
     std::shared_ptr<Hashable> graph_key;
     SubgraphOp() = default;
-    SubgraphOp(std::string name, Subgraph graph, SmallVector<bool> output_grad_mask={}, std::shared_ptr<Hashable> key=nullptr)
+    SubgraphOp(std::string name, std::shared_ptr<Subgraph> graph, SmallVector<bool> output_grad_mask={}, std::shared_ptr<Hashable> key=nullptr)
         : name{name}, graph{graph}, output_grad_mask{output_grad_mask}, graph_key{std::move(key)}{
         if (this->output_grad_mask.empty()) {
-            this->output_grad_mask.resize(graph.outputs.size(), true);
+            this->output_grad_mask.resize(graph->outputs.size(), true);
         }
     }
     MGB_DYN_TYPE_OBJ_FINAL_DECL;