# -*- coding: utf-8 -*-
"""Server side of a WebSocket-based TCP tunnel.

Two listeners share one Tornado IOLoop:

* ``local_port``: a WebSocket endpoint at ``/websocket`` handled by
  :class:`LocalSocketHandler`; peers connecting here are called "local"
  connections below.
* ``remote_port``: a plain TCP server (:class:`RemoteServer`). Each accepted
  TCP connection is bound to a randomly chosen local peer, which is sent a
  ``connect`` command; afterwards bytes are relayed in both directions as
  ``forwardData`` messages.

Messages on the WebSocket are newline-terminated JSON objects.
"""
import os
import json
import random
import time
import base64
import threading

import tornado.ioloop
import tornado.iostream
import tornado.web
import tornado.gen
import tornado.websocket
import tornado.tcpserver

# The single ServerSide instance. The handler classes below are instantiated
# by Tornado, so they reach the shared state through this module global.
app = None

# Upper bound on the bytes read from a remote TCP connection per message.
READ_CHUNK_SIZE = 64 * 1024

# Seconds a new remote connection waits for the local peer's ``connectOk``.
CONNECT_TIMEOUT = 0.3


def _encode_payload(data):
    """Encode raw bytes for the ``data`` field of a ``forwardData`` message.

    The wire format is ``str(base64.b64encode(data))``, i.e. the text
    ``"b'<base64>'"``. It is kept unchanged because local peers presumably
    parse exactly this form (TODO confirm against the client).
    """
    return str(base64.b64encode(data))


def _decode_payload(text):
    """Decode the ``data`` field of a ``forwardData`` message into bytes.

    Accepts the ``"b'<base64>'"`` form produced by :func:`_encode_payload`
    as well as plain base64 text (``'`` never occurs in base64, so the two
    forms cannot be confused).
    """
    if len(text) >= 3 and text.startswith("b'") and text.endswith("'"):
        text = text[2:-1]
    return base64.b64decode(text)


class ServerSide:
    """Owns both listeners and the tables linking remote and local connections.

    NOTE(review): every caller visible in this file runs as a Tornado
    callback on the IOLoop thread, so the locks appear to be defensive only.
    They are never nested, which avoids lock-ordering deadlocks.
    """

    def __init__(self, local_port, remote_port):
        global app
        self._local_port = int(local_port)
        self._remote_port = int(remote_port)
        app = self
        self.local = None
        self.remote = None
        # Connected WebSocket handlers, used as a set (handler -> handler).
        self.localHandlerHash = {}
        self.localHandlerHashMutex = threading.Lock()
        # Last index handed out to a remote TCP connection.
        self.remoteConnIdx = 1000
        # Remote connections still waiting for ``connectOk``, keyed by index.
        self.remotePending = {}
        self.remotePendingMutex = threading.Lock()
        # Established remote connections, keyed by index.
        self.remoteConnHash = {}
        self.remoteConnHashMutex = threading.Lock()

    def run(self):
        """Start both listeners and block running the IOLoop."""
        webapp = tornado.web.Application([
            (r"/websocket", LocalSocketHandler),
        ])
        webapp.listen(self._local_port)
        remote_server = RemoteServer()
        remote_server.listen(self._remote_port)
        self.local = webapp
        self.remote = remote_server
        tornado.ioloop.IOLoop.current().start()

    def addRemoteConn(self, conn):
        """Register a new remote TCP connection and ask a local peer to connect.

        Assigns ``conn.idx`` and ``conn.localSocket`` and parks ``conn`` in
        ``remotePending`` until the peer answers with ``connectOk``.

        Returns:
            False when no local peer is connected, True otherwise.
        """
        local_conn = self.randLocalConn()
        if not local_conn:
            return False
        with self.remotePendingMutex:
            self.remoteConnIdx += 1
            conn.idx = self.remoteConnIdx
            conn.localSocket = local_conn
            self.remotePending[conn.idx] = conn
        local_conn.sendMsg({
            'remoteConnIdx': conn.idx,
            'cmd': 'connect',
        })
        print('addRemoteConn')
        return True

    def addLocalConn(self, conn):
        """Record a newly opened local WebSocket peer."""
        with self.localHandlerHashMutex:
            self.localHandlerHash[conn] = conn

    def removeLocalConn(self, conn):
        """Forget a closed local peer and close every remote connection bound to it.

        Both established and still-pending remote connections are closed.
        Their coroutines then observe the closed stream and clean up their
        own table entries.
        """
        with self.localHandlerHashMutex:
            if conn not in self.localHandlerHash:
                return
            del self.localHandlerHash[conn]
        with self.remoteConnHashMutex:
            bound = [c for c in self.remoteConnHash.values()
                     if c.localSocket is conn]
        with self.remotePendingMutex:
            bound += [c for c in self.remotePending.values()
                      if c.localSocket is conn]
        for remote_conn in bound:
            try:
                remote_conn.stream.close()
            except OSError:
                pass  # best effort: the socket may already be unusable

    def removeRemoteConn(self, conn):
        """Drop ``conn`` from both the established and the pending tables."""
        with self.remoteConnHashMutex:
            self.remoteConnHash.pop(conn.idx, None)
        with self.remotePendingMutex:
            self.remotePending.pop(conn.idx, None)

    def randLocalConn(self):
        """Return a randomly chosen connected local peer, or None if there is none."""
        with self.localHandlerHashMutex:
            if not self.localHandlerHash:
                return None
            return random.choice(list(self.localHandlerHash.values()))

    def onLocalConnectOk(self, msg):
        """Promote the pending remote connection named in ``msg`` to established.

        A ``connectOk`` for an unknown index (e.g. one that already timed
        out) is ignored.
        """
        with self.remotePendingMutex:
            remote_conn = self.remotePending.pop(msg['remoteConnIdx'], None)
        if remote_conn is None:
            return
        with self.remoteConnHashMutex:
            self.remoteConnHash[remote_conn.idx] = remote_conn

    def onLocalForwardData(self, msg):
        """Write the payload of a ``forwardData`` message to its remote TCP stream."""
        with self.remoteConnHashMutex:
            remote_conn = self.remoteConnHash.get(msg['remoteConnIdx'])
        if remote_conn is None:
            return
        data = _decode_payload(msg['data'])
        print(data)
        try:
            remote_conn.lastWrite = remote_conn.stream.write(data)
        except tornado.iostream.StreamClosedError:
            # The remote side is gone; its handle_stream coroutine cleans up.
            pass

    def onLocalSocketClose(self, msg):
        """Close the remote TCP connection named in a ``socketClose`` message.

        The close is deferred until data already queued for that stream has
        been flushed, so a trailing ``forwardData`` is not lost.
        """
        idx = msg['remoteConnIdx']
        with self.remoteConnHashMutex:
            remote_conn = self.remoteConnHash.get(idx)
        if remote_conn is None:
            with self.remotePendingMutex:
                remote_conn = self.remotePending.get(idx)
        if remote_conn is None:
            return
        # The peer already knows; don't echo a socketClose back to it.
        remote_conn.closedByLocal = True
        pending_write = remote_conn.lastWrite
        if pending_write is None or pending_write.done():
            remote_conn.stream.close()
        else:
            pending_write.add_done_callback(
                lambda _future: remote_conn.stream.close())

    def isConnectOk(self, connIdx):
        """Return True once ``connIdx`` is no longer in the pending table."""
        with self.remotePendingMutex:
            return connIdx not in self.remotePending


