216 lines
6.4 KiB
Python
216 lines
6.4 KiB
Python
# -*- coding: utf-8 -*-
|
|
import os
|
|
import json
|
|
import random
|
|
import time
|
|
import base64
|
|
import threading
|
|
import tornado.ioloop
|
|
import tornado.web
|
|
import tornado.gen
|
|
import tornado.websocket
|
|
import tornado.tcpserver
|
|
|
|
# Module-level handle to the running ServerSide instance. ServerSide.__init__
# assigns it, and the tornado handler classes below read it to reach the server.
app = None
|
|
class ServerSide:
    """Relay server joining remote TCP clients to local websocket peers.

    ``run()`` starts two listeners:

    * a websocket endpoint (``/websocket``) on ``local_port``, used by local peers;
    * a plain TCP listener on ``remote_port``, used by remote clients.

    Each accepted remote connection gets an integer index and is assigned to
    a randomly chosen local peer. Traffic with that peer is exchanged as JSON
    messages (``connect``, ``connectOk``, ``forwardData``, ``socketClose``).

    Creating an instance stores it in the module-level ``app`` global. The
    handler classes use that global to reach this object.

    Nothing in this file starts extra threads, so the locks below are
    defensive. ``threading.Lock`` is not reentrant, so every lock is held
    through a ``with`` block. A lock left held after an exception would block
    the IOLoop thread on its next acquire.
    """

    def __init__(self, local_port, remote_port):
        """Record the listening ports and register this instance as ``app``.

        Args:
            local_port: port (int or numeric string) for the websocket endpoint.
            remote_port: port (int or numeric string) for remote TCP clients.
        """
        global app
        self._local_port = int(local_port)
        self._remote_port = int(remote_port)
        app = self
        self.local = None   # tornado.web.Application, set by run()
        self.remote = None  # RemoteServer, set by run()
        # Connected local websocket handlers (each key maps to itself).
        self.localHandlerHash = {}
        self.localHandlerHashMutex = threading.Lock()
        # Last index handed out; the first remote connection gets 1001.
        self.remoteConnIdx = 1000
        # Remote connections waiting for ``connectOk``, keyed by index.
        self.remotePending = {}
        self.remotePendingMutex = threading.Lock()
        # Established remote connections, keyed by index.
        self.remoteConnHash = {}
        self.remoteConnHashMutex = threading.Lock()

    def run(self):
        """Start both listeners, then block running the IOLoop."""
        # No settings are needed. Application's second positional parameter
        # is ``default_host``, so nothing is passed after the handler list.
        webapp = tornado.web.Application([
            (r"/websocket", LocalSocketHandler),
        ])
        webapp.listen(self._local_port)

        remote_server = RemoteServer()
        remote_server.listen(self._remote_port)

        self.local = webapp
        self.remote = remote_server
        tornado.ioloop.IOLoop.instance().start()

    def addRemoteConn(self, conn):
        """Assign ``conn`` to a random local peer and ask that peer to connect.

        Allocates the next index into ``conn.idx``, sets ``conn.localSocket``,
        parks ``conn`` in ``remotePending`` and sends a ``connect`` message.

        Returns:
            True if the request was sent. False if no local peer is connected
            or the chosen peer's websocket is already closed.
        """
        local_conn = self.randLocalConn()
        if local_conn is None:
            return False
        with self.remotePendingMutex:
            self.remoteConnIdx += 1
            conn.idx = self.remoteConnIdx
            conn.localSocket = local_conn
            self.remotePending[conn.idx] = conn
        try:
            local_conn.sendMsg({
                'remoteConnIdx': conn.idx,
                'cmd': 'connect',
            })
        except tornado.websocket.WebSocketClosedError:
            # The peer is going away. Don't leave a pending entry that no one
            # will ever answer.
            with self.remotePendingMutex:
                self.remotePending.pop(conn.idx, None)
            return False
        print('addRemoteConn')
        return True

    def addLocalConn(self, conn):
        """Register a newly opened local websocket handler."""
        with self.localHandlerHashMutex:
            self.localHandlerHash[conn] = conn

    def removeLocalConn(self, conn):
        """Forget local peer ``conn`` and close the remote streams relayed through it.

        Unknown handlers are ignored. Closing is best-effort: a failure is
        reported and the remaining streams are still closed. The index entries
        are not removed here. The reader coroutine
        (``RemoteServerConnection.handle_stream``) calls ``removeRemoteConn``
        once its read loop ends.
        """
        with self.localHandlerHashMutex:
            if conn not in self.localHandlerHash:
                return
            del self.localHandlerHash[conn]
        with self.remoteConnHashMutex:
            # Take a snapshot so no dict iteration overlaps the close calls.
            orphans = [remote_conn for remote_conn in self.remoteConnHash.values()
                       if remote_conn.localSocket is conn]
        # Close outside the lock. If a close path ever re-entered this object,
        # it would otherwise deadlock on the non-reentrant lock.
        for remote_conn in orphans:
            try:
                remote_conn.stream.close()
            except Exception as exc:
                print('removeLocalConn: failed to close remote stream', exc)

    def removeRemoteConn(self, conn):
        """Drop ``conn`` from the established-connection index (no-op if absent)."""
        with self.remoteConnHashMutex:
            self.remoteConnHash.pop(conn.idx, None)

    def randLocalConn(self):
        """Return a uniformly random connected local handler, or None if none are connected."""
        with self.localHandlerHashMutex:
            if not self.localHandlerHash:
                return None
            return random.choice(list(self.localHandlerHash.values()))

    def onLocalConnectOk(self, msg):
        """Promote the pending connection named by ``msg['remoteConnIdx']`` to established.

        Unknown or already-handled indices are ignored, for example a reply
        that arrives after the connection timed out.
        """
        with self.remotePendingMutex:
            remote_conn = self.remotePending.pop(msg['remoteConnIdx'], None)
        if remote_conn is not None:
            with self.remoteConnHashMutex:
                self.remoteConnHash[remote_conn.idx] = remote_conn

    @staticmethod
    def _decodeData(payload):
        """Decode the base64 ``data`` field of a ``forwardData`` message.

        ``RemoteServerConnection`` encodes outgoing data as
        ``str(base64.b64encode(...))``. That produces the bytes repr
        ``"b'...'"``, so that wrapper is stripped when present. A bare base64
        string is accepted too. Base64 text never contains ``'``, so the two
        forms cannot be confused.
        """
        if payload.startswith("b'") and payload.endswith("'"):
            payload = payload[2:-1]
        return base64.b64decode(payload)

    def onLocalForwardData(self, msg):
        """Write a ``forwardData`` payload from a local peer to its remote stream.

        Messages whose ``remoteConnIdx`` is not an established connection are
        dropped. A write to an already-closed remote stream is ignored.

        Raises:
            binascii.Error: if ``msg['data']`` is not valid base64. No lock is
                held at that point, so the server keeps working.
        """
        with self.remoteConnHashMutex:
            remote_conn = self.remoteConnHash.get(msg['remoteConnIdx'])
        if remote_conn is None:
            return
        data = self._decodeData(msg['data'])
        print(data)
        try:
            remote_conn.stream.write(data)
        except tornado.iostream.StreamClosedError:
            # The remote side has already gone away; its reader coroutine
            # removes the index entry.
            pass

    def isConnectOk(self, connIdx):
        """Return True once ``connIdx`` is no longer waiting for ``connectOk``."""
        with self.remotePendingMutex:
            return connIdx not in self.remotePending
