rework database config
This commit is contained in:
parent
fc3323faa5
commit
6cc895aa5d
|
@ -6,6 +6,7 @@ from blib import HttpDate
|
|||
from pathlib import Path
|
||||
|
||||
from . import TRANS
|
||||
from .database import check_setup
|
||||
from .enums import PermissionLevel
|
||||
from .server import Application
|
||||
|
||||
|
@ -26,20 +27,22 @@ def cli(ctx: click.Context, config: Path) -> None:
|
|||
@cli.command("setup")
|
||||
@click.pass_context
|
||||
def cli_setup(ctx: click.Context) -> None:
|
||||
if ctx.obj.config.sqlite_path.exists():
|
||||
app: Application = ctx.obj
|
||||
|
||||
if app.config.sqlite_path.exists():
|
||||
click.confirm(TRANS.fetch("setup", "prompt-database"), abort = True)
|
||||
ctx.obj.config.sqlite_path.unlink()
|
||||
app.config.sqlite_path.unlink()
|
||||
|
||||
current = HttpDate.new_utc()
|
||||
|
||||
with ctx.obj.database.session(True) as s:
|
||||
with app.database.session(True) as s:
|
||||
s.create_tables()
|
||||
|
||||
s.insert("instance", {
|
||||
"id": -99,
|
||||
"domain": ctx.obj.config.host,
|
||||
"web_domain": ctx.obj.config.web_host,
|
||||
"shared_inbox": f"https://{ctx.obj.config.web_host}/inbox",
|
||||
"domain": app.config.host,
|
||||
"web_domain": app.config.web_host,
|
||||
"shared_inbox": f"https://{app.config.web_host}/inbox",
|
||||
"name": "Barkshark Social",
|
||||
"description": "Lightweight ActivityPub server",
|
||||
"software": "barksharksocial",
|
||||
|
@ -47,20 +50,20 @@ def cli_setup(ctx: click.Context) -> None:
|
|||
"updated": current
|
||||
})
|
||||
|
||||
click.echo(TRANS.fetch("setup", "create-key", username = ctx.obj.config.host))
|
||||
instance_signer = Signer.new(f"https://{ctx.obj.config.web_host}/actor")
|
||||
click.echo(TRANS.fetch("setup", "create-key", username = app.config.host))
|
||||
instance_signer = Signer.new(f"https://{app.config.web_host}/actor")
|
||||
|
||||
click.echo(TRANS.fetch("setup", "create-key", username = "relay"))
|
||||
relay_signer = Signer.new(f"https://{ctx.obj.config.web_host}/relay")
|
||||
relay_signer = Signer.new(f"https://{app.config.web_host}/relay")
|
||||
|
||||
s.insert("user", {
|
||||
"id": -99,
|
||||
"username": ctx.obj.config.host,
|
||||
"domain": ctx.obj.config.host,
|
||||
"display_name": ctx.obj.config.host,
|
||||
"actor": f"https://{ctx.obj.config.web_host}/actor",
|
||||
"inbox": f"https://{ctx.obj.config.web_host}/inbox",
|
||||
"page_url": f"https://{ctx.obj.config.web_host}/about",
|
||||
"username": app.config.host,
|
||||
"domain": app.config.host,
|
||||
"display_name": app.config.host,
|
||||
"actor": f"https://{app.config.web_host}/actor",
|
||||
"inbox": f"https://{app.config.web_host}/inbox",
|
||||
"page_url": f"https://{app.config.web_host}/about",
|
||||
"permission": 0,
|
||||
"locked": False,
|
||||
"private_key": instance_signer.export(),
|
||||
|
@ -72,11 +75,11 @@ def cli_setup(ctx: click.Context) -> None:
|
|||
s.insert("user", {
|
||||
"id": -98,
|
||||
"username": "relay",
|
||||
"domain": ctx.obj.config.host,
|
||||
"domain": app.config.host,
|
||||
"display_name": "Relay",
|
||||
"actor": f"https://{ctx.obj.config.web_host}/relay",
|
||||
"inbox": f"https://{ctx.obj.config.web_host}/relay",
|
||||
"page_url": f"https://{ctx.obj.config.web_host}/about#relay",
|
||||
"actor": f"https://{app.config.web_host}/relay",
|
||||
"inbox": f"https://{app.config.web_host}/relay",
|
||||
"page_url": f"https://{app.config.web_host}/about#relay",
|
||||
"permission": 0,
|
||||
"locked": False,
|
||||
"private_key": relay_signer.export(),
|
||||
|
@ -85,6 +88,8 @@ def cli_setup(ctx: click.Context) -> None:
|
|||
"updated": current
|
||||
})
|
||||
|
||||
s.put_config("version", None)
|
||||
|
||||
click.echo(TRANS.fetch("setup", "done"))
|
||||
|
||||
|
||||
|
@ -93,9 +98,15 @@ def cli_setup(ctx: click.Context) -> None:
|
|||
@click.option("--dev", "-d", is_flag = True)
|
||||
@click.pass_context
|
||||
def cli_run(ctx: click.Context, language: str, dev: bool = False) -> None:
|
||||
ctx.obj.dev = dev
|
||||
ctx.obj.trans.default = language
|
||||
ctx.obj.run()
|
||||
app: Application = ctx.obj
|
||||
|
||||
if "example.com" in app.config.host or not check_setup(app.database):
|
||||
click.echo("Instance not setup yet")
|
||||
return
|
||||
|
||||
app.config.dev = dev
|
||||
app.trans.default = language
|
||||
app.run()
|
||||
|
||||
|
||||
@cli.group("user")
|
||||
|
@ -116,8 +127,10 @@ def cli_user_create(
|
|||
permissions: PermissionLevel,
|
||||
activate: bool) -> None:
|
||||
|
||||
with ctx.obj.database.session(True) as s:
|
||||
if (user := s.get_user(username, ctx.obj.config.host)) is not None:
|
||||
app: Application = ctx.obj
|
||||
|
||||
with app.database.session(True) as s:
|
||||
if (user := s.get_user(username, app.config.host)) is not None:
|
||||
TRANS.print("error", "user-exists", handle = user.handle)
|
||||
return
|
||||
|
||||
|
@ -147,7 +160,9 @@ def cli_user_create(
|
|||
@click.argument("username")
|
||||
@click.pass_context
|
||||
def cli_user_reset_pass(ctx: click.Context, username: str) -> None:
|
||||
with ctx.obj.database.session(True) as s:
|
||||
app: Application = ctx.obj
|
||||
|
||||
with app.database.session(True) as s:
|
||||
if (user := s.get_user(username, None)) is None:
|
||||
click.echo("User does not exist")
|
||||
return
|
||||
|
|
|
@ -80,6 +80,7 @@ class Config(dict[str, dict[str, Any]]):
|
|||
# Advanced
|
||||
allowed_ips: Property[list[str]] = Property("Advanced", "allowed-ips", [])
|
||||
cookie_name: Property[str] = Property("Advanced", "cookie-name", "bsocial_token")
|
||||
media_path: Property[str] = Property("Advanced", "media-path", "media")
|
||||
workers: Property[int] = Property("Advanced", "workers", len(os.sched_getaffinity(0)))
|
||||
dev: Property[bool] = Property("Advanced", "dev", False)
|
||||
|
||||
|
@ -108,6 +109,14 @@ class Config(dict[str, dict[str, Any]]):
|
|||
)
|
||||
|
||||
|
||||
@property
|
||||
def media(self) -> Path:
|
||||
if not os.path.isabs(self.media_path):
|
||||
return self.path.parent.joinpath(self.media_path)
|
||||
|
||||
return Path(self.media_path).expanduser().resolve()
|
||||
|
||||
|
||||
@property
|
||||
def sqlite_path(self) -> Path:
|
||||
return self.path.parent.joinpath("database.sqlite3")
|
||||
|
|
|
@ -2,7 +2,9 @@ import sqlite3
|
|||
|
||||
from aputils import MessageDate
|
||||
from blib import HttpDate, Enum
|
||||
from bsql import Database
|
||||
|
||||
from .config import ConfigData
|
||||
from .connection import Connection
|
||||
from .schema import SCHEMA
|
||||
|
||||
|
@ -10,3 +12,18 @@ from .schema import SCHEMA
|
|||
sqlite3.register_adapter(Enum, lambda v: v.value)
|
||||
sqlite3.register_adapter(HttpDate, lambda v: v.timestamp())
|
||||
sqlite3.register_adapter(MessageDate, lambda v: v.timestamp())
|
||||
|
||||
|
||||
def check_setup(db: Database[Connection]) -> bool:
|
||||
with db.session(False) as s:
|
||||
if "config" not in s.get_tables():
|
||||
return False
|
||||
|
||||
with s.execute("SELECT * FROM config WHERE key = 'version'") as cur:
|
||||
if (row := cur.one()) is None:
|
||||
return False
|
||||
|
||||
if int(row["value"]) < ConfigData().version:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
|
|
@ -8,7 +8,10 @@ from .. import TRANS
|
|||
|
||||
|
||||
T = TypeVar("T")
|
||||
CONVERTERS: dict[type, tuple[Callable[[str], Any], Callable[[Any], str]]] = {
|
||||
DeserializerCallback = Callable[[str], Any]
|
||||
SerializerCallback = Callable[[Any], str]
|
||||
|
||||
CONVERTERS: dict[type, tuple[DeserializerCallback, SerializerCallback]] = {
|
||||
bool: (convert_to_boolean, str),
|
||||
int: (int, str),
|
||||
dict: (JsonBase.parse, lambda value: JsonBase(value).to_json()),
|
||||
|
@ -16,24 +19,24 @@ CONVERTERS: dict[type, tuple[Callable[[str], Any], Callable[[Any], str]]] = {
|
|||
}
|
||||
|
||||
|
||||
class Value(Generic[T]):
|
||||
class ConfigValue(Generic[T]):
|
||||
__slots__: tuple[str, ...] = ("default", "type", "description", "name")
|
||||
|
||||
|
||||
def __init__(self, default: T, description: str) -> None:
|
||||
self.default: T = default
|
||||
self.type: type[Any] = type(default)
|
||||
self.type: type = type(default)
|
||||
self.description: str = description
|
||||
self.name: str = ""
|
||||
|
||||
|
||||
def __set_name__(self, owner: Any, name: str) -> None:
|
||||
self.name = name.replace("_", "-")
|
||||
self.name = ConfigValue.parse_key(name)
|
||||
|
||||
|
||||
def __get__(self, obj: dict[str, str] | None, cls: type[dict[str, str]] | None) -> Self | T:
|
||||
def __get__(self, obj: dict[str, str] | None, cls: type[dict[str, str]] | None) -> T:
|
||||
if obj is None:
|
||||
return self
|
||||
raise RuntimeError("No object")
|
||||
|
||||
try:
|
||||
return self.deserialize(obj[self.name])
|
||||
|
@ -42,10 +45,7 @@ class Value(Generic[T]):
|
|||
return self.default
|
||||
|
||||
|
||||
def __set__(self, obj: dict[str, str], value: T) -> None:
|
||||
if not isinstance(value, self.type):
|
||||
raise TypeError(f"Value '{value}' is not type '{self.type.__name__}'")
|
||||
|
||||
def __set__(self, obj: dict[str, str], value: T | str) -> None:
|
||||
obj.__setitem__(self.name, self.serialize(value), parse_value = False) # type: ignore
|
||||
|
||||
|
||||
|
@ -53,7 +53,15 @@ class Value(Generic[T]):
|
|||
self.__set__(obj, self.default)
|
||||
|
||||
|
||||
def serialize(self, value: T) -> str:
|
||||
@classmethod
|
||||
def parse_key(cls: type[Self], key: str, to_underscore: bool = False) -> str:
|
||||
if to_underscore:
|
||||
return key.replace("-", "_")
|
||||
|
||||
return key.replace("_", "-")
|
||||
|
||||
|
||||
def serialize(self, value: T | str) -> str:
|
||||
if self.type is str and isinstance(value, str):
|
||||
return value
|
||||
|
||||
|
@ -64,106 +72,106 @@ class Value(Generic[T]):
|
|||
raise TypeError(TRANS.fetch("error", "cannot-convert", type = self.type.__name__))
|
||||
|
||||
|
||||
def deserialize(self, value: str) -> T:
|
||||
def deserialize(self, value: T | str) -> T:
|
||||
if isinstance(value, self.type):
|
||||
return value # type: ignore[no-any-return]
|
||||
return value # type: ignore[return-value]
|
||||
|
||||
try:
|
||||
return CONVERTERS[self.type][0](value) # type: ignore[no-any-return]
|
||||
return CONVERTERS[self.type][0](value) # type: ignore[arg-type,no-any-return]
|
||||
|
||||
except KeyError:
|
||||
raise TypeError(TRANS.fetch("error", "cannot-convert", type = type(value).__name__))
|
||||
|
||||
|
||||
class ConfigData(dict[str, str]):
|
||||
class ConfigData(dict[str, Any]):
|
||||
__slots__: tuple[str, ...] = tuple([])
|
||||
|
||||
|
||||
version: Value[int] = Value(
|
||||
version: ConfigValue[int] = ConfigValue(
|
||||
20240409, "Version of the database schema"
|
||||
)
|
||||
|
||||
description: Value[str] = Value(
|
||||
description: ConfigValue[str] = ConfigValue(
|
||||
"Lightweight AP server", "Description for the instance"
|
||||
)
|
||||
|
||||
description_short: Value[str] = Value(
|
||||
description_short: ConfigValue[str] = ConfigValue(
|
||||
"", "Short version of the description for the instance"
|
||||
)
|
||||
|
||||
email: Value[str] = Value(
|
||||
email: ConfigValue[str] = ConfigValue(
|
||||
"", "E-mail address used for contact"
|
||||
)
|
||||
|
||||
invites_enabled: Value[bool] = Value(
|
||||
invites_enabled: ConfigValue[bool] = ConfigValue(
|
||||
True, "Allow users to invite new users"
|
||||
)
|
||||
|
||||
max_attachments: Value[int] = Value(
|
||||
max_attachments: ConfigValue[int] = ConfigValue(
|
||||
4, "Maximum number of media attachments allowed in a post"
|
||||
)
|
||||
|
||||
max_audio_size: Value[int] = Value(
|
||||
max_audio_size: ConfigValue[int] = ConfigValue(
|
||||
10, "Maximum size of uploaded audio files in mebibytes"
|
||||
)
|
||||
|
||||
max_emoji_size: Value[int] = Value(
|
||||
max_emoji_size: ConfigValue[int] = ConfigValue(
|
||||
512, "Maximum size of emojis in kibibytes"
|
||||
)
|
||||
|
||||
max_featured_tags: Value[int] = Value(
|
||||
max_featured_tags: ConfigValue[int] = ConfigValue(
|
||||
8, "Maximum number of tags to display on a user's profile"
|
||||
)
|
||||
|
||||
max_image_size: Value[int] = Value(
|
||||
max_image_size: ConfigValue[int] = ConfigValue(
|
||||
2, "Maximum size of uploaded image files in mebibytes"
|
||||
)
|
||||
|
||||
max_name_chars: Value[int] = Value(
|
||||
max_name_chars: ConfigValue[int] = ConfigValue(
|
||||
64, "Maximum number of characters allowed in a display name or handle"
|
||||
)
|
||||
|
||||
max_pinned_posts: Value[int] = Value(
|
||||
max_pinned_posts: ConfigValue[int] = ConfigValue(
|
||||
8, "Maximum number of posts that can be pinned on a profile"
|
||||
)
|
||||
|
||||
max_poll_expiration: Value[int] = Value(
|
||||
max_poll_expiration: ConfigValue[int] = ConfigValue(
|
||||
60 * 24 * 7, "Max time in minutes for a poll to be open"
|
||||
)
|
||||
|
||||
max_poll_option_chars: Value[int] = Value(
|
||||
max_poll_option_chars: ConfigValue[int] = ConfigValue(
|
||||
200, "Maximum characters allowed in a poll option"
|
||||
)
|
||||
|
||||
max_poll_options: Value[int] = Value(
|
||||
max_poll_options: ConfigValue[int] = ConfigValue(
|
||||
10, "Maxiumum number of poll options in a post"
|
||||
)
|
||||
|
||||
max_post_chars: Value[int] = Value(
|
||||
max_post_chars: ConfigValue[int] = ConfigValue(
|
||||
4096, "Maximum number of characters allowed in a post or bio"
|
||||
)
|
||||
|
||||
max_profile_fields: Value[int] = Value(
|
||||
max_profile_fields: ConfigValue[int] = ConfigValue(
|
||||
8, "Maximum number of key/value fields allowed in a profile"
|
||||
)
|
||||
|
||||
max_video_size: Value[int] = Value(
|
||||
max_video_size: ConfigValue[int] = ConfigValue(
|
||||
100, "Maximum size of uploaded video files in mebibytes"
|
||||
)
|
||||
|
||||
min_poll_expiration: Value[int] = Value(
|
||||
min_poll_expiration: ConfigValue[int] = ConfigValue(
|
||||
5, "Minimum time in minutes for a poll to be open"
|
||||
)
|
||||
|
||||
name: Value[str] = Value(
|
||||
name: ConfigValue[str] = ConfigValue(
|
||||
"Barkshark Social", "Name of the instance"
|
||||
)
|
||||
|
||||
registration_open: Value[bool] = Value(
|
||||
registration_open: ConfigValue[bool] = ConfigValue(
|
||||
True, "Allow new users to sign up on the instance"
|
||||
)
|
||||
|
||||
require_approval: Value[bool] = Value(
|
||||
require_approval: ConfigValue[bool] = ConfigValue(
|
||||
False, "Require admins/mods to manually accept new users"
|
||||
)
|
||||
|
||||
|
@ -179,16 +187,17 @@ class ConfigData(dict[str, str]):
|
|||
self.set(key, value)
|
||||
|
||||
|
||||
def __getitem__(self, raw_key: str) -> str:
|
||||
key = raw_key.replace("_", "-")
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
key = ConfigValue.parse_key(key)
|
||||
|
||||
if key not in self:
|
||||
self.__delitem__(raw_key)
|
||||
self.__delitem__(key)
|
||||
|
||||
return dict.__getitem__(self, key)
|
||||
|
||||
|
||||
def __setitem__(self, key: str, value: str, parse_value: bool = True) -> None:
|
||||
def __setitem__(self, key: str, value: Any, parse_value: bool = True) -> None:
|
||||
key = ConfigValue.parse_key(key)
|
||||
if parse_value:
|
||||
try:
|
||||
self.set(key, value)
|
||||
|
@ -197,11 +206,11 @@ class ConfigData(dict[str, str]):
|
|||
except AttributeError:
|
||||
raise KeyError(key) from None
|
||||
|
||||
dict.__setitem__(self, key.replace("_", "-"), value)
|
||||
dict.__setitem__(self, key, value)
|
||||
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
delattr(self, key.replace("-", "_"))
|
||||
delattr(self, ConfigValue.parse_key(key, True))
|
||||
|
||||
|
||||
@classmethod
|
||||
|
@ -209,24 +218,17 @@ class ConfigData(dict[str, str]):
|
|||
cfg = cls()
|
||||
|
||||
for row in rows:
|
||||
cfg.set_from_row(row)
|
||||
cfg.set(row["key"], row["value"])
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
def get(self, key: str) -> Any: # type: ignore[override]
|
||||
return getattr(self, key.replace("-", "_"))
|
||||
return getattr(self, ConfigValue.parse_key(key, True))
|
||||
|
||||
|
||||
def set(self, key: str, value: Any) -> None:
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
def set_from_row(self, row: Row | None) -> None:
|
||||
if row is None:
|
||||
return
|
||||
|
||||
self[row["key"]] = row["value"]
|
||||
setattr(self, ConfigValue.parse_key(key, True), value)
|
||||
|
||||
|
||||
def update(self, data: dict[str, Any]) -> None: # type: ignore[override]
|
||||
|
|
|
@ -66,25 +66,29 @@ class Connection(BsqlConnection):
|
|||
pass
|
||||
|
||||
|
||||
def get_config(self) -> ConfigData:
|
||||
with self.execute("SELECT * FROM config") as cur:
|
||||
return ConfigData.from_rows(cur.all())
|
||||
def get_config(self, key: str) -> Any:
|
||||
cfg = ConfigData()
|
||||
|
||||
with self.execute("SELECT * FROM config WHERE key = $key", {"key": key}) as cur:
|
||||
if (row := cur.one()) is not None:
|
||||
cfg.set(key, row["value"])
|
||||
|
||||
return cfg.get(key)
|
||||
|
||||
|
||||
def set_config(self, data: ConfigData) -> ConfigData:
|
||||
for key, value in data.items():
|
||||
with self.update("config", {"value": value}, key = key) as cur:
|
||||
data.set_from_row(cur.one())
|
||||
def put_config(self, key: str, value: Any | None) -> None:
|
||||
cfg = ConfigData({key: value})
|
||||
params = {
|
||||
"key": key,
|
||||
"value": cfg.get(key)
|
||||
}
|
||||
|
||||
return data
|
||||
with self.run("put-config", params):
|
||||
pass
|
||||
|
||||
|
||||
def get_config_value(self, key: str) -> Any:
|
||||
return self.get_config().get(key)
|
||||
|
||||
|
||||
def set_config_key(self, key: str, value: Any) -> None:
|
||||
self.set_config(ConfigData({key: value}))
|
||||
def get_config_all(self) -> ConfigData:
|
||||
return ConfigData.from_rows(self.execute("SELECT * FROM config").all())
|
||||
|
||||
|
||||
def get_instance(self, domain: str) -> Instance | None:
|
||||
|
|
|
@ -3,10 +3,16 @@ from bsql import Column, Table, Tables
|
|||
SCHEMA = Tables(
|
||||
Table(
|
||||
"config",
|
||||
Column("key", "text", nullable = False, unique = True),
|
||||
Column("value", "text")
|
||||
),
|
||||
|
||||
Table(
|
||||
"media",
|
||||
Column("id", "serial"),
|
||||
Column("key", "text", nullable = False),
|
||||
Column("value", "text"),
|
||||
Column("type", "text", nullable = False)
|
||||
Column("filename", "text", nullable = False),
|
||||
Column("original_name", "text", nullable = False),
|
||||
Column("created", "timestamp")
|
||||
),
|
||||
Table(
|
||||
"instance",
|
||||
|
@ -16,6 +22,7 @@ SCHEMA = Tables(
|
|||
Column("shared_inbox", "text", nullable = False),
|
||||
Column("name", "text"),
|
||||
Column("description", "text"),
|
||||
Column("banner", "integer", foreign_key = ("media", "id")),
|
||||
Column("software", "text", nullable = False),
|
||||
Column("mod_action", "text"),
|
||||
Column("mod_action_reason", "text"),
|
||||
|
@ -30,6 +37,8 @@ SCHEMA = Tables(
|
|||
Column("username", "text", nullable = False),
|
||||
Column("domain", "text", foreign_key = ("instance", "domain")),
|
||||
Column("display_name", "text", nullable = False),
|
||||
Column("avatar", "integer", foreign_key = ("media", "id")),
|
||||
Column("banner", "integer", foreign_key = ("media", "id")),
|
||||
Column("actor", "text", nullable = False),
|
||||
Column("inbox", "text", nullable = False),
|
||||
Column("page_url", "text"),
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
-- name: put-config
|
||||
INSERT INTO config (key, value)
|
||||
VALUES($key, $value)
|
||||
ON CONFLICT(key) DO UPDATE SET value = $value
|
||||
RETURNING *;
|
||||
|
||||
-- name: get-instance
|
||||
SELECT * FROM instance WHERE
|
||||
domain = $domain or
|
||||
|
|
|
@ -5,6 +5,7 @@ from basgi import Request, Response, router
|
|||
from blib import HttpError
|
||||
|
||||
from ..processors import process_message
|
||||
from ..server import Application
|
||||
|
||||
|
||||
def ensure_signed(request: Request) -> None:
|
||||
|
@ -14,10 +15,10 @@ def ensure_signed(request: Request) -> None:
|
|||
|
||||
@router.get("BarksharkSocial", "/actor")
|
||||
async def handle_instance_actor_get(request: Request) -> Response:
|
||||
config = request.app.state.config
|
||||
app = Application.default()
|
||||
|
||||
with request.app.state.database.session(False) as s:
|
||||
if (user := s.get_user(config.host, config.host)) is None:
|
||||
with app.database.session(False) as s:
|
||||
if (user := s.get_user(app.config.host, app.config.host)) is None:
|
||||
raise HttpError(404, "User not found")
|
||||
|
||||
instance = user.get_instance(s)
|
||||
|
@ -59,9 +60,10 @@ async def handle_instance_actor_get(request: Request) -> Response:
|
|||
@router.post("BarksharkSocial", "/inbox")
|
||||
async def handle_instance_actor_post(request: Request) -> Response:
|
||||
ensure_signed(request)
|
||||
app = Application.default()
|
||||
|
||||
with request.app.state.database.session(False) as s:
|
||||
if (user := s.get_user(request.app.state.config.host)) is None:
|
||||
with app.database.session(False) as s:
|
||||
if (user := s.get_user(app.config.host)) is None:
|
||||
raise HttpError(404, "User not found")
|
||||
|
||||
message = Message.parse(await request.body())
|
||||
|
@ -71,7 +73,7 @@ async def handle_instance_actor_post(request: Request) -> Response:
|
|||
|
||||
@router.get("BarksharkSocial", "/relay")
|
||||
async def handle_relay_get(request: Request) -> Response:
|
||||
with request.app.state.database.session(False) as s:
|
||||
with Application.default().database.session(False) as s:
|
||||
if (user := s.get_user("relay")) is None:
|
||||
raise HttpError(404, "User not found")
|
||||
|
||||
|
@ -92,7 +94,7 @@ async def handle_relay_get(request: Request) -> Response:
|
|||
async def handle_relay_post(request: Request) -> Response:
|
||||
ensure_signed(request)
|
||||
|
||||
with request.app.state.database.session(False) as s:
|
||||
with Application.default().database.session(False) as s:
|
||||
if (user := s.get_user("relay")) is None:
|
||||
raise HttpError(404, "User not found")
|
||||
|
||||
|
|
|
@ -4,16 +4,11 @@ from blib import HttpDate, HttpError
|
|||
from datetime import timedelta
|
||||
|
||||
from ..misc import get_resource
|
||||
from ..server import Application
|
||||
|
||||
|
||||
@router.get("BarksharkSocial", "/")
|
||||
async def handle_home(request: Request) -> Response:
|
||||
if request.state.user is not None:
|
||||
print(request.state.user.handle)
|
||||
|
||||
else:
|
||||
print("no user")
|
||||
|
||||
return TemplateResponse("page/home.haml")
|
||||
|
||||
|
||||
|
@ -33,11 +28,11 @@ async def handle_login_get(request: Request) -> Response:
|
|||
|
||||
@router.post("BarksharkSocial", "/login")
|
||||
async def handle_login_post(request: Request) -> Response:
|
||||
state = request.app.state
|
||||
app = Application.default()
|
||||
form = await request.form()
|
||||
username, password = form.get("username"), form.get("password")
|
||||
|
||||
if username in [state.config.host, state.config.web_host, *state.config.alt_hosts, "relay"]:
|
||||
if username in [app.config.host, app.config.web_host, *app.config.alt_hosts, "relay"]:
|
||||
raise HttpError(400, "User does not exist")
|
||||
|
||||
if username is None or password is None:
|
||||
|
@ -46,23 +41,23 @@ async def handle_login_post(request: Request) -> Response:
|
|||
if not isinstance(username, str) or not isinstance(password, str):
|
||||
raise HttpError(400, "Invalid field type")
|
||||
|
||||
with state.database.session(False) as s:
|
||||
with app.database.session(False) as s:
|
||||
if (user := s.get_user(username.lower())) is None:
|
||||
raise HttpError(400, "User does not exist")
|
||||
|
||||
try:
|
||||
s.hasher.verify(user.password, password)
|
||||
s.hasher.verify(user.password, password) # type: ignore[arg-type]
|
||||
|
||||
except VerifyMismatchError:
|
||||
raise HttpError(400, "Password does not match")
|
||||
|
||||
cookie = s.put_cookie(user, request.headers.get("User-Agent"))
|
||||
|
||||
host = request.headers.get("Host", state.config.host)
|
||||
host = request.headers.get("Host", app.config.host)
|
||||
|
||||
response = Response.new_redirect(f"https://{host}/", 303)
|
||||
response.set_cookie(
|
||||
key = state.config.cookie_name,
|
||||
key = app.config.cookie_name,
|
||||
value = cookie.code,
|
||||
same_site = "strict",
|
||||
expires = HttpDate.new_utc() + timedelta(days = 30),
|
||||
|
@ -76,17 +71,17 @@ async def handle_login_post(request: Request) -> Response:
|
|||
|
||||
@router.get("BarksharkSocial", "/logout")
|
||||
async def handle_logout_get(request: Request) -> Response:
|
||||
cname = request.app.state.config.cookie_name
|
||||
app = app = Application.default()
|
||||
response = Response.new_redirect("/")
|
||||
|
||||
try:
|
||||
cookie = request.cookies[cname]
|
||||
cookie = request.cookies[app.config.cookie_name]
|
||||
|
||||
with request.app.state.database.session(True) as s:
|
||||
s.del_cookie(cookie.value)
|
||||
with app.database.session(True) as s:
|
||||
s.del_cookie(cookie.value) # type: ignore[arg-type]
|
||||
|
||||
response.cookies.append(cookie)
|
||||
response.delete_cookie(cname)
|
||||
response.delete_cookie(app.config.cookie_name)
|
||||
|
||||
except KeyError:
|
||||
pass
|
||||
|
@ -101,10 +96,12 @@ async def handle_register_get(request: Request) -> Response:
|
|||
|
||||
@router.get("BarksharkSocial", "/@{username}", "/@{username}@{domain}")
|
||||
async def handle_user_page(request: Request) -> Response:
|
||||
with request.app.state.database.session(False) as s:
|
||||
app = Application.default()
|
||||
|
||||
with app.database.session(False) as s:
|
||||
user = s.get_user(
|
||||
request.params["username"],
|
||||
request.params.get("domain", request.app.state.config.host)
|
||||
request.params.get("domain", app.config.host)
|
||||
)
|
||||
|
||||
if user is None or user.id < 1:
|
||||
|
|
|
@ -2,17 +2,18 @@ from basgi import Request, Response, router
|
|||
from typing import Any
|
||||
|
||||
from .. import __version__
|
||||
from ..server import Application
|
||||
|
||||
|
||||
@router.get("BarksharkSocial", "/api/v1/instance")
|
||||
async def handle_instance(request: Request) -> Response:
|
||||
state = request.app.state
|
||||
app = Application.default()
|
||||
|
||||
with state.database.session(False) as s:
|
||||
config = s.get_config()
|
||||
with app.database.session(False) as s:
|
||||
config = s.get_config_all()
|
||||
|
||||
data: dict[str, Any] = {
|
||||
"uri": state.config.host,
|
||||
"uri": app.config.host,
|
||||
"version": f"4.0.0 (compatible; Barkshark-Social {__version__})",
|
||||
"title": config.name,
|
||||
"short_description": config.description_short or None,
|
||||
|
@ -26,7 +27,7 @@ async def handle_instance(request: Request) -> Response:
|
|||
"en"
|
||||
],
|
||||
"urls": {
|
||||
"streaming_api": f"wss://{state.config.web_host}"
|
||||
"streaming_api": f"wss://{app.config.web_host}"
|
||||
},
|
||||
"stats": {
|
||||
"user_count": s.count_users(None),
|
||||
|
@ -121,9 +122,9 @@ async def handle_instance(request: Request) -> Response:
|
|||
|
||||
@router.get("BarksharkSocial", "/api/v1/instance/peers")
|
||||
async def handle_instance_peers(request: Request) -> Response:
|
||||
state = request.app.state
|
||||
app = Application.default()
|
||||
|
||||
with state.database.session(False) as s:
|
||||
with app.database.session(False) as s:
|
||||
data = list(s.get_peer_instances())
|
||||
|
||||
return Response.new_json(200, data)
|
||||
|
|
|
@ -3,26 +3,27 @@ from basgi import Request, Response, router
|
|||
from blib import HttpError
|
||||
|
||||
from .. import __version__
|
||||
from ..server import Application
|
||||
|
||||
|
||||
@router.get("BarksharkSocial", "/.well-known/host-meta")
|
||||
async def handle_hostmeta(request: Request) -> Response:
|
||||
config = request.app.state.config
|
||||
app = Application.default()
|
||||
|
||||
if "json" in request.content_type:
|
||||
return Response.new_json(200, HostMetaJson.new(config.host))
|
||||
return Response.new_json(200, HostMetaJson.new(app.config.host))
|
||||
|
||||
return Response(200, HostMeta.new(config.host), mimetype = "application/xrd+xml")
|
||||
return Response(200, HostMeta.new(app.config.host), mimetype = "application/xrd+xml")
|
||||
|
||||
|
||||
@router.get("BarksharkSocial", "/.well-known/host-meta.json")
|
||||
async def handle_hostmeta_json(request: Request) -> Response:
|
||||
return Response.new_json(200, HostMetaJson.new(request.app.state.config.host))
|
||||
return Response.new_json(200, HostMetaJson.new(Application.default().config.host))
|
||||
|
||||
|
||||
@router.get("BarksharkSocial", "/.well-known/webfinger")
|
||||
async def handle_webfinger(request: Request) -> Response:
|
||||
config = request.app.state.config
|
||||
app = Application.default()
|
||||
|
||||
if (res := request.query.get("resource")) is None:
|
||||
raise HttpError(400, "Missing 'resource' query parameter")
|
||||
|
@ -40,11 +41,11 @@ async def handle_webfinger(request: Request) -> Response:
|
|||
except ValueError:
|
||||
raise HttpError(400, "Invalid account format")
|
||||
|
||||
if not (request.app.state.config.check_host(domain)):
|
||||
if not (app.config.check_host(domain)):
|
||||
raise HttpError(404, "Domain not handled")
|
||||
|
||||
with request.app.state.database.session(False) as s:
|
||||
if (user := s.get_user(username, config.host)) is None:
|
||||
with app.database.session(False) as s:
|
||||
if (user := s.get_user(username, app.config.host)) is None:
|
||||
raise HttpError(404, "User not found")
|
||||
|
||||
data = Webfinger.new(
|
||||
|
@ -52,7 +53,7 @@ async def handle_webfinger(request: Request) -> Response:
|
|||
domain = user.domain,
|
||||
actor = user.actor,
|
||||
profile = user.page_url,
|
||||
interaction = f"https://{config.web_host}/interact?uri={{uri}}"
|
||||
interaction = f"https://{app.config.web_host}/interact?uri={{uri}}"
|
||||
)
|
||||
|
||||
return Response.new_json(200, data)
|
||||
|
@ -60,16 +61,16 @@ async def handle_webfinger(request: Request) -> Response:
|
|||
|
||||
@router.get("BarksharkSocial", "/.well-known/nodeinfo")
|
||||
async def handle_wk_nodeinfo(request: Request) -> Response:
|
||||
data = WellKnownNodeinfo.new_template(request.app.state.config.host)
|
||||
data = WellKnownNodeinfo.new_template(Application.default().config.host)
|
||||
return Response.new_json(200, data)
|
||||
|
||||
|
||||
@router.get("BarksharkSocial", "/nodeinfo/{version:float}.json", "/nodeinfo/{version:float}")
|
||||
async def handle_nodeinfo(request: Request) -> Response:
|
||||
ni_version = request.params["version"]
|
||||
config = request.app.state.config
|
||||
app = Application.default()
|
||||
|
||||
with request.app.state.database.session(False) as s:
|
||||
with app.database.session(False) as s:
|
||||
data = Nodeinfo.new(
|
||||
"barksharksocial", __version__,
|
||||
protocols = [NodeinfoProtocol.ACTIVITYPUB],
|
||||
|
@ -78,7 +79,7 @@ async def handle_nodeinfo(request: Request) -> Response:
|
|||
repo = "https://git.barkshark.xyz/barkshark/social",
|
||||
homepage = "https://docs.barkshark.xyz/social",
|
||||
open_regs = True,
|
||||
users = s.count_users(config.host),
|
||||
users = s.count_users(app.config.host),
|
||||
halfyear = 0,
|
||||
month = 0,
|
||||
posts = s.count_local_posts(),
|
||||
|
|
|
@ -42,19 +42,25 @@ class Application(App[Request, RequestState, AppState]):
|
|||
self.state.setup(cfg_path, language)
|
||||
self.error_handlers[HttpError] = self.handle_http_exception
|
||||
|
||||
if self.state.config.proxy_enabled:
|
||||
self.client = Client(self.name, self.state.config.proxy_url)
|
||||
if self.config.proxy_enabled:
|
||||
self.client = Client(self.name, self.config.proxy_url)
|
||||
|
||||
self.client.useragent = f"BarksharkSocial/{__version__} (https://{self.config.web_host})"
|
||||
|
||||
self.add_static(
|
||||
"/static", get_resource("frontend/static"), cached = not self.state.config.dev
|
||||
"/static", get_resource("frontend/static"), cached = not self.config.dev
|
||||
)
|
||||
|
||||
self.on_request.connect(mw.FrontendAuthMiddleware)
|
||||
self.on_request.connect(mw.ActivitypubAuthMiddleware)
|
||||
|
||||
|
||||
# I broke generics on the basgi.Application class, so this is the workaround for now
|
||||
@staticmethod
|
||||
def default() -> Application:
|
||||
return App.get("BarksharkSocial") # type: ignore[return-value]
|
||||
|
||||
|
||||
@property
|
||||
def config(self) -> Config:
|
||||
return self.state.config
|
||||
|
@ -86,14 +92,14 @@ class Application(App[Request, RequestState, AppState]):
|
|||
"--no-access-log"
|
||||
]
|
||||
|
||||
if self.state.config.dev:
|
||||
if self.config.dev:
|
||||
cmd.extend(["--reload", "--reload-dir", str(Path(__file__).parent)])
|
||||
|
||||
else:
|
||||
cmd.extend(["--workers", str(self.config.workers)])
|
||||
|
||||
env = os.environ.copy()
|
||||
env["SOCIAL_CONFIG_PATH"] = str(self.state.config.path)
|
||||
env["SOCIAL_CONFIG_PATH"] = str(self.config.path)
|
||||
subprocess.run(cmd, env = env)
|
||||
|
||||
|
||||
|
@ -110,11 +116,11 @@ class Application(App[Request, RequestState, AppState]):
|
|||
"--opt"
|
||||
]
|
||||
|
||||
if self.state.config.dev:
|
||||
if self.config.dev:
|
||||
cmd.append("--reload")
|
||||
|
||||
env = os.environ.copy()
|
||||
env["SOCIAL_CONFIG_PATH"] = str(self.state.config.path)
|
||||
env["SOCIAL_CONFIG_PATH"] = str(self.config.path)
|
||||
subprocess.run(cmd, env = env)
|
||||
|
||||
|
||||
|
|
|
@ -40,6 +40,7 @@ dependencies = [
|
|||
"barkshark-lib >= 0.1.1",
|
||||
"barkshark-sql == 0.1.4",
|
||||
"click == 8.1.7",
|
||||
"pillow == 10.3.0",
|
||||
"pymemcache == 4.0.0",
|
||||
"pyyaml == 6.0.1",
|
||||
|
||||
|
|
Loading…
Reference in a new issue