@@ -413,7 +413,7 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
         megdnn_throw_if(
                 cur_shape != 1 && cur_stride != 0, tensor_reshape_error,
                 megdnn_mangle(ssprintf(
-                        "brodcast on dim with shape not equal to 1: "
+                        "broadcast on dim with shape not equal to 1: "
                         "src_shape=%s dst_shape=%s",
                         to_string().c_str(), tshape.to_string().c_str())));
         result.shape[target_idx] = tshape.shape[target_idx];
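For context: this error fires when a non-unit dimension would have to change size. A legal broadcast only expands size-1 dimensions, which the layout represents as stride-0 views rather than copies (hence the `cur_stride != 0` guard above). NumPy exposes the same trick, so a quick sanity check:

    import numpy as np

    x = np.arange(3).reshape(1, 3)
    y = np.broadcast_to(x, (4, 3))   # expand the size-1 axis, no copy
    print(y.strides)                 # the broadcast axis has stride 0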
@@ -47,7 +47,9 @@ def _(op: OpDef, inputs, outputs, input_requires_grad):
             grad_fn = reduce_sum_grad_fn
         else:
             grad_fn = default_grad_fn
-    elif isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD:
+    elif isinstance(op, Broadcast) or (
+        isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD
+    ):
         grad_fn = elemwise_add_grad_fn
     else:
         grad_fn = default_grad_fn
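Routing Broadcast through elemwise_add_grad_fn is sound because the backward of a broadcast is a sum-reduction of the output gradient back onto the input shape, which is exactly the reduction the add gradient already performs for shape-mismatched operands. A minimal NumPy sketch of that rule (function name hypothetical):

    import numpy as np

    def broadcast_grad(dy, src_shape):
        dy = np.asarray(dy)
        # sum away the leading axes the broadcast prepended
        while dy.ndim > len(src_shape):
            dy = dy.sum(axis=0)
        # sum over axes that were expanded from size 1
        for axis, size in enumerate(src_shape):
            if size == 1:
                dy = dy.sum(axis=axis, keepdims=True)
        return dy

    assert broadcast_grad(np.ones((4, 2, 3)), (2, 1)).shape == (2, 1)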
@@ -212,5 +214,4 @@ _oprAttr_grad_fn = {
     Reshape.name: reshape_grad_fn,
     Subtensor.name: subtensor_grad_fn,
     IndexingMultiAxisVec.name: indexingMultiAxisVec_grad_fn,
-    Broadcast.name: elemwise_add_grad_fn,
 }
@@ -59,29 +59,7 @@ def _transpose(data, axes):


 def _broadcast(inp, shape):
-    def valid_broadcast(src, tar):
-        def failed():
-            raise ValueError(
-                "the input shape {} can not be broadcasted to target shape {}".format(
-                    src, tar
-                )
-            )
-
-        if isinstance(src, (TensorBase, TensorWrapperBase)):
-            src = src.numpy()
-        if isinstance(tar, (TensorBase, TensorWrapperBase)):
-            tar = tar.numpy()
-
-        if len(src) > len(tar):
-            failed()
-
-        for i in range(min(len(src), len(tar))):
-            if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]:
-                failed()
-
     shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device)
-    valid_broadcast(inp.shape, shape)
     (result,) = apply(builtin.Broadcast(), inp, shape)
     return result
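With the Python-side valid_broadcast pre-check removed, invalid target shapes are now rejected by the C++ shape inference (see broadcast.cpp below), which is why the tests further down switch from ValueError to RuntimeError. User-facing behavior is otherwise unchanged:

    import megengine.functional as F

    x = F.ones((2, 1, 3))
    y = F.broadcast_to(x, (2, 5, 3))   # ok: only the size-1 axis expands
    # F.broadcast_to(x, (2, 3, 4))     # now raises RuntimeError from C++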
@@ -21,6 +21,7 @@
 #include "megbrain/imperative/ops/nms.h"
 #include "megbrain/imperative/ops/elemwise.h"
 #include "megbrain/imperative/ops/batch_norm.h"
+#include "megbrain/imperative/ops/broadcast.h"

 namespace py = pybind11;
@@ -206,4 +207,7 @@ void init_ops(py::module m) {
     V(INFERENCE);
 #undef V

+    py::class_<Broadcast, std::shared_ptr<Broadcast>, OpDef>(m, "Broadcast")
+        .def(py::init<>());
+
 }
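The binding takes no constructor arguments, matching `py::init<>()` and the parameter-free Broadcast OpDef declared in broadcast.h below. From Python an instance is applied exactly as `_broadcast` does above (import path assumed):

    from megengine.core.ops import builtin    # import path assumed
    op = builtin.Broadcast()                   # parameter-free, per py::init<>()
    (result,) = apply(op, inp, shape)          # as in _broadcast above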
@@ -262,13 +262,13 @@ def test_broadcast():
     opr_test(cases, F.broadcast_to, compare_fn=compare_fn)

     x = F.ones((2, 1, 3))
-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         F.broadcast_to(x, (2, 3, 4))

-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         F.broadcast_to(x, (4, 1, 3))

-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         F.broadcast_to(x, (1, 3))
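All three failing cases violate the rule checked by valid_broadcast below: shapes are aligned from the right, each source dimension must be 1 or equal the target dimension, and the target may not have fewer dimensions. NumPy agrees on the same inputs:

    import numpy as np

    x = np.ones((2, 1, 3))
    np.broadcast_to(x, (2, 3, 3)).shape   # ok: (2, 1, 3) -> (2, 3, 3)
    # (2, 3, 4) fails: trailing 3 vs 4; (4, 1, 3) fails: leading 2 vs 4;
    # (1, 3) fails: target has fewer dimensions than the source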
@@ -0,0 +1,95 @@
+/**
+ * \file imperative/src/impl/ops/broadcast.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "megbrain/imperative/ops/broadcast.h"
+#include "../op_trait.h"
+
+namespace mgb {
+namespace imperative {
+namespace {
+
+std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
+    node_->cast_final_safe<opr::Broadcast>();
+    return Broadcast::make();
+}
+
+cg::OperatorNodeBase* apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    def.cast_final_safe<Broadcast>();
+    size_t nr_inp = inputs.size();
+    mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
+    return opr::Broadcast::make(inputs[0], inputs[1]).node()->owner_opr();
+}
+
+bool valid_broadcast(const TensorShape& src_shape,
+                     const TensorShape& tar_shape) {
+    size_t src_ndim = src_shape.ndim, tar_ndim = tar_shape.ndim;
+    if (src_ndim > tar_ndim) {
+        return false;
+    }
+    size_t min_ndim = src_ndim < tar_ndim ? src_ndim : tar_ndim;
+    for (size_t i = 0; i < min_ndim; ++i) {
+        if (src_shape[src_ndim - i - 1] != 1 &&
+            src_shape[src_ndim - i - 1] != tar_shape[tar_ndim - i - 1]) {
+            return false;
+        }
+    }
+    return true;
+}
+
+SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    def.cast_final_safe<Broadcast>();
+    size_t nr_inp = inputs.size();
+    mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
+    auto&& src = inputs[0];
+    auto&& tshp = inputs[1];
+
+    TensorLayout out_layout = src.layout;
+    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
+        out_layout.ndim = 0;
+        return {{out_layout, src.comp_node}};
+    }
+    mgb_assert(
+            tshp.layout.ndim == 1,
+            "target shape of Broadcast expects ndim=1; got ndim=%lu actually",
+            tshp.layout.ndim);
+
+    size_t target_ndim = tshp.layout.shape[0];
+    out_layout.ndim = target_ndim;
+    auto* ptr = tshp.value.ptr<dt_int32>();
+    for (size_t i = 0; i < target_ndim; ++i) {
+        out_layout.shape[i] = ptr[i];
+    }
+
+    mgb_assert(valid_broadcast(src.layout, out_layout),
+               "the input shape %s cannot be broadcast to target shape %s",
+               src.layout.TensorShape::to_string().c_str(),
+               out_layout.TensorShape::to_string().c_str());
+
+    return {{out_layout, src.comp_node}};
+}
+
+OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
+    .make_from_op_node(make_from_op_node)
+    .apply_on_var_node(apply_on_var_node)
+    .infer_output_attrs_fallible(infer_output_attrs_fallible)
+    .fallback();
+
+} // anonymous namespace
+
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(Broadcast);
+
+} // namespace imperative
+} // namespace mgb
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
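valid_broadcast mirrors the Python check deleted earlier, now running on TensorShape inside shape inference. A pure-Python rendering of the same rule, for reference (function name kept from the diff):

    def valid_broadcast(src, tar):
        if len(src) > len(tar):
            return False
        # align from the right: each src dim must be 1 or match tar
        return all(s == 1 or s == t for s, t in zip(reversed(src), reversed(tar)))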
@@ -0,0 +1,35 @@
+/**
+ * \file imperative/src/include/megbrain/imperative/ops/broadcast.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "megbrain/opr/tensor_manip.h"
+
+#include "megbrain/imperative/ops/opr_attr.h"
+#include "megbrain/imperative/op_def.h"
+
+namespace mgb::imperative {
+
+class Broadcast : public OpDefImplBase<Broadcast> {
+    MGB_DYN_TYPE_OBJ_FINAL_DECL;
+
+public:
+    Broadcast() = default;
+
+    size_t hash() const override {
+        return reinterpret_cast<std::uintptr_t>(dyn_typeinfo());
+    }
+
+    bool is_same_st(const Hashable& rhs) const override {
+        // parameter-free op: all instances are equivalent
+        return true;
+    }
+};
+
+} // namespace mgb::imperative
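Since Broadcast carries no parameters, hash() can simply return the type pointer and is_same_st() can unconditionally report equality: any two instances are interchangeable for op deduplication. The same reasoning presumably motivates the NMSKeep hunk below, where the dyn_typeinfo comparison is dropped from is_same_st because the framework matches types before dispatching to it. A toy Python analogy (illustrative only):

    class ParamFreeOp:
        """Identity is the type itself: all instances hash and compare equal."""
        def __hash__(self):
            return hash(type(self))
        def __eq__(self, other):
            return type(self) is type(other)

    assert ParamFreeOp() == ParamFreeOp()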
@@ -32,8 +32,7 @@ public:
     bool is_same_st(const Hashable& rhs_) const override {
         auto&& rhs = static_cast<const NMSKeep&>(rhs_);
-        return rhs.dyn_typeinfo() == dyn_typeinfo()
-               && rhs.iou_thresh == iou_thresh
+        return rhs.iou_thresh == iou_thresh
                && rhs.max_output == max_output;
     }