#!/usr/bin/env python3
"""Forward local client connections to a remote server.

Listens on localhost:4713 and on the per-user ``pulse/native`` Unix socket
and relays every accepted connection, in both directions, to
REMOTE_ADDRESS:REMOTE_PORT.

Usage: <script> REMOTE_ADDRESS [REMOTE_PORT]   (REMOTE_PORT defaults to 4713)
"""

import asyncio
import os
import signal
import sys
import traceback


# Remote endpoint to forward to, taken from the command line.
try:
    remote_address = sys.argv[1]
except IndexError:
    # `from None` keeps the internal IndexError out of the user's traceback.
    raise ValueError('No remote address specified') from None

try:
    remote_port = int(sys.argv[2])
except IndexError:
    remote_port = 4713
except ValueError:
    raise ValueError(f'Invalid remote port: {sys.argv[2]!r}') from None


# True while the proxy is serving; stop() clears it so connection handlers
# and proxy loops wind down.
running = False

# Local Unix socket to serve. Prefer $XDG_RUNTIME_DIR, where the per-user
# runtime directory normally lives, and fall back to the previously
# hard-coded /var/run/user/<uid> when it is unset or empty.
_runtime_dir = os.environ.get('XDG_RUNTIME_DIR') or f'/var/run/user/{os.getuid()}'
socket_path = os.path.join(_runtime_dir, 'pulse', 'native')

# Background tasks cleaned up by stop().
tasks = []

# Listening servers closed by stop(); the __main__ section fills this in.
# Defined here so stop() never hits an undefined global.
servers = []

# Signals that trigger stop() once the __main__ section installs it.
signals_to_handle = [
    signal.SIGHUP,
    signal.SIGINT,
    signal.SIGTERM,
    signal.SIGQUIT,
]


async def proxy(reader, writer, chunk_size=512):
    """Copy bytes from *reader* to *writer* until EOF, error or shutdown.

    :param reader: stream to read from.
    :param writer: stream to forward data to; it is closed on every exit
        path, including cancellation.
    :param chunk_size: maximum number of bytes read per iteration.
    """
    try:
        # Re-check the global flag every iteration so open connections wind
        # down once stop() has cleared it.
        while running:
            data = await reader.read(chunk_size)
            if not data:
                break  # EOF: the peer closed its side.
            writer.write(data)
            await writer.drain()
    except ConnectionError:
        # Reset, broken pipe or aborted connection on either side simply
        # ends this direction of the proxy.
        pass
    finally:
        # Closing the writer tears down that connection, so the proxy
        # reading from it in the opposite direction should see EOF and
        # finish too.
        writer.close()


async def handle_connection(source_reader, source_writer):
    """Relay one accepted client connection to the remote server.

    Opens a connection to ``remote_address:remote_port`` and pumps data in
    both directions with proxy() until both directions have finished.
    """
    if not running:
        # Not serving (startup not finished, or shutting down): refuse the
        # client instead of leaving its socket open. stop() already stops
        # the event loop, so that is not done here.
        source_writer.close()
        return

    peername = source_writer.get_extra_info('peername')
    if isinstance(peername, tuple):
        address = peername[0]        # (host, port, ...) for TCP clients
    elif peername:
        address = peername           # path of a bound Unix-socket peer
    else:
        address = 'unix socket'      # unnamed Unix peer ('' or None)

    try:
        target_reader, target_writer = await asyncio.open_connection(
            remote_address, remote_port)
    except OSError:
        print(f'Could not connect to {remote_address}:{remote_port} for', address)
        traceback.print_exc()
        source_writer.close()
        return

    print('New connection from', address)

    await asyncio.gather(
        proxy(source_reader, target_writer),
        proxy(target_reader, source_writer),
    )

    print('Closed connection to', address)


def signal_handler(callback=None):
    """Install or remove the loop's handlers for ``signals_to_handle``.

    With a truthy *callback*, register it on ``loop`` for every listed
    signal; otherwise unregister the loop's handler for each of them.
    """
    for signum in signals_to_handle:
        if not callback:
            loop.remove_signal_handler(signum)
            continue
        loop.add_signal_handler(signum, callback)


def stop(*args):
    """Shut the proxy down: stop accepting clients and stop the event loop.

    Installed as the signal callback (hence ``*args``, which is ignored) and
    also called directly after the loop exits; only the first call while
    ``running`` is true does anything.
    """
    global running

    if not running:
        return

    running = False
    # Uninstall the loop's signal handlers for the listed signals.
    signal_handler()

    for server in servers:
        server.close()

    # Entries in `tasks` may wrap start_server() coroutines: close the
    # server a finished task produced and cancel one still in progress.
    # Never wait for tasks here: this runs on the loop's own thread, so a
    # busy-wait for task.done() could never make progress.
    for task in tasks:
        if not task.done():
            task.cancel()
            continue
        if task.cancelled():
            continue
        error = task.exception()
        if error is not None:
            print('Listener failed to start:', error)
            continue
        result = task.result()
        if isinstance(result, asyncio.AbstractServer):
            result.close()

    loop.stop()


if __name__ == '__main__':
    # Clear a socket left behind by a previous run; otherwise binding the
    # Unix listener fails with "address already in use".
    try:
        os.remove(socket_path)
    except FileNotFoundError:
        pass

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    signal_handler(stop)

    # start_server()/start_unix_server() are coroutines: run each to
    # completion so `servers` holds real Server objects that stop() can
    # close. A listener that fails is reported and the other one keeps
    # serving.
    servers = []
    for server_coro in (
        asyncio.start_server(handle_connection, 'localhost', 4713),
        asyncio.start_unix_server(handle_connection, path=socket_path),
    ):
        try:
            servers.append(loop.run_until_complete(server_coro))
        except OSError:
            traceback.print_exc()

    if not servers:
        signal_handler()
        loop.close()
        sys.exit('Could not open any listening socket')

    running = True

    try:
        loop.run_forever()
    except KeyboardInterrupt:
        pass

    stop()

    # Cancel connection handlers still in flight and let them unwind, so
    # their sockets are closed before the loop is torn down.
    pending = asyncio.all_tasks(loop)
    for task in pending:
        task.cancel()
    if pending:
        loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))

    loop.close()