GitOrigin-RevId: 6e23456250
tags/v1.6.0-rc1
@@ -227,19 +227,19 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         gopt_level = None  # disable jit and compile
     binary_ops = {
-        "+": builtin.Elemwise(mode="add"),
-        "-": builtin.Elemwise(mode="sub"),
-        "*": builtin.Elemwise(mode="mul"),
-        "/": builtin.Elemwise(mode="true_div"),
-        "//": builtin.Elemwise(mode="floor_div"),
-        "**": builtin.Elemwise(mode="pow"),
-        "√": builtin.Elemwise(mode="expm1"),
-        "max": builtin.Elemwise(mode="max"),
-        "additive": builtin.Elemwise(mode="add"),
+        "+": lambda: builtin.Elemwise(mode="add"),
+        "-": lambda: builtin.Elemwise(mode="sub"),
+        "*": lambda: builtin.Elemwise(mode="mul"),
+        "/": lambda: builtin.Elemwise(mode="true_div"),
+        "//": lambda: builtin.Elemwise(mode="floor_div"),
+        "**": lambda: builtin.Elemwise(mode="pow"),
+        "√": lambda: builtin.Elemwise(mode="expm1"),
+        "max": lambda: builtin.Elemwise(mode="max"),
+        "additive": lambda: builtin.Elemwise(mode="add"),
     }
     unary_ops = {
-        "-": builtin.Elemwise(mode="negate"),
+        "-": lambda: builtin.Elemwise(mode="negate"),
     }
     def decorator(func):
@@ -248,9 +248,9 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         def apply_expr(op, *args):
             if isinstance(op, str):
                 if len(args) == 2:
-                    op = binary_ops[op]
+                    op = binary_ops[op]()
                 elif len(args) == 1:
-                    op = unary_ops[op]
+                    op = unary_ops[op]()
             return builder.apply(op, args, 1)[0]
         def apply_const(value, dtype=dtype, device=device):
@@ -261,8 +261,8 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         builder.outputs(outputs)
         builder.outputs_has_grad(outputs_has_grad)
         if gopt_level is None:
-            return builder.get()
+            return lambda: builder.get()
         else:
-            return builder.compile(gopt_level)
+            return lambda: builder.compile(gopt_level)
     return decorator
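The hunks above replace eagerly constructed `builtin.Elemwise` instances with zero-argument factories, and the decorator now returns `lambda: builder.get()` rather than the built op, so an op is only constructed at the moment it is used. A minimal sketch of that deferral pattern, using a stand-in `Op` class and `apply_expr` helper rather than the real MegEngine builtins:

```python
# Sketch of deferring object construction behind zero-argument factories,
# as the hunk above does for builtin.Elemwise and builder.get().
# Op and apply_expr are illustrative stand-ins, not MegEngine APIs.

class Op:
    def __init__(self, mode):
        self.mode = mode  # each call builds an independent instance

# table values are callables, so no Op is created at import time
binary_ops = {
    "+": lambda: Op("add"),
    "-": lambda: Op("sub"),
}

def apply_expr(name, *args):
    op = binary_ops[name]()  # note the trailing (): build the op per use
    return op, args

# two uses never share the same Op object
op1, _ = apply_expr("+", 1, 2)
op2, _ = apply_expr("+", 3, 4)
assert op1 is not op2
```

Deferring construction this way keeps the module-level tables cheap to build and lets each call site decide when, and how many times, an op gets materialized.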
@@ -767,6 +767,19 @@ def matinv(inp: Tensor) -> Tensor:
     return result
+class _Hashable:
+    def __init__(self, value) -> None:
+        self.value = value
+    def __hash__(self) -> int:
+        return hash(str(self.value))
+    def __eq__(self, o: object) -> bool:
+        if not isinstance(o, _Hashable):
+            return False
+        return self.value == o.value
 @lru_cache(maxsize=None)
 def _get_extentedMatrixMulOp(
     device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
@@ -833,7 +846,7 @@ def _get_extentedMatrixMulOp(
             transposeB=transpose_b,
             compute_mode=compute_mode,
             format=format,
-            strategy=strategy,
+            strategy=strategy.value,
         )
         result = f(op, inp1, inp2)
         result_shape = f(GetVarShape(), result)
@@ -954,7 +967,7 @@ def _get_extentedBatchedMatrixMulOp(
             transposeB=transpose_b,
             compute_mode=compute_mode,
             format=format,
-            strategy=strategy,
+            strategy=strategy.value,
         )
         result = f(op, inp1, inp2)
@@ -1051,9 +1064,9 @@ def matmul(
             transpose_b,
             compute_mode,
             format,
-            strategy=get_execution_strategy(),
+            strategy=_Hashable(get_execution_strategy()),
         )
-        (result,) = apply(extentedMatrixMulOp, inp1, inp2)
+        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
         return result
     else:  # dispath to BatchedMatrixMul
         extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
@@ -1065,9 +1078,9 @@ def matmul(
             transpose_b,
             compute_mode,
             format,
-            strategy=get_execution_strategy(),
+            strategy=_Hashable(get_execution_strategy()),
         )
-        (result,) = apply(extentedBatchedMatrixMulOp, inp1, inp2)
+        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
         return result
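The new `_Hashable` wrapper exists so the value of `get_execution_strategy()` can be passed through the `lru_cache`-decorated op builders: the wrapper hashes a stable string form of the value, and the builders unwrap it again via `strategy.value` before handing it to `builtin.MatrixMul`. A small self-contained sketch of the same caching trick; `build_op` is a made-up stand-in, not the MegEngine helper:

```python
from functools import lru_cache

class _Hashable:
    # same contract as the class added above: hash on a stable string form
    def __init__(self, value):
        self.value = value
    def __hash__(self):
        return hash(str(self.value))
    def __eq__(self, o):
        return isinstance(o, _Hashable) and self.value == o.value

@lru_cache(maxsize=None)
def build_op(dim1, dim2, strategy):
    # strategy arrives wrapped; unwrap it where the real value is needed
    return ("matmul-graph", dim1, dim2, str(strategy.value))

# the wrapper makes the call cacheable even if the raw value is not hashable
a = build_op(2, 2, _Hashable(["HEURISTIC"]))
b = build_op(2, 2, _Hashable(["HEURISTIC"]))
assert a is b  # second call is served from the cache
```

The wrapped value itself never needs to be hashable; it only needs a `str()` form that is identical for values that should hit the same cache entry.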
@@ -1328,7 +1328,7 @@ def sync_batch_norm(
         syncbn_split_stats,
     ) = _get_sync_bn_ops(_device, _dtype, eps_mode, _ndim, _channels)
-    reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0, inp)
+    reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0(), inp)
     eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)
@@ -1338,19 +1338,28 @@ def sync_batch_norm(
     if training:
         if is_distributed():
             # reduce all nodes' data to calculate mean and variance
-            (stat,) = apply(syncbn_concat_stats, reduce_size, channel_x1s, channel_x2s)
+            (stat,) = apply(
+                syncbn_concat_stats(), reduce_size, channel_x1s, channel_x2s
+            )
             stat = all_reduce_sum(stat, group)
-            reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats, stat)
+            reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats(), stat)
         outvar, channel_mean, *_ = apply(
-            syncbn_stage1, inp, reduce_size, channel_x1s, channel_x2s, eps, weight, bias
+            syncbn_stage1(),
+            inp,
+            reduce_size,
+            channel_x1s,
+            channel_x2s,
+            eps,
+            weight,
+            bias,
         )
     else:
         assert running_var is not None and running_mean is not None
         channel_mean = running_mean
         channel_var = running_var
         outvar, *_ = apply(
-            syncbn_stage1_inference, inp, channel_mean, channel_var, eps, weight, bias
+            syncbn_stage1_inference(), inp, channel_mean, channel_var, eps, weight, bias
         )
     # outvar = output * weight + bias
@@ -1362,7 +1371,7 @@ def sync_batch_norm(
     if training and running_var is not None and running_mean is not None:
         momentum = convert_single_value(momentum, dtype=inp.dtype, device=inp.device)
         running_mean[...], running_var[...] = apply(
-            syncbn_stage2,
+            syncbn_stage2(),
             running_mean,
             running_var,
             momentum,
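Every call site in `sync_batch_norm` now invokes the cached helper first (`syncbn_stage0()` rather than `syncbn_stage0`), which suggests `_get_sync_bn_ops` hands back factories and a fresh op is built per application. A toy sketch of that call-site shape; `get_stage_ops` and `apply` here are placeholders, not the real MegEngine helpers:

```python
from functools import lru_cache

def apply(op, *tensors):
    # placeholder for imperative apply(): consumes one op instance per call
    return (op, tensors)

@lru_cache(maxsize=None)
def get_stage_ops(device, dtype):
    # cache the cheap, hashable factories rather than the op instances
    stage0 = lambda: ("syncbn_stage0", device, dtype)
    stage1 = lambda: ("syncbn_stage1", device, dtype)
    return stage0, stage1

stage0, stage1 = get_stage_ops("cpu0", "float32")
out0 = apply(stage0(), "inp")       # note the (): build the op at use time
out1 = apply(stage1(), "inp", out0)
```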
@@ -482,9 +482,15 @@ void init_ops(py::module m) {
     struct PySubgraphBuilder {
         explicit PySubgraphBuilder(std::string name) : name{name}{}
         std::string name;
-        Subgraph graph;
+        std::shared_ptr<Subgraph> graph_storage = std::make_shared<Subgraph>();
+        std::shared_ptr<UniqueKey> graph_key = std::make_shared<UniqueKey>();
+        Subgraph& graph = *graph_storage;
         mgb::SmallVector<bool> output_grad_mask;
         Subgraph::var_t next_var = 1;
+        std::shared_ptr<OpDef> build() const {
+            return SubgraphOp::make(name, graph_storage, output_grad_mask, graph_key);
+        }
     };
     py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
@@ -518,10 +524,9 @@ void init_ops(py::module m) {
                 self.output_grad_mask = outputs_has_grad;
             })
             .def("get", [](PySubgraphBuilder& self){
-                return (std::shared_ptr<OpDef>)SubgraphOp::make(self.name, self.graph, self.output_grad_mask);
+                return (std::shared_ptr<OpDef>)self.build();
             })
             .def("compile", [](PySubgraphBuilder& self, int gopt_level){
-                auto op = SubgraphOp::make(self.name, self.graph, self.output_grad_mask);
-                return (std::shared_ptr<OpDef>)CompiledOp::make(op, gopt_level);
+                return (std::shared_ptr<OpDef>)CompiledOp::make(self.build(), gopt_level);
             });
 }
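With `graph_storage` and `graph_key` shared by every op the builder emits, `get()` and `compile()` now produce ops that hash and compare as the same subgraph, instead of comparing by address of a per-op `Subgraph` copy. A rough Python analogue of that key-based identity; these `SubgraphOp`/`UniqueKey` classes are illustrative stand-ins for the C++ types, not bindings:

```python
class UniqueKey:
    # identity-hashed sentinel, mirroring the graph_key member above and the
    # UniqueKey struct defined later in the diff
    pass

class SubgraphOp:
    def __init__(self, name, graph, key):
        self.name, self.graph, self.key = name, graph, key
    def __hash__(self):
        return id(self.key) if self.key is not None else id(self.graph)
    def __eq__(self, other):
        if self.key is not None or other.key is not None:
            return self.key is other.key
        return self.graph is other.graph

graph, key = {"exprs": []}, UniqueKey()
# two ops built from the same builder share graph storage and key ...
assert SubgraphOp("f", graph, key) == SubgraphOp("f", graph, key)
# ... while ops with distinct keys stay distinct even for equal graphs
assert SubgraphOp("f", dict(graph), UniqueKey()) != SubgraphOp("f", dict(graph), UniqueKey())
```

This mirrors the `hash` and `is_same_st` trait functions further down in the diff, which fall back to pointer identity of the shared graph only when no key is present.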
@@ -181,7 +181,7 @@ OP_TRAIT_REG(Identity, Identity)
 namespace { namespace subgraph {
 EncodedSubraph make_forward_graph(const OpDef& def, SmallVector<LogicalTensorDesc> inputs) {
-    return EncodedSubraph::make(def.cast_final_safe<SubgraphOp>().graph);
+    return EncodedSubraph::make(*def.cast_final_safe<SubgraphOp>().graph);
 }
 EncodedSubraph make_backward_graph(
@@ -197,16 +197,19 @@ EncodedSubraph make_backward_graph(
         }
     }
     auto bgraph = subgraph_detail::make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
-    return EncodedSubraph::make_single(SubgraphOp::make(op.name+"Grad", bgraph.graph), bgraph.input_mask, bgraph.output_mask);
+    return EncodedSubraph::make_single(
+            SubgraphOp::make(op.name + "Grad",
+                    std::make_shared<Subgraph>(bgraph.graph)),
+            bgraph.input_mask, bgraph.output_mask);
 }
 std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
     auto& op = def.cast_final_safe<SubgraphOp>();
     return {
         {"name", op.name},
-        {"inputs", mgb::imperative::to_string(op.graph.inputs)},
-        {"exprs", mgb::imperative::to_string(op.graph.exprs)},
-        {"outputs", mgb::imperative::to_string(op.graph.outputs)},
+        {"inputs", mgb::imperative::to_string(op.graph->inputs)},
+        {"exprs", mgb::imperative::to_string(op.graph->exprs)},
+        {"outputs", mgb::imperative::to_string(op.graph->outputs)},
     };
 }
@@ -222,7 +225,7 @@ std::string make_name(const OpDef& def) {
 auto hash(const OpDef& def) {
     auto& op = def.cast_final_safe<SubgraphOp>();
     if (!op.graph_key) {
-        return (size_t)reinterpret_cast<uintptr_t>(&op.graph);
+        return (size_t)reinterpret_cast<uintptr_t>(op.graph.get());
     }
     return op.graph_key->hash();
 }
@@ -238,7 +241,7 @@ auto is_same_st(const OpDef& def, const OpDef& another) {
     if (has_graph_key) {
         graph_same = rhs.graph_key && lhs.graph_key->is_same(*rhs.graph_key);
     } else {
-        graph_same = !rhs.graph_key && &lhs.graph == &rhs.graph;
+        graph_same = !rhs.graph_key && lhs.graph.get() == rhs.graph.get();
     }
     return graph_same;
 }
@@ -354,7 +357,9 @@ auto apply_on_physical_tensor(
 auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
-    return OpDef::apply_on_var_node(*def.cast_final_safe<CompiledOp>().op, inputs);
+    auto& op = def.cast_final_safe<CompiledOp>();
+    op.op->set_scope(op.scope());
+    return OpDef::apply_on_var_node(*op.op, inputs);
 }
 auto infer_output_attrs_fallible(
@@ -397,7 +402,9 @@ EncodedSubraph make_backward_graph(
     if (backward_graph.graph.is_single()) {
         bgraph_op = backward_graph.graph.as_single();
     } else {
-        bgraph_op = SubgraphOp::make(name+"Grad", backward_graph.graph, grad_outputs_has_grad, key);
+        bgraph_op = SubgraphOp::make(
+                name + "Grad", std::make_shared<Subgraph>(backward_graph.graph),
+                grad_outputs_has_grad, key);
     }
     auto compiled_op = CompiledOp::make(bgraph_op, op.gopt_level);
     auto encoded_graph = EncodedSubraph::make_single(compiled_op, backward_graph.input_mask, backward_graph.output_mask);
@@ -431,6 +438,8 @@ OP_TRAIT_REG(CompiledOp, CompiledOp)
         .fallback();
 }}
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(UniqueKey);
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(SubgraphOp);
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardOpKey);
@@ -28,7 +28,8 @@ VarNodeArray apply_on_var_node(
     for (auto&& input: inputs) {
         input_descs.push_back({TensorLayout{input->dtype()}, input->comp_node()});
     }
-    auto apply_functor = [](const std::shared_ptr<OpDef>& op, const VarNodeArray& inputs, size_t nr_outputs){
+    auto apply_functor = [&](const std::shared_ptr<OpDef>& op, const VarNodeArray& inputs, size_t nr_outputs){
+        op->set_scope(def.scope());
         return OpDef::apply_on_var_node(*op, inputs);
     };
     auto const_functor = [&](const TensorPtr& value) {
@@ -48,16 +48,28 @@ struct ShapeInfer final : OpDefImplBase<ShapeInfer> {
     MGB_DYN_TYPE_OBJ_FINAL_DECL;
 };
+struct UniqueKey final: Hashable {
+public:
+    size_t hash() const override {
+        return reinterpret_cast<uintptr_t>(this);
+    }
+protected:
+    bool is_same_st(const Hashable& rhs) const override {
+        return this == &rhs.cast_final_safe<UniqueKey>();
+    }
+    MGB_DYN_TYPE_OBJ_FINAL_DECL;
+};
 struct SubgraphOp final: OpDefImplBase<SubgraphOp> {
     std::string name;
-    Subgraph graph;
+    std::shared_ptr<Subgraph> graph;
     SmallVector<bool> output_grad_mask;
     std::shared_ptr<Hashable> graph_key;
     SubgraphOp() = default;
-    SubgraphOp(std::string name, Subgraph graph, SmallVector<bool> output_grad_mask={}, std::shared_ptr<Hashable> key=nullptr)
+    SubgraphOp(std::string name, std::shared_ptr<Subgraph> graph, SmallVector<bool> output_grad_mask={}, std::shared_ptr<Hashable> key=nullptr)
         : name{name}, graph{graph}, output_grad_mask{output_grad_mask}, graph_key{std::move(key)}{
         if (this->output_grad_mask.empty()) {
-            this->output_grad_mask.resize(graph.outputs.size(), true);
+            this->output_grad_mask.resize(graph->outputs.size(), true);
        }
     }
     MGB_DYN_TYPE_OBJ_FINAL_DECL;