# social/barkshark_social/database/connection.py
# 2024-04-18 22:36:06 -04:00 · 376 lines · 9.5 KiB · Python

from __future__ import annotations
import secrets
from aputils import Message, MessageDate, Nodeinfo, ObjectType, Signer, Webfinger
from argon2 import PasswordHasher
from basgi import Application as App
from blib import HttpDate
from bsql import Connection as BsqlConnection, Row
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, TypeVar
from .config import ConfigData
from .objects import ObjectMixin, Cookie, Follow, Instance, User
from .. import TRANS
from ..config import Config
from ..enums import PermissionLevel
if TYPE_CHECKING:
    # Only used in annotations (the return type of ``Connection.app``), so it is imported
    # for type checkers only and never at runtime.
    from ..server import Application

# Type variable for helpers that take an ObjectMixin subclass instance and return an
# object of that same subclass (``insert_row`` / ``update_row``).
T = TypeVar("T", bound = ObjectMixin)
class Connection(BsqlConnection):
    """
    Database connection with BarksharkSocial helpers for the ``config``, ``instance``,
    ``user``, ``follow`` and ``cookie`` tables.
    """

    @property
    def app(self) -> Application:
        """The application registered under the name ``"BarksharkSocial"``."""
        return App.get("BarksharkSocial") # type: ignore[return-value]


    @property
    def config(self) -> Config:
        """Server config object stored at ``app.state.config``."""
        return self.app.state.config


    @property
    def hasher(self) -> PasswordHasher:
        """Argon2 password hasher stored at ``app.state.hasher``."""
        return self.app.state.hasher


    def insert_row(self, item: T) -> T:
        """
        Insert ``item`` into its table and return the stored row as a new object of the
        same class.

        :param item: Object to insert. An ``id`` of ``None`` or ``0`` means "no id yet".
        :returns: The row returned by the insert
        :raises ValueError: If the insert did not return a row
        """
        # Treat ``None`` like ``0`` so this agrees with ``update_row``/``delete_row``.
        # NOTE(review): the flag presumably tells ``to_dict`` whether to include ``id``.
        include_id = item.id is not None and item.id != 0

        with self.insert(item.table(), item.to_dict(include_id)) as cur:
            if (row := item.__class__.from_possible_row(cur.one())) is None:
                raise ValueError("Failed to insert row")

            return row


    def update_row(self, item: T) -> T:
        """
        Update the row matching ``item.id`` with the item's values.

        :param item: Object to save; must have a non-zero id
        :returns: The updated row as a new object of the same class
        :raises ValueError: If ``item.id`` is empty or the update returned no row
        """
        if item.id is None or item.id == 0:
            raise ValueError("ID field is empty")

        with self.update(item.table(), item.to_dict(False), id = item.id) as cur:
            if (row := item.__class__.from_possible_row(cur.one())) is None:
                raise ValueError("Failed to update row")

            return row


    def delete_row(self, item: T) -> None:
        """
        Delete the row matching ``item.id`` from the item's table.

        :raises ValueError: If ``item.id`` is ``None`` or ``0``
        """
        if item.id is None or item.id == 0:
            raise ValueError("ID field is empty")

        with self.delete(item.table(), id = item.id):
            pass


    def get_config(self) -> ConfigData:
        """Load every row of the ``config`` table into a ``ConfigData`` object."""
        with self.execute("SELECT * FROM config") as cur:
            return ConfigData.from_rows(cur.all())


    def set_config(self, data: ConfigData) -> ConfigData:
        """
        Write each key/value pair of ``data`` to the ``config`` table.

        Each returned row is written back into ``data``, and the same (mutated)
        object is returned.
        """
        for key, value in data.items():
            with self.update("config", {"value": value}, key = key) as cur:
                data.set_from_row(cur.one())

        return data


    def get_config_value(self, key: str) -> Any:
        """Return ``get_config().get(key)``. This reloads the whole config table on each call."""
        return self.get_config().get(key)


    def set_config_key(self, key: str, value: Any) -> None:
        """Store a single config value via ``set_config``."""
        self.set_config(ConfigData({key: value}))


    def get_instance(self, domain: str) -> Instance | None:
        """Fetch an instance with the ``get-instance`` query, or ``None`` if there is no match."""
        with self.run("get-instance", {"domain": domain}) as cur:
            return Instance.from_possible_row(cur.one())


    def get_instance_by_id(self, userid: int) -> Instance | None:
        """
        Fetch an instance by row id.

        :param userid: Despite its name, this value is matched against ``instance.id``
        """
        with self.execute("SELECT * FROM instance WHERE id = $id", {"id": userid}) as cur:
            return Instance.from_possible_row(cur.one())


    def get_user(self, username: str, domain: str | None = None) -> User | None:
        """
        Fetch a user by username and domain.

        :param domain: Domain of the user. Falls back to ``config.host`` when empty.
        """
        if not domain:
            domain = self.config.host

        with self.select("user", username = username, domain = domain) as cur:
            return User.from_possible_row(cur.one())


    def get_user_by_actor(self, actor: str) -> User | None:
        """Fetch a user by actor URL and call ``user.get_instance`` on the result."""
        with self.execute("SELECT * FROM user WHERE actor = $actor", {"actor": actor}) as cur:
            user = User.from_possible_row(cur.one())

        if user is not None:
            user.get_instance(self)

        return user


    def get_user_by_id(self, userid: int) -> User | None:
        """Fetch a user by row id and call ``user.get_instance`` on the result."""
        with self.execute("SELECT * FROM user WHERE id = $id", {"id": userid}) as cur:
            user = User.from_possible_row(cur.one())

        if user is not None:
            user.get_instance(self)

        return user


    def get_users(self, domain: str) -> Iterator[User]:
        """
        Yield the users returned by the ``get-users`` query for ``domain``.

        ``user.get_instance`` is called on each user before it is yielded.
        """
        for row in self.run("get-users", {"domain": domain}):
            if (user := User.from_possible_row(row)) is not None:
                user.get_instance(self)
                yield user


    def count_domains(self) -> int:
        """Count the instance rows with a positive id."""
        with self.execute("SELECT COUNT(*) FROM instance WHERE id > 0") as cur:
            return get_count(cur.one())


    def count_users(self, domain: str | None = None) -> int:
        """
        Count users with the ``count-users`` query.

        :param domain: Domain to count. Falls back to ``config.host`` when empty or omitted.
        """
        with self.run("count-users", {"domain": domain or self.config.host}) as cur:
            return get_count(cur.one())


    def count_local_posts(self) -> int:
        """Return the result of the ``count-local-posts`` query."""
        with self.run("count-local-posts") as cur:
            return get_count(cur.one())


    def put_instance(self,
            domain: str,
            web_domain: str,
            shared_inbox: str | None = None,
            name: str | None = None,
            description: str | None = None,
            short_description: str | None = None) -> Instance:
        """
        Insert a new instance row with ``created`` and ``updated`` set to now.

        :param shared_inbox: Defaults to ``https://{web_domain}/inbox``
        :param short_description: Currently accepted but NOT stored.
            NOTE(review): it is never passed to ``Instance.new``. Confirm whether
            ``Instance`` has a field for it.
        :returns: The inserted row
        """
        date = HttpDate.new_utc()
        data = Instance.new(
            domain = domain,
            web_domain = web_domain,
            shared_inbox = shared_inbox or f"https://{web_domain}/inbox",
            name = name,
            description = description,
            created = date,
            updated = date
        )

        return self.insert_row(data)


    def put_instance_data(self,
            nodeinfo: Nodeinfo,
            webfinger: Webfinger,
            actor: Message) -> Instance:
        """
        Create an instance row from data fetched from a remote server.

        :param nodeinfo: Source of the software name and, when present,
            ``metadata.nodeName`` / ``metadata.nodeDescription``
        :param webfinger: Its ``domain`` becomes the instance domain
        :param actor: Its ``domain`` becomes the web domain and its ``shared_inbox``
            the instance's shared inbox
        :returns: The inserted row
        :raises KeyError: If ``get_instance(actor.domain)`` already finds a row
        """
        if self.get_instance(actor.domain) is not None:
            raise KeyError("Instance already exists")

        with self.run("get-at-least-1-instance") as cur:
            has_instance = cur.one() is not None

        # ``metadata`` may be missing or not a mapping (e.g. ``null``). Treat that the
        # same as a missing key instead of failing the whole insert with a TypeError.
        try:
            name: str | None = nodeinfo["metadata"]["nodeName"]

        except (KeyError, TypeError):
            name = None

        try:
            description: str | None = nodeinfo["metadata"]["nodeDescription"]

        except (KeyError, TypeError):
            description = None

        current = HttpDate.new_utc()
        instance = Instance(
            # The first instance gets id 1 explicitly. Otherwise id 0 makes insert_row
            # leave the id out of the insert.
            id = 0 if has_instance else 1,
            domain = webfinger.domain,
            web_domain = actor.domain,
            shared_inbox = actor.shared_inbox,
            name = name,
            description = description,
            software = nodeinfo.sw_name,
            mod_action = None,
            mod_action_reason = None,
            mod_action_date = None,
            note = None,
            created = current,
            updated = current
        )

        return self.insert_row(instance)


    def put_local_user(self,
            username: str,
            email: str,
            display_name: str,
            password: str,
            permission: int = 10,
            activate: bool = True) -> User:
        """
        Create a user on the local instance (the instance row with id ``-99``).

        A new signing key is generated, and the password is hashed with ``hasher``.

        :param permission: Stored as-is. NOTE(review): presumably a ``PermissionLevel``
            value; confirm what ``10`` maps to.
        :param activate: When true, ``activated`` is set to the creation date
        :returns: The inserted row
        :raises ValueError: If the local instance row does not exist
        """
        with self.run("get-at-least-1-user") as cur:
            has_user = cur.one() is not None

        with self.execute("SELECT * FROM instance WHERE id = -99") as cur:
            if (instance := Instance.from_possible_row(cur.one())) is None:
                raise ValueError(TRANS.fetch("error", "empty-row"))

        TRANS.print("setup", "create-key", username = username)
        signer = Signer.new(f"https://{instance.web_domain}/user/{username}#main-key")
        date = HttpDate.new_utc()

        data = User(
            id = 0 if has_user else 1,
            username = username,
            domain = instance.domain,
            display_name = display_name,
            actor = f"https://{instance.web_domain}/user/{username}",
            inbox = f"https://{instance.web_domain}/user/{username}/inbox",
            page_url = f"https://{instance.web_domain}/@{username}",
            email = email,
            password = self.hasher.hash(password),
            permission = permission,
            locked = False,
            is_bot = False,
            private_key = signer.export(),
            public_key = signer.pubkey,
            bio = "",
            info = {},
            activated = date if activate else None,
            mod_action = None,
            mod_action_reason = None,
            mod_action_date = None,
            note = None,
            created = date,
            updated = date
        )

        return self.insert_row(data)


    def put_remote_user(self, actor: Message) -> User:
        """
        Create a user row from a remote actor object.

        ``created``/``updated`` come from ``actor.published``, falling back to now if
        that attribute raises ``AttributeError``.

        :returns: The inserted row
        :raises ValueError: If the actor's instance is not stored yet, or the user
            already exists
        """
        with self.run("get-at-least-1-user") as cur:
            has_user = cur.one() is not None

        if (instance := self.get_instance(actor.domain)) is None:
            raise ValueError("Failed to get instance row")

        if self.get_user(actor.username, instance.domain) is not None:
            raise ValueError("User already exists")

        try:
            date = actor.published

        except AttributeError:
            date = MessageDate.now()

        data = User(
            id = 0 if has_user else 1,
            username = actor.username,
            domain = instance.domain,
            display_name = actor.get("name", actor.handle),
            actor = actor.id,
            inbox = actor.inbox,
            page_url = actor.get("url", actor.id),
            email = None,
            password = None,
            permission = PermissionLevel.REMOTE,
            locked = actor.manually_approves_followers,
            is_bot = actor.type != ObjectType.PERSON,
            private_key = None,
            public_key = actor.pubkey,
            bio = actor.get("summary", ""),
            info = actor.get_fields(),
            activated = None,
            mod_action = None,
            mod_action_reason = None,
            mod_action_date = None,
            note = None,
            created = date,
            updated = date
        )

        return self.insert_row(data)


    def get_follow(self, source: User, target: User) -> Follow | None:
        """Fetch the follow from ``source`` to ``target`` (``get-follow`` query)."""
        params = {
            "sourceid": source.id,
            "targetid": target.id
        }

        with self.run("get-follow", params) as cur:
            return Follow.from_possible_row(cur.one())


    def get_follow_by_id(self, followid: str) -> Follow | None:
        """Fetch a follow by its ``followid`` column."""
        params = {"followid": followid}

        with self.execute("SELECT * FROM follow WHERE followid = $followid", params) as cur:
            return Follow.from_possible_row(cur.one())


    def get_follower_count(self, user: User) -> int:
        """Return the result of the ``get-follower-count`` query for ``user``."""
        with self.run("get-follower-count", {"userid": user.id}) as cur:
            return get_count(cur.one())


    def get_following_count(self, user: User) -> int:
        """Return the result of the ``get-following-count`` query for ``user``."""
        with self.run("get-following-count", {"userid": user.id}) as cur:
            return get_count(cur.one())


    def get_post_count(self, user: User) -> int:
        """Return the result of the ``get-post-count`` query for ``user``."""
        with self.run("get-post-count", {"userid": user.id}) as cur:
            return get_count(cur.one())


    def get_peer_instances(self) -> Iterator[str]:
        """Yield the ``domain`` column of each row from the ``get-peer-instances`` query."""
        for row in self.run("get-peer-instances"):
            yield row["domain"]


    def get_cookie(self, code: str) -> Cookie | None:
        """Fetch a login cookie by its code."""
        with self.select("cookie", code = code) as cur:
            return Cookie.from_possible_row(cur.one())


    def put_cookie(self, user: User, user_agent: str | None) -> Cookie:
        """
        Create a login cookie for ``user`` with a random 32-character hex code.

        ``last_access`` and ``created`` are both set to now.

        :raises ValueError: If the insert did not return a row
        """
        date = HttpDate.new_utc()
        params = {
            "code": secrets.token_hex(16),
            "userid": user.id,
            "user_agent": user_agent,
            "last_access": date,
            "created": date
        }

        with self.insert("cookie", params) as cur:
            if (row := cur.one()) is None:
                raise ValueError("Failed to insert cookie")

            return Cookie.from_row(row)


    def del_cookie(self, token: str) -> None:
        """Delete the cookie whose code is ``token``."""
        with self.delete("cookie", code = token):
            pass
def get_count(row: Row | None) -> int:
    """
    Read the integer result of an aggregate (``COUNT``-style) query.

    :param row: Single result row, or ``None`` when the query produced nothing
    :returns: The first column's value as ``int``; ``0`` for a missing or empty row
    """
    if row:
        columns = list(row.values())
        return int(columns[0])

    return 0