paws/paws/middleware.py
2020-01-13 08:10:48 -05:00

224 lines
5.7 KiB
Python

import asyncio
import aiohttp
import json
import logging
import binascii
import base64
import traceback
from urllib.parse import urlparse
from aiohttp.http_exceptions import *
from aiohttp.client_exceptions import *
from .signature import validate, pass_hash
from .functions import json_error, user_check
from .config import MASTOCONFIG, script_path
from . import database as db
# I'm a little teapot :3
class HTTPTeapot(aiohttp.web.HTTPClientError):
    """HTTP 418 "I'm a teapot" response, raised by ``http_filter`` for blocked agents.

    Derives from ``HTTPClientError`` because 418 is a 4xx status.
    ``HTTPClientError`` is itself a subclass of ``HTTPError``, so code that
    catches ``aiohttp.web.HTTPError`` still catches this exception.
    """

    # aiohttp's HTTP exceptions take their response status from this attribute.
    status_code = 418
# Substrings checked against the lower-cased User-Agent in http_filter;
# any match is answered with a 418 (HTTPTeapot).
blocked_agents = [
    'gabsocial', 'kiwifarms', 'fedichive', 'liveview', 'freespeech',
    'shitposter.club', 'baraag', 'gameliberty', 'neckbeard',
]

# Path prefixes whose non-POST requests go through http_signatures'
# user_check / HTTP-signature / Basic-auth checks.
auth_paths = ['/@', '/users']
def parse_sig(signature):
    """Return the actor URL named by the ``keyId`` of an HTTP Signature header.

    The header is a comma-separated list of ``name="value"`` parameters. The
    ``keyId`` value has its quotes removed and anything from the first ``#``
    on (the key fragment) dropped.

    Args:
        signature: Raw ``Signature`` header value.

    Returns:
        The actor URL, or ``None`` when no ``keyId`` parameter is present.
    """
    for param in signature.split(','):
        name, sep, value = param.partition('=')
        # Match the parameter name exactly and split on the FIRST '=' only:
        # key URLs may contain '=' themselves (e.g. a query string), which
        # split('=')[1] used to truncate.
        if sep and name.strip() == 'keyId':
            return value.strip().replace('"', '').split('#')[0]
    return None
def parse_ua(agent):
    """Extract the instance domain from a fediverse-style User-Agent string.

    Takes the text between the first ``+https://`` marker and any later
    marker, then returns the part before its first ``/``. For example, it
    returns ``'mastodon.social'`` for
    ``'http.rb/4.1.1 (Mastodon/3.0.1; +https://mastodon.social/)'``.

    Returns ``None`` for an empty agent, when the marker is absent, or when
    no ``/`` follows the domain.
    """
    if not agent:
        return None
    pieces = agent.split('+https://')
    if len(pieces) < 2:
        return None
    host, slash, _rest = pieces[1].partition('/')
    return host if slash else None
async def raise_auth_error(request, auth_realm):
    """Raise ``401 Unauthorized`` that asks the client for HTTP Basic credentials.

    Args:
        request: The incoming request (unused; kept for call-site compatibility).
        auth_realm: Realm advertised in the ``WWW-Authenticate`` header.

    Raises:
        aiohttp.web.HTTPUnauthorized: Always. The body is
            ``templates/unauthorized.html`` under ``script_path``.
    """
    # Read the template with a context manager; the old bare open() leaked a
    # file handle on every 401.
    with open(f'{script_path}/templates/unauthorized.html', encoding='utf-8') as template:
        body = template.read()
    # RFC 7235 section 2.2: senders must use the quoted-string form for realm,
    # so escape backslashes and quotes before wrapping it in quotes.
    realm = auth_realm.replace('\\', '\\\\').replace('"', '\\"')
    raise aiohttp.web.HTTPUnauthorized(
        headers={aiohttp.hdrs.WWW_AUTHENTICATE: f'Basic realm="{realm}"'},
        body=body,
        content_type='text/html',
    )
