tools/scripts/proxy/srvside.py
aozhiwei d1fab83969 1
2018-12-15 20:16:33 +08:00

207 lines
6.2 KiB
Python

# -*- coding: utf-8 -*-
import base64
import json
import os
import random
import threading
import time

import tornado.gen
import tornado.ioloop
import tornado.iostream
import tornado.tcpserver
import tornado.web
import tornado.websocket
# The single ServerSide instance. It is published by ServerSide.__init__
# (via ``global app``) and read by the Tornado handler classes in this file.
# It stays None until a ServerSide has been constructed.
app = None
class ServerSide:
    """Server half of a TCP-over-WebSocket tunnel.

    Raw TCP clients connect to ``remote_port`` (handled by ``RemoteServer``).
    Tunnel peers connect over WebSocket to ``/websocket`` on ``local_port``
    (handled by ``LocalSocketHandler``).

    Each accepted TCP connection gets an integer index and is announced to a
    randomly chosen peer with a ``connect`` message. It waits in
    ``remotePending`` until the peer answers ``connectOk``, then moves to
    ``remoteConnHash``.

    Constructing an instance also publishes it as the module-level ``app``.
    """

    def __init__(self, local_port, remote_port):
        global app
        self._local_port = int(local_port)
        self._remote_port = int(remote_port)
        app = self
        self.local = None   # tornado.web.Application, set by run()
        self.remote = None  # RemoteServer, set by run()
        # Connected tunnel peers; each handler is stored as its own key.
        self.localHandlerHash = {}
        self.localHandlerHashMutex = threading.Lock()
        # Last index handed out; the first connection gets 1001.
        self.remoteConnIdx = 1000
        # idx -> RemoteServerConnection still waiting for "connectOk".
        self.remotePending = {}
        self.remotePendingMutex = threading.Lock()
        # idx -> RemoteServerConnection acknowledged by its peer.
        self.remoteConnHash = {}
        self.remoteConnHashMutex = threading.Lock()

    def run(self):
        """Start both listeners and block inside the Tornado IOLoop."""
        # No second positional argument: Application's second parameter is
        # ``default_host``, not the settings dict.
        webapp = tornado.web.Application([
            (r"/websocket", LocalSocketHandler),
        ])
        webapp.listen(self._local_port)
        remote_server = RemoteServer()
        remote_server.listen(self._remote_port)
        self.local = webapp
        self.remote = remote_server
        # IOLoop.instance() is a deprecated alias of current() since Tornado 5.
        tornado.ioloop.IOLoop.current().start()

    def addRemoteConn(self, conn):
        """Register a new TCP client and ask a random peer to open its side.

        Assigns ``conn.idx`` and ``conn.localSocket``, stores ``conn`` in
        ``remotePending`` and sends ``{'remoteConnIdx', 'cmd': 'connect'}``
        to the chosen peer.

        Returns:
            True if a peer was asked to connect. False if no peer is available
            or the chosen peer's websocket is already closed.
        """
        local_conn = self.randLocalConn()
        if local_conn is None:
            return False
        with self.remotePendingMutex:
            self.remoteConnIdx += 1
            conn.idx = self.remoteConnIdx
            conn.localSocket = local_conn
            self.remotePending[conn.idx] = conn
        try:
            local_conn.sendMsg({
                'remoteConnIdx' : conn.idx,
                'cmd' : 'connect'
            })
        except tornado.websocket.WebSocketClosedError:
            # The peer is closing but on_close has not unregistered it yet.
            # Undo the registration so the entry does not linger.
            with self.remotePendingMutex:
                self.remotePending.pop(conn.idx, None)
            return False
        return True

    def addLocalConn(self, conn):
        """Register a connected tunnel peer (a LocalSocketHandler)."""
        with self.localHandlerHashMutex:
            self.localHandlerHash[conn] = conn

    def removeLocalConn(self, conn):
        """Unregister a tunnel peer and drop every TCP client routed to it.

        Both acknowledged and still-pending connections whose ``localSocket``
        is ``conn`` are removed and their streams closed. Unknown peers are
        ignored.
        """
        with self.localHandlerHashMutex:
            if conn not in self.localHandlerHash:
                return
            del self.localHandlerHash[conn]
        # Snapshot under each lock separately. removeRemoteConn takes the same
        # (non-reentrant) locks, so nothing may be held while calling it.
        with self.remoteConnHashMutex:
            orphans = [c for c in self.remoteConnHash.values()
                       if c.localSocket is conn]
        with self.remotePendingMutex:
            orphans.extend(c for c in self.remotePending.values()
                           if c.localSocket is conn)
        for orphan in orphans:
            self.removeRemoteConn(orphan)

    def removeRemoteConn(self, conn):
        """Forget a TCP client connection and close its stream.

        Args:
            conn: a RemoteServerConnection, or its integer index (the
                ``socketClose`` message from a peer carries only the index).

        The connection is removed from both the acknowledged and pending maps.
        Unknown or already-removed connections are ignored, so repeated calls
        are harmless.
        """
        idx = conn if isinstance(conn, int) else conn.idx
        with self.remoteConnHashMutex:
            target = self.remoteConnHash.pop(idx, None)
        with self.remotePendingMutex:
            pending = self.remotePending.pop(idx, None)
        if target is None:
            target = pending
        if target is None:
            return
        # Detach first so the stream's read loop does not report this closure
        # back to a peer that initiated it or is already gone.
        target.localSocket = None
        try:
            target.stream.close()
        except OSError:
            pass  # best-effort close; the connection is already unregistered

    def randLocalConn(self):
        """Return a randomly chosen connected peer, or None if there are none."""
        with self.localHandlerHashMutex:
            if not self.localHandlerHash:
                return None
            return random.choice(list(self.localHandlerHash.values()))

    def onLocalConnectOk(self, msg):
        """Promote the pending connection named by ``msg['remoteConnIdx']``."""
        with self.remotePendingMutex:
            remote_conn = self.remotePending.pop(msg['remoteConnIdx'], None)
        if remote_conn is None:
            return
        with self.remoteConnHashMutex:
            self.remoteConnHash[remote_conn.idx] = remote_conn

    def _findRemoteConn(self, idx):
        """Return the acknowledged or pending connection with ``idx``, or None."""
        with self.remoteConnHashMutex:
            remote_conn = self.remoteConnHash.get(idx)
        if remote_conn is None:
            with self.remotePendingMutex:
                remote_conn = self.remotePending.get(idx)
        return remote_conn

    def onLocalForwardData(self, msg):
        """Write base64 ``msg['data']`` to the TCP client ``msg['remoteConnIdx']``.

        Acknowledged connections are checked first. Pending ones are still
        accepted, as the original lookup did. Messages for unknown indexes
        are dropped.
        """
        remote_conn = self._findRemoteConn(msg['remoteConnIdx'])
        if remote_conn is None:
            return
        try:
            remote_conn.stream.write(base64.b64decode(msg['data']))
        except tornado.iostream.StreamClosedError:
            # The client already hung up. Cleanup happens where the stream's
            # closure is observed, not here.
            pass

    def isConnectOk(self, connIdx):
        """Return True once the peer has acknowledged connection ``connIdx``.

        A connection that was removed (or never registered) is not OK.
        """
        with self.remoteConnHashMutex:
            return connIdx in self.remoteConnHash
