diff --git a/server/wsproxy/kcpsession.cc b/server/wsproxy/kcpsession.cc index 50126ba..fc41b91 100644 --- a/server/wsproxy/kcpsession.cc +++ b/server/wsproxy/kcpsession.cc @@ -85,10 +85,10 @@ void KcpSession::DecodeUserPacket(char* buf, int& offset, unsigned int buflen) //packagelen + msgid + magiccode + msgbody //2 + 2 + 4+ xx + \0 + xx bool warning = false; - while (buflen - offset >= sizeof(f8::PackHead)) { - f8::PackHead* p = (f8::PackHead*)&buf[offset]; + while (buflen - offset >= sizeof(f8::PackHead) + GetSecretKeyLen()) { + f8::PackHead* p = (f8::PackHead*)&buf[offset + GetSecretKeyLen()]; if (p->magic_code == f8::MAGIC_CODE) { - if (buflen - offset < sizeof(f8::PackHead) + p->packlen) { + if (buflen - offset < sizeof(f8::PackHead) + p->packlen + GetSecretKeyLen()) { break; } //a8::XPrintf("Recv MsgId:%d\n", {p->msgid}); @@ -98,7 +98,7 @@ void KcpSession::DecodeUserPacket(char* buf, int& offset, unsigned int buflen) //saddr, p->msgid, p->seqid, - &buf[offset + sizeof(f8::PackHead)], + &buf[offset + sizeof(f8::PackHead) + GetSecretKeyLen()], p->packlen); offset += sizeof(f8::PackHead) + p->packlen; } else { diff --git a/server/wsproxy/kcpsession.h b/server/wsproxy/kcpsession.h index 6f72d3e..c1ee0e9 100644 --- a/server/wsproxy/kcpsession.h +++ b/server/wsproxy/kcpsession.h @@ -22,6 +22,12 @@ public: void SendClientMsg(char* buf, int buf_len); virtual void OnRecvPacket(a8::UdpPacket* pkt) override; + static int GetSecretKeyLen() { return sizeof(long long) / 4; } + static long long ReadSecretKey(const char* buf, int buf_len) + { + return buf_len < GetSecretKeyLen() ? 0 : *((long long*)buf); + } + protected: virtual void DecodeUserPacket(char* buf, int& offset, unsigned int buflen) override; diff --git a/server/wsproxy/longsessionmgr.cc b/server/wsproxy/longsessionmgr.cc index d0779a5..0a1bbda 100644 --- a/server/wsproxy/longsessionmgr.cc +++ b/server/wsproxy/longsessionmgr.cc @@ -81,11 +81,14 @@ std::shared_ptr LongSessionMgr::GetSession(int socket_handle) void LongSessionMgr::ProcUdpPacket(a8::UdpPacket* pkt) { - int socket_handle = 0; - long long secret_key = 0; - auto session = GetSession(socket_handle); - if (session && secret_key == session->GetKcpSession()->GetSecretKey()) { - session->GetKcpSession()->OnRecvPacket(pkt); + const int IKCP_OVERHEAD = 24; + if (pkt->buf_len > IKCP_OVERHEAD + KcpSession::GetSecretKeyLen()) { + int socket_handle = ikcp_getconv(pkt->buf); + long long secret_key = KcpSession::ReadSecretKey(pkt->buf, pkt->buf_len); + auto session = GetSession(socket_handle); + if (session && secret_key == session->GetKcpSession()->GetSecretKey()) { + session->GetKcpSession()->OnRecvPacket(pkt); + } } }