basgi/basgi/template.py
2024-04-13 09:33:55 -04:00

192 lines
5 KiB
Python

from __future__ import annotations
import os
import sass
from collections.abc import Callable, Sequence
from hamlish_jinja import HamlishExtension, OutputMode
from jinja2 import Environment, FileSystemLoader
from jinja2.ext import Extension
from os.path import splitext
from pathlib import Path
from typing import Any
from xml.dom import minidom
from xml.etree import ElementTree
from .enums import SassOutputStyle
from .misc import Color
# Signature of a per-render context hook: it is called with the environment and
# the render context dict, and returns the dict that is actually rendered with.
TemplateContextType = Callable[
    [Environment, dict[str, Any]],
    dict[str, Any],
]
# jinja's annotation for the `searchpath` property was weird, so this fixes it
class FsLoader(FileSystemLoader):
    """
    A ``FileSystemLoader`` whose ``searchpath`` attribute is typed as ``list[str]``.

    Accepts either a single directory or a sequence of directories.
    """

    def __init__(self,
                 searchpath: str | Path | Sequence[str | Path],
                 encoding: str = "utf-8",
                 followlinks: bool = False) -> None:
        """
        :param searchpath: A directory, or a sequence of directories, to load templates from
        :param encoding: Encoding used when reading template files
        :param followlinks: Stored for jinja's loader (controls following symlinks
            when listing templates)
        """
        # A bare ``str`` is itself a ``Sequence[str]``; without this check it would be
        # split into one bogus search path per character.
        if isinstance(searchpath, (str, os.PathLike)):
            searchpath = [searchpath]

        self.searchpath: list[str] = [os.fspath(p) for p in searchpath]  # type: ignore[assignment]
        self.encoding: str = encoding
        self.followlinks: bool = followlinks
class SassExtension(Extension):
    "An extension for Jinja2 that adds support for sass and scss compiling."

    def __init__(self, environment: Environment):
        super().__init__(environment)

        # Output style handed to ``sass.compile``; change it via ``set_output_style``.
        self.output_style: SassOutputStyle = SassOutputStyle.NESTED
        # Extra directories passed to ``sass.compile`` for resolving imports.
        self.include_paths: list[str] = []
        # Template extensions (compared lower-cased) that get compiled.
        self._exts: tuple[str, str] = (".sass", ".scss")

        # Expose the settings on the environment so they can be changed without
        # holding a reference to this extension instance.
        environment.extend(  # type: ignore[no-untyped-call]
            sass_get_output_style = self.get_output_style,
            sass_set_output_style = self.set_output_style,
            sass_append_include_path = self.include_paths.append,
            sass_remove_include_path = self.include_paths.remove
        )

    def __repr__(self) -> str:
        return f"SassExtension(output_style='{self.output_style.value}')"

    def get_output_style(self) -> SassOutputStyle:
        "Return the output style used when compiling stylesheets."
        return self.output_style

    def set_output_style(self, value: SassOutputStyle) -> None:
        "Set the output style; ``value`` is normalized with ``SassOutputStyle.parse``."
        self.output_style = SassOutputStyle.parse(value)

    def preprocess(self, source: str, name: str | None, filename: str | None = None) -> str:
        """
        Transpile a sass or scss file into a css file

        Templates whose name does not end in ``.sass`` or ``.scss`` (case-insensitive)
        are returned unchanged.

        :param source: Full text source of the template
        :param name: Name of the template
        :param filename: Path to the template
        :raises CompileError: When the template cannot be parsed
        """
        if (tpl_name := filename or name) is None:
            return source

        ext = splitext(tpl_name)[1].lower()

        if ext not in self._exts:
            return source

        return sass.compile(  # type: ignore[no-any-return]
            string = source,
            output_style = self.output_style.value,
            # Previously collected but never forwarded, so registered include paths
            # had no effect. ``os.fspath`` accepts ``Path`` entries as well.
            include_paths = [os.fspath(p) for p in self.include_paths],
            # ``.sass`` uses the indented syntax; ``.scss`` uses braces.
            indented = ext == ".sass"
        )
