server: limit header length

This commit is contained in:
Izalia Mae 2024-03-20 10:52:03 -04:00
parent 0df445ea22
commit a579c5f145
3 changed files with 34 additions and 8 deletions

View file

@ -4,7 +4,7 @@ import typing
from .document import Document
from .enums import StatusCode
from .error import BodyTooLargeError
from .error import BodyTooLargeError, GeminiError
from .misc import FileSize, Url
if typing.TYPE_CHECKING:
@ -116,7 +116,12 @@ class Request(Message):
:param transport: The transport of the server the request is being sent to
"""
url = (await transport.readline()).decode("utf-8")
try:
url = (await transport.readline(1024)).decode("utf-8")
except ValueError as error:
if "separator is found" in str(error).lower():
raise GeminiError(59, "Header too long")
if url.startswith("/"):
url = f"gemini://{transport.local_address}:{transport.local_port}{url}"
@ -193,7 +198,12 @@ class Response(Message):
:param transport: The transport of the client the response is being sent to
"""
header = (await transport.readline()).decode("utf-8")
try:
header = (await transport.readline(1024)).decode("utf-8")
except ValueError as error:
if "separator is found" in str(error).lower():
raise GeminiError(59, "Header too long")
try:
status, meta = header.strip().split(" ", 1)

View file

@ -222,7 +222,7 @@ class AsyncServer(BaseApp, dict):
"""
transport = AsyncTransport(reader, writer, self.timeout)
request = Request("localhost:1965")
request = Request(f"{self.addr}:{self.port}/")
try:
request = await Request.from_transport(self, transport)

View file

@ -3,6 +3,8 @@ from __future__ import annotations
import asyncio
import typing
from contextlib import contextmanager
from .misc import convert_to_bytes
if typing.TYPE_CHECKING:
@ -102,13 +104,14 @@ class AsyncTransport:
return await asyncio.wait_for(self.reader.read(length), self.timeout)
async def readline(self) -> bytes:
async def readline(self, limit: int = 65536) -> bytes:
"Read until a line ending ('\\\\r' or '\\\\n') is encountered"
return await asyncio.wait_for(self.reader.readline(), self.timeout)
with self._set_limit(limit):
return await asyncio.wait_for(self.reader.readline(), self.timeout)
async def readuntil(self, separator: bytes | str) -> bytes:
async def readuntil(self, separator: bytes | str, limit = 65536) -> bytes:
"""
Read upto the separator
@ -118,7 +121,8 @@ class AsyncTransport:
if isinstance(separator, str):
separator = separator.encode(self.encoding)
return await asyncio.wait_for(self.reader.readuntil(separator), self.timeout)
with self._set_limit(limit):
return await asyncio.wait_for(self.reader.readuntil(separator), self.timeout)
async def write(self, data: Any) -> None:
@ -131,3 +135,15 @@ class AsyncTransport:
data = convert_to_bytes(data, self.encoding)
self.writer.write(data)
await self.writer.drain()
@contextmanager
def _set_limit(self, limit: int = 65536):
orig_limit = self.reader._limit
self.reader._limit = limit
try:
yield
finally:
self.reader._limit = orig_limit