async def passthrough(path, headers, post=None, query=None):
    """Forward a request to the configured Mastodon instance and relay its reply.

    Args:
        path: Request path such as ``request.path``; a leading ``/`` is allowed.
        headers: Headers to forward (e.g. ``request.headers``) or ``None``.
        post: Request body. ``None`` sends a GET. ``bytes``/``str`` is sent
            as-is in a POST. Any other value (e.g. decoded JSON) is
            JSON-encoded and sent in a POST.
        query: Raw query string including its leading ``?``, or ``None``.

    Raises:
        aiohttp.web.HTTPOk: Carrying Mastodon's body when it answers 200/202.
        The exception returned by ``json_error(504, ...)``: When Mastodon
            answers with any other status.

    Returns:
        ``json_error(504, ...)`` when the connection to Mastodon fails; this
        path returns rather than raises, as before.
    """
    url = f'https://{MASTOCONFIG["domain"]}/{path.lstrip("/")}{query or ""}'

    # The body may be re-encoded below, so the client's framing headers no
    # longer describe it; let aiohttp compute its own.
    forward_headers = [
        (name, value)
        for name, value in (headers.items() if headers else ())
        if name.lower() not in ('content-length', 'transfer-encoding')
    ]
    request_kwargs = {'headers': forward_headers}

    if post is None:
        method = 'GET'
    elif isinstance(post, (bytes, str)):
        method = 'POST'
        request_kwargs['data'] = post
    else:
        method = 'POST'
        # Passing a dict as ``data`` made aiohttp form-encode it under the
        # client's JSON Content-Type; send it as JSON instead.
        request_kwargs['json'] = post

    try:
        async with aiohttp.request(method, url, **request_kwargs) as resp:
            data = await resp.read()
            if resp.status not in (200, 202):
                logging.debug('Mastodon error body: %r', data)
                logging.warning(f'Recieved error {resp.status} from Mastodon')
                # This result used to be discarded, so the error body was
                # relayed to the client as a 200 OK.
                raise json_error(504, f'Failed to forward request. Recieved error {resp.status} from Mastodon')
            raise aiohttp.web.HTTPOk(body=data, content_type=resp.content_type)
    except ClientConnectionError:
        traceback.print_exc()
        return json_error(504, 'Failed to connect to Mastodon')
async def http_redirect(app, handler):
    """Old-style aiohttp middleware factory that proxies requests to Mastodon.

    Each request is handed to ``passthrough`` with its path, headers, raw
    query string and (if it parses) JSON body. ``handler`` only runs if
    ``passthrough`` returns instead of raising.
    """
    async def redirect_handler(request):
        # Use the raw, still-percent-encoded query string. Rebuilding it from
        # the decoded ``request.query`` dropped repeated keys and left '&',
        # '=', '#' and spaces inside values unescaped. An empty query yields
        # '?', as before.
        query = f'?{request.query_string}'
        try:
            data = await request.json()
        except Exception:
            # Best effort: a missing or non-JSON body is forwarded as no body.
            data = None
        await passthrough(request.path, request.headers, post=data, query=query)
        return await handler(request)
    return redirect_handler
