This commit is contained in:
Izalia Mae 2024-04-13 09:33:55 -04:00
parent 6ffe05c977
commit d58fbc6061
10 changed files with 508 additions and 122 deletions

View file

@ -1,31 +1,27 @@
__version__ = "0.1.0" __version__ = "0.1.0"
try: from .application import Application, ExceptionType, ExceptionCallback
from .application import Application, ExceptionType, ExceptionCallback from .enums import HttpStatus
from .enums import HttpStatus from .error import HttpError
from .error import HttpError from .misc import Color, StateProxy, Stream
from .misc import Color from .request import Request
from .request import Request from .response import Response, FileResponse, TemplateResponse
from .response import Response, TemplateResponse from .runner import run_app
from .runner import run_app from .signal import Signal, SignalCallback
from .signal import Signal, SignalCallback from .template import Template
from .template import Template
from .router import ( from .router import (
Router, Router,
get_router, get_router,
route, route,
connect, connect,
delete, delete,
get, get,
head, head,
options, options,
patch, patch,
post, post,
put, put,
trace trace
) )
except ModuleNotFoundError:
pass

View file

@ -1,16 +1,19 @@
from __future__ import annotations
import bsql import bsql
import traceback import traceback
from collections.abc import Awaitable, Callable from collections.abc import Callable
from functools import lru_cache
from os.path import normpath
from pathlib import Path from pathlib import Path
from typing import Any, TypeVar from typing import Any, TypeVar
from .error import HttpError from .error import HttpError
from .misc import State from .misc import StateProxy, Stream, ReaderFunction, WriterFunction
from .objtypes import ScopeType from .request import Request, ScopeType
from .request import Request from .response import FileResponse, Response
from .response import Response, TemplateResponse from .router import Router, RouteHandler, get_router
from .router import Router, get_router
from .signal import Signal from .signal import Signal
from .template import Template, TemplateContextType from .template import Template, TemplateContextType
@ -18,32 +21,26 @@ from .template import Template, TemplateContextType
ExceptionType = TypeVar("ExceptionType", bound = Exception) ExceptionType = TypeVar("ExceptionType", bound = Exception)
ExceptionCallback = Callable[[Request, ExceptionType], Response] ExceptionCallback = Callable[[Request, ExceptionType], Response]
APPLICATIONS: dict[str, Application] = {}
class Application: class Application:
def __init__(self, def __init__(self,
name: str, name: str,
app_path: str,
address: str = "127.0.0.1",
port: int = 8080,
workers: int = 1,
allowed_ips: list[str] | None = None,
env: dict[str, str] | None = None,
dev: bool = False,
reload_dirs: list[Path | str] | None = None,
template_search: list[Path | str] | None = None, template_search: list[Path | str] | None = None,
template_env: dict[str, Any] | None = None, template_env: dict[str, Any] | None = None,
template_context: TemplateContextType | None = None) -> None: template_context: TemplateContextType | None = None,
request_class: type[Request] = Request,
request_state_class: type[StateProxy] = StateProxy,
app_state_class: type[StateProxy] = StateProxy) -> None:
if name in APPLICATIONS:
raise ValueError(f"Application with name '{name}' already exists")
self.name: str = name self.name: str = name
self.app_path: str = app_path self.state: StateProxy = app_state_class({})
self.address: str = address self.request_class: type[Request] = request_class
self.port: int = port self.request_state_class: type[StateProxy] = request_state_class
self.workers: int = 1
self.allowed_ips: list[str] = allowed_ips or []
self.env: dict[str, str] = env or {}
self.dev: bool = dev
self.reload_dirs: list[Path | str] = reload_dirs or []
self.state: State = State()
# Not sure how to properly annotate `Exception` here, so just ignore the warnings # Not sure how to properly annotate `Exception` here, so just ignore the warnings
self.error_handlers: dict[type[ExceptionType], ExceptionCallback] = { # type: ignore self.error_handlers: dict[type[ExceptionType], ExceptionCallback] = { # type: ignore
@ -58,41 +55,48 @@ class Application:
context_function = template_context, context_function = template_context,
**(template_env or {})) **(template_env or {}))
APPLICATIONS[name] = self
@classmethod
def get(cls: type[Application], name: str) -> Application:
return APPLICATIONS[name]
async def __call__(self, async def __call__(self,
raw_scope: ScopeType, scope: ScopeType,
receive: Callable[..., Awaitable[dict[str, Any]]], reader: ReaderFunction,
send: Callable[[dict[str, Any]], Awaitable[None]]) -> None: writer: WriterFunction) -> None:
if raw_scope["type"] == "lifespan": if scope["type"] == "lifespan":
while True: while True:
message = await receive() message = await reader()
if message["type"] == "lifespan.startup": if message["type"] == "lifespan.startup":
self.on_startup.emit() self.on_startup.emit()
await send({"type": "lifespan.startup.complete"}) await writer({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown": elif message["type"] == "lifespan.shutdown":
await self.on_shutdown.handle_emit() await self.on_shutdown.handle_emit()
await send({"type": "lifespan.shutdown.complete"}) await writer({"type": "lifespan.shutdown.complete"})
break break
elif raw_scope["type"] == "http": elif scope["type"] == "http":
request = Request(self, raw_scope) stream = Stream(reader, writer)
response: Response | None = None response: Response | None = None
try: try:
request = self.request_class(self, scope, stream)
match = self.router(request.path, request.method) match = self.router(request.path, request.method)
request.params = match.params or {} # type: ignore[assignment]
await self.on_request.handle_emit(request) await self.on_request.handle_emit(request)
response = await match.target(request, **(match.params or {})) response = await match.target(request)
if response is None: if response is None:
raise HttpError(500, "Empty response") raise HttpError(500, "Empty response")
if isinstance(response, TemplateResponse): # type: ignore[unreachable]
response = response.get_response(request) # type: ignore[unreachable]
await self.on_response.handle_emit(request, response) await self.on_response.handle_emit(request, response)
except Exception as error: except Exception as error:
@ -106,21 +110,9 @@ class Application:
else: else:
response = self.handle_error_default(request, error) response = self.handle_error_default(request, error)
await response.send(stream, request)
self.print_access_log(request, response) self.print_access_log(request, response)
await send({
"type": "http.response.start",
"status": response.status,
"headers": tuple(response.headers.items())
})
await send({
"type": "http.response.body",
"body": response.body
})
return
@Signal(5.0) @Signal(5.0)
async def on_request(self, request: Request) -> None: async def on_request(self, request: Request) -> None:
@ -142,6 +134,15 @@ class Application:
pass pass
def add_route(self, method: str, path: str, handler: RouteHandler) -> None:
self.router.bind(handler, path, methods = [method])
def add_statuc(self, path: str, location: Path | str, cached: bool = False) -> None:
handler = StaticHandler(path, location, cached)
self.add_route("GET", handler.path, handler)
def print_access_log(self, request: Request, response: Response) -> None: def print_access_log(self, request: Request, response: Response) -> None:
message = "{}: {} \"{} {}\" {} {} \"{}\"".format( message = "{}: {} \"{} {}\" {} {} \"{}\"".format(
"INFO", "INFO",
@ -174,3 +175,46 @@ class Application:
def handle_error_default(self, request: Request, exception: Exception) -> Response: def handle_error_default(self, request: Request, exception: Exception) -> Response:
traceback.print_exc() traceback.print_exc()
return Response(500, "Internal Server Error") return Response(500, "Internal Server Error")
class StaticHandler:
def __init__(self, path: str, location: Path | str, cached: bool) -> None:
if isinstance(location, str):
location = Path(location)
self.path: str = path.rstrip("/") + "/{filepath}"
self.location: Path = Path(location).expanduser().resolve()
self.cached: bool = cached
if not self.location.exists():
raise FileNotFoundError(self.location)
if not self.location.is_dir():
raise ValueError("Location is not a directory or file")
async def __call__(self, request: Request) -> Response:
filepath = request.params["filepath"]
if self.cached:
return await self.handle_call_cached(request, filepath)
return await self.handle_call(request, filepath)
async def handle_call(self, request: Request, filepath: str) -> Response:
filepath = normpath(filepath)
path = self.location.joinpath(filepath.lstrip("/"))
if path.is_dir():
path = path.joinpath("index.html")
if not path.is_file():
raise HttpError(404, str(filepath))
return FileResponse(path)
@lru_cache(maxsize = 128, typed = True)
async def handle_call_cached(self, request: Request, filepath: str) -> Response:
return await self.handle_call(request, filepath)

View file

@ -1,6 +1,6 @@
import re import re
from aputils import IntEnum from aputils import IntEnum, StrEnum
CAPITAL = re.compile("[A-Z][^A-Z]") CAPITAL = re.compile("[A-Z][^A-Z]")
@ -84,3 +84,10 @@ class HttpStatus(IntEnum):
@property @property
def reason(self) -> str: def reason(self) -> str:
return " ".join(CAPITAL.findall(self.name)) return " ".join(CAPITAL.findall(self.name))
class SassOutputStyle(StrEnum):
NESTED = "nested"
EXPANDED = "expanded"
COMPACT = "compact"
COMPRESSED = "compressed"

View file

@ -2,11 +2,26 @@ import asyncio
import json import json
from aputils import JsonBase from aputils import JsonBase
from collections.abc import Iterable
from colorsys import rgb_to_hls, hls_to_rgb from colorsys import rgb_to_hls, hls_to_rgb
from functools import cached_property from functools import cached_property
from typing import Any, Self from typing import Any, Self
from collections.abc import (
Awaitable,
Callable,
Iterable,
Iterator,
ItemsView,
KeysView,
MutableMapping,
Sequence,
ValuesView
)
ReaderFunction = Callable[..., Awaitable[dict[str, Any]]]
WriterFunction = Callable[[dict[str, Any]], Awaitable[None]]
HTTP_STATUS = { HTTP_STATUS = {
100: "Continue", 100: "Continue",
@ -320,7 +335,11 @@ class Color(str):
return f"rgba({values}, {trans:.2})" return f"rgba({values}, {trans:.2})"
class State(dict[str, Any]): class StateProxy(MutableMapping[str, Any]):
def __init__(self, state: dict[str, str]) -> None:
self._state: dict[str, str] = state
def __getattr__(self, key: str) -> Any: def __getattr__(self, key: str) -> Any:
try: try:
object.__getattribute__(self, key) object.__getattribute__(self, key)
@ -330,8 +349,122 @@ class State(dict[str, Any]):
def __setattr__(self, key: str, value: Any) -> None: def __setattr__(self, key: str, value: Any) -> None:
if key.startswith("_"):
object.__setattr__(self, key, value)
return
self[key] = value self[key] = value
def __delattr__(self, key: str) -> None: def __delattr__(self, key: str) -> None:
if key.startswith("_"):
object.__delattr__(self, key)
return
del self[key] del self[key]
def __getitem__(self, key: str) -> Any:
return self._state[key]
def __setitem__(self, key: str, value: Any) -> None:
self._state[key] = value
def __delitem__(self, key: str) -> None:
del self._state[key]
def __len__(self) -> int:
return len(self._state)
def __iter__(self) -> Iterator[str]:
for key in self._state:
yield key
def get(self, key: str, default: Any = None) -> Any:
return self._state.get(key, default)
def items(self) -> ItemsView[str, Any]:
return self._state.items()
def keys(self) -> KeysView[str]:
return self._state.keys()
def set(self, key: str, value: Any) -> None:
self._state[key] = value
def values(self) -> ValuesView[Any]:
return self._state.values()
class Stream:
def __init__(self, reader: ReaderFunction, writer: WriterFunction) -> None:
self.reader = reader
self.writer = writer
self._sent_headers: bool = False
self._sent_body: bool = False
def __aiter__(self) -> Self:
return self
async def __anext__(self) -> bytes:
data, more = await self.read_chunk()
if not data and not more:
raise StopAsyncIteration
return data
async def close(self) -> None:
await self.writer({"type": "http.disconnect"})
async def read(self) -> bytes:
body = b""
while True:
data, more = await self.read_chunk()
body += data
if not more:
break
return body
async def read_chunk(self) -> tuple[bytes, bool]:
data = await self.reader()
return data.get("body", b""), data.get("more_body", False)
async def write_headers(self, status: int, headers: Sequence[tuple[str, str]]) -> None:
if self._sent_headers:
return
await self.writer({
"type": "http.response.start",
"status": status,
"headers": headers
})
async def write_body(self, data: bytes, eof: bool = False) -> None:
if self._sent_body:
return
await self.writer({
"type": "http.response.body",
"body": data,
"more_body": eof
})

View file

@ -1,13 +1,15 @@
from __future__ import annotations from __future__ import annotations
import multipart
import typing import typing
from aputils import JsonBase
from collections.abc import Iterable from collections.abc import Iterable
from multidict import CIMultiDict, CIMultiDictProxy from multidict import CIMultiDict, CIMultiDictProxy
from typing import Any, Literal, TypedDict from typing import Any, Literal, TypedDict
from urllib.parse import parse_qsl from urllib.parse import parse_qsl
from .misc import State from .misc import StateProxy, Stream
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from .application import Application from .application import Application
@ -36,8 +38,14 @@ class ScopeType(TypedDict):
class Request: class Request:
def __init__(self, app: Application, scope: ScopeType) -> None: def __init__(self,
app: Application,
scope: ScopeType,
stream: Stream) -> None:
self.app: Application = app self.app: Application = app
self.stream: Stream = stream
self._body: bytes = b""
raw_query = parse_qsl(scope["query_string"].decode("utf-8"), keep_blank_values = True) raw_query = parse_qsl(scope["query_string"].decode("utf-8"), keep_blank_values = True)
raw_headers = ((k.decode("utf-8").title(), v.decode("utf-8")) for k, v in scope["headers"]) raw_headers = ((k.decode("utf-8").title(), v.decode("utf-8")) for k, v in scope["headers"])
@ -46,13 +54,66 @@ class Request:
self.path: str = scope["path"] self.path: str = scope["path"]
self.query: CIMultiDictProxy[str] = CIMultiDictProxy(CIMultiDict(raw_query)) self.query: CIMultiDictProxy[str] = CIMultiDictProxy(CIMultiDict(raw_query))
self.headers: CIMultiDictProxy[str] = CIMultiDictProxy(CIMultiDict(raw_headers)) self.headers: CIMultiDictProxy[str] = CIMultiDictProxy(CIMultiDict(raw_headers))
self.params: dict[str, Any] = {}
self.remote: str = (scope.get("client") or ("n/a"))[0] self.remote: str = (scope.get("client") or ("n/a"))[0]
self.local: str = (scope.get("server") or ("n/a"))[0] self.local: str = (scope.get("server") or ("n/a"))[0]
self.state: State = State(scope.get("state") or {}) self.state: StateProxy = app.request_state_class(scope.get("state") or {})
self.extensions: dict[str, Any] = scope.get("extensions", {}) # type: ignore[assignment] self.extensions: dict[str, Any] = scope.get("extensions", {}) # type: ignore[assignment]
# keep? # keep?
self.asgi: tuple[str, str] = (scope["asgi"]["spec_version"], scope["asgi"]["version"]) self.asgi: tuple[str, str] = (scope["asgi"]["spec_version"], scope["asgi"]["version"])
self.version: float = float(scope["http_version"]) self.version: float = float(scope["http_version"])
self.scheme: str = scope["scheme"].lower() self.scheme: str = scope["scheme"].lower()
@property
def content_length(self) -> int:
return int(self.headers.get("Content-Length", "0"))
@property
def content_type(self) -> str:
return self.headers.getone("Content-Type", "")
async def body(self) -> bytes:
if not self._body:
self._body = await self.stream.read()
return self._body
async def text(self) -> str:
return (await self.body()).decode("utf-8")
async def json(self, parser_class: type[JsonBase] = JsonBase) -> JsonBase:
return parser_class.parse(await self.body())
async def form(self) -> CIMultiDictProxy[str | multipart.File]:
if self.content_type != "multipart/form-data":
raise ValueError(f"Invalid mimetype for form data: {self.content_type}")
fields: CIMultiDict[str | multipart.File] = CIMultiDict({})
def handle_field(field: multipart.Field | multipart.File) -> None:
if isinstance(field, multipart.Field):
fields[field.field_name.decode("utf-8")] = field.value.decode("utf-8")
else:
field.file_object.seek(0)
fields[field.field_name.decode("utf-8")] = field
parser = multipart.create_form_parser(self.headers, handle_field, handle_field)
async for chunk in self.stream:
parser.write(chunk)
parser.finalize()
return CIMultiDictProxy(fields)
async def stream_response(self, status: int, headers: dict[str, str]) -> None:
pass

View file

@ -1,11 +1,17 @@
from __future__ import annotations
import os import os
from aputils import JsonBase
from mimetypes import guess_type
from multidict import CIMultiDict from multidict import CIMultiDict
from pathlib import Path
from typing import Any, Self from typing import Any, Self
from .enums import HttpStatus from .enums import HttpStatus
from .error import HttpError from .error import HttpError
from .misc import convert_to_bytes from .misc import Stream, convert_to_bytes
from .request import Request from .request import Request
@ -14,12 +20,16 @@ class Response:
status: HttpStatus | int, status: HttpStatus | int,
body: Any, body: Any,
mimetype: str | None = None, mimetype: str | None = None,
headers: dict[str, str] | None = None) -> None: headers: CIMultiDict[str] | dict[str, str] | None = None) -> None:
self._body = b"" self._body = b""
self.status: HttpStatus = HttpStatus.parse(status) self.status: HttpStatus = HttpStatus.parse(status)
self.headers: CIMultiDict[str] = CIMultiDict(headers or {})
if isinstance(headers, CIMultiDict):
self.headers = headers
else:
self.headers = CIMultiDict(headers or {})
if body: if body:
self.body = body self.body = body
@ -28,6 +38,26 @@ class Response:
self.mimetype = mimetype self.mimetype = mimetype
@classmethod
def new_json(cls: type[Self],
status: HttpStatus | int,
body: JsonBase | dict[str, Any] | str,
mimetype: str = "application/json",
headers: dict[str, str] | None = None) -> Self:
return cls(status, body, mimetype, headers)
@classmethod
def new_activity(cls: type[Self],
status: HttpStatus | int,
body: JsonBase | dict[str, Any] | str,
mimetype: str = "application/activity+json",
headers: dict[str, str] | None = None) -> Self:
return cls(status, body, mimetype, headers)
@classmethod @classmethod
def new_from_http_error(cls: type[Self], error: HttpError) -> Self: def new_from_http_error(cls: type[Self], error: HttpError) -> Self:
message = f"HTTP Error {error.status.value}: {error.message}" message = f"HTTP Error {error.status.value}: {error.message}"
@ -42,21 +72,63 @@ class Response:
@body.setter @body.setter
def body(self, value: Any) -> None: def body(self, value: Any) -> None:
self._body = convert_to_bytes(value) self._body = convert_to_bytes(value)
self.headers.popall("Content-Length", None) self.length = len(self._body)
self.headers["Content-Length"] = str(len(self._body))
@property
def length(self) -> int:
return int(self.headers.getone("Content-Length", "0"))
@length.setter
def length(self, value: int) -> None:
self.headers.update({"Content-Length": str(value)})
@property @property
def mimetype(self) -> str: def mimetype(self) -> str:
return self.headers["Content-Type"] return self.headers.getone("Content-Type", "")
@mimetype.setter @mimetype.setter
def mimetype(self, value: str) -> None: def mimetype(self, value: str) -> None:
self.headers["Content-Type"] = value self.headers.update({"Content-Type": value})
class TemplateResponse: async def send(self, stream: Stream, request: Request) -> None:
await stream.write_headers(self.status, tuple(self.headers.items()))
await stream.write_body(self.body)
class FileResponse(Response):
def __init__(self,
path: Path,
mimetype: str | None = None,
chunk_size: int = 8192,
status: int = 200,
headers: dict[str, str] | None = None) -> None:
Response.__init__(self, status, b"", headers = headers)
self.path: Path = path
self.mimetype: str = mimetype or guess_type(path)[0] or "application/octet+stream"
self.chunk_size: int = chunk_size or 8192
async def send(self, stream: Stream, request: Request) -> None:
await stream.write_headers(200, tuple([]))
with self.path.open("rb") as fd:
while True:
if not (data := fd.read(self.chunk_size)):
break
await stream.write_body(data, True)
await stream.write_body(b"", False)
class TemplateResponse(Response):
def __init__(self, def __init__(self,
name: str, name: str,
status: HttpStatus | int = HttpStatus.Ok, status: HttpStatus | int = HttpStatus.Ok,
@ -65,10 +137,10 @@ class TemplateResponse:
pretty_print: bool = False, pretty_print: bool = False,
**context: Any) -> None: **context: Any) -> None:
Response.__init__(self, status, b"", headers = headers)
self.name: str = name self.name: str = name
self.status: HttpStatus = HttpStatus.parse(status)
self.mimetype: str = mimetype or self.detect_mimetype() self.mimetype: str = mimetype or self.detect_mimetype()
self.headers: dict[str, Any] = headers or {}
self.pretty_print: bool = pretty_print self.pretty_print: bool = pretty_print
self.context: dict[str, Any] = context self.context: dict[str, Any] = context
@ -85,12 +157,11 @@ class TemplateResponse:
if ext == ".xml": if ext == ".xml":
return "text/xml" return "text/xml"
return "text/plain" return guess_type(self.name)[0] or "text/plain"
def get_response(self, request: Request) -> Response: async def send(self, stream: Stream, request: Request) -> None:
return Response(self.status, self.render(request), self.mimetype, self.headers) text = request.app.template.render(self.name, self.pretty_print, **self.context)
self.body = text.encode("utf-8")
await Response.send(self, stream, request)
def render(self, request: Request) -> str:
return request.app.template.render(self.name, self.pretty_print, **self.context)

View file

@ -6,10 +6,10 @@ from typing import Any, Pattern
from .error import HttpError from .error import HttpError
from .request import Request from .request import Request
from .response import Response, TemplateResponse from .response import Response
RouteHandler = Callable[[Request], Awaitable[Response | TemplateResponse]] RouteHandler = Callable[[Request], Awaitable[Response]]
class Router(http_router.Router): class Router(http_router.Router):

View file

@ -1,7 +1,8 @@
import asyncio import asyncio
import os import os
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine, Sequence
from pathlib import Path
from typing import Any from typing import Any
from uvicorn import Config, Server from uvicorn import Config, Server
from uvicorn.supervisors import ChangeReload, Multiprocess from uvicorn.supervisors import ChangeReload, Multiprocess
@ -12,31 +13,39 @@ from .application import Application
BackgroundTask = Callable[[Application], Coroutine[Any, Any, None]] BackgroundTask = Callable[[Application], Coroutine[Any, Any, None]]
def run_app(app: Application, *bgtasks: BackgroundTask) -> None: def run_app(*args: Any, **kwargs: Any) -> None:
try: try:
asyncio.run(handle_run_server(app, *bgtasks)) asyncio.run(handle_run_server(*args, **kwargs))
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
async def handle_run_server(app: Application, *bgtasks: BackgroundTask) -> None: async def handle_run_server(
for key, value in app.env.items(): app: Application,
os.environ[key] = value app_spec: str | None = None,
address: str = "127.0.0.1",
os.environ["BARKSHARK_ASGI_NAME"] = app.name port: int = 8080,
workers: int = 1,
allowed_ips: list[str] | None = None,
reload: bool = False,
reload_dirs: list[Path | str] | None = None,
bgtasks: Sequence[BackgroundTask] | None = None,
env: dict[str, str] | None = None,
dev: bool = False) -> None:
cfg = Config( cfg = Config(
app.app_path if app.workers > 1 else app, app_spec if workers > 1 else app, # type: ignore[arg-type]
host = app.address, host = address,
port = app.port, port = port,
workers = app.workers if not app.dev else 1, workers = workers if not dev else 1,
access_log = False, access_log = False,
log_level = "info", log_level = "info",
factory = True, factory = True,
forwarded_allow_ips = app.allowed_ips or None, forwarded_allow_ips = allowed_ips or None,
reload = app.dev, reload = dev,
reload_dirs = [str(path) for path in app.reload_dirs], reload_dirs = [str(path) for path in reload_dirs] if reload_dirs else [],
reload_delay = 1, reload_delay = 1,
server_header = False, server_header = False,
lifespan = "on", lifespan = "on",
@ -45,19 +54,22 @@ async def handle_run_server(app: Application, *bgtasks: BackgroundTask) -> None:
] ]
) )
for key, value in (env or {}).items():
os.environ[key] = value
tasks: list[asyncio.Task[Any]] = [] tasks: list[asyncio.Task[Any]] = []
for task in bgtasks: for task in (bgtasks or []):
tasks.append(asyncio.create_task(task(app))) tasks.append(asyncio.create_task(task(app)))
server = Server(cfg) server = Server(cfg)
runner: ChangeReload | Multiprocess # type: ignore[valid-type] runner: ChangeReload | Multiprocess # type: ignore[valid-type]
if app.dev: if dev:
runner = ChangeReload(cfg, target = server.run, sockets = [cfg.bind_socket()]) runner = ChangeReload(cfg, target = server.run, sockets = [cfg.bind_socket()])
runner.run() runner.run()
elif cfg.workers > 1: elif workers > 1:
runner = Multiprocess(cfg, target = server.run, sockets = [cfg.bind_socket()]) runner = Multiprocess(cfg, target = server.run, sockets = [cfg.bind_socket()])
runner.run() runner.run()

View file

@ -1,15 +1,19 @@
from __future__ import annotations from __future__ import annotations
import os import os
import sass
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from hamlish_jinja import HamlishExtension, OutputMode from hamlish_jinja import HamlishExtension, OutputMode
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
from jinja2.ext import Extension
from os.path import splitext
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from xml.dom import minidom from xml.dom import minidom
from xml.etree import ElementTree from xml.etree import ElementTree
from .enums import SassOutputStyle
from .misc import Color from .misc import Color
@ -28,6 +32,60 @@ class FsLoader(FileSystemLoader):
self.followlinks: bool = followlinks self.followlinks: bool = followlinks
class SassExtension(Extension):
"An extension for Jinja2 that adds support for sass and scss compiling."
def __init__(self, environment: Environment):
Extension.__init__(self, environment)
self.output_style: SassOutputStyle = SassOutputStyle.NESTED
self.include_paths: list[str] = []
self._exts: tuple[str, str] = (".sass", ".scss")
environment.extend( # type: ignore[no-untyped-call]
sass_get_output_style = self.get_output_style,
sass_set_output_style = self.set_output_style,
sass_append_include_path = self.include_paths.append,
sass_remove_include_path = self.include_paths.remove
)
def __repr__(self) -> str:
return f"SassExtension(output_style='{self.output_style.value}')"
def get_output_style(self) -> SassOutputStyle:
return self.output_style
def set_output_style(self, value: SassOutputStyle) -> None:
self.output_style = SassOutputStyle.parse(value)
def preprocess(self, source: str, name: str | None, filename: str | None = None) -> str:
"""
Transpile a sass or scss file into a css file
:param source: Full text source of the template
:param name: Name of the template
:param filename: Path to the template
:raises CompileError: When the template cannot be parsed
"""
if (tpl_name := filename or name) is None:
return source
if (ext := splitext(tpl_name)[1]) not in self._exts:
return source
return sass.compile( # type: ignore[no-any-return]
string = source,
output_style = self.output_style.value,
indented = ext == ".sass"
)
class Template(Environment): class Template(Environment):
def __init__(self, def __init__(self,
*search: str | Path, *search: str | Path,
@ -38,9 +96,12 @@ class Template(Environment):
super().__init__( super().__init__(
loader = self.search, loader = self.search,
extensions = [HamlishExtension],
lstrip_blocks = True, lstrip_blocks = True,
trim_blocks = True trim_blocks = True,
extensions = [
HamlishExtension,
SassExtension
]
) )
for path in search: for path in search:

View file

@ -42,6 +42,7 @@ dependencies = [
"jinja2-haml == 0.3.5", "jinja2-haml == 0.3.5",
"libsass == 0.23.0", "libsass == 0.23.0",
"multidict == 6.0.5", "multidict == 6.0.5",
"multipart == 0.2.4",
"python-multipart == 0.0.9", "python-multipart == 0.0.9",
"pyyaml == 6.0.1", "pyyaml == 6.0.1",
"uvicorn == 0.29.0" "uvicorn == 0.29.0"