rework database config

This commit is contained in:
Izalia Mae 2024-04-21 00:50:45 -04:00
parent fc3323faa5
commit 6cc895aa5d
13 changed files with 218 additions and 148 deletions

View file

@ -6,6 +6,7 @@ from blib import HttpDate
from pathlib import Path
from . import TRANS
from .database import check_setup
from .enums import PermissionLevel
from .server import Application
@ -26,20 +27,22 @@ def cli(ctx: click.Context, config: Path) -> None:
@cli.command("setup")
@click.pass_context
def cli_setup(ctx: click.Context) -> None:
if ctx.obj.config.sqlite_path.exists():
app: Application = ctx.obj
if app.config.sqlite_path.exists():
click.confirm(TRANS.fetch("setup", "prompt-database"), abort = True)
ctx.obj.config.sqlite_path.unlink()
app.config.sqlite_path.unlink()
current = HttpDate.new_utc()
with ctx.obj.database.session(True) as s:
with app.database.session(True) as s:
s.create_tables()
s.insert("instance", {
"id": -99,
"domain": ctx.obj.config.host,
"web_domain": ctx.obj.config.web_host,
"shared_inbox": f"https://{ctx.obj.config.web_host}/inbox",
"domain": app.config.host,
"web_domain": app.config.web_host,
"shared_inbox": f"https://{app.config.web_host}/inbox",
"name": "Barkshark Social",
"description": "Lightweight ActivityPub server",
"software": "barksharksocial",
@ -47,20 +50,20 @@ def cli_setup(ctx: click.Context) -> None:
"updated": current
})
click.echo(TRANS.fetch("setup", "create-key", username = ctx.obj.config.host))
instance_signer = Signer.new(f"https://{ctx.obj.config.web_host}/actor")
click.echo(TRANS.fetch("setup", "create-key", username = app.config.host))
instance_signer = Signer.new(f"https://{app.config.web_host}/actor")
click.echo(TRANS.fetch("setup", "create-key", username = "relay"))
relay_signer = Signer.new(f"https://{ctx.obj.config.web_host}/relay")
relay_signer = Signer.new(f"https://{app.config.web_host}/relay")
s.insert("user", {
"id": -99,
"username": ctx.obj.config.host,
"domain": ctx.obj.config.host,
"display_name": ctx.obj.config.host,
"actor": f"https://{ctx.obj.config.web_host}/actor",
"inbox": f"https://{ctx.obj.config.web_host}/inbox",
"page_url": f"https://{ctx.obj.config.web_host}/about",
"username": app.config.host,
"domain": app.config.host,
"display_name": app.config.host,
"actor": f"https://{app.config.web_host}/actor",
"inbox": f"https://{app.config.web_host}/inbox",
"page_url": f"https://{app.config.web_host}/about",
"permission": 0,
"locked": False,
"private_key": instance_signer.export(),
@ -72,11 +75,11 @@ def cli_setup(ctx: click.Context) -> None:
s.insert("user", {
"id": -98,
"username": "relay",
"domain": ctx.obj.config.host,
"domain": app.config.host,
"display_name": "Relay",
"actor": f"https://{ctx.obj.config.web_host}/relay",
"inbox": f"https://{ctx.obj.config.web_host}/relay",
"page_url": f"https://{ctx.obj.config.web_host}/about#relay",
"actor": f"https://{app.config.web_host}/relay",
"inbox": f"https://{app.config.web_host}/relay",
"page_url": f"https://{app.config.web_host}/about#relay",
"permission": 0,
"locked": False,
"private_key": relay_signer.export(),
@ -85,6 +88,8 @@ def cli_setup(ctx: click.Context) -> None:
"updated": current
})
s.put_config("version", None)
click.echo(TRANS.fetch("setup", "done"))
@ -93,9 +98,15 @@ def cli_setup(ctx: click.Context) -> None:
@click.option("--dev", "-d", is_flag = True)
@click.pass_context
def cli_run(ctx: click.Context, language: str, dev: bool = False) -> None:
ctx.obj.dev = dev
ctx.obj.trans.default = language
ctx.obj.run()
app: Application = ctx.obj
if "example.com" in app.config.host or not check_setup(app.database):
click.echo("Instance not setup yet")
return
app.config.dev = dev
app.trans.default = language
app.run()
@cli.group("user")
@ -116,8 +127,10 @@ def cli_user_create(
permissions: PermissionLevel,
activate: bool) -> None:
with ctx.obj.database.session(True) as s:
if (user := s.get_user(username, ctx.obj.config.host)) is not None:
app: Application = ctx.obj
with app.database.session(True) as s:
if (user := s.get_user(username, app.config.host)) is not None:
TRANS.print("error", "user-exists", handle = user.handle)
return
@ -147,7 +160,9 @@ def cli_user_create(
@click.argument("username")
@click.pass_context
def cli_user_reset_pass(ctx: click.Context, username: str) -> None:
with ctx.obj.database.session(True) as s:
app: Application = ctx.obj
with app.database.session(True) as s:
if (user := s.get_user(username, None)) is None:
click.echo("User does not exist")
return

View file

@ -80,6 +80,7 @@ class Config(dict[str, dict[str, Any]]):
# Advanced
allowed_ips: Property[list[str]] = Property("Advanced", "allowed-ips", [])
cookie_name: Property[str] = Property("Advanced", "cookie-name", "bsocial_token")
media_path: Property[str] = Property("Advanced", "media-path", "media")
workers: Property[int] = Property("Advanced", "workers", len(os.sched_getaffinity(0)))
dev: Property[bool] = Property("Advanced", "dev", False)
@ -108,6 +109,14 @@ class Config(dict[str, dict[str, Any]]):
)
@property
def media(self) -> Path:
if not os.path.isabs(self.media_path):
return self.path.parent.joinpath(self.media_path)
return Path(self.media_path).expanduser().resolve()
@property
def sqlite_path(self) -> Path:
return self.path.parent.joinpath("database.sqlite3")

View file

@ -2,7 +2,9 @@ import sqlite3
from aputils import MessageDate
from blib import HttpDate, Enum
from bsql import Database
from .config import ConfigData
from .connection import Connection
from .schema import SCHEMA
@ -10,3 +12,18 @@ from .schema import SCHEMA
sqlite3.register_adapter(Enum, lambda v: v.value)
sqlite3.register_adapter(HttpDate, lambda v: v.timestamp())
sqlite3.register_adapter(MessageDate, lambda v: v.timestamp())
def check_setup(db: Database[Connection]) -> bool:
with db.session(False) as s:
if "config" not in s.get_tables():
return False
with s.execute("SELECT * FROM config WHERE key = 'version'") as cur:
if (row := cur.one()) is None:
return False
if int(row["value"]) < ConfigData().version:
return False
return True

View file

@ -8,7 +8,10 @@ from .. import TRANS
T = TypeVar("T")
CONVERTERS: dict[type, tuple[Callable[[str], Any], Callable[[Any], str]]] = {
DeserializerCallback = Callable[[str], Any]
SerializerCallback = Callable[[Any], str]
CONVERTERS: dict[type, tuple[DeserializerCallback, SerializerCallback]] = {
bool: (convert_to_boolean, str),
int: (int, str),
dict: (JsonBase.parse, lambda value: JsonBase(value).to_json()),
@ -16,24 +19,24 @@ CONVERTERS: dict[type, tuple[Callable[[str], Any], Callable[[Any], str]]] = {
}
class Value(Generic[T]):
class ConfigValue(Generic[T]):
__slots__: tuple[str, ...] = ("default", "type", "description", "name")
def __init__(self, default: T, description: str) -> None:
self.default: T = default
self.type: type[Any] = type(default)
self.type: type = type(default)
self.description: str = description
self.name: str = ""
def __set_name__(self, owner: Any, name: str) -> None:
self.name = name.replace("_", "-")
self.name = ConfigValue.parse_key(name)
def __get__(self, obj: dict[str, str] | None, cls: type[dict[str, str]] | None) -> Self | T:
def __get__(self, obj: dict[str, str] | None, cls: type[dict[str, str]] | None) -> T:
if obj is None:
return self
raise RuntimeError("No object")
try:
return self.deserialize(obj[self.name])
@ -42,10 +45,7 @@ class Value(Generic[T]):
return self.default
def __set__(self, obj: dict[str, str], value: T) -> None:
if not isinstance(value, self.type):
raise TypeError(f"Value '{value}' is not type '{self.type.__name__}'")
def __set__(self, obj: dict[str, str], value: T | str) -> None:
obj.__setitem__(self.name, self.serialize(value), parse_value = False) # type: ignore
@ -53,7 +53,15 @@ class Value(Generic[T]):
self.__set__(obj, self.default)
def serialize(self, value: T) -> str:
@classmethod
def parse_key(cls: type[Self], key: str, to_underscore: bool = False) -> str:
if to_underscore:
return key.replace("-", "_")
return key.replace("_", "-")
def serialize(self, value: T | str) -> str:
if self.type is str and isinstance(value, str):
return value
@ -64,106 +72,106 @@ class Value(Generic[T]):
raise TypeError(TRANS.fetch("error", "cannot-convert", type = self.type.__name__))
def deserialize(self, value: str) -> T:
def deserialize(self, value: T | str) -> T:
if isinstance(value, self.type):
return value # type: ignore[no-any-return]
return value # type: ignore[return-value]
try:
return CONVERTERS[self.type][0](value) # type: ignore[no-any-return]
return CONVERTERS[self.type][0](value) # type: ignore[arg-type,no-any-return]
except KeyError:
raise TypeError(TRANS.fetch("error", "cannot-convert", type = type(value).__name__))
class ConfigData(dict[str, str]):
class ConfigData(dict[str, Any]):
__slots__: tuple[str, ...] = tuple([])
version: Value[int] = Value(
version: ConfigValue[int] = ConfigValue(
20240409, "Version of the database schema"
)
description: Value[str] = Value(
description: ConfigValue[str] = ConfigValue(
"Lightweight AP server", "Description for the instance"
)
description_short: Value[str] = Value(
description_short: ConfigValue[str] = ConfigValue(
"", "Short version of the description for the instance"
)
email: Value[str] = Value(
email: ConfigValue[str] = ConfigValue(
"", "E-mail address used for contact"
)
invites_enabled: Value[bool] = Value(
invites_enabled: ConfigValue[bool] = ConfigValue(
True, "Allow users to invite new users"
)
max_attachments: Value[int] = Value(
max_attachments: ConfigValue[int] = ConfigValue(
4, "Maximum number of media attachments allowed in a post"
)
max_audio_size: Value[int] = Value(
max_audio_size: ConfigValue[int] = ConfigValue(
10, "Maximum size of uploaded audio files in mebibytes"
)
max_emoji_size: Value[int] = Value(
max_emoji_size: ConfigValue[int] = ConfigValue(
512, "Maximum size of emojis in kibibytes"
)
max_featured_tags: Value[int] = Value(
max_featured_tags: ConfigValue[int] = ConfigValue(
8, "Maximum number of tags to display on a user's profile"
)
max_image_size: Value[int] = Value(
max_image_size: ConfigValue[int] = ConfigValue(
2, "Maximum size of uploaded image files in mebibytes"
)
max_name_chars: Value[int] = Value(
max_name_chars: ConfigValue[int] = ConfigValue(
64, "Maximum number of characters allowed in a display name or handle"
)
max_pinned_posts: Value[int] = Value(
max_pinned_posts: ConfigValue[int] = ConfigValue(
8, "Maximum number of posts that can be pinned on a profile"
)
max_poll_expiration: Value[int] = Value(
max_poll_expiration: ConfigValue[int] = ConfigValue(
60 * 24 * 7, "Max time in minutes for a poll to be open"
)
max_poll_option_chars: Value[int] = Value(
max_poll_option_chars: ConfigValue[int] = ConfigValue(
200, "Maximum characters allowed in a poll option"
)
max_poll_options: Value[int] = Value(
max_poll_options: ConfigValue[int] = ConfigValue(
10, "Maxiumum number of poll options in a post"
)
max_post_chars: Value[int] = Value(
max_post_chars: ConfigValue[int] = ConfigValue(
4096, "Maximum number of characters allowed in a post or bio"
)
max_profile_fields: Value[int] = Value(
max_profile_fields: ConfigValue[int] = ConfigValue(
8, "Maximum number of key/value fields allowed in a profile"
)
max_video_size: Value[int] = Value(
max_video_size: ConfigValue[int] = ConfigValue(
100, "Maximum size of uploaded video files in mebibytes"
)
min_poll_expiration: Value[int] = Value(
min_poll_expiration: ConfigValue[int] = ConfigValue(
5, "Minimum time in minutes for a poll to be open"
)
name: Value[str] = Value(
name: ConfigValue[str] = ConfigValue(
"Barkshark Social", "Name of the instance"
)
registration_open: Value[bool] = Value(
registration_open: ConfigValue[bool] = ConfigValue(
True, "Allow new users to sign up on the instance"
)
require_approval: Value[bool] = Value(
require_approval: ConfigValue[bool] = ConfigValue(
False, "Require admins/mods to manually accept new users"
)
@ -179,16 +187,17 @@ class ConfigData(dict[str, str]):
self.set(key, value)
def __getitem__(self, raw_key: str) -> str:
key = raw_key.replace("_", "-")
def __getitem__(self, key: str) -> Any:
key = ConfigValue.parse_key(key)
if key not in self:
self.__delitem__(raw_key)
self.__delitem__(key)
return dict.__getitem__(self, key)
def __setitem__(self, key: str, value: str, parse_value: bool = True) -> None:
def __setitem__(self, key: str, value: Any, parse_value: bool = True) -> None:
key = ConfigValue.parse_key(key)
if parse_value:
try:
self.set(key, value)
@ -197,11 +206,11 @@ class ConfigData(dict[str, str]):
except AttributeError:
raise KeyError(key) from None
dict.__setitem__(self, key.replace("_", "-"), value)
dict.__setitem__(self, key, value)
def __delitem__(self, key: str) -> None:
delattr(self, key.replace("-", "_"))
delattr(self, ConfigValue.parse_key(key, True))
@classmethod
@ -209,24 +218,17 @@ class ConfigData(dict[str, str]):
cfg = cls()
for row in rows:
cfg.set_from_row(row)
cfg.set(row["key"], row["value"])
return cfg
def get(self, key: str) -> Any: # type: ignore[override]
return getattr(self, key.replace("-", "_"))
return getattr(self, ConfigValue.parse_key(key, True))
def set(self, key: str, value: Any) -> None:
setattr(self, key, value)
def set_from_row(self, row: Row | None) -> None:
if row is None:
return
self[row["key"]] = row["value"]
setattr(self, ConfigValue.parse_key(key, True), value)
def update(self, data: dict[str, Any]) -> None: # type: ignore[override]

View file

@ -66,25 +66,29 @@ class Connection(BsqlConnection):
pass
def get_config(self) -> ConfigData:
with self.execute("SELECT * FROM config") as cur:
return ConfigData.from_rows(cur.all())
def get_config(self, key: str) -> Any:
cfg = ConfigData()
with self.execute("SELECT * FROM config WHERE key = $key", {"key": key}) as cur:
if (row := cur.one()) is not None:
cfg.set(key, row["value"])
return cfg.get(key)
def set_config(self, data: ConfigData) -> ConfigData:
for key, value in data.items():
with self.update("config", {"value": value}, key = key) as cur:
data.set_from_row(cur.one())
def put_config(self, key: str, value: Any | None) -> None:
cfg = ConfigData({key: value})
params = {
"key": key,
"value": cfg.get(key)
}
return data
with self.run("put-config", params):
pass
def get_config_value(self, key: str) -> Any:
return self.get_config().get(key)
def set_config_key(self, key: str, value: Any) -> None:
self.set_config(ConfigData({key: value}))
def get_config_all(self) -> ConfigData:
return ConfigData.from_rows(self.execute("SELECT * FROM config").all())
def get_instance(self, domain: str) -> Instance | None:

View file

@ -3,10 +3,16 @@ from bsql import Column, Table, Tables
SCHEMA = Tables(
Table(
"config",
Column("key", "text", nullable = False, unique = True),
Column("value", "text")
),
Table(
"media",
Column("id", "serial"),
Column("key", "text", nullable = False),
Column("value", "text"),
Column("type", "text", nullable = False)
Column("filename", "text", nullable = False),
Column("original_name", "text", nullable = False),
Column("created", "timestamp")
),
Table(
"instance",
@ -16,6 +22,7 @@ SCHEMA = Tables(
Column("shared_inbox", "text", nullable = False),
Column("name", "text"),
Column("description", "text"),
Column("banner", "integer", foreign_key = ("media", "id")),
Column("software", "text", nullable = False),
Column("mod_action", "text"),
Column("mod_action_reason", "text"),
@ -30,6 +37,8 @@ SCHEMA = Tables(
Column("username", "text", nullable = False),
Column("domain", "text", foreign_key = ("instance", "domain")),
Column("display_name", "text", nullable = False),
Column("avatar", "integer", foreign_key = ("media", "id")),
Column("banner", "integer", foreign_key = ("media", "id")),
Column("actor", "text", nullable = False),
Column("inbox", "text", nullable = False),
Column("page_url", "text"),

View file

@ -1,3 +1,9 @@
-- name: put-config
INSERT INTO config (key, value)
VALUES($key, $value)
ON CONFLICT(key) DO UPDATE SET value = $value
RETURNING *;
-- name: get-instance
SELECT * FROM instance WHERE
domain = $domain or

View file

@ -5,6 +5,7 @@ from basgi import Request, Response, router
from blib import HttpError
from ..processors import process_message
from ..server import Application
def ensure_signed(request: Request) -> None:
@ -14,10 +15,10 @@ def ensure_signed(request: Request) -> None:
@router.get("BarksharkSocial", "/actor")
async def handle_instance_actor_get(request: Request) -> Response:
config = request.app.state.config
app = Application.default()
with request.app.state.database.session(False) as s:
if (user := s.get_user(config.host, config.host)) is None:
with app.database.session(False) as s:
if (user := s.get_user(app.config.host, app.config.host)) is None:
raise HttpError(404, "User not found")
instance = user.get_instance(s)
@ -59,9 +60,10 @@ async def handle_instance_actor_get(request: Request) -> Response:
@router.post("BarksharkSocial", "/inbox")
async def handle_instance_actor_post(request: Request) -> Response:
ensure_signed(request)
app = Application.default()
with request.app.state.database.session(False) as s:
if (user := s.get_user(request.app.state.config.host)) is None:
with app.database.session(False) as s:
if (user := s.get_user(app.config.host)) is None:
raise HttpError(404, "User not found")
message = Message.parse(await request.body())
@ -71,7 +73,7 @@ async def handle_instance_actor_post(request: Request) -> Response:
@router.get("BarksharkSocial", "/relay")
async def handle_relay_get(request: Request) -> Response:
with request.app.state.database.session(False) as s:
with Application.default().database.session(False) as s:
if (user := s.get_user("relay")) is None:
raise HttpError(404, "User not found")
@ -92,7 +94,7 @@ async def handle_relay_get(request: Request) -> Response:
async def handle_relay_post(request: Request) -> Response:
ensure_signed(request)
with request.app.state.database.session(False) as s:
with Application.default().database.session(False) as s:
if (user := s.get_user("relay")) is None:
raise HttpError(404, "User not found")

View file

@ -4,16 +4,11 @@ from blib import HttpDate, HttpError
from datetime import timedelta
from ..misc import get_resource
from ..server import Application
@router.get("BarksharkSocial", "/")
async def handle_home(request: Request) -> Response:
if request.state.user is not None:
print(request.state.user.handle)
else:
print("no user")
return TemplateResponse("page/home.haml")
@ -33,11 +28,11 @@ async def handle_login_get(request: Request) -> Response:
@router.post("BarksharkSocial", "/login")
async def handle_login_post(request: Request) -> Response:
state = request.app.state
app = Application.default()
form = await request.form()
username, password = form.get("username"), form.get("password")
if username in [state.config.host, state.config.web_host, *state.config.alt_hosts, "relay"]:
if username in [app.config.host, app.config.web_host, *app.config.alt_hosts, "relay"]:
raise HttpError(400, "User does not exist")
if username is None or password is None:
@ -46,23 +41,23 @@ async def handle_login_post(request: Request) -> Response:
if not isinstance(username, str) or not isinstance(password, str):
raise HttpError(400, "Invalid field type")
with state.database.session(False) as s:
with app.database.session(False) as s:
if (user := s.get_user(username.lower())) is None:
raise HttpError(400, "User does not exist")
try:
s.hasher.verify(user.password, password)
s.hasher.verify(user.password, password) # type: ignore[arg-type]
except VerifyMismatchError:
raise HttpError(400, "Password does not match")
cookie = s.put_cookie(user, request.headers.get("User-Agent"))
host = request.headers.get("Host", state.config.host)
host = request.headers.get("Host", app.config.host)
response = Response.new_redirect(f"https://{host}/", 303)
response.set_cookie(
key = state.config.cookie_name,
key = app.config.cookie_name,
value = cookie.code,
same_site = "strict",
expires = HttpDate.new_utc() + timedelta(days = 30),
@ -76,17 +71,17 @@ async def handle_login_post(request: Request) -> Response:
@router.get("BarksharkSocial", "/logout")
async def handle_logout_get(request: Request) -> Response:
cname = request.app.state.config.cookie_name
app = app = Application.default()
response = Response.new_redirect("/")
try:
cookie = request.cookies[cname]
cookie = request.cookies[app.config.cookie_name]
with request.app.state.database.session(True) as s:
s.del_cookie(cookie.value)
with app.database.session(True) as s:
s.del_cookie(cookie.value) # type: ignore[arg-type]
response.cookies.append(cookie)
response.delete_cookie(cname)
response.delete_cookie(app.config.cookie_name)
except KeyError:
pass
@ -101,10 +96,12 @@ async def handle_register_get(request: Request) -> Response:
@router.get("BarksharkSocial", "/@{username}", "/@{username}@{domain}")
async def handle_user_page(request: Request) -> Response:
with request.app.state.database.session(False) as s:
app = Application.default()
with app.database.session(False) as s:
user = s.get_user(
request.params["username"],
request.params.get("domain", request.app.state.config.host)
request.params.get("domain", app.config.host)
)
if user is None or user.id < 1:

View file

@ -2,17 +2,18 @@ from basgi import Request, Response, router
from typing import Any
from .. import __version__
from ..server import Application
@router.get("BarksharkSocial", "/api/v1/instance")
async def handle_instance(request: Request) -> Response:
state = request.app.state
app = Application.default()
with state.database.session(False) as s:
config = s.get_config()
with app.database.session(False) as s:
config = s.get_config_all()
data: dict[str, Any] = {
"uri": state.config.host,
"uri": app.config.host,
"version": f"4.0.0 (compatible; Barkshark-Social {__version__})",
"title": config.name,
"short_description": config.description_short or None,
@ -26,7 +27,7 @@ async def handle_instance(request: Request) -> Response:
"en"
],
"urls": {
"streaming_api": f"wss://{state.config.web_host}"
"streaming_api": f"wss://{app.config.web_host}"
},
"stats": {
"user_count": s.count_users(None),
@ -121,9 +122,9 @@ async def handle_instance(request: Request) -> Response:
@router.get("BarksharkSocial", "/api/v1/instance/peers")
async def handle_instance_peers(request: Request) -> Response:
state = request.app.state
app = Application.default()
with state.database.session(False) as s:
with app.database.session(False) as s:
data = list(s.get_peer_instances())
return Response.new_json(200, data)

View file

@ -3,26 +3,27 @@ from basgi import Request, Response, router
from blib import HttpError
from .. import __version__
from ..server import Application
@router.get("BarksharkSocial", "/.well-known/host-meta")
async def handle_hostmeta(request: Request) -> Response:
config = request.app.state.config
app = Application.default()
if "json" in request.content_type:
return Response.new_json(200, HostMetaJson.new(config.host))
return Response.new_json(200, HostMetaJson.new(app.config.host))
return Response(200, HostMeta.new(config.host), mimetype = "application/xrd+xml")
return Response(200, HostMeta.new(app.config.host), mimetype = "application/xrd+xml")
@router.get("BarksharkSocial", "/.well-known/host-meta.json")
async def handle_hostmeta_json(request: Request) -> Response:
return Response.new_json(200, HostMetaJson.new(request.app.state.config.host))
return Response.new_json(200, HostMetaJson.new(Application.default().config.host))
@router.get("BarksharkSocial", "/.well-known/webfinger")
async def handle_webfinger(request: Request) -> Response:
config = request.app.state.config
app = Application.default()
if (res := request.query.get("resource")) is None:
raise HttpError(400, "Missing 'resource' query parameter")
@ -40,11 +41,11 @@ async def handle_webfinger(request: Request) -> Response:
except ValueError:
raise HttpError(400, "Invalid account format")
if not (request.app.state.config.check_host(domain)):
if not (app.config.check_host(domain)):
raise HttpError(404, "Domain not handled")
with request.app.state.database.session(False) as s:
if (user := s.get_user(username, config.host)) is None:
with app.database.session(False) as s:
if (user := s.get_user(username, app.config.host)) is None:
raise HttpError(404, "User not found")
data = Webfinger.new(
@ -52,7 +53,7 @@ async def handle_webfinger(request: Request) -> Response:
domain = user.domain,
actor = user.actor,
profile = user.page_url,
interaction = f"https://{config.web_host}/interact?uri={{uri}}"
interaction = f"https://{app.config.web_host}/interact?uri={{uri}}"
)
return Response.new_json(200, data)
@ -60,16 +61,16 @@ async def handle_webfinger(request: Request) -> Response:
@router.get("BarksharkSocial", "/.well-known/nodeinfo")
async def handle_wk_nodeinfo(request: Request) -> Response:
data = WellKnownNodeinfo.new_template(request.app.state.config.host)
data = WellKnownNodeinfo.new_template(Application.default().config.host)
return Response.new_json(200, data)
@router.get("BarksharkSocial", "/nodeinfo/{version:float}.json", "/nodeinfo/{version:float}")
async def handle_nodeinfo(request: Request) -> Response:
ni_version = request.params["version"]
config = request.app.state.config
app = Application.default()
with request.app.state.database.session(False) as s:
with app.database.session(False) as s:
data = Nodeinfo.new(
"barksharksocial", __version__,
protocols = [NodeinfoProtocol.ACTIVITYPUB],
@ -78,7 +79,7 @@ async def handle_nodeinfo(request: Request) -> Response:
repo = "https://git.barkshark.xyz/barkshark/social",
homepage = "https://docs.barkshark.xyz/social",
open_regs = True,
users = s.count_users(config.host),
users = s.count_users(app.config.host),
halfyear = 0,
month = 0,
posts = s.count_local_posts(),

View file

@ -42,19 +42,25 @@ class Application(App[Request, RequestState, AppState]):
self.state.setup(cfg_path, language)
self.error_handlers[HttpError] = self.handle_http_exception
if self.state.config.proxy_enabled:
self.client = Client(self.name, self.state.config.proxy_url)
if self.config.proxy_enabled:
self.client = Client(self.name, self.config.proxy_url)
self.client.useragent = f"BarksharkSocial/{__version__} (https://{self.config.web_host})"
self.add_static(
"/static", get_resource("frontend/static"), cached = not self.state.config.dev
"/static", get_resource("frontend/static"), cached = not self.config.dev
)
self.on_request.connect(mw.FrontendAuthMiddleware)
self.on_request.connect(mw.ActivitypubAuthMiddleware)
# I broke generics on the basgi.Application class, so this is the workaround for now
@staticmethod
def default() -> Application:
return App.get("BarksharkSocial") # type: ignore[return-value]
@property
def config(self) -> Config:
return self.state.config
@ -86,14 +92,14 @@ class Application(App[Request, RequestState, AppState]):
"--no-access-log"
]
if self.state.config.dev:
if self.config.dev:
cmd.extend(["--reload", "--reload-dir", str(Path(__file__).parent)])
else:
cmd.extend(["--workers", str(self.config.workers)])
env = os.environ.copy()
env["SOCIAL_CONFIG_PATH"] = str(self.state.config.path)
env["SOCIAL_CONFIG_PATH"] = str(self.config.path)
subprocess.run(cmd, env = env)
@ -110,11 +116,11 @@ class Application(App[Request, RequestState, AppState]):
"--opt"
]
if self.state.config.dev:
if self.config.dev:
cmd.append("--reload")
env = os.environ.copy()
env["SOCIAL_CONFIG_PATH"] = str(self.state.config.path)
env["SOCIAL_CONFIG_PATH"] = str(self.config.path)
subprocess.run(cmd, env = env)

View file

@ -40,6 +40,7 @@ dependencies = [
"barkshark-lib >= 0.1.1",
"barkshark-sql == 0.1.4",
"click == 8.1.7",
"pillow == 10.3.0",
"pymemcache == 4.0.0",
"pyyaml == 6.0.1",