diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index b3a30d12..b29b83ce 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -1399,12 +1399,12 @@ void init_tensor(py::module m) {
 std::function array_comparator;
 
 bool compare_value(ValueRef lhs, ValueRef rhs) {
-    auto lvalue = lhs.numpy();
-    auto rvalue = rhs.numpy();
+    auto lvalue = lhs.cast_ref<HostValue>();
+    auto rvalue = rhs.cast_ref<HostValue>();
     if (lvalue->shape() != rvalue->shape()) {
         return false;
     }
-    if (lvalue->shape().is_scalar()) {
+    if (lvalue->shape().total_nr_elems() == 1) {
         return lvalue->item() == rvalue->item();
     }
     HostTensorND lnd = lvalue->as_nd(true);
diff --git a/imperative/src/impl/transformations/lazy.cpp b/imperative/src/impl/transformations/lazy.cpp
index af37e39b..403c1987 100644
--- a/imperative/src/impl/transformations/lazy.cpp
+++ b/imperative/src/impl/transformations/lazy.cpp
@@ -50,9 +50,10 @@ ValueRefList LazyEvalTransformation::apply_transformation(
     }
     if (require_link && m_io_link.node()) {
         mgb_assert(!input_nodes.empty());
-        input_nodes[0] =
-                opr::VirtualDep::make({SymbolVar(input_nodes[0]), m_io_link})
-                        .node();
+        auto comp_node = m_io_link.node()->comp_node();
+        input_nodes[0] = opr::VirtualDep::make(
+                                 {SymbolVar(input_nodes[0]), m_io_link}, comp_node)
+                                 .node();
     }
     VarNodeArray output_nodes = OpDef::apply_on_var_node(op_val->op(), input_nodes);
     if (require_link) {
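Note: the lazy.cpp hunk above (and the matching hunk in trace.cpp below) pins the
VirtualDep operator to the io-link's comp_node instead of letting it default from the
first input var. A minimal sketch of the pattern, assuming the two-argument
opr::VirtualDep::make overload used in this patch; the helper name is illustrative,
not MegEngine API:

    // Hedged sketch: force the virtual dependency onto the io-link's computing
    // node, so a first input on another device does not pull the link with it.
    SymbolVar attach_io_link(SymbolVar first_input, SymbolVar io_link) {
        auto comp_node = io_link.node()->comp_node();  // link's own device
        return opr::VirtualDep::make({first_input, io_link}, comp_node);
    }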
diff --git a/imperative/src/impl/transformations/trace.cpp b/imperative/src/impl/transformations/trace.cpp
index f13ec73f..afce8017 100644
--- a/imperative/src/impl/transformations/trace.cpp
+++ b/imperative/src/impl/transformations/trace.cpp
@@ -196,10 +196,11 @@ ValueRefList TracingTransformation::apply_transformation(
         return outputs;
     }
     bool is_const = create_tensor->kind() == CreateTensor::Const;
+    bool as_const = is_const || m_capture_as_const;
     auto wrapped_input = record_var(
-            outputs[0], is_const || m_capture_as_const,
-            is_const ? VarKind::Constant : VarKind::External);
-    auto wrapped_output = record_var(outputs[0], false, VarKind::Internal);
+            outputs[0], as_const, is_const ? VarKind::Constant : VarKind::External);
+    // bind data to outputs too, to reduce runtime overhead of shape/value inference
+    auto wrapped_output = record_var(outputs[0], as_const, VarKind::Internal);
     auto input_id = wrapped_input->id();
     auto output_id = wrapped_output->id();
     m_seq.push_back({{}, {input_id}, {output_id}});
@@ -311,6 +312,18 @@ void CompiledTransformation::compile() {
     auto make_output = [&](TraceResult::VarInfo* var_info, SymbolVar node) {
         VarAccessor accessor;
         accessor.node = node.node();
+        if (auto bound_data = var_info->bound_data) {
+            accessor.shape_getter = [bound_data]() -> TensorShape {
+                return bound_data.shape()->as_tensor_shape();
+            };
+            accessor.data_getter = [bound_data]() -> DeviceTensorND {
+                return bound_data.dev_tensor()->as_nd();
+            };
+            accessor.value_getter = [bound_data]() -> HostTensorND {
+                return bound_data.numpy()->as_nd();
+            };
+            return accessor;
+        }
         if (var_info->data_required) {
             // reduce d2h when data is available
             // FIXME: compile should not change var_info in-place
@@ -410,16 +423,28 @@ void CompiledTransformation::compile() {
                     "internal node should be valid when used as input");
             }
         }
-        input_vars.push_back(var_accessors[input].node);
+        auto& node = var_accessors[input].node;
+        if (input_vars.empty() && require_link && mm_io_link.node()) {
+            /*mgb_assert(
+                    !input_vars.empty(),
+                    "io-mm operator should have at least one input");*/
+            auto comp_node = mm_io_link.node()->comp_node();
+            // auto comp_node = input_vars[0]->comp_node();
+            node = opr::VirtualDep::make({SymbolVar(node), mm_io_link}, comp_node)
+                           .node();
+        }
+        input_vars.push_back(node);
     }
-    if (require_link && mm_io_link.node()) {
+    /*if (require_link && mm_io_link.node()) {
         mgb_assert(
                 !input_vars.empty(),
                 "io-mm operator should have at least one input");
-        input_vars[0] =
-                opr::VirtualDep::make({SymbolVar(input_vars[0]), mm_io_link})
-                        .node();
-    }
+        auto comp_node = mm_io_link.node()->comp_node();
+        // auto comp_node = input_vars[0]->comp_node();
+        input_vars[0] = opr::VirtualDep::make(
+                                {SymbolVar(input_vars[0]), mm_io_link}, comp_node)
+                                .node();
+    }*/
     VarNodeArray output_vars;
     if (item.op) {
         output_vars = OpDef::apply_on_var_node(*item.op, input_vars);
@@ -479,6 +504,12 @@ void CompiledTransformation::recompile() {
 }
 
 void CompiledTransformation::assert_tensor_equal(ValueRef lhs, ValueRef rhs) {
+    if (!lhs.is<HostValue>()) {
+        lhs = lhs.numpy();
+    }
+    if (!rhs.is<HostValue>()) {
+        rhs = rhs.numpy();
+    }
     trace_assert(m_value_comparator(lhs, rhs), "tensors not equals");
 }
 
@@ -507,6 +538,7 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
                 break;
             }
             case VarKind::Constant: {
+                // expect host value here
                 mgb_assert(var.bound_data, "const var without data bound");
                 assert_tensor_equal(var.bound_data, value);
                 break;
@@ -611,7 +643,17 @@ ValueRefList CompiledTransformation::apply_create_tensor(
     trace_assert(item.op == nullptr, "operator mismatch");
     auto input_id = item.inputs[0];
     auto output_id = item.outputs[0];
-    auto tensor = imperative::apply(create_tensor, inputs)[0];
+    ValueRef tensor;
+    if (create_tensor.kind() == CreateTensor::Const) {
+        auto args = create_tensor.parse(inputs);
+        if (args.host) {
+            // wrap the host data directly; a full apply is a performance issue here
+            tensor = HostValue::make(*args.host);
+        }
+    }
+    if (!tensor) {
+        tensor = imperative::apply(create_tensor, inputs)[0];
+    }
     trace_input(input_id, tensor);
     return {trace_output(output_id)};
 }
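Note: the make_output change above lets shape/data/value queries on a var with bound
(constant) data be answered from that recorded data, never touching the compiled graph
and avoiding a d2h copy per access. A standalone sketch of the pattern with illustrative
names (toy types, not the MegEngine API):

    #include <functional>
    #include <memory>
    #include <vector>

    struct AccessorSketch {
        std::function<std::vector<size_t>()> shape_getter;
        std::function<const float*()> value_getter;
    };

    struct Bound {  // host-side data recorded at trace time
        std::vector<size_t> shape;
        std::vector<float> values;
    };

    AccessorSketch make_accessor(std::shared_ptr<const Bound> bound) {
        AccessorSketch acc;
        if (bound) {  // fast path: close over recorded data, no graph query
            acc.shape_getter = [bound] { return bound->shape; };
            acc.value_getter = [bound] { return bound->values.data(); };
        }
        // otherwise fall back to querying the compiled graph, as before
        return acc;
    }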
diff --git a/imperative/src/include/megbrain/imperative/basic_values.h b/imperative/src/include/megbrain/imperative/basic_values.h
index 4b97bb2f..02e528ed 100644
--- a/imperative/src/include/megbrain/imperative/basic_values.h
+++ b/imperative/src/include/megbrain/imperative/basic_values.h
@@ -103,7 +103,8 @@ public:
     CompNode device() const { return m_storage.comp_node(); }
     const HostTensorStorage& storage() const { return m_storage; }
     DTypeScalar item() const {
-        mgb_assert(m_shape.is_scalar());
+        // FIXME: check scalar
+        mgb_assert(m_shape.total_nr_elems());
         return DTypeScalar::make_from_raw(m_dtype, m_storage.ptr());
     }
 
diff --git a/imperative/src/include/megbrain/imperative/transformations/trace.h b/imperative/src/include/megbrain/imperative/transformations/trace.h
index c73b4ce4..36432107 100644
--- a/imperative/src/include/megbrain/imperative/transformations/trace.h
+++ b/imperative/src/include/megbrain/imperative/transformations/trace.h
@@ -47,7 +47,8 @@ struct TraceResult {
         DTypeValue::ref_t dtype;
         CompNodeValue::ref_t device;
 
-        // if exists, assert equal when meet
+        // if set, for inputs: assert the incoming value equals it;
+        // for outputs: serve shape/data/value queries
         ValueRef bound_data;
         std::string mark;
         std::string name;
diff --git a/imperative/src/include/megbrain/imperative/utils/value_shape.h b/imperative/src/include/megbrain/imperative/utils/value_shape.h
index c33fbe6a..a9b60bf0 100644
--- a/imperative/src/include/megbrain/imperative/utils/value_shape.h
+++ b/imperative/src/include/megbrain/imperative/utils/value_shape.h
@@ -49,6 +49,7 @@ struct ValueShape {
 
     size_t total_nr_elems() const {
         size_t prod = 1;
+        mgb_assert(ndim >= 0 && ndim < 8);
         for (int i = 0; i < ndim; ++i) {
             prod *= shape[i];
         }
@@ -103,4 +104,4 @@ static_assert(sizeof(size_t) >= sizeof(int));
 static_assert(TensorShape::MAX_NDIM == 7);
 static_assert(sizeof(ValueShape) <= sizeof(size_t) * 8);
 
-} // namespace mgb::imperative
\ No newline at end of file
+} // namespace mgb::imperative
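Note: two hunks in this patch (compare_value in tensor.cpp and HostValue::item above)
relax the "shape is a scalar" test to "shape has exactly one element", so single-element
tensors of any rank take the cheap item() path. A standalone illustration of the
difference, using a toy type rather than the real ValueShape:

    #include <cstddef>

    struct ShapeSketch {
        size_t dims[7];
        int ndim;
        // old check: only a rank-0 shape counts
        bool is_scalar() const { return ndim == 0; }
        // new check: product of dims, 1 for any single-element shape
        size_t total_nr_elems() const {
            size_t prod = 1;
            for (int i = 0; i < ndim; ++i)
                prod *= dims[i];
            return prod;
        }
    };

    int main() {
        ShapeSketch scalar{{}, 0};    // rank-0: passes both checks
        ShapeSketch one{{1, 1}, 2};   // shape (1, 1): passes only the new check
        return (scalar.total_nr_elems() == 1 && !one.is_scalar() &&
                one.total_nr_elems() == 1)
                       ? 0
                       : 1;
    }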