class RemoteServerConnection(object):
    """One accepted TCP connection on the remote port."""

    def __init__(self, stream, address):
        self.stream = stream
        self.address = address
        # Assigned by ServerSide.addRemoteConn.
        self.idx = 0
        self.localSocket = None
        # Set when the local peer requested the close via ``socketClose``.
        self.closedByLocal = False
        # Future of the most recent stream.write, used to defer closing.
        self.lastWrite = None

    def notifyClosed(self):
        """Tell the bound local peer that this connection is gone (best effort)."""
        if not self.localSocket:
            return
        try:
            self.localSocket.sendMsg({
                'cmd': 'socketClose',
                'remoteConnIdx': self.idx,
            })
        except tornado.websocket.WebSocketClosedError:
            pass  # the peer is already gone; nothing to notify

    async def handle_stream(self):
        """Relay bytes from the TCP stream to the local peer until either side closes."""
        while True:
            try:
                # partial=True forwards whatever has arrived instead of
                # waiting for a delimiter, so any byte protocol is tunnelled.
                data = await self.stream.read_bytes(READ_CHUNK_SIZE,
                                                    partial=True)
            except tornado.iostream.StreamClosedError:
                break
            try:
                self.localSocket.sendMsg({
                    'cmd': 'forwardData',
                    'remoteConnIdx': self.idx,
                    'data': _encode_payload(data),
                })
            except tornado.websocket.WebSocketClosedError:
                break
        if not self.closedByLocal:
            self.notifyClosed()
        self.stream.close()  # no-op if already closed
        app.removeRemoteConn(self)


class RemoteServer(tornado.tcpserver.TCPServer):
    """TCP server on the remote port; binds each connection to a local peer."""

    async def handle_stream(self, stream, address):
        print('zzzz')
        conn = RemoteServerConnection(stream, address)
        if not app.addRemoteConn(conn):
            stream.close()
            return
        await tornado.gen.sleep(CONNECT_TIMEOUT)
        if not app.isConnectOk(conn.idx):
            stream.close()
            # Drop the stale pending entry and let the peer release any
            # socket it may still open after this timeout.
            app.removeRemoteConn(conn)
            conn.notifyClosed()
            return
        await conn.handle_stream()


class LocalSocketHandler(tornado.websocket.WebSocketHandler):
    """WebSocket endpoint for local peers; speaks newline-delimited JSON."""

    def open(self):
        # Text received so far that does not yet end in a newline.
        self._recvBuf = ''
        app.addLocalConn(self)

    def on_message(self, message):
        if isinstance(message, bytes):
            # Binary frames would otherwise fail on str concatenation.
            message = message.decode('utf-8')
        self._recvBuf += message
        self.parsePacket()

    def on_close(self):
        app.removeLocalConn(self)
        print('on_close')

    def parsePacket(self):
        """Dispatch every complete line in the buffer and keep the remainder."""
        if not self._recvBuf:
            return
        lines = self._recvBuf.split('\n')
        # The last piece is an incomplete message ('' when the buffer ends
        # with a newline); keep it until the rest arrives.
        self._recvBuf = lines.pop()
        for line in lines:
            if line.strip():
                self.dispatchMsg(json.loads(line))

    def dispatchMsg(self, msg):
        """Route a decoded message to the matching ServerSide callback."""
        cmd = msg.get('cmd')
        if cmd == 'connectOk':
            app.onLocalConnectOk(msg)
        elif cmd == 'forwardData':
            app.onLocalForwardData(msg)
        elif cmd == 'socketClose':
            app.onLocalSocketClose(msg)

    def sendMsg(self, msg):
        """Send ``msg`` as one newline-terminated JSON text message."""
        data = json.dumps(msg) + '\n'
        print(data)
        self.write_message(data)