class RemoteServerConnection(object):
    """State for one raw TCP client accepted on the remote port.

    ``idx`` and ``localSocket`` start as placeholders. They are expected to be
    filled in by ``ServerSide.addRemoteConn``; ``localSocket`` is the
    websocket handler that carries this connection's traffic.
    """

    # Maximum bytes taken from the stream per read. With partial=True a read
    # completes as soon as any data is available, so this only caps chunk size.
    READ_CHUNK_SIZE = 64 * 1024

    def __init__(self, stream, address):
        self.stream = stream
        self.address = address
        self.idx = 0
        self.localSocket = None

    async def handle_stream(self):
        """Relay bytes from the TCP client to its tunnel peer until EOF.

        Each chunk is sent as a ``forwardData`` message. The message carries
        this connection's ``remoteConnIdx`` so the peer can route it, and the
        bytes as base64 text, because JSON cannot hold raw bytes.

        When the client disconnects, the peer (if still attached) receives
        ``socketClose``. The stream is then closed and the connection
        unregistered from ``app``.
        """
        while True:
            try:
                data = await self.stream.read_bytes(self.READ_CHUNK_SIZE,
                                                    partial=True)
            except tornado.iostream.StreamClosedError:
                break
            peer = self.localSocket
            if peer is None:
                # localSocket was cleared elsewhere; there is nobody to relay to.
                break
            try:
                peer.sendMsg({
                    'cmd' : 'forwardData',
                    'remoteConnIdx' : self.idx,
                    'data' : base64.b64encode(data).decode('ascii'),
                })
            except tornado.websocket.WebSocketClosedError:
                # The tunnel died; stop reading and tear the client down below.
                break
        peer = self.localSocket
        if peer is not None:
            try:
                peer.sendMsg({
                    'cmd' : 'socketClose',
                    'remoteConnIdx' : self.idx,
                })
            except tornado.websocket.WebSocketClosedError:
                pass  # the peer is gone as well; nothing left to notify
        # Needed when the loop ended on the websocket side; a no-op otherwise.
        self.stream.close()
        app.removeRemoteConn(self)
class RemoteServer(tornado.tcpserver.TCPServer):
    """Accepts raw TCP clients and hands each one to the tunnel."""

    # Total time (seconds) to wait for the peer's "connectOk".
    CONNECT_TIMEOUT = 2.0
    # How often (seconds) the handshake state is re-checked while waiting.
    CONNECT_POLL_INTERVAL = 0.05

    async def handle_stream(self, stream, address):
        """Register the client, wait for the peer handshake, then relay data.

        The stream is closed immediately in two cases: when no peer is
        available, or when the handshake does not complete within
        ``CONNECT_TIMEOUT``.
        """
        conn = RemoteServerConnection(stream, address)
        if not app.addRemoteConn(conn):
            stream.close()
            return
        if not await self._waitConnectOk(conn):
            stream.close()
            # Drop the stale pending entry so it does not accumulate.
            app.removeRemoteConn(conn)
            return
        await conn.handle_stream()

    async def _waitConnectOk(self, conn):
        """Poll ``app.isConnectOk`` until it succeeds or the timeout expires."""
        io_loop = tornado.ioloop.IOLoop.current()
        deadline = io_loop.time() + self.CONNECT_TIMEOUT
        while not app.isConnectOk(conn.idx):
            if io_loop.time() >= deadline:
                return False
            await tornado.gen.sleep(self.CONNECT_POLL_INTERVAL)
        return True
class LocalSocketHandler(tornado.websocket.WebSocketHandler):
    """WebSocket endpoint for a tunnel peer.

    Traffic in both directions is newline-delimited JSON objects. Each
    object has a ``cmd`` key; frames may split or join objects arbitrarily.
    """

    def open(self):
        # Holds the trailing, not-yet-newline-terminated part of the input.
        self._recvBuf = ''
        app.addLocalConn(self)

    def on_message(self, message):
        # Binary frames arrive as bytes; the protocol is UTF-8 JSON text.
        if isinstance(message, bytes):
            message = message.decode('utf-8')
        self._recvBuf += message
        self.parsePacket()

    def on_close(self):
        app.removeLocalConn(self)

    def parsePacket(self):
        """Dispatch every complete line in the buffer and keep the remainder.

        Everything before the last newline is complete; the text after it
        (possibly empty) is kept for the next frame. Blank lines are skipped.
        """
        *lines, self._recvBuf = self._recvBuf.split('\n')
        for line in lines:
            if line.strip():
                self.dispatchMsg(json.loads(line))

    def dispatchMsg(self, msg):
        """Route one decoded message to the matching ServerSide handler.

        Unknown ``cmd`` values are ignored.
        """
        cmd = msg.get('cmd')
        if cmd == 'connectOk':
            app.onLocalConnectOk(msg)
        elif cmd == 'forwardData':
            app.onLocalForwardData(msg)
        elif cmd == 'socketClose':
            app.removeRemoteConn(msg['remoteConnIdx'])

    def sendMsg(self, msg):
        """Send ``msg`` to the peer as one newline-terminated JSON line."""
        data = json.dumps(msg) + '\n'
        self.write_message(data)