server: limit header length
This commit is contained in:
parent
0df445ea22
commit
a579c5f145
|
@ -4,7 +4,7 @@ import typing
|
|||
|
||||
from .document import Document
|
||||
from .enums import StatusCode
|
||||
from .error import BodyTooLargeError
|
||||
from .error import BodyTooLargeError, GeminiError
|
||||
from .misc import FileSize, Url
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
|
@ -116,7 +116,12 @@ class Request(Message):
|
|||
:param transport: The transport of the server the request is being sent to
|
||||
"""
|
||||
|
||||
url = (await transport.readline()).decode("utf-8")
|
||||
try:
|
||||
url = (await transport.readline(1024)).decode("utf-8")
|
||||
|
||||
except ValueError as error:
|
||||
if "separator is found" in str(error).lower():
|
||||
raise GeminiError(59, "Header too long")
|
||||
|
||||
if url.startswith("/"):
|
||||
url = f"gemini://{transport.local_address}:{transport.local_port}{url}"
|
||||
|
@ -193,7 +198,12 @@ class Response(Message):
|
|||
:param transport: The transport of the client the response is being sent to
|
||||
"""
|
||||
|
||||
header = (await transport.readline()).decode("utf-8")
|
||||
try:
|
||||
header = (await transport.readline(1024)).decode("utf-8")
|
||||
|
||||
except ValueError as error:
|
||||
if "separator is found" in str(error).lower():
|
||||
raise GeminiError(59, "Header too long")
|
||||
|
||||
try:
|
||||
status, meta = header.strip().split(" ", 1)
|
||||
|
|
|
@ -222,7 +222,7 @@ class AsyncServer(BaseApp, dict):
|
|||
"""
|
||||
|
||||
transport = AsyncTransport(reader, writer, self.timeout)
|
||||
request = Request("localhost:1965")
|
||||
request = Request(f"{self.addr}:{self.port}/")
|
||||
|
||||
try:
|
||||
request = await Request.from_transport(self, transport)
|
||||
|
|
|
@ -3,6 +3,8 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import typing
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from .misc import convert_to_bytes
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
|
@ -102,13 +104,14 @@ class AsyncTransport:
|
|||
return await asyncio.wait_for(self.reader.read(length), self.timeout)
|
||||
|
||||
|
||||
async def readline(self) -> bytes:
|
||||
async def readline(self, limit: int = 65536) -> bytes:
|
||||
"Read until a line ending ('\\\\r' or '\\\\n') is encountered"
|
||||
|
||||
with self._set_limit(limit):
|
||||
return await asyncio.wait_for(self.reader.readline(), self.timeout)
|
||||
|
||||
|
||||
async def readuntil(self, separator: bytes | str) -> bytes:
|
||||
async def readuntil(self, separator: bytes | str, limit = 65536) -> bytes:
|
||||
"""
|
||||
Read upto the separator
|
||||
|
||||
|
@ -118,6 +121,7 @@ class AsyncTransport:
|
|||
if isinstance(separator, str):
|
||||
separator = separator.encode(self.encoding)
|
||||
|
||||
with self._set_limit(limit):
|
||||
return await asyncio.wait_for(self.reader.readuntil(separator), self.timeout)
|
||||
|
||||
|
||||
|
@ -131,3 +135,15 @@ class AsyncTransport:
|
|||
data = convert_to_bytes(data, self.encoding)
|
||||
self.writer.write(data)
|
||||
await self.writer.drain()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _set_limit(self, limit: int = 65536):
|
||||
orig_limit = self.reader._limit
|
||||
self.reader._limit = limit
|
||||
|
||||
try:
|
||||
yield
|
||||
|
||||
finally:
|
||||
self.reader._limit = orig_limit
|
||||
|
|
Loading…
Reference in a new issue