@@ -413,7 +413,7 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const { | |||
megdnn_throw_if( | |||
cur_shape != 1 && cur_stride != 0, tensor_reshape_error, | |||
megdnn_mangle(ssprintf( | |||
"brodcast on dim with shape not equal to 1: " | |||
"broadcast on dim with shape not equal to 1: " | |||
"src_shape=%s dst_shape=%s", | |||
to_string().c_str(), tshape.to_string().c_str()))); | |||
result.shape[target_idx] = tshape.shape[target_idx]; | |||
@@ -47,7 +47,9 @@ def _(op: OpDef, inputs, outputs, input_requires_grad): | |||
grad_fn = reduce_sum_grad_fn | |||
else: | |||
grad_fn = default_grad_fn | |||
elif isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD: | |||
elif isinstance(op, Broadcast) or ( | |||
isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD | |||
): | |||
grad_fn = elemwise_add_grad_fn | |||
else: | |||
grad_fn = default_grad_fn | |||
@@ -212,5 +214,4 @@ _oprAttr_grad_fn = { | |||
Reshape.name: reshape_grad_fn, | |||
Subtensor.name: subtensor_grad_fn, | |||
IndexingMultiAxisVec.name: indexingMultiAxisVec_grad_fn, | |||
Broadcast.name: elemwise_add_grad_fn, | |||
} |
@@ -59,29 +59,7 @@ def _transpose(data, axes): | |||
def _broadcast(inp, shape): | |||
def valid_broadcast(src, tar): | |||
def failed(): | |||
raise ValueError( | |||
"the input shape {} can not be broadcasted to target shape {}".format( | |||
src, tar | |||
) | |||
) | |||
if isinstance(src, (TensorBase, TensorWrapperBase)): | |||
src = src.numpy() | |||
if isinstance(tar, (TensorBase, TensorWrapperBase)): | |||
tar = tar.numpy() | |||
if len(src) > len(tar): | |||
failed() | |||
for i in range(min(len(src), len(tar))): | |||
if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]: | |||
failed() | |||
shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device) | |||
valid_broadcast(inp.shape, shape) | |||
(result,) = apply(builtin.Broadcast(), inp, shape) | |||
return result | |||
@@ -21,6 +21,7 @@ | |||
#include "megbrain/imperative/ops/nms.h" | |||
#include "megbrain/imperative/ops/elemwise.h" | |||
#include "megbrain/imperative/ops/batch_norm.h" | |||
#include "megbrain/imperative/ops/broadcast.h" | |||
namespace py = pybind11; | |||
@@ -206,4 +207,7 @@ void init_ops(py::module m) { | |||
V(INFERENCE); | |||
#undef V | |||
py::class_<Broadcast, std::shared_ptr<Broadcast>, OpDef>(m, "Broadcast") | |||
.def(py::init<>()); | |||
} |
@@ -262,13 +262,13 @@ def test_broadcast(): | |||
opr_test(cases, F.broadcast_to, compare_fn=compare_fn) | |||
x = F.ones((2, 1, 3)) | |||
with pytest.raises(ValueError): | |||
with pytest.raises(RuntimeError): | |||
F.broadcast_to(x, (2, 3, 4)) | |||
with pytest.raises(ValueError): | |||
with pytest.raises(RuntimeError): | |||
F.broadcast_to(x, (4, 1, 3)) | |||
with pytest.raises(ValueError): | |||
with pytest.raises(RuntimeError): | |||
F.broadcast_to(x, (1, 3)) | |||
@@ -0,0 +1,95 @@ | |||
/** | |||
* \file imperative/src/impl/ops/broadcast.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/imperative/ops/broadcast.h" | |||
#include "../op_trait.h" | |||
namespace mgb { | |||
namespace imperative { | |||
namespace { | |||
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
node_->cast_final_safe<opr::Broadcast>(); | |||
return Broadcast::make(); | |||
} | |||
cg::OperatorNodeBase* apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
def.cast_final_safe<Broadcast>(); | |||
size_t nr_inp = inputs.size(); | |||
mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); | |||
return opr::Broadcast::make(inputs[0], inputs[1]).node()->owner_opr(); | |||
} | |||
bool valid_broadcast(const TensorShape& src_shape, | |||
const TensorShape& tar_shape) { | |||
size_t src_ndim = src_shape.ndim, tar_ndim = tar_shape.ndim; | |||
if (src_ndim > tar_ndim) { | |||
return false; | |||
} | |||
size_t min_ndim = src_ndim < tar_ndim ? src_ndim : tar_ndim; | |||
for (size_t i = 0; i < min_ndim; ++i) { | |||
if (src_shape[src_ndim - i - 1] != 1 && | |||
src_shape[src_ndim - i - 1] != tar_shape[tar_ndim - i - 1]) { | |||
return false; | |||
} | |||
} | |||
return true; | |||
} | |||
SmallVector<LogicalTensorDesc> infer_output_attrs_fallible( | |||
const OpDef& def, | |||
const SmallVector<LogicalTensorDesc>& inputs) { | |||
def.cast_final_safe<Broadcast>(); | |||
size_t nr_inp = inputs.size(); | |||
mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); | |||
auto&& src = inputs[0]; | |||
auto&& tshp = inputs[1]; | |||
TensorLayout out_layout = src.layout; | |||
if (tshp.layout.ndim == 0 || tshp.value.empty()) { | |||
out_layout.ndim = 0; | |||
return {{out_layout, src.comp_node}}; | |||
} | |||
mgb_assert( | |||
tshp.layout.ndim == 1, | |||
"target shape of Broadcast expects ndim=1; got ndim=%lu actually", | |||
tshp.layout.ndim); | |||
size_t target_ndim = tshp.layout.shape[0]; | |||
out_layout.ndim = target_ndim; | |||
auto* ptr = tshp.value.ptr<dt_int32>(); | |||
for(size_t i=0; i<target_ndim; ++i) { | |||
out_layout.shape[i] = ptr[i]; | |||
} | |||
mgb_assert(valid_broadcast(src.layout, out_layout), | |||
"the input shape %s can not be broadcasted to target shape %s", | |||
src.layout.TensorShape::to_string().c_str(), | |||
out_layout.TensorShape::to_string().c_str()); | |||
return {{out_layout, src.comp_node}}; | |||
} | |||
OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast) | |||
.make_from_op_node(make_from_op_node) | |||
.apply_on_var_node(apply_on_var_node) | |||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
.fallback(); | |||
} // anonymous namespace | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Broadcast); | |||
} // namespace imperative | |||
} // namespace mgb | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,35 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/ops/broadcast.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "megbrain/imperative/ops/opr_attr.h" | |||
#include "megbrain/imperative/op_def.h" | |||
namespace mgb::imperative { | |||
class Broadcast : public OpDefImplBase<Broadcast> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
public: | |||
Broadcast() = default; | |||
size_t hash() const override { | |||
return reinterpret_cast<std::uintptr_t>(dyn_typeinfo()); | |||
} | |||
bool is_same_st(const Hashable& rhs) const override { | |||
return true; | |||
} | |||
}; | |||
} // namespace mgb::imperative |
@@ -32,8 +32,7 @@ public: | |||
bool is_same_st(const Hashable& rhs_) const override { | |||
auto&& rhs = static_cast<const NMSKeep&>(rhs_); | |||
return rhs.dyn_typeinfo() == dyn_typeinfo() | |||
&& rhs.iou_thresh == iou_thresh | |||
return rhs.iou_thresh == iou_thresh | |||
&& rhs.max_output == max_output; | |||
} | |||