@@ -0,0 +1,404 @@
/**
 * \file imperative/src/impl/transformations/scalar.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/transformations/scalar.h"

#include "megbrain/imperative/ops/autogen.h"

namespace mgb {
namespace imperative {

namespace {
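
// Table of per-op scalar rules, keyed by the op's Typeinfo. An entry
// intercepts ApplyOp for its op so that scalar-ness can be tracked through
// the call; ops without an entry fall back to plain input unwrapping.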
using ScalarRule = std::function<std::vector<ValueRef>(const OpDef&, Span<ValueRef>)>;
static std::unordered_map<Typeinfo*, ScalarRule> scalar_rules;
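
// A ScalarValue is a zero-dim wrapper around an ordinary 1-element tensor
// value; these helpers strip the wrapper so that ops below this
// transformation only ever see ordinary tensors.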
ValueRef unwrap_input(ValueRef input) {
    if (auto scalar_input = input.as_ref<ScalarValue>()) {
        return scalar_input->value();
    } else {
        return input;
    }
}

std::vector<ValueRef> unwrap_inputs(Span<ValueRef> inputs) {
    std::vector<ValueRef> unwrapped_inputs;
    for (auto&& input : inputs) {
        unwrapped_inputs.push_back(unwrap_input(input));
    }
    return unwrapped_inputs;
}
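
// Builds a constant {1} shape tensor on `device`. Scalars are stored as
// 1-element tensors, so this is the shape handed to ops such as Reduce,
// Reshape and Broadcast when the target is a scalar.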
ValueRef make_scalar_shape(CompNode device) {
    HostTensorND scalar_shape(device, {1}, dtype::Int32());
    scalar_shape.ptr<dt_int32>()[0] = 1;
    return imperative::apply(
            CreateTensor(CreateTensor::Const, device, scalar_shape.layout()),
            HostStorage::make(scalar_shape.storage()))[0];
}
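
// A shape tensor describes a scalar target iff the shape tensor itself is
// empty (ValueShape{0}). Unknown shapes are conservatively treated as
// non-scalar.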
bool is_scalar_shape(ValueRef shape) {
    if (shape.is<ScalarValue>()) {
        return false;
    }
    auto shape_of_shape = shape.shape();
    if (!shape_of_shape) {
        // assume not scalar
        return false;
    }
    return *shape_of_shape == ValueShape{0};
}
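
// Type-erases a strongly typed rule and stores it under the op's static
// typeinfo. Usage sketch (for a hypothetical op MyOp whose rule my_op_rule
// has type std::vector<ValueRef>(const MyOp&, Span<ValueRef>)):
//     register_scalar_rule(my_op_rule);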
template <typename T>
void register_scalar_rule(std::vector<ValueRef> (*rule)(const T&, Span<ValueRef>)) {
    scalar_rules[T::typeinfo()] = [rule](const OpDef& def, Span<ValueRef> inputs) {
        return (*rule)(def.cast_final_safe<T>(), inputs);
    };
}
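
// Elemwise: the result is a scalar only when every input is a scalar;
// broadcasting against any non-scalar input yields a shaped tensor.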
std::vector<ValueRef> elemwise_rule(const Elemwise& elem, Span<ValueRef> inputs) {
    bool all_scalar = true;
    for (auto&& input : inputs) {
        if (!input.is<ScalarValue>()) {
            all_scalar = false;
            break;
        }
    }
    auto output = imperative::apply(elem, unwrap_inputs(inputs))[0];
    if (all_scalar) {
        return {ScalarValue::make(output)};
    } else {
        return {output};
    }
}
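
// RemoveAxis: removing all remaining axes collapses the tensor to a scalar;
// a scalar input is rejected since there is no axis left to remove.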
std::vector<ValueRef> remove_axis_rule(
        const RemoveAxis& remove_axis, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 1);
    mgb_assert(!inputs[0].is<ScalarValue>());
    auto output = imperative::apply(remove_axis, inputs)[0];
    bool is_scalar = inputs[0].shape()->ndim == remove_axis.axis.size();
    if (is_scalar) {
        return {ScalarValue::make(output)};
    } else {
        return {output};
    }
}
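
// Reduce: with an explicit target shape (second input), reducing to an empty
// shape produces a scalar; the computation itself still runs against a {1}
// shape tensor from make_scalar_shape.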
std::vector<ValueRef> reduce_rule(const Reduce& reduce, Span<ValueRef> inputs) {
    if (inputs.size() == 1) {
        return imperative::apply(reduce, unwrap_inputs(inputs));
    }
    mgb_assert(inputs.size() == 2);
    if (is_scalar_shape(inputs[1])) {
        auto unwrapped_input = unwrap_input(inputs[0]);
        CompNode device = *unwrapped_input.device();
        return {ScalarValue::make(imperative::apply(
                reduce, unwrapped_input, make_scalar_shape(device))[0])};
    }
    // the target shape is known to be non-scalar here
    return {imperative::apply(reduce, unwrap_inputs(inputs))[0]};
}
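
// TypeCvt is purely element-wise, so scalar-ness passes straight through.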
std::vector<ValueRef> typecvt_rule(const TypeCvt& typecvt, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 1);
    if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) {
        return {ScalarValue::make(
                imperative::apply(typecvt, scalar_input->value())[0])};
    } else {
        return imperative::apply(typecvt, inputs);
    }
}
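
// CollectiveComm: only modes that preserve the participant's shape
// (all-reduce variants, broadcast, reduce-sum) can preserve scalar-ness;
// other modes fall through untouched.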
std::vector<ValueRef> collective_comm_rule(
        const CollectiveComm& collective_comm, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 1);
    static std::unordered_set<CollectiveComm::Mode> modes = {
            CollectiveComm::Mode::ALL_REDUCE_MAX, CollectiveComm::Mode::ALL_REDUCE_MIN,
            CollectiveComm::Mode::ALL_REDUCE_SUM, CollectiveComm::Mode::BROADCAST,
            CollectiveComm::Mode::REDUCE_SUM,
    };
    if (modes.count(collective_comm.mode) == 0) {
        return imperative::apply(collective_comm, inputs);
    }
    if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) {
        return {ScalarValue::make(
                imperative::apply(collective_comm, scalar_input->value())[0])};
    } else {
        return imperative::apply(collective_comm, inputs);
    }
}
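
// ParamPackSplit: an empty entry in `shapes` marks an output that must be
// re-wrapped as a scalar.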
std::vector<ValueRef> param_pack_split_rule(
        const ParamPackSplit& param_pack_split, Span<ValueRef> inputs) {
    auto outputs = imperative::apply(param_pack_split, unwrap_inputs(inputs));
    size_t nr_outputs = outputs.size();
    mgb_assert(nr_outputs == param_pack_split.shapes.size());
    for (size_t i = 0; i < nr_outputs; ++i) {
        if (param_pack_split.shapes[i].empty()) {
            outputs[i] = ScalarValue::make(outputs[i]);
        }
    }
    return outputs;
}
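
// Dot always produces a scalar.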
std::vector<ValueRef> dot_rule(const Dot& dot, Span<ValueRef> inputs) {
    return {ScalarValue::make(imperative::apply(dot, unwrap_inputs(inputs))[0])};
}
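
// AddAxis on a scalar: axis 0 is satisfied by exposing the underlying
// 1-element tensor; any remaining axes are applied with a reduced AddAxis.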
std::vector<ValueRef> add_axis_rule(const AddAxis& add_axis, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 1);
    if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) {
        mgb_assert(add_axis.axis[0] == 0);
        if (add_axis.axis.size() == 1) {
            return {scalar_input->value()};
        } else {
            std::vector<int32_t> axis(add_axis.axis.begin() + 1, add_axis.axis.end());
            return imperative::apply(
                    ApplyOp(*AddAxis::make(axis, add_axis.scope())),
                    scalar_input->value());
        }
    } else {
        return imperative::apply(add_axis, inputs);
    }
}
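
// RemoteRecv: an empty recv shape means "receive a scalar"; it is rewritten
// to shape {1}, presumably because the transport works on concrete tensors.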
std::vector<ValueRef> remote_recv_rule(
        const RemoteRecv& remote_recv, Span<ValueRef> inputs) {
    if (remote_recv.shape.empty()) {
        std::vector<int32_t> shape = {1};
        auto remote_recv_no_scalar = RemoteRecv::make(
                remote_recv.key, remote_recv.addr, remote_recv.port,
                remote_recv.rank_from, remote_recv.cn, shape, remote_recv.dtype,
                remote_recv.backend);
        remote_recv_no_scalar->set_scope(remote_recv.scope());
        return imperative::apply(
                ApplyOp(*remote_recv_no_scalar), unwrap_inputs(inputs));
    } else {
        return imperative::apply(remote_recv, unwrap_inputs(inputs));
    }
}
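
// CheckNonFinite appends one extra output (the finiteness flag) after the
// pass-through outputs; the flag is always a scalar, the others keep the
// scalar-ness of their corresponding inputs.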
std::vector<ValueRef> check_no_finite_rule(
        const CheckNonFinite& check_no_finite, Span<ValueRef> inputs) {
    auto outputs = imperative::apply(check_no_finite, unwrap_inputs(inputs));
    mgb_assert(outputs.size() == inputs.size() + 1, "output size mismatch");
    outputs.back() = ScalarValue::make(outputs.back());
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (inputs[i].is<ScalarValue>()) {
            outputs[i] = ScalarValue::make(outputs[i]);
        }
    }
    return outputs;
}
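
// Subtensor: each index item (as opposed to a slice) removes one dimension;
// if that eliminates every dimension, the result is a scalar.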
std::vector<ValueRef> subtensor_rule(
        const Subtensor& subtensor, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() >= 1);
    auto input = inputs[0];
    size_t ndim = input.is<ScalarValue>() ? 0 : input.shape()->ndim;
    for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
        if (idx) {
            ndim--;
        }
    }
    auto output = imperative::apply(subtensor, unwrap_inputs(inputs))[0];
    if (!ndim) {
        return {ScalarValue::make(output)};
    } else {
        return {output};
    }
}
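
// GetVarShape: the shape of a scalar is the empty shape, materialized here
// as an empty Int32 tensor without running the op at all.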
std::vector<ValueRef> get_var_shape_rule(
        const GetVarShape& get_var_shape, Span<ValueRef> inputs) {
    bool all_scalar = true;
    mgb_assert(inputs.size() >= 1);
    for (auto&& input : inputs) {
        if (!input.is<ScalarValue>()) {
            all_scalar = false;
        }
    }
    if (all_scalar) {
        auto device = inputs[0].cast<ScalarValue>().value().device();
        auto storage = HostStorage::make(*device);
        // storage->ensure_size(1);
        return imperative::apply(
                CreateTensor(
                        CreateTensor::Const, *device, dtype::Int32(), ValueShape{0}),
                storage);
    } else {
        return imperative::apply(get_var_shape, unwrap_inputs(inputs));
    }
}
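
// FastpathCopy is identity-like: scalar in, scalar out.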
std::vector<ValueRef> fastpath_copy_rule(
        const FastpathCopy& fastpath_copy, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 1);
    bool is_scalar = inputs[0].is<ScalarValue>();
    auto output = imperative::apply(fastpath_copy, unwrap_inputs(inputs))[0];
    if (is_scalar) {
        return {ScalarValue::make(output)};
    } else {
        return {output};
    }
}
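
// Reshape: a non-scalar, zero-length target shape means "reshape to scalar";
// as with Reduce, the op itself runs against a {1} shape tensor.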
std::vector<ValueRef> reshape_rule(const Reshape& reshape, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 2);
    bool is_scalar =
            (!inputs[1].is<ScalarValue>()) && *inputs[1].shape() == ValueShape{0};
    auto unwrapped_input = unwrap_input(inputs[0]);
    if (is_scalar) {
        return {ScalarValue::make(imperative::apply(
                reshape, unwrapped_input,
                make_scalar_shape(*unwrapped_input.device()))[0])};
    } else {
        return imperative::apply(reshape, unwrap_inputs(inputs));
    }
}
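
// Broadcast to an empty target shape likewise yields a scalar.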
std::vector<ValueRef> broadcast_rule(
        const Broadcast& broadcast, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 2);
    bool is_scalar = is_scalar_shape(inputs[1]);
    auto unwrapped_input = unwrap_input(inputs[0]);
    if (is_scalar) {
        return {ScalarValue::make(imperative::apply(
                broadcast, unwrapped_input,
                make_scalar_shape(*unwrapped_input.device()))[0])};
    } else {
        return imperative::apply(broadcast, unwrap_inputs(inputs));
    }
}
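
// Copy (possibly across devices) preserves scalar-ness.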
std::vector<ValueRef> copy_rule(const Copy& copy, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 1);
    bool is_scalar = inputs[0].is<ScalarValue>();
    if (is_scalar) {
        return {ScalarValue::make(imperative::apply(copy, unwrap_inputs(inputs))[0])};
    } else {
        return imperative::apply(copy, unwrap_inputs(inputs));
    }
}
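
// InplaceAdd takes four inputs; the result keeps the scalar-ness of the
// tensor being updated (inputs[0]).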
std::vector<ValueRef> inplace_add_rule(
        const InplaceAdd& inplace_add, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 4);
    bool is_scalar = inputs[0].is<ScalarValue>();
    if (is_scalar) {
        return {ScalarValue::make(
                imperative::apply(inplace_add, unwrap_inputs(inputs))[0])};
    } else {
        return imperative::apply(inplace_add, unwrap_inputs(inputs));
    }
}
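
// Installs every rule above at static-initialization time. Supporting a new
// op means writing a rule with the matching signature and adding one more
// register_scalar_rule call here.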
struct ScalarRuleRegistry {
    ScalarRuleRegistry() {
        register_scalar_rule(elemwise_rule);
        register_scalar_rule(remove_axis_rule);
        register_scalar_rule(reduce_rule);
        register_scalar_rule(typecvt_rule);
        register_scalar_rule(collective_comm_rule);
        register_scalar_rule(param_pack_split_rule);
        register_scalar_rule(dot_rule);
        register_scalar_rule(add_axis_rule);
        register_scalar_rule(remote_recv_rule);
        register_scalar_rule(check_no_finite_rule);
        register_scalar_rule(subtensor_rule);
        register_scalar_rule(get_var_shape_rule);
        register_scalar_rule(fastpath_copy_rule);
        register_scalar_rule(reshape_rule);
        register_scalar_rule(broadcast_rule);
        register_scalar_rule(copy_rule);
        register_scalar_rule(inplace_add_rule);
    }
} _;
} // namespace
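
// Entry point of the transformation. ApplyOp is dispatched through
// scalar_rules when a rule is registered; CreateTensor, GetAttr, IsScalar
// and identity-like operators are handled inline; everything else simply
// has its scalar inputs unwrapped.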
std::vector<ValueRef> ScalarTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto apply_op = op.as<ApplyOp>()) {
        auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo());
        if (iter != scalar_rules.end()) {
            return iter->second(apply_op->op(), inputs);
        } else {
            // TODO: repeat op
            return imperative::apply(op, unwrap_inputs(inputs));
        }
    } else if (auto* create_tensor = op.as<CreateTensor>()) {
        if (create_tensor->shape().is_scalar()) {
            ValueShape scalar_shape = {1};
            CreateTensor scalar_op(
                    create_tensor->kind(), create_tensor->device(),
                    create_tensor->dtype(), scalar_shape);
            return {ScalarValue::make(imperative::apply(scalar_op, inputs)[0])};
        } else {
            return imperative::apply(op, inputs);
        }
    } else if (auto* get_attr = op.as<GetAttr>()) {
        bool is_scalar = inputs.as_array<1>()[0].is<ScalarValue>();
        auto output = imperative::apply(op, unwrap_inputs(inputs))[0];
        if (!is_scalar) {
            return {output};
        }
        switch (get_attr->attr()) {
            case GetAttr::Shape: {
                // Scalar Shape
                return {ShapeValue::make()};
            }
            case GetAttr::Value: {
                auto& hv = output.cast<HostValue>();
                mgb_assert(
                        hv.shape() == ValueShape({1}),
                        "underlying value should have shape {1}, got %s",
                        hv.shape().to_string().c_str());
                return {HostValue::make(hv.dtype(), ValueShape(), hv.storage())};
            }
            case GetAttr::Data: {
                auto& dv = output.cast<DeviceValue>();
                mgb_assert(
                        dv.shape() == ValueShape({1}),
                        "underlying value should have shape {1}, got %s",
                        dv.shape().to_string().c_str());
                return {DeviceValue::make(dv.dtype(), ValueShape(), dv.storage())};
            }
            default:
                return {output};
        }
    } else if (op.as<IsScalar>()) {
        return {BoolValue::make(inputs.as_array<1>()[0].is<ScalarValue>())};
    } else if (op.is<Operator::IdentityLike>()) {
        bool is_scalar = inputs.as_array<1>()[0].is<ScalarValue>();
        if (is_scalar) {
            return {ScalarValue::make(imperative::apply(op, unwrap_inputs(inputs))[0])};
        } else {
            return imperative::apply(op, inputs);
        }
    } else {
        return imperative::apply(op, unwrap_inputs(inputs));
    }
}

} // namespace imperative
} // namespace mgb