async def http_signatures(app, handler):
    """Old-style aiohttp middleware factory that enforces request authentication.

    * POST requests must carry a ``Signature`` header, and the JSON body's
      ``actor`` must pass ``validate``.
    * Non-POST requests under ``auth_paths`` are handled as follows:
      - They are let through when ``user_check(request.path)`` is true.
      - JSON requests (``Accept`` containing ``json`` or a path ending in
        ``.json``) need a ``Signature`` header whose keyId actor passes
        ``validate``.
      - Anything else must pass HTTP Basic auth.

    Failed signature checks raise ``json_error(401, ...)``; failed Basic auth
    goes through ``raise_auth_error``. ``request['validated']`` is set to
    ``False`` up front.
    """
    import hmac  # constant-time credential comparison

    # NOTE(review): hard-coded credentials; consider moving them to config.
    auth_username = 'admin'
    auth_password = 'doubleheck'
    auth_realm = 'Nope'

    async def http_signatures_handler(request):
        request['validated'] = False
        json_req = 'json' in request.headers.get('Accept', '')

        if request.method == 'POST':
            if 'signature' not in request.headers:
                logging.info('missing signature')
                raise json_error(401, 'Missing signature')
            data = await request.json()
            # A non-object JSON body would crash on data["actor"] below.
            if not isinstance(data, dict) or 'actor' not in data:
                logging.info('signature check failed, no actor in message')
                raise json_error(401, 'signature check failed, no actor in message')
            actor = data['actor']
            if not await validate(actor, request):
                logging.info(f'Signature validation failed for: {actor}')
                raise json_error(401, 'signature check failed, signature did not match key')

        elif any(map(request.path.startswith, auth_paths)):
            if user_check(request.path):
                logging.info('allowing passthrough of user')
            elif json_req or request.path.endswith('.json'):
                signature = request.headers.get('signature', '')
                if not signature:
                    logging.info('missing signature')
                    raise json_error(401, 'Missing signature')
                actor = parse_sig(signature)
                # A signature without keyId has no actor; reject it instead
                # of calling validate(None, ...).
                if not actor or not await validate(actor, request):
                    logging.info(f'Signature validation failed for: {actor}')
                    raise json_error(401, 'signature check failed, signature did not match key')
            else:
                auth_header = request.headers.get(aiohttp.hdrs.AUTHORIZATION)
                if auth_header is None or not auth_header.startswith('Basic '):
                    await raise_auth_error(request, auth_realm)
                try:
                    secret = auth_header[6:].encode('utf-8')
                    auth_decoded = base64.decodebytes(secret).decode('utf-8')
                except (UnicodeDecodeError, UnicodeEncodeError, binascii.Error):
                    # Previously called without auth_realm -> TypeError (500).
                    await raise_auth_error(request, auth_realm)
                # RFC 7617: the user-id cannot contain ':' but the password
                # may, so split on the first ':' only.
                username, sep, password = auth_decoded.partition(':')
                if not sep:
                    await raise_auth_error(request, auth_realm)
                # Compare both fields in constant time, without short-circuiting.
                user_ok = hmac.compare_digest(username.encode('utf-8'), auth_username.encode('utf-8'))
                pass_ok = hmac.compare_digest(password.encode('utf-8'), auth_password.encode('utf-8'))
                if not (user_ok and pass_ok):
                    await raise_auth_error(request, auth_realm)

        return await handler(request)
    return http_signatures_handler
async def http_filter(app, handler):
    """Old-style aiohttp middleware factory that rejects unwanted clients.

    A request is refused in any of these cases:
    * its User-Agent yields no domain via ``parse_ua`` (401);
    * its lower-cased User-Agent contains any ``blocked_agents`` entry (418);
    * ``db.ban_check`` is truthy for its domain (403).
    """
    async def http_filter_handler(request):
        user_agent = request.headers.get('User-Agent')
        domain = parse_ua(user_agent)
        if not domain:
            raise json_error(401, 'Missing User-Agent')

        lowered = (user_agent or '').lower()
        if any(bad in lowered for bad in blocked_agents):
            logging.info(f'Blocked garbage: {domain}')
            raise HTTPTeapot(body='418 This teapot kills fascists', content_type='text/plain')

        if db.ban_check(domain):
            logging.info(f'Blocked instance: {domain}')
            raise json_error(403, 'Forbidden')

        return await handler(request)
    return http_filter_handler
async def http_trailing_slash(app, handler):
    """Old-style aiohttp middleware factory that strips trailing slashes.

    Any path other than ``/`` that ends in ``/`` is answered with a 302
    redirect to the same path without its trailing slashes. The query string
    is preserved.
    """
    async def http_trailing_slash_handler(request):
        path = request.path
        if path != '/' and path.endswith('/'):
            # Strip every trailing slash in one hop; a path of only slashes
            # collapses to '/'.
            location = path.rstrip('/') or '/'
            # Keep the raw (still-encoded) query string, which the old
            # redirect silently dropped.
            if request.query_string:
                location = f'{location}?{request.query_string}'
            # Raise rather than return: returning HTTPException objects from
            # handlers is deprecated in aiohttp 3.x.
            raise aiohttp.web.HTTPFound(location)
        return await handler(request)
    return http_trailing_slash_handler
# Public middleware factories. The previous entries named functions not
# defined in this module, so `from .middleware import *` raised AttributeError.
__all__ = ['http_signatures', 'http_filter', 'http_trailing_slash']