|
|
|
|
class RemoteServerConnection(object):
    """One accepted remote TCP client whose traffic is relayed to a local peer.

    ``idx`` and ``localSocket`` are filled in by ``ServerSide.addRemoteConn``
    before ``handle_stream`` runs.
    """

    # Upper bound on the bytes forwarded in a single ``forwardData`` message.
    READ_CHUNK_SIZE = 64 * 1024

    def __init__(self, stream, address):
        self.stream = stream      # tornado IOStream of the remote client
        self.address = address    # peer address as given by TCPServer
        self.idx = 0              # connection index, assigned by ServerSide
        self.localSocket = None   # LocalSocketHandler relaying this connection

    async def handle_stream(self):
        """Forward everything read from the remote client to the local peer.

        Data is forwarded in chunks as soon as it arrives, not line by line.
        A TCP tunnel carries a byte stream, and waiting for ``\\n`` would stall
        binary protocols and lose a final unterminated fragment on close.

        When the remote side closes, the local peer is sent ``socketClose``.
        If the local peer's websocket is already gone, the remote stream is
        closed instead. In every case the connection is removed from the
        server's index.
        """
        local_alive = True
        try:
            while True:
                try:
                    data = await self.stream.read_bytes(self.READ_CHUNK_SIZE, partial=True)
                except tornado.iostream.StreamClosedError:
                    break
                try:
                    self.localSocket.sendMsg({
                        'cmd' : 'forwardData',
                        'remoteConnIdx' : self.idx,
                        # The repr form "b'...'" is kept for wire compatibility;
                        # ServerSide.onLocalForwardData strips it the same way.
                        'data' : str(base64.b64encode(data))
                    })
                except tornado.websocket.WebSocketClosedError:
                    # Nothing can be relayed any more; drop the remote client.
                    local_alive = False
                    self.stream.close()
                    break

            if self.localSocket and local_alive:
                try:
                    self.localSocket.sendMsg({
                        'cmd' : 'socketClose',
                        'remoteConnIdx' : self.idx,
                    })
                except tornado.websocket.WebSocketClosedError:
                    pass
        finally:
            # Always unregister, even if sending to the local peer failed.
            app.removeRemoteConn(self)
