diff --git a/cpp/dbpool.cc b/cpp/dbpool.cc new file mode 100644 index 0000000..2447a3d --- /dev/null +++ b/cpp/dbpool.cc @@ -0,0 +1,421 @@ +#include "precompile.h" + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "framework/cpp/dbpool.h" +#include "framework/cpp/msgqueue.h" +#include "framework/cpp/utils.h" + +enum AsyncQueryError +{ + AQE_NO_ERROR = 0, + AQE_EXEC_TYPE_ERROR = 1, + AQE_QUERY_TYPE_ERROR = 2, + AQE_SYNTAX_ERROR = 3, + AQE_CONN_ERROR = 4 +}; + +struct AsyncQueryRequest +{ + list_head entry; + long long context_id = 0; + std::string sql; + a8::XParams param; + time_t add_time = 0; + AsyncDBOnOkFunc on_ok = nullptr; + AsyncDBOnErrorFunc on_error = nullptr; +}; + +struct AsyncQueryNode +{ + int socket_handle = 0; + int query_type = 0; + long long context_id = 0; + std::string sql; +#if 1 + std::string _sql_fmt; + std::initializer_list _sql_params; + a8::XObject conn_info; +#endif + AsyncQueryNode* nextnode = nullptr; +}; + +class DBThread +{ +public: + + void Init() + { + loop_mutex_ = new std::mutex(); + loop_cond_ = new std::condition_variable(); + + last_checkdb_tick_ = a8::XGetTickCount(); + + top_node_ = nullptr; + bot_node_ = nullptr; + work_node_ = nullptr; + msg_mutex_ = new std::mutex(); + work_thread_ = new std::thread(&DBThread::WorkThreadProc, this); + } + + void AddAsyncQuery(int sockhandle, int query_type, long long context_id, const std::string& sql) + { + AsyncQueryNode *p = new AsyncQueryNode(); + p->query_type = query_type; + p->socket_handle = sockhandle; + p->context_id = context_id; + p->sql = sql; + + std::unique_lock lk(*loop_mutex_); + msg_mutex_->lock(); + if (bot_node_) { + bot_node_->nextnode = p; + bot_node_ = p; + } else { + top_node_ = p; + bot_node_ = p; + } + msg_mutex_->unlock(); + loop_cond_->notify_all(); + } + + void AddAsyncQuery(AsyncQueryNode* p) + { + std::unique_lock lk(*loop_mutex_); + msg_mutex_->lock(); + if (bot_node_) { + bot_node_->nextnode = p; + bot_node_ = p; + } else { + top_node_ = p; + bot_node_ = p; + } + msg_mutex_->unlock(); + loop_cond_->notify_all(); + } + +private: + + void WorkThreadProc() + { + while (true) { + #if 0 + a8::mysql::Connection conn; + a8::mysql::Query* query = conn.CreateQuery(); + conn.Connect(dbhost_, 3306, dbuser_, dbpasswd_, gamedb_); + InitMysqlConnection(query); + + CheckDB(conn, *query); + ProcessMsg(*query); + #endif + WaitLoopCond(); + } + } + + void CheckDB(a8::mysql::Connection& conn, a8::mysql::Query& query) + { + if (a8::XGetTickCount() - last_checkdb_tick_ < 1000 * 60 * 5) { + return; + } + last_checkdb_tick_ = a8::XGetTickCount(); + if (query.ExecQuery("SELECT 1;", {}) <= 0) { + #if 0 + a8::UdpLog::Instance()->Warning("mysql disconnect", {}); + if (conn.Connect(dbhost_, 3306, dbuser_, dbpasswd_, gamedb_)) { + InitMysqlConnection(&query); + a8::UdpLog::Instance()->Info("mysql reconnect successed", {}); + } else { + a8::UdpLog::Instance()->Info("mysql reconnect failed", {}); + } + #endif + } + } + + void ProcessMsg(a8::mysql::Query& query) + { + if (!work_node_ && top_node_) { + msg_mutex_->lock(); + work_node_ = top_node_; + top_node_ = nullptr; + bot_node_ = nullptr; + msg_mutex_->unlock(); + } + while (work_node_) { + AsyncQueryNode *pdelnode = work_node_; + work_node_ = work_node_->nextnode; + ProcAsyncQuery(query, pdelnode); + delete pdelnode; + } + } + + void WaitLoopCond() + { + std::unique_lock lk(*loop_mutex_); + { + msg_mutex_->lock(); + if (!work_node_ && top_node_) { + work_node_ = top_node_; + top_node_ = nullptr; + bot_node_ = nullptr; + } + msg_mutex_->unlock(); + } + if (!work_node_) { + loop_cond_->wait_for(lk, std::chrono::seconds(10)); + } + } + + void ProcAsyncQuery(a8::mysql::Query& query, AsyncQueryNode* node) + { + switch (node->query_type) { + case 0: + { + int ret = query.ExecQuery(node->sql.c_str(), {}); + if (ret < 0) { + MsgQueue::Instance()->PostMsg_r(exec_async_query_msgid, + a8::XParams() + .SetSender(node->context_id) + .SetParam1(AQE_SYNTAX_ERROR) + .SetParam2(query.GetError())); + } else { + DataSet* data_set = new DataSet(); + data_set->reserve(query.RowsNum()); + while (!query.Eof()) { + auto& row = a8::FastAppend(*data_set); + int field_num = query.FieldsNum(); + row.reserve(field_num); + for (int i = 0; i < field_num; i++) { + row.push_back(query.GetValue(i).GetString()); + } + query.Next(); + } + MsgQueue::Instance()->PostMsg_r(exec_async_query_msgid, + a8::XParams() + .SetSender(node->context_id) + .SetParam1(AQE_NO_ERROR) + .SetParam2((void*)data_set)); + } + } + break; + case 1: + { + bool ret = query.ExecScript(node->sql.c_str(), {}); + if (!ret) { + MsgQueue::Instance()->PostMsg_r(exec_async_query_msgid, + a8::XParams() + .SetSender(node->context_id) + .SetParam1(AQE_SYNTAX_ERROR) + .SetParam2(query.GetError())); + } else { + DataSet* data_set = new DataSet(); + MsgQueue::Instance()->PostMsg_r(exec_async_query_msgid, + a8::XParams() + .SetSender(node->context_id) + .SetParam1(AQE_NO_ERROR) + .SetParam2((void*)data_set)); + } + } + break; + default: + { + MsgQueue::Instance()->PostMsg_r(exec_async_query_msgid, + a8::XParams() + .SetSender(node->context_id) + .SetParam1(AQE_QUERY_TYPE_ERROR) + .SetParam2("不可识别的query类型")); + } + break; + } + } + +public: + int exec_async_query_msgid = 0; +private: + std::mutex *loop_mutex_ = nullptr; + std::condition_variable *loop_cond_ = nullptr; + + #if 0 + std::string gamedb_; + std::string dbhost_; + std::string dbuser_; + std::string dbpasswd_; + #endif + long long last_checkdb_tick_ = 0; + + std::thread *work_thread_ = nullptr; + AsyncQueryNode *top_node_ = nullptr; + AsyncQueryNode *bot_node_ = nullptr; + AsyncQueryNode *work_node_ = nullptr; + std::mutex *msg_mutex_ = nullptr; +}; + +void DBPool::Init() +{ + curr_seqid_ = 1000001; + #if 0 + INIT_LIST_HEAD(&query_list_); + #endif + #if 1 + /*mysql_init()不是完全线程安全的,但是只要成功调用一次就后就线程安全了, + 如果有多线程并发使用mysql_init(),建议在程序初始化时空调一次mysql_init(),他的这点特性很像qsort() + */ + a8::mysql::Connection conn; + #endif + exec_async_query_msgid_ = MsgQueue::Instance()->AllocIMMsgId(); + MsgQueue::Instance()->RegisterCallBack(exec_async_query_msgid_, + [] (const a8::XParams& param) + { + if (param.param1.GetInt() == AQE_NO_ERROR) { + DataSet* data_set = (DataSet*)param.param2.GetUserData(); + DBPool::Instance()->AsyncSqlOnOk(param.sender, data_set); + delete data_set; + } else { + DBPool::Instance()->AsyncSqlOnError(param.sender, + param.param1, + param.param2); + } + } + ); +} + +void DBPool::UnInit() +{ + +} + +void DBPool::SetThreadNum(int thread_num) +{ + assert(thread_num > 0); + for (int i = 0; i < thread_num; i++) { + DBThread *db_thread = new DBThread(); + db_thread->exec_async_query_msgid = exec_async_query_msgid_; + db_thread->Init(); + db_thread_pool_.push_back(db_thread); + } +} + +void DBPool::ExecAsyncQuery(a8::XObject conn_info, const char* querystr, std::initializer_list args, + a8::XParams param, AsyncDBOnOkFunc on_ok, AsyncDBOnErrorFunc on_error, long long hash_code) +{ + long long context_id = ++curr_seqid_; + { + AsyncQueryRequest* p = new AsyncQueryRequest(); + p->context_id = context_id; + p->param = param; + p->sql = ""; + p->add_time = time(nullptr); + p->on_ok = on_ok; + p->on_error = on_error; + #if 0 + list_add_tail(&p->entry, &query_list_); + #endif + async_query_hash_[p->context_id] = p; + } + if (db_thread_pool_.empty()) { + MsgQueue::Instance()->PostMsg_r(exec_async_query_msgid_, + a8::XParams() + .SetSender(context_id) + .SetParam1(AQE_CONN_ERROR)); + return; + } + DBThread *db_thread = nullptr; + if (hash_code != 0) { + db_thread = db_thread_pool_[hash_code % db_thread_pool_.size()]; + } else { + db_thread = db_thread_pool_[rand() % db_thread_pool_.size()]; + } + { + AsyncQueryNode* node = new AsyncQueryNode(); + node->socket_handle = 0; + node->query_type = 0; + node->context_id = context_id; + node->sql = ""; + node->_sql_fmt = querystr; + node->_sql_params = args; + conn_info.DeepCopy(node->conn_info); + db_thread->AddAsyncQuery(node); + } +} + +void DBPool::ExecAsyncScript(a8::XObject conn_info, const char* querystr, std::initializer_list args, + a8::XParams param, AsyncDBOnOkFunc on_ok, AsyncDBOnErrorFunc on_error, long long hash_code) +{ + long long context_id = ++curr_seqid_; + { + AsyncQueryRequest* p = new AsyncQueryRequest(); + p->context_id = context_id; + p->param = param; + p->sql = ""; + p->add_time = time(nullptr); + p->on_ok = on_ok; + p->on_error = on_error; + #if 0 + list_add_tail(&p->entry, &query_list_); + #endif + async_query_hash_[p->context_id] = p; + } + if (db_thread_pool_.empty()) { + MsgQueue::Instance()->PostMsg_r(exec_async_query_msgid_, + a8::XParams() + .SetSender(context_id) + .SetParam1(AQE_CONN_ERROR)); + return; + } + DBThread *db_thread = nullptr; + if (hash_code != 0) { + db_thread = db_thread_pool_[hash_code % db_thread_pool_.size()]; + } else { + db_thread = db_thread_pool_[rand() % db_thread_pool_.size()]; + } + { + AsyncQueryNode* node = new AsyncQueryNode(); + node->socket_handle = 0; + node->query_type = 1; + node->context_id = context_id; + node->sql = ""; + node->_sql_fmt = querystr; + node->_sql_params = args; + conn_info.DeepCopy(node->conn_info); + db_thread->AddAsyncQuery(node); + } +} + +AsyncQueryRequest* DBPool::GetAsyncQueryRequest(long long seqid) +{ + auto itr = async_query_hash_.find(seqid); + return itr != async_query_hash_.end() ? itr->second : nullptr; +} + +void DBPool::AsyncSqlOnOk(long long seqid, DataSet* data_set) +{ + AsyncQueryRequest* request = GetAsyncQueryRequest(seqid); + if (!request) { + return; + } + if (request->on_ok) { + request->on_ok(request->param, data_set); + } + async_query_hash_.erase(seqid); + delete request; +} + +void DBPool::AsyncSqlOnError(long long seqid, int errcode, const std::string& errmsg) +{ + AsyncQueryRequest* request = GetAsyncQueryRequest(seqid); + if (!request) { + return; + } + if (request->on_error) { + request->on_error(request->param, errcode, errmsg); + } + async_query_hash_.erase(seqid); + delete request; +} diff --git a/cpp/dbpool.h b/cpp/dbpool.h new file mode 100644 index 0000000..df4cb3e --- /dev/null +++ b/cpp/dbpool.h @@ -0,0 +1,38 @@ +#pragma once + +typedef std::vector> DataSet; +typedef void (*AsyncDBOnOkFunc)(a8::XParams& param, const DataSet* data_set); +typedef void (*AsyncDBOnErrorFunc)(a8::XParams& param, int error_code, const std::string& error_msg); + +struct AsyncQueryRequest; +class DBThread; +class DBPool : public a8::Singleton +{ + private: + DBPool() {}; + friend class a8::Singleton; + + public: + void Init(); + void UnInit(); + void SetThreadNum(int thread_num); + + //执行异步并行查询 + void ExecAsyncQuery(a8::XObject conn_info, const char* querystr, std::initializer_list args, + a8::XParams param, AsyncDBOnOkFunc on_ok, AsyncDBOnErrorFunc on_error, long long hash_code); + //执行异步并行sql + void ExecAsyncScript(a8::XObject conn_info, const char* querystr, std::initializer_list args, + a8::XParams param, AsyncDBOnOkFunc on_ok, AsyncDBOnErrorFunc on_error, long long hash_code); + + private: + AsyncQueryRequest* GetAsyncQueryRequest(long long seqid); + void AsyncSqlOnOk(long long seqid, DataSet* data_set); + void AsyncSqlOnError(long long seqid, int errcode, const std::string& errmsg); + + private: + long long curr_seqid_ = 0; + std::map async_query_hash_; + + unsigned short exec_async_query_msgid_ = 0; + std::vector db_thread_pool_; +}; diff --git a/cpp/msgqueue.cc b/cpp/msgqueue.cc new file mode 100644 index 0000000..f86cc91 --- /dev/null +++ b/cpp/msgqueue.cc @@ -0,0 +1,116 @@ +#include "precompile.h" + +#include + +#include "framework/cpp/msgqueue.h" +#include "app.h" + +struct MsgQueueNode +{ + struct list_head entry; + MsgHandleFunc func; +}; + +class MsgQueueImp +{ +public: + int curr_im_msgid = 10000; + std::map msg_handlers; + + void ProcessMsg(int msgid, const a8::XParams& param) + { + auto itr = msg_handlers.find(msgid); + if (itr != msg_handlers.end()) { + list_head* head = &itr->second; + struct MsgQueueNode *node = nullptr; + struct MsgQueueNode *tmp = nullptr; + list_for_each_entry_safe(node, tmp, head, entry) { + node->func(param); + } + } + } + + CallBackHandle RegisterCallBack(int msgid, MsgHandleFunc handle_func) + { + MsgQueueNode* node = new MsgQueueNode(); + INIT_LIST_HEAD(&node->entry); + node->func = handle_func; + + auto itr = msg_handlers.find(msgid); + if (itr == msg_handlers.end()) { + msg_handlers[msgid] = list_head(); + itr = msg_handlers.find(msgid); + assert(itr != msg_handlers.end()); + INIT_LIST_HEAD(&itr->second); + } + list_add_tail(&node->entry, &itr->second); + return &node->entry; + } + +}; + +void MsgQueue::Init() +{ + imp_ = new MsgQueueImp(); +} + +void MsgQueue::UnInit() +{ + delete imp_; + imp_ = nullptr; +} + +void MsgQueue::SendMsg(int msgid, a8::XParams param) +{ + imp_->ProcessMsg(msgid, param); +} + +void MsgQueue::PostMsg(int msgid, a8::XParams param) +{ + param._sys_field = msgid; + a8::Timer::Instance()->AddDeadLineTimer(0, param, + [] (const a8::XParams& param) + { + MsgQueue::Instance()->imp_->ProcessMsg(param._sys_field, param); + }); +} + +void MsgQueue::AddDelayMsg(int msgid, a8::XParams param, int milli_seconds) +{ + param._sys_field = msgid; + a8::Timer::Instance()->AddDeadLineTimer(milli_seconds, param, + [] (const a8::XParams& param) + { + MsgQueue::Instance()->imp_->ProcessMsg(param._sys_field, param); + }); +} + +void MsgQueue::RemoveCallBack(CallBackHandle handle) +{ + list_head* head = handle; + MsgQueueNode* node = list_entry(head, struct MsgQueueNode, entry); + list_del_init(&node->entry); + delete node; +} + +CallBackHandle MsgQueue::RegisterCallBack(int msgid, MsgHandleFunc handle_func) +{ + return imp_->RegisterCallBack(msgid, handle_func); +} + +int MsgQueue::AllocIMMsgId() +{ + return ++imp_->curr_im_msgid; +} + +void MsgQueue::ProcessMsg(int msgid, const a8::XParams& param) +{ + imp_->ProcessMsg(msgid, param); +} + +void MsgQueue::PostMsg_r(int msgid, a8::XParams param) +{ + a8::XParams* p = new a8::XParams(); + param.DeepCopy(*p); + App::Instance()->AddIMMsg(IM_SysMsgQueue, a8::XParams().SetSender(msgid).SetParam1((void*)p)); +} diff --git a/cpp/msgqueue.h b/cpp/msgqueue.h new file mode 100644 index 0000000..5e23737 --- /dev/null +++ b/cpp/msgqueue.h @@ -0,0 +1,30 @@ +#pragma once + +typedef std::function MsgHandleFunc; +typedef list_head* CallBackHandle; + +class MsgQueueImp; +class MsgQueue : public a8::Singleton +{ + private: + MsgQueue() {}; + friend class a8::Singleton; + + public: + void Init(); + void UnInit(); + + void SendMsg(int msgid, a8::XParams param); + void PostMsg(int msgid, a8::XParams param); + void AddDelayMsg(int msgid, a8::XParams param, int milli_seconds); + void RemoveCallBack(CallBackHandle handle); + CallBackHandle RegisterCallBack(int msgid, MsgHandleFunc handle_func); + int AllocIMMsgId(); + void ProcessMsg(int msgid, const a8::XParams& param); + + //线程安全版本 + void PostMsg_r(int msgid, a8::XParams param); + + private: + MsgQueueImp* imp_ = nullptr; +}; diff --git a/cpp/types.h b/cpp/types.h index 5c13c1f..298949c 100644 --- a/cpp/types.h +++ b/cpp/types.h @@ -17,3 +17,10 @@ struct JsonHttpRequest JsonHttpRequest(); ~JsonHttpRequest(); }; + +enum SysInnerMesssage_e +{ + IM_SysBegin = 1, + IM_SysMsgQueue = 2, + IM_SysEnd = 99, +};