diff --git a/scripts/proxy/srvside.py b/scripts/proxy/srvside.py index 580daa9..74d4964 100644 --- a/scripts/proxy/srvside.py +++ b/scripts/proxy/srvside.py @@ -1,8 +1,12 @@ # -*- coding: utf-8 -*- import os import json +import random +import time +import threading import tornado.ioloop import tornado.web +import tornado.gen import tornado.websocket import tornado.tcpserver @@ -16,6 +20,12 @@ class ServerSide: self.local = None self.remote = None self.localHandlerHash = {} + self.localHandlerHashMutex = threading.Lock() + self.remoteConnIdx = 1000 + self.remotePending = {} + self.remotePendingMutex = threading.Lock() + self.remoteConnHash = {} + self.remoteConnHashMutex = threading.Lock() def run(self): webapp = tornado.web.Application([ @@ -30,34 +40,125 @@ class ServerSide: self.remote = remote_server tornado.ioloop.IOLoop.instance().start() - def addRemoteConn(conn): - pass + def addRemoteConn(self, conn): + local_conn = self.randLocalConn() + if not local_conn: + return False + self.remotePendingMutex.acquire() + conn.idx = self.remoteConnIdx + 1 + self.remoteConnIdx += 1 + self.remotePending[conn.idx] = conn + conn.localSocket = local_conn + self.remotePendingMutex.release() + local_conn.SendMsg({ + 'remoteConnIdx' : conn.idx, + 'cmd' : 'connect' + }) + return True - def addLocalConn(conn): + def addLocalConn(self, conn): + self.localHandlerHashMutex.acquire() self.localHandlerHash[conn] = conn + self.localHandlerHashMutex.release() - def removeLocalConn(conn): - del self.localHandlerHash[conn] + def removeLocalConn(self, conn): + self.localHandlerHashMutex.acquire() + if conn in self.localHandlerHash: + self.remoteConnHashMutex.acquire() + try: + for remoteConn in self.remoteConnHash.values(): + if remoteConn.localSocket == conn: + try: + remoteConn.steam.close() + except: + pass + except: + pass + finally: + self.remoteConnHashMutex.release() + del self.localHandlerHash[conn] + self.localHandlerHashMutex.release() + + def removeRemoteConn(self, conn): + self.remoteConnHashMutex.acquire() + if conn.idx in self.remoteConnHash: + del self.remoteConnHash[conn] + self.remoteConnHashMutex.release() + + def randLocalConn(self): + self.localHandlerHashMutex.acquire() + try: + if len(self.localHandlerHash) <= 0: + return None + key = random.choice(list(self.localHanderHash.keys())) + return self.localHandlerHash[key] + finally: + self.localHandlerHashMutex.release() + + def onLocalConnectOk(self, msg): + local_conn = None + self.remotePendingMutex.acquire() + if msg['remoteConnIdx'] in self.remotePending: + local_conn = self.remotePending[msg['remoteConnIdx']] + del self.remotePending[msg['remoteConnIdx']] + self.remotePendingMutex.release() + + self.remoteConnHashMutex.acquire() + if local_conn: + self.remoteConnHash[local_conn.idx] = local_conn + self.remoteConnHashMutex.release() + + def onLocalForwardData(self, msg): + self.remotePendingMutex.acquire() + if msg['remoteConnIdx'] in self.remotePending: + local_conn = self.remotePending[msg['remoteConnIdx']] + local_conn.steam.write(base64.b64decode(msg['data'])) + self.remotePendingMutex.release() + + def isConnectOk(self, connIdx): + isOk = False + self.remotePendingMutex.acquire() + isOk = connIdx not in self.remotePending + self.remotePendingMutex.release() + return isOk class RemoteServerConnection(object): def __init__(self, steam, address): self.steam = steam self.address = address + self.idx = 0 + self.localSocket = None async def handle_stream(self): while True: try: data = await self.stream.read_until(b"\n") - await self.stream.write(data) + self.localSocket.sendMsg({ + 'cmd' : 'forwardData', + 'data' : base64.b64encode(data) + }) except tornado.iostream.StreamClosedError: break + #while + if self.localSocket: + self.localSocket.sendMsg({ + 'cmd' : 'socketClose', + 'remoteConnIdx' : self.idx, + }) + app.removeRemoteConn(self) class RemoteServer(tornado.tcpserver.TCPServer): async def handle_stream(self, stream, address): conn = RemoteServerConnection(steam, address) - app.addRemoteConn(conn) + if not app.addRemoteConn(conn): + stream.close() + return + yield gen.sleep(2) + if not app.isConnectOk(conn.connIdx): + stream.close() + return conn.handle_stream() class LocalSocketHandler(tornado.websocket.WebSocketHandler): @@ -85,4 +186,14 @@ class LocalSocketHandler(tornado.websocket.WebSocketHandler): self.dispatchMsg(msg) def dispatchMsg(self, msg): - pass + if msg['cmd'] == 'connectOk': + app.onLocalConnectOk(msg) + elif msg['cmd'] == 'forwardData': + app.onLocalForward(msg) + elif msg['cmd'] == 'socketClose': + app.removeRemoteConn(msg['remoteConnIdx']) + + def sendMsg(self, msg): + data = json.dumps(msg) + '\n' + self.write_message(data) +