class Template(Environment):
    """
    Jinja2 environment with the Hamlish and Sass extensions loaded, block whitespace
    trimming and autoescaping enabled, a dict of shared template variables
    (``global_env``) and an optional per-render context hook.
    """

    def __init__(self,
                 *search: str | Path,
                 context_function: TemplateContextType | None = None,
                 **global_env: Any):
        """
        :param search: Directories to search for templates. Each one is inserted at the
            front of the search list, so paths given later come first.
        :param context_function: Optional hook; see ``set_context_function``
        :param global_env: Extra variables made available to every render; they
            override the built-in helpers of the same name
        :raises FileNotFoundError: If a search directory does not exist
        :raises TypeError: If ``context_function`` is not callable
        :raises ValueError: If ``context_function`` does not return a dict
        """
        # Created before ``Environment.__init__`` so it can be passed in as the loader.
        self.search: FsLoader = FsLoader([])

        super().__init__(
            loader = self.search,
            lstrip_blocks = True,
            trim_blocks = True,
            extensions = [
                HamlishExtension,
                SassExtension
            ]
        )

        for path in search:
            self.add_search_path(path)

        self.autoescape: bool = True
        self.context_function: TemplateContextType | None = None

        self.global_env: dict[str, Any] = {
            # Text content of markup; the input must parse as a single well-formed
            # XML element (``ElementTree.fromstring``).
            "cleanhtml": lambda text: "".join(ElementTree.fromstring(text).itertext()),
            "color": Color,
            "lighten": lambda c, v: Color(c).lighten(v),
            "darken": lambda c, v: Color(c).darken(v),
            "saturate": lambda c, v: Color(c).saturate(v),
            "desaturate": lambda c, v: Color(c).desaturate(v),
            "rgba": lambda c, v: Color(c).rgba(v),
            **global_env
        }

        # Presumably read by ``HamlishExtension`` from the environment.
        self.hamlish_file_extensions = (".haml", ".jhaml", ".jaml")
        self.hamlish_enable_div_shortcut = True
        self.hamlish_mode = OutputMode.INDENTED

        # Installed last: validation calls the hook with this environment, so every
        # attribute it might read (e.g. ``global_env``) must already exist.
        if context_function:
            self.set_context_function(context_function)

    def add_search_path(self, path: Path | str, index: int = 0) -> None:
        """
        Register a directory to search for templates.

        :param path: Directory to add; ``~`` is expanded and the path is made absolute
        :param index: Position at which to insert it in the loader's search path list
        :raises FileNotFoundError: If the path does not exist
        """
        # Normalize ``Path`` objects exactly like strings so that ``~`` is expanded
        # and the duplicate check compares absolute paths.
        path = Path(path).expanduser().resolve()

        if not path.exists():
            raise FileNotFoundError(f"Cannot find search path: {path}")

        if (path_str := str(path)) not in self.search.searchpath:
            self.search.searchpath.insert(index, path_str)

    def set_context_function(self, context: TemplateContextType) -> TemplateContextType:
        """
        Install a hook that is called as ``context(environment, ctx)`` on every render
        and returns the dict to render with.

        The hook is called once here with an empty dict to validate it. It is
        returned unchanged, which allows using this method as a decorator.

        :raises TypeError: If ``context`` is not callable
        :raises ValueError: If the hook does not return a ``dict``
        """
        if not callable(context):
            raise TypeError("Context is not callable")

        if not isinstance(context(self, {}), dict):
            raise ValueError("Context does not return a dict object")

        self.context_function = context
        return context

    def add_env(self, key: str, value: Any) -> None:
        "Add or replace a shared template variable."
        self.global_env[key] = value

    def del_env(self, key: str) -> None:
        "Remove a shared template variable; raises ``KeyError`` if it is not set."
        del self.global_env[key]

    def update_env(self, **kwargs: Any) -> None:
        "Add or replace several shared template variables at once."
        self.global_env.update(kwargs)

    def add_filter(self, funct: Callable[..., Any], name: str | None = None) -> None:
        "Register ``funct`` as a template filter, named after the function by default."
        self.filters[name or funct.__name__] = funct

    def del_filter(self, name: str) -> None:
        "Remove a template filter; raises ``KeyError`` if it is not registered."
        del self.filters[name]

    def render(self,
               template_name: str,
               pprint: bool = False,
               **context: Any) -> str:
        """
        Render a template by name.

        :param template_name: Name of the template to load
        :param pprint: Re-indent the output with ``minidom`` when the name ends in
            ``haml``, ``jhaml``, ``jaml``, ``html`` or ``xml``. The output must then
            be well-formed XML.
        :param context: Variables for this render; they take precedence over
            ``global_env`` entries with the same name
        """
        # Per-call variables override the shared globals. Previously the globals were
        # merged on top and silently clobbered caller-supplied values.
        context = {**self.global_env, **context}

        if self.context_function:
            context = self.context_function(self, context)

        result = self.get_template(template_name).render(context)

        if pprint and template_name.lower().endswith(("haml", "jhaml", "jaml", "html", "xml")):
            return minidom.parseString(result).toprettyxml(indent=" ")

        return result