|
|
|
|
class RemoteServer(tornado.tcpserver.TCPServer):
    """TCP listener for remote clients; each stream is relayed through a local peer."""

    # Seconds to wait for the local peer's ``connectOk`` before giving up.
    CONNECT_TIMEOUT = 2
    # Seconds between checks while waiting for ``connectOk``.
    CONNECT_POLL_INTERVAL = 0.05

    async def handle_stream(self, stream, address):
        """Register the new client, wait for ``connectOk``, then start relaying.

        The stream is closed if no local peer is available, or if ``connectOk``
        does not arrive within ``CONNECT_TIMEOUT`` seconds. The loop polls
        instead of sleeping for the full timeout, so relaying starts as soon
        as the peer answers.
        """
        print('zzzz')
        conn = RemoteServerConnection(stream, address)
        if not app.addRemoteConn(conn):
            stream.close()
            return

        deadline = time.monotonic() + self.CONNECT_TIMEOUT
        while not app.isConnectOk(conn.idx):
            if time.monotonic() >= deadline:
                # Remove the pending entry. Otherwise it leaks, and a late
                # connectOk would register this already-closed stream.
                with app.remotePendingMutex:
                    app.remotePending.pop(conn.idx, None)
                stream.close()
                return
            await tornado.gen.sleep(self.CONNECT_POLL_INTERVAL)

        await conn.handle_stream()
|
|
|
|
class LocalSocketHandler(tornado.websocket.WebSocketHandler):
    """Websocket endpoint for local peers.

    Both directions carry newline-terminated JSON objects. A websocket frame
    may hold several messages or a fragment of one, so incoming text is
    buffered until a ``\\n`` arrives.
    """

    def open(self):
        """Start with an empty receive buffer and register with the server."""
        self._recvBuf = ''
        app.addLocalConn(self)

    def on_message(self, message):
        """Append a frame to the buffer and dispatch every complete message."""
        # Binary frames arrive as bytes; the protocol itself is UTF-8 JSON text.
        if isinstance(message, bytes):
            message = message.decode('utf-8')
        self._recvBuf += message
        self.parsePacket()

    def on_close(self):
        """Unregister from the server, which closes streams relayed through us."""
        app.removeLocalConn(self)
        print('on_close')

    def parsePacket(self):
        """Dispatch each complete line in the buffer; keep any unterminated tail.

        After ``split('\\n')`` the last element is always the text after the
        final newline. It is an empty string when the buffer ends in a newline,
        and an incomplete message otherwise. That element becomes the new
        buffer, and everything before it is complete. Blank lines are skipped.

        Raises:
            json.JSONDecodeError: if a complete line is not valid JSON. The
                buffer has already been advanced, so nothing is re-parsed later.
        """
        *complete, self._recvBuf = self._recvBuf.split('\n')
        for line in complete:
            if line.strip():
                self.dispatchMsg(json.loads(line))

    def dispatchMsg(self, msg):
        """Route one decoded message by its ``cmd`` field; unknown commands are ignored."""
        cmd = msg.get('cmd')
        if cmd == 'connectOk':
            app.onLocalConnectOk(msg)
        elif cmd == 'forwardData':
            app.onLocalForwardData(msg)
        elif cmd == 'socketClose':
            # NOTE(review): currently a no-op, so a close reported by the local
            # peer does not close the matching remote stream. The earlier
            # commented-out call passed an index to removeRemoteConn, but that
            # method expects a connection object (it reads ``conn.idx``).
            pass

    def sendMsg(self, msg):
        """Serialize ``msg`` as one JSON line and send it as a text frame.

        Raises:
            tornado.websocket.WebSocketClosedError: if the websocket is closed.
        """
        data = json.dumps(msg) + '\n'
        print(data)
        self.write_message(data)
|