This commit is contained in:
Izalia Mae 2024-04-13 09:33:55 -04:00
parent 6ffe05c977
commit d58fbc6061
10 changed files with 508 additions and 122 deletions

View file

@ -1,31 +1,27 @@
__version__ = "0.1.0"
try:
from .application import Application, ExceptionType, ExceptionCallback
from .enums import HttpStatus
from .error import HttpError
from .misc import Color
from .request import Request
from .response import Response, TemplateResponse
from .runner import run_app
from .signal import Signal, SignalCallback
from .template import Template
from .application import Application, ExceptionType, ExceptionCallback
from .enums import HttpStatus
from .error import HttpError
from .misc import Color, StateProxy, Stream
from .request import Request
from .response import Response, FileResponse, TemplateResponse
from .runner import run_app
from .signal import Signal, SignalCallback
from .template import Template
from .router import (
Router,
get_router,
route,
connect,
delete,
get,
head,
options,
patch,
post,
put,
trace
)
except ModuleNotFoundError:
pass
from .router import (
Router,
get_router,
route,
connect,
delete,
get,
head,
options,
patch,
post,
put,
trace
)

View file

@ -1,16 +1,19 @@
from __future__ import annotations
import bsql
import traceback
from collections.abc import Awaitable, Callable
from collections.abc import Callable
from functools import lru_cache
from os.path import normpath
from pathlib import Path
from typing import Any, TypeVar
from .error import HttpError
from .misc import State
from .objtypes import ScopeType
from .request import Request
from .response import Response, TemplateResponse
from .router import Router, get_router
from .misc import StateProxy, Stream, ReaderFunction, WriterFunction
from .request import Request, ScopeType
from .response import FileResponse, Response
from .router import Router, RouteHandler, get_router
from .signal import Signal
from .template import Template, TemplateContextType
@ -18,32 +21,26 @@ from .template import Template, TemplateContextType
ExceptionType = TypeVar("ExceptionType", bound = Exception)
ExceptionCallback = Callable[[Request, ExceptionType], Response]
APPLICATIONS: dict[str, Application] = {}
class Application:
def __init__(self,
name: str,
app_path: str,
address: str = "127.0.0.1",
port: int = 8080,
workers: int = 1,
allowed_ips: list[str] | None = None,
env: dict[str, str] | None = None,
dev: bool = False,
reload_dirs: list[Path | str] | None = None,
template_search: list[Path | str] | None = None,
template_env: dict[str, Any] | None = None,
template_context: TemplateContextType | None = None) -> None:
template_context: TemplateContextType | None = None,
request_class: type[Request] = Request,
request_state_class: type[StateProxy] = StateProxy,
app_state_class: type[StateProxy] = StateProxy) -> None:
if name in APPLICATIONS:
raise ValueError(f"Application with name '{name}' already exists")
self.name: str = name
self.app_path: str = app_path
self.address: str = address
self.port: int = port
self.workers: int = 1
self.allowed_ips: list[str] = allowed_ips or []
self.env: dict[str, str] = env or {}
self.dev: bool = dev
self.reload_dirs: list[Path | str] = reload_dirs or []
self.state: State = State()
self.state: StateProxy = app_state_class({})
self.request_class: type[Request] = request_class
self.request_state_class: type[StateProxy] = request_state_class
# Not sure how to properly annotate `Exception` here, so just ignore the warnings
self.error_handlers: dict[type[ExceptionType], ExceptionCallback] = { # type: ignore
@ -58,41 +55,48 @@ class Application:
context_function = template_context,
**(template_env or {}))
APPLICATIONS[name] = self
@classmethod
def get(cls: type[Application], name: str) -> Application:
return APPLICATIONS[name]
async def __call__(self,
raw_scope: ScopeType,
receive: Callable[..., Awaitable[dict[str, Any]]],
send: Callable[[dict[str, Any]], Awaitable[None]]) -> None:
scope: ScopeType,
reader: ReaderFunction,
writer: WriterFunction) -> None:
if raw_scope["type"] == "lifespan":
if scope["type"] == "lifespan":
while True:
message = await receive()
message = await reader()
if message["type"] == "lifespan.startup":
self.on_startup.emit()
await send({"type": "lifespan.startup.complete"})
await writer({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
await self.on_shutdown.handle_emit()
await send({"type": "lifespan.shutdown.complete"})
await writer({"type": "lifespan.shutdown.complete"})
break
elif raw_scope["type"] == "http":
request = Request(self, raw_scope)
elif scope["type"] == "http":
stream = Stream(reader, writer)
response: Response | None = None
try:
request = self.request_class(self, scope, stream)
match = self.router(request.path, request.method)
request.params = match.params or {} # type: ignore[assignment]
await self.on_request.handle_emit(request)
response = await match.target(request, **(match.params or {}))
response = await match.target(request)
if response is None:
raise HttpError(500, "Empty response")
if isinstance(response, TemplateResponse): # type: ignore[unreachable]
response = response.get_response(request) # type: ignore[unreachable]
await self.on_response.handle_emit(request, response)
except Exception as error:
@ -106,21 +110,9 @@ class Application:
else:
response = self.handle_error_default(request, error)
await response.send(stream, request)
self.print_access_log(request, response)
await send({
"type": "http.response.start",
"status": response.status,
"headers": tuple(response.headers.items())
})
await send({
"type": "http.response.body",
"body": response.body
})
return
@Signal(5.0)
async def on_request(self, request: Request) -> None:
@ -142,6 +134,15 @@ class Application:
pass
def add_route(self, method: str, path: str, handler: RouteHandler) -> None:
self.router.bind(handler, path, methods = [method])
def add_statuc(self, path: str, location: Path | str, cached: bool = False) -> None:
handler = StaticHandler(path, location, cached)
self.add_route("GET", handler.path, handler)
def print_access_log(self, request: Request, response: Response) -> None:
message = "{}: {} \"{} {}\" {} {} \"{}\"".format(
"INFO",
@ -174,3 +175,46 @@ class Application:
def handle_error_default(self, request: Request, exception: Exception) -> Response:
traceback.print_exc()
return Response(500, "Internal Server Error")
class StaticHandler:
def __init__(self, path: str, location: Path | str, cached: bool) -> None:
if isinstance(location, str):
location = Path(location)
self.path: str = path.rstrip("/") + "/{filepath}"
self.location: Path = Path(location).expanduser().resolve()
self.cached: bool = cached
if not self.location.exists():
raise FileNotFoundError(self.location)
if not self.location.is_dir():
raise ValueError("Location is not a directory or file")
async def __call__(self, request: Request) -> Response:
filepath = request.params["filepath"]
if self.cached:
return await self.handle_call_cached(request, filepath)
return await self.handle_call(request, filepath)
async def handle_call(self, request: Request, filepath: str) -> Response:
filepath = normpath(filepath)
path = self.location.joinpath(filepath.lstrip("/"))
if path.is_dir():
path = path.joinpath("index.html")
if not path.is_file():
raise HttpError(404, str(filepath))
return FileResponse(path)
@lru_cache(maxsize = 128, typed = True)
async def handle_call_cached(self, request: Request, filepath: str) -> Response:
return await self.handle_call(request, filepath)

View file

@ -1,6 +1,6 @@
import re
from aputils import IntEnum
from aputils import IntEnum, StrEnum
CAPITAL = re.compile("[A-Z][^A-Z]")
@ -84,3 +84,10 @@ class HttpStatus(IntEnum):
@property
def reason(self) -> str:
return " ".join(CAPITAL.findall(self.name))
class SassOutputStyle(StrEnum):
NESTED = "nested"
EXPANDED = "expanded"
COMPACT = "compact"
COMPRESSED = "compressed"

View file

@ -2,11 +2,26 @@ import asyncio
import json
from aputils import JsonBase
from collections.abc import Iterable
from colorsys import rgb_to_hls, hls_to_rgb
from functools import cached_property
from typing import Any, Self
from collections.abc import (
Awaitable,
Callable,
Iterable,
Iterator,
ItemsView,
KeysView,
MutableMapping,
Sequence,
ValuesView
)
ReaderFunction = Callable[..., Awaitable[dict[str, Any]]]
WriterFunction = Callable[[dict[str, Any]], Awaitable[None]]
HTTP_STATUS = {
100: "Continue",
@ -320,7 +335,11 @@ class Color(str):
return f"rgba({values}, {trans:.2})"
class State(dict[str, Any]):
class StateProxy(MutableMapping[str, Any]):
def __init__(self, state: dict[str, str]) -> None:
self._state: dict[str, str] = state
def __getattr__(self, key: str) -> Any:
try:
object.__getattribute__(self, key)
@ -330,8 +349,122 @@ class State(dict[str, Any]):
def __setattr__(self, key: str, value: Any) -> None:
if key.startswith("_"):
object.__setattr__(self, key, value)
return
self[key] = value
def __delattr__(self, key: str) -> None:
if key.startswith("_"):
object.__delattr__(self, key)
return
del self[key]
def __getitem__(self, key: str) -> Any:
return self._state[key]
def __setitem__(self, key: str, value: Any) -> None:
self._state[key] = value
def __delitem__(self, key: str) -> None:
del self._state[key]
def __len__(self) -> int:
return len(self._state)
def __iter__(self) -> Iterator[str]:
for key in self._state:
yield key
def get(self, key: str, default: Any = None) -> Any:
return self._state.get(key, default)
def items(self) -> ItemsView[str, Any]:
return self._state.items()
def keys(self) -> KeysView[str]:
return self._state.keys()
def set(self, key: str, value: Any) -> None:
self._state[key] = value
def values(self) -> ValuesView[Any]:
return self._state.values()
class Stream:
def __init__(self, reader: ReaderFunction, writer: WriterFunction) -> None:
self.reader = reader
self.writer = writer
self._sent_headers: bool = False
self._sent_body: bool = False
def __aiter__(self) -> Self:
return self
async def __anext__(self) -> bytes:
data, more = await self.read_chunk()
if not data and not more:
raise StopAsyncIteration
return data
async def close(self) -> None:
await self.writer({"type": "http.disconnect"})
async def read(self) -> bytes:
body = b""
while True:
data, more = await self.read_chunk()
body += data
if not more:
break
return body
async def read_chunk(self) -> tuple[bytes, bool]:
data = await self.reader()
return data.get("body", b""), data.get("more_body", False)
async def write_headers(self, status: int, headers: Sequence[tuple[str, str]]) -> None:
if self._sent_headers:
return
await self.writer({
"type": "http.response.start",
"status": status,
"headers": headers
})
async def write_body(self, data: bytes, eof: bool = False) -> None:
if self._sent_body:
return
await self.writer({
"type": "http.response.body",
"body": data,
"more_body": eof
})

View file

@ -1,13 +1,15 @@
from __future__ import annotations
import multipart
import typing
from aputils import JsonBase
from collections.abc import Iterable
from multidict import CIMultiDict, CIMultiDictProxy
from typing import Any, Literal, TypedDict
from urllib.parse import parse_qsl
from .misc import State
from .misc import StateProxy, Stream
if typing.TYPE_CHECKING:
from .application import Application
@ -36,8 +38,14 @@ class ScopeType(TypedDict):
class Request:
def __init__(self, app: Application, scope: ScopeType) -> None:
def __init__(self,
app: Application,
scope: ScopeType,
stream: Stream) -> None:
self.app: Application = app
self.stream: Stream = stream
self._body: bytes = b""
raw_query = parse_qsl(scope["query_string"].decode("utf-8"), keep_blank_values = True)
raw_headers = ((k.decode("utf-8").title(), v.decode("utf-8")) for k, v in scope["headers"])
@ -46,13 +54,66 @@ class Request:
self.path: str = scope["path"]
self.query: CIMultiDictProxy[str] = CIMultiDictProxy(CIMultiDict(raw_query))
self.headers: CIMultiDictProxy[str] = CIMultiDictProxy(CIMultiDict(raw_headers))
self.params: dict[str, Any] = {}
self.remote: str = (scope.get("client") or ("n/a"))[0]
self.local: str = (scope.get("server") or ("n/a"))[0]
self.state: State = State(scope.get("state") or {})
self.state: StateProxy = app.request_state_class(scope.get("state") or {})
self.extensions: dict[str, Any] = scope.get("extensions", {}) # type: ignore[assignment]
# keep?
self.asgi: tuple[str, str] = (scope["asgi"]["spec_version"], scope["asgi"]["version"])
self.version: float = float(scope["http_version"])
self.scheme: str = scope["scheme"].lower()
@property
def content_length(self) -> int:
return int(self.headers.get("Content-Length", "0"))
@property
def content_type(self) -> str:
return self.headers.getone("Content-Type", "")
async def body(self) -> bytes:
if not self._body:
self._body = await self.stream.read()
return self._body
async def text(self) -> str:
return (await self.body()).decode("utf-8")
async def json(self, parser_class: type[JsonBase] = JsonBase) -> JsonBase:
return parser_class.parse(await self.body())
async def form(self) -> CIMultiDictProxy[str | multipart.File]:
if self.content_type != "multipart/form-data":
raise ValueError(f"Invalid mimetype for form data: {self.content_type}")
fields: CIMultiDict[str | multipart.File] = CIMultiDict({})
def handle_field(field: multipart.Field | multipart.File) -> None:
if isinstance(field, multipart.Field):
fields[field.field_name.decode("utf-8")] = field.value.decode("utf-8")
else:
field.file_object.seek(0)
fields[field.field_name.decode("utf-8")] = field
parser = multipart.create_form_parser(self.headers, handle_field, handle_field)
async for chunk in self.stream:
parser.write(chunk)
parser.finalize()
return CIMultiDictProxy(fields)
async def stream_response(self, status: int, headers: dict[str, str]) -> None:
pass

View file

@ -1,11 +1,17 @@
from __future__ import annotations
import os
from aputils import JsonBase
from mimetypes import guess_type
from multidict import CIMultiDict
from pathlib import Path
from typing import Any, Self
from .enums import HttpStatus
from .error import HttpError
from .misc import convert_to_bytes
from .misc import Stream, convert_to_bytes
from .request import Request
@ -14,12 +20,16 @@ class Response:
status: HttpStatus | int,
body: Any,
mimetype: str | None = None,
headers: dict[str, str] | None = None) -> None:
headers: CIMultiDict[str] | dict[str, str] | None = None) -> None:
self._body = b""
self.status: HttpStatus = HttpStatus.parse(status)
self.headers: CIMultiDict[str] = CIMultiDict(headers or {})
if isinstance(headers, CIMultiDict):
self.headers = headers
else:
self.headers = CIMultiDict(headers or {})
if body:
self.body = body
@ -28,6 +38,26 @@ class Response:
self.mimetype = mimetype
@classmethod
def new_json(cls: type[Self],
status: HttpStatus | int,
body: JsonBase | dict[str, Any] | str,
mimetype: str = "application/json",
headers: dict[str, str] | None = None) -> Self:
return cls(status, body, mimetype, headers)
@classmethod
def new_activity(cls: type[Self],
status: HttpStatus | int,
body: JsonBase | dict[str, Any] | str,
mimetype: str = "application/activity+json",
headers: dict[str, str] | None = None) -> Self:
return cls(status, body, mimetype, headers)
@classmethod
def new_from_http_error(cls: type[Self], error: HttpError) -> Self:
message = f"HTTP Error {error.status.value}: {error.message}"
@ -42,21 +72,63 @@ class Response:
@body.setter
def body(self, value: Any) -> None:
self._body = convert_to_bytes(value)
self.headers.popall("Content-Length", None)
self.headers["Content-Length"] = str(len(self._body))
self.length = len(self._body)
@property
def length(self) -> int:
return int(self.headers.getone("Content-Length", "0"))
@length.setter
def length(self, value: int) -> None:
self.headers.update({"Content-Length": str(value)})
@property
def mimetype(self) -> str:
return self.headers["Content-Type"]
return self.headers.getone("Content-Type", "")
@mimetype.setter
def mimetype(self, value: str) -> None:
self.headers["Content-Type"] = value
self.headers.update({"Content-Type": value})
class TemplateResponse:
async def send(self, stream: Stream, request: Request) -> None:
await stream.write_headers(self.status, tuple(self.headers.items()))
await stream.write_body(self.body)
class FileResponse(Response):
def __init__(self,
path: Path,
mimetype: str | None = None,
chunk_size: int = 8192,
status: int = 200,
headers: dict[str, str] | None = None) -> None:
Response.__init__(self, status, b"", headers = headers)
self.path: Path = path
self.mimetype: str = mimetype or guess_type(path)[0] or "application/octet+stream"
self.chunk_size: int = chunk_size or 8192
async def send(self, stream: Stream, request: Request) -> None:
await stream.write_headers(200, tuple([]))
with self.path.open("rb") as fd:
while True:
if not (data := fd.read(self.chunk_size)):
break
await stream.write_body(data, True)
await stream.write_body(b"", False)
class TemplateResponse(Response):
def __init__(self,
name: str,
status: HttpStatus | int = HttpStatus.Ok,
@ -65,10 +137,10 @@ class TemplateResponse:
pretty_print: bool = False,
**context: Any) -> None:
Response.__init__(self, status, b"", headers = headers)
self.name: str = name
self.status: HttpStatus = HttpStatus.parse(status)
self.mimetype: str = mimetype or self.detect_mimetype()
self.headers: dict[str, Any] = headers or {}
self.pretty_print: bool = pretty_print
self.context: dict[str, Any] = context
@ -85,12 +157,11 @@ class TemplateResponse:
if ext == ".xml":
return "text/xml"
return "text/plain"
return guess_type(self.name)[0] or "text/plain"
def get_response(self, request: Request) -> Response:
return Response(self.status, self.render(request), self.mimetype, self.headers)
async def send(self, stream: Stream, request: Request) -> None:
text = request.app.template.render(self.name, self.pretty_print, **self.context)
self.body = text.encode("utf-8")
def render(self, request: Request) -> str:
return request.app.template.render(self.name, self.pretty_print, **self.context)
await Response.send(self, stream, request)

View file

@ -6,10 +6,10 @@ from typing import Any, Pattern
from .error import HttpError
from .request import Request
from .response import Response, TemplateResponse
from .response import Response
RouteHandler = Callable[[Request], Awaitable[Response | TemplateResponse]]
RouteHandler = Callable[[Request], Awaitable[Response]]
class Router(http_router.Router):

View file

@ -1,7 +1,8 @@
import asyncio
import os
from collections.abc import Callable, Coroutine
from collections.abc import Callable, Coroutine, Sequence
from pathlib import Path
from typing import Any
from uvicorn import Config, Server
from uvicorn.supervisors import ChangeReload, Multiprocess
@ -12,31 +13,39 @@ from .application import Application
BackgroundTask = Callable[[Application], Coroutine[Any, Any, None]]
def run_app(app: Application, *bgtasks: BackgroundTask) -> None:
def run_app(*args: Any, **kwargs: Any) -> None:
try:
asyncio.run(handle_run_server(app, *bgtasks))
asyncio.run(handle_run_server(*args, **kwargs))
except KeyboardInterrupt:
pass
async def handle_run_server(app: Application, *bgtasks: BackgroundTask) -> None:
for key, value in app.env.items():
os.environ[key] = value
os.environ["BARKSHARK_ASGI_NAME"] = app.name
async def handle_run_server(
app: Application,
app_spec: str | None = None,
address: str = "127.0.0.1",
port: int = 8080,
workers: int = 1,
allowed_ips: list[str] | None = None,
reload: bool = False,
reload_dirs: list[Path | str] | None = None,
bgtasks: Sequence[BackgroundTask] | None = None,
env: dict[str, str] | None = None,
dev: bool = False) -> None:
cfg = Config(
app.app_path if app.workers > 1 else app,
host = app.address,
port = app.port,
workers = app.workers if not app.dev else 1,
app_spec if workers > 1 else app, # type: ignore[arg-type]
host = address,
port = port,
workers = workers if not dev else 1,
access_log = False,
log_level = "info",
factory = True,
forwarded_allow_ips = app.allowed_ips or None,
reload = app.dev,
reload_dirs = [str(path) for path in app.reload_dirs],
forwarded_allow_ips = allowed_ips or None,
reload = dev,
reload_dirs = [str(path) for path in reload_dirs] if reload_dirs else [],
reload_delay = 1,
server_header = False,
lifespan = "on",
@ -45,19 +54,22 @@ async def handle_run_server(app: Application, *bgtasks: BackgroundTask) -> None:
]
)
for key, value in (env or {}).items():
os.environ[key] = value
tasks: list[asyncio.Task[Any]] = []
for task in bgtasks:
for task in (bgtasks or []):
tasks.append(asyncio.create_task(task(app)))
server = Server(cfg)
runner: ChangeReload | Multiprocess # type: ignore[valid-type]
if app.dev:
if dev:
runner = ChangeReload(cfg, target = server.run, sockets = [cfg.bind_socket()])
runner.run()
elif cfg.workers > 1:
elif workers > 1:
runner = Multiprocess(cfg, target = server.run, sockets = [cfg.bind_socket()])
runner.run()

View file

@ -1,15 +1,19 @@
from __future__ import annotations
import os
import sass
from collections.abc import Callable, Sequence
from hamlish_jinja import HamlishExtension, OutputMode
from jinja2 import Environment, FileSystemLoader
from jinja2.ext import Extension
from os.path import splitext
from pathlib import Path
from typing import Any
from xml.dom import minidom
from xml.etree import ElementTree
from .enums import SassOutputStyle
from .misc import Color
@ -28,6 +32,60 @@ class FsLoader(FileSystemLoader):
self.followlinks: bool = followlinks
class SassExtension(Extension):
"An extension for Jinja2 that adds support for sass and scss compiling."
def __init__(self, environment: Environment):
Extension.__init__(self, environment)
self.output_style: SassOutputStyle = SassOutputStyle.NESTED
self.include_paths: list[str] = []
self._exts: tuple[str, str] = (".sass", ".scss")
environment.extend( # type: ignore[no-untyped-call]
sass_get_output_style = self.get_output_style,
sass_set_output_style = self.set_output_style,
sass_append_include_path = self.include_paths.append,
sass_remove_include_path = self.include_paths.remove
)
def __repr__(self) -> str:
return f"SassExtension(output_style='{self.output_style.value}')"
def get_output_style(self) -> SassOutputStyle:
return self.output_style
def set_output_style(self, value: SassOutputStyle) -> None:
self.output_style = SassOutputStyle.parse(value)
def preprocess(self, source: str, name: str | None, filename: str | None = None) -> str:
"""
Transpile a sass or scss file into a css file
:param source: Full text source of the template
:param name: Name of the template
:param filename: Path to the template
:raises CompileError: When the template cannot be parsed
"""
if (tpl_name := filename or name) is None:
return source
if (ext := splitext(tpl_name)[1]) not in self._exts:
return source
return sass.compile( # type: ignore[no-any-return]
string = source,
output_style = self.output_style.value,
indented = ext == ".sass"
)
class Template(Environment):
def __init__(self,
*search: str | Path,
@ -38,9 +96,12 @@ class Template(Environment):
super().__init__(
loader = self.search,
extensions = [HamlishExtension],
lstrip_blocks = True,
trim_blocks = True
trim_blocks = True,
extensions = [
HamlishExtension,
SassExtension
]
)
for path in search:

View file

@ -42,6 +42,7 @@ dependencies = [
"jinja2-haml == 0.3.5",
"libsass == 0.23.0",
"multidict == 6.0.5",
"multipart == 0.2.4",
"python-multipart == 0.0.9",
"pyyaml == 6.0.1",
"uvicorn == 0.29.0"