and enable_opencl_deploy lite api
GitOrigin-RevId: 9d932ff27e
tags/v1.9.0
@@ -16,6 +16,7 @@ | |||||
#include "decryption/rc4_cryption.h" | #include "decryption/rc4_cryption.h" | ||||
#include "lite/global.h" | #include "lite/global.h" | ||||
#include "misc.h" | #include "misc.h" | ||||
#include "network_impl_base.h" | |||||
#include "parse_info/default_parse.h" | #include "parse_info/default_parse.h" | ||||
#include "parse_info/parse_info_base.h" | #include "parse_info/parse_info_base.h" | ||||
@@ -51,6 +51,29 @@ LITE_API std::string ssprintf(const char* fmt = 0, ...) | |||||
*/ | */ | ||||
LITE_API void print_log(LiteLogLevel level, const char* format = 0, ...) | LITE_API void print_log(LiteLogLevel level, const char* format = 0, ...) | ||||
__attribute__((format(printf, 2, 3))); | __attribute__((format(printf, 2, 3))); | ||||
/*! | |||||
* \brief NonCopyableObj base. | |||||
*/ | |||||
class NonCopyableObj { | |||||
public: | |||||
NonCopyableObj() {} | |||||
private: | |||||
NonCopyableObj(const NonCopyableObj&); | |||||
NonCopyableObj& operator=(const NonCopyableObj&); | |||||
}; | |||||
template <class T> | |||||
class Singleton : public NonCopyableObj { | |||||
public: | |||||
Singleton() {} | |||||
static T& Instance() { | |||||
static T _; | |||||
return _; | |||||
} | |||||
}; | |||||
} // namespace lite | } // namespace lite | ||||
#if LITE_ENABLE_LOGGING | #if LITE_ENABLE_LOGGING | ||||
@@ -16,11 +16,33 @@ | |||||
#include "tensor_impl_base.h" | #include "tensor_impl_base.h" | ||||
#include "type_info.h" | #include "type_info.h" | ||||
#include <atomic> | |||||
#include <unordered_map> | #include <unordered_map> | ||||
namespace lite { | namespace lite { | ||||
/*! | /*! | ||||
* \brief network reference count | |||||
*/ | |||||
class NetworkRefCount : public Singleton<NetworkRefCount> { | |||||
public: | |||||
NetworkRefCount() : count(0) {} | |||||
NetworkRefCount& operator++(int) { | |||||
++count; | |||||
return *this; | |||||
} | |||||
NetworkRefCount& operator--(int) { | |||||
--count; | |||||
return *this; | |||||
} | |||||
int refcount() { return count; } | |||||
private: | |||||
std::atomic<int> count; | |||||
}; | |||||
/*! | |||||
* \brief the Inner IO data struct, add some inner data from IO | * \brief the Inner IO data struct, add some inner data from IO | ||||
*/ | */ | ||||
class IOInner : public IO { | class IOInner : public IO { | ||||
@@ -54,7 +76,8 @@ struct NetworkIOInner { | |||||
*/ | */ | ||||
class Network::NetworkImplBase : public DynTypeObj { | class Network::NetworkImplBase : public DynTypeObj { | ||||
public: | public: | ||||
virtual ~NetworkImplBase() = default; | |||||
virtual ~NetworkImplBase() { NetworkRefCount::Instance()--; }; | |||||
NetworkImplBase() { NetworkRefCount::Instance()++; }; | |||||
//! set the config of the network, include: | //! set the config of the network, include: | ||||
//! the inference device | //! the inference device | ||||
@@ -70,6 +70,19 @@ TEST(TestNetWork, Basic) { | |||||
compare_lite_tensor<float>(result_lite, result_mgb); | compare_lite_tensor<float>(result_lite, result_mgb); | ||||
} | } | ||||
TEST(TestNetWork, RefCount) { | |||||
Config config; | |||||
ASSERT_EQ(NetworkRefCount::Instance().refcount(), 0); | |||||
std::shared_ptr<Network> network = std::make_shared<Network>(config); | |||||
ASSERT_EQ(NetworkRefCount::Instance().refcount(), 1); | |||||
std::shared_ptr<Network> network_s = std::make_shared<Network>(config); | |||||
ASSERT_EQ(NetworkRefCount::Instance().refcount(), 2); | |||||
network.reset(); | |||||
ASSERT_EQ(NetworkRefCount::Instance().refcount(), 1); | |||||
network_s.reset(); | |||||
ASSERT_EQ(NetworkRefCount::Instance().refcount(), 0); | |||||
} | |||||
TEST(TestNetWork, SetDeviceId) { | TEST(TestNetWork, SetDeviceId) { | ||||
Config config; | Config config; | ||||
auto lite_tensor = get_input_data("./input_data.npy"); | auto lite_tensor = get_input_data("./input_data.npy"); | ||||