from __future__ import annotations

import asyncio
import contextlib
import json
import os
import random
import statistics
import string
import timeit
import traceback

from collections.abc import Callable, Generator, Iterator, Sequence
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone, tzinfo
from functools import wraps
from http.client import HTTPResponse
from pathlib import Path
from typing import Any, Generic, TypeVar
from urllib.request import Request, urlopen

from . import __version__
from .enums import FileSizeUnit, HttpMethod
try:
    from importlib.resources import files as pkgfiles

except ImportError:
    # ``importlib.resources.files`` is missing on older interpreters. To keep the module
    # free of runtime dependencies, provide a minimal stand-in that resolves a module to
    # the directory containing its source file.
    import inspect
    from importlib import import_module
    from importlib.abc import Traversable
    from types import ModuleType

    def pkgfiles(anchor: str | ModuleType | None = None) -> Traversable:
        """
        Fallback for :func:`importlib.resources.files`

        :param anchor: Module object or importable module name to locate
        :raises ValueError: If no module or name is given
        """

        if anchor is None:
            raise ValueError("A module or name must be specified")

        if isinstance(anchor, str):
            anchor = import_module(anchor)

        return Path(inspect.getfile(anchor)).parent


try:
    from typing import Self

except ImportError:
    from typing_extensions import Self
DictValueType = TypeVar("DictValueType")

# Lower-case strings that ``convert_to_boolean`` maps to ``True`` / ``False``
TRUE_STR = ['on', 'y', 'yes', 'true', 'enable', 'enabled', '1']
FALSE_STR = ['off', 'n', 'no', 'false', 'disable', 'disabled', '0']

# NOTE(review): ``get_top_domain`` reads ``TLD_CACHE_URL`` but no definition is visible in
# this file; this is the canonical Public Suffix List location -- confirm before merging.
TLD_CACHE_URL = "https://publicsuffix.org/list/public_suffix_list.dat"
TLD_CACHE_PATH = Path("~/.cache").expanduser().joinpath('public_suffix_list.txt')
TLD_CACHE_DATA: list[str] = []
@contextlib.contextmanager
def catch_errors(suppress: bool = False) -> Generator[None, None, None]:
    """
    Context manager for running a block of code and catching any errors that get raised.
    The exception never propagates out of the ``with`` block.

    :param suppress: If ``True``, don't print a raised exception
    """

    try:
        yield

    except Exception:
        if not suppress:
            traceback.print_exc()
def convert_to_boolean(value: Any) -> bool:
    """
    Convert an object to :class:`bool`. Strings listed in ``TRUE_STR`` / ``FALSE_STR``
    (compared case-insensitively) map to ``True`` / ``False``. Anything else falls back to
    the object's normal truthiness.

    :param value: Object to be converted
    """

    if value is None:
        return False

    if isinstance(value, bool):
        return value

    if isinstance(value, str):
        lowered = value.lower()

        if lowered in TRUE_STR:
            return True

        if lowered in FALSE_STR:
            return False

    # covers ints (``1`` -> True, ``0`` -> False), unrecognised strings and everything else
    return bool(value)
def convert_to_bytes(value: Any, encoding: str = "utf-8") -> bytes:
    """
    Convert an object to :class:`bytes`

    :param value: Object to be converted
    :param encoding: Character encoding to use if the object is a string or gets converted to
        one in the process
    :raises TypeError: If the object cannot be converted
    """

    if isinstance(value, bytes):
        return value

    # other binary buffers can be copied directly instead of going through ``str``
    if isinstance(value, (bytearray, memoryview)):
        return bytes(value)

    try:
        return convert_to_string(value).encode(encoding)

    except TypeError:
        raise TypeError(f"Cannot convert '{get_object_name(value)}' into bytes") from None
def convert_to_string(value: Any, encoding: str = 'utf-8') -> str:
    """
    Convert an object to :class:`str`

    ``bytes`` are decoded, containers are dumped as JSON (a ``set`` is dumped as a JSON
    array), numbers and booleans use :class:`str`, and ``None`` becomes an empty string.

    :param value: Object to be converted
    :param encoding: Character encoding to use if the object is a :class:`bytes` object
    :raises TypeError: If the object type is not supported
    """

    if isinstance(value, bytes):
        return value.decode(encoding)

    if isinstance(value, str):
        return value

    if isinstance(value, set):
        # ``json`` cannot encode a set, so serialise it as an array
        return json.dumps(list(value))

    if isinstance(value, (dict, list, tuple)):
        return json.dumps(value)

    # ``bool`` is a subclass of ``int``, so ``True`` / ``False`` are handled here too
    if isinstance(value, (int, float)):
        return str(value)

    if value is None:
        return ''

    raise TypeError(f'Cannot convert "{get_object_name(value)}" into a string')
def deprecated(new_method: str, version: str, remove: str | None = None) -> Callable[..., Any]:
    """
    Decorator to mark a function as deprecated and display a warning on first use.

    :param new_method: Name of the function to replace the wrapped function
    :param version: Version of the module in which the wrapped function was considered
        deprecated
    :param remove: Version the wrapped function will get removed
    """

    def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
        # tracked per wrapped function so each one warns exactly once
        called = False

        @wraps(func)
        def inner(*args: Any, **kwargs: Any) -> Any:
            nonlocal called

            if not called:
                called = True
                name = func.__qualname__ if hasattr(func, "__qualname__") else func.__name__

                if not remove:
                    print(f"WARN: {name} was deprecated in {version}. Use {new_method} instead.")

                else:
                    msg = f"WARN: {name} was deprecated in {version} and will be removed in "
                    msg += f"{remove}. Use {new_method} instead."
                    print(msg)

            return func(*args, **kwargs)

        return inner

    return wrapper
def get_object_name(obj: Any) -> str:
    """
    Get the name of an object. Objects with a ``__name__`` attribute (classes, functions,
    modules) return it; anything else returns the name of its type.

    :param obj: Object to get the name of
    """

    try:
        return obj.__name__  # type: ignore[no-any-return]

    except AttributeError:
        return type(obj).__name__
def get_object_properties(
                        obj: Any,
                        ignore_descriptors: bool = True,
                        ignore_underscore: bool = True) -> Iterator[tuple[str, Any]]:
    """
    Get an object's properties and their values. Callable attributes are always skipped.

    :param obj: Object to get the properties of
    :param ignore_descriptors: Skip every name that is also defined on the object's type
        (ex. ``@property`` descriptors, methods and class attributes)
    :param ignore_underscore: Skip properties that start with an underscore (``_``)
    """

    # the type's names do not change while iterating, so collect them once
    type_keys = set(dir(type(obj))) if ignore_descriptors else set()

    for key in dir(obj):
        if key in type_keys:
            continue

        if ignore_underscore and key.startswith("_"):
            continue

        if callable(value := getattr(obj, key)):
            continue

        yield key, value
def get_resource_path(module: str, path: str | None = None) -> Path:
    """
    Get a path to a module resource

    :param module: Name of the module to get the resource from
    :param path: Path of the resource starting from the path of the module. Leading
        slashes are stripped so the result always stays relative to the module.
    """

    new_path = Path(str(pkgfiles(module)))

    if path is None:
        return new_path

    return new_path.joinpath(path.lstrip("/"))
def get_top_domain(domain: str) -> str:
    """
    Get the main domain from a string. The top-level domain list is cached as
    ``~/.cache/public_suffix_list.txt`` for 7 days

    :param str domain: The domain to extract the top-level domain from
    :raises ValueError: When the top domain cannot be found
    """

    if not TLD_CACHE_DATA:
        try:
            modified = datetime.fromtimestamp(TLD_CACHE_PATH.stat().st_mtime)

        except FileNotFoundError:
            modified = None

        if modified is None or modified + timedelta(days=7) < datetime.now():
            # Download and filter the list BEFORE opening the cache file for writing.
            # A failed request then cannot leave an empty but "fresh" cache behind.
            with http_request(TLD_CACHE_URL) as resp:
                raw_list = resp.read().decode('utf-8')

            suffixes: list[str] = []

            for line in raw_list.splitlines():
                # everything after this marker is the non-ICANN (private) section
                if 'end icann domains' in line.lower():
                    break

                if not line or line.startswith('//'):
                    continue

                # wildcard rules ("*.ck") are stored as the bare suffix ("ck")
                if line.startswith('*'):
                    line = line[2:]

                suffixes.append(line)

            TLD_CACHE_PATH.parent.mkdir(exist_ok = True, parents = True)

            with TLD_CACHE_PATH.open("wt", encoding = "utf-8") as fd:
                fd.writelines(f"{suffix}\n" for suffix in suffixes)

        with TLD_CACHE_PATH.open("r", encoding = "utf-8") as fd:
            TLD_CACHE_DATA.extend(fd.read().splitlines())

    domain_split = domain.split('.')

    # prefer a two-label public suffix ("co.uk") over a single-label one ("uk")
    if '.'.join(domain_split[-2:]) in TLD_CACHE_DATA:
        return '.'.join(domain_split[-3:])

    if '.'.join(domain_split[-1:]) in TLD_CACHE_DATA:
        return '.'.join(domain_split[-2:])

    raise ValueError('Cannot find TLD')
def http_request(
                url: str,
                data: Any = None,
                headers: dict[str, str] | None = None,
                method: HttpMethod | str = HttpMethod.GET,
                timeout: int = 60) -> HTTPResponse:
    """
    Make an http request. The default User-Agent is "BLib/:attr:`BLib.__version__`"

    :param url: Url to send the request to
    :param data: Data to send with the request. Must be parsable by
        :func:`convert_to_bytes`. Falsy values send no body.
    :param headers: HTTP header key/value pairs to send with the request
    :param method: HTTP method to use when making the request
    :param timeout: How long to wait when connecting before giving up
    :raises TimeoutError: When the connection was not established before the timeout limit
    :raises urllib.error.HTTPError: When the server returns an error
    """

    method = HttpMethod.parse(method)

    # normalise header names ("user-agent" -> "User-Agent") on a copy so the caller's
    # dict is never modified
    headers = {key.title(): value for key, value in (headers or {}).items()}

    if headers.get("User-Agent") is None:
        headers["User-Agent"] = f"BLib/{__version__}"

    request = Request(
        url = url,
        method = method.upper(),
        data = convert_to_bytes(data) if data else None,
        headers = headers
    )

    return urlopen(request, timeout = timeout)  # type: ignore[no-any-return]
def is_loop_running() -> bool:
    """Check if an event loop is running in the current thread"""

    try:
        return asyncio.get_running_loop().is_running()

    except RuntimeError:
        # ``get_running_loop`` raises when no loop is running in this thread
        return False
def random_str(
            length: int = 20,
            letters: bool = True,
            numbers: bool = True,
            extra: str = "") -> str:
    """
    Return a randomly generated string. Uses alphanumeric characters by default, but more can
    be specified as a string.

    The string is built with :mod:`random`, so it is not suitable for secrets or tokens.

    :param int length: Length of the resulting string in characters
    :param bool letters: If ``True``, include all ascii letters
    :param bool numbers: If ``True``, include numbers
    :param str extra: Characters to also include in the resulting string
    """

    characters = ""

    if letters:
        characters += string.ascii_letters

    if numbers:
        characters += string.digits

    # added exactly once so the extra characters are not picked twice as often
    characters += extra

    return "".join(random.choices(characters, k=length))
def time_function(
                func: Callable[..., Any],
                *args: Any,
                passes: int = 1,
                use_gc: bool = True,
                **kwargs: Any) -> RunData:
    """
    Call a function n times and return each run time, the average time, and the total time in
    seconds

    :param func: Function to call
    :param args: Positional arguments to pass to the function
    :param passes: Number of times to call the function. Values below ``1`` are treated
        as ``1``.
    :param use_gc: Enable garbage collection during the runs
    :param kwargs: Keyword arguments to pass to the function
    """

    def call() -> Any:
        return func(*args, **kwargs)

    # ``timeit`` disables the garbage collector while timing unless the setup statement
    # turns it back on
    timer = timeit.Timer(call, "gc.enable()") if use_gc else timeit.Timer(call)

    # one timed call per pass; ``repeat(1, 1)`` is the same as ``[timeit(1)]``
    times = timer.repeat(max(passes, 1), 1)

    return RunData(tuple(times), statistics.fmean(times), sum(times))
def time_function_pprint(
                        func: Callable[..., Any],
                        *args: Any,
                        passes: int = 5,
                        use_gc: bool = True,
                        floatout: bool = True,
                        **kwargs: Any) -> RunData:
    """
    Prints out readable results from ``time_function`` and returns the raw data. Convert the
    printed times to an ``int`` by setting ``floatout`` to ``False``

    :param func: Function to call
    :param args: Positional arguments to pass to the function
    :param passes: Number of times to call the function
    :param use_gc: Enable garbage collection during the runs
    :param floatout: Print values as ``float`` instead of ``int``
    :param kwargs: Keyword arguments to pass to the function
    """

    data = time_function(func, *args, **kwargs, passes = passes, use_gc = use_gc)

    # 8 decimal places for float output, rounded to a whole number otherwise
    spec = ".8f" if floatout else ".0f"

    for idx, passtime in enumerate(data.runs, start = 1):
        print(f"Pass {idx}: {passtime:{spec}}")

    print(f"Average: {data.average:{spec}}")
    print(f"Total: {data.total:{spec}}")

    return data
class DictProperty(Generic[DictValueType]):
    """
    Represents a key in a dict. Intended to be set as a class attribute on a ``dict``
    subclass so that the key can be read, written and deleted as an attribute.
    """


    def __init__(self, key: str) -> None:
        """
        Create a new dict property

        :param key: Name of the key to be handled by this ``Property``
        """

        self.key: str = key


    def __get__(self,
                obj: dict[str, DictValueType | Any] | None,
                objtype: Any = None) -> Self | DictValueType:

        # accessed on the class itself rather than an instance
        if obj is None:
            return self

        try:
            return obj[self.key]

        except KeyError:
            # report a missing key as a missing attribute, like normal attribute access
            objname = get_object_name(obj)
            raise AttributeError(f"'{objname}' has no attribute '{self.key}'") from None


    def __set__(self, obj: dict[str, DictValueType | Any], value: DictValueType) -> None:
        obj[self.key] = value


    def __delete__(self, obj: dict[str, DictValueType | Any]) -> None:
        del obj[self.key]
class Env:
    """Easy access to environmental variables"""


    @staticmethod
    def get(key: str,
            default: Any = None,
            converter: Callable[[str], Any] = str) -> Any:
        """
        Get an environmental variable

        :param key: Name of the variable
        :param default: The default value to return if the key is not found. It is returned
            as-is, without going through ``converter``.
        :param converter: Function to convert the value to a different type
        """

        # only the lookup is guarded so a ``KeyError`` raised by the converter is not
        # mistaken for a missing variable
        try:
            raw = os.environ[key]

        except KeyError:
            return default

        return converter(raw)


    @classmethod
    def get_array(cls: type[Env],
                key: str,
                separator: str = ":",
                converter: Callable[[str], Any] = str) -> Iterator[Any]:
        """
        Get an environmental variable as an iterator of items. A missing or empty variable
        yields nothing.

        :param key: Name of the variable
        :param separator: String to use to split items
        :param converter: Function to convert each value to a different type
        """

        raw = cls.get(key, "")

        # ``"".split(":")`` would produce a single bogus empty item
        if not raw:
            return

        for value in raw.split(separator):
            yield converter(value.strip())


    @classmethod
    def get_int(cls: type[Env], key: str, default: int = 0) -> int:
        """
        Get an environmental variable as an ``int``

        :param key: Name of the variable
        :param default: The default value to return if the key is not found
        """

        return cls.get(key, default, int)  # type: ignore[no-any-return]


    @classmethod
    def get_float(cls: type[Env], key: str, default: float = 0.0) -> float:
        """
        Get an environmental variable as a ``float``

        :param key: Name of the variable
        :param default: The default value to return if the key is not found
        """

        return cls.get(key, default, float)  # type: ignore[no-any-return]


    @classmethod
    def get_bool(cls: type[Env], key: str, default: bool = False) -> bool:
        """
        Get an environmental variable as a ``bool``

        :param key: Name of the variable
        :param default: The default value to return if the key is not found
        """

        return cls.get(key, default, convert_to_boolean)  # type: ignore[no-any-return]


    @classmethod
    def get_json(cls: type[Env], key: str, default: dict[Any, Any] | None = None) -> JsonBase:
        """
        Get an environmental variable as a JSON-parsed ``dict``

        :param key: Name of the variable
        :param default: The default value to return if the key is not found
        """

        return cls.get(key, default, JsonBase.parse)  # type: ignore[no-any-return]


    @classmethod
    def get_list(cls: type[Env],
                key: str,
                separator: str = ":",
                converter: Callable[[str], Any] = str) -> list[Any]:
        """
        Get an environmental variable as a ``list``

        :param key: Name of the variable
        :param separator: String to use to split items
        :param converter: Function to convert each value to a different type
        """

        return list(cls.get_array(key, separator, converter))


    @classmethod
    def get_tuple(cls: type[Env],
                key: str,
                separator: str = ":",
                converter: Callable[[str], Any] = str) -> tuple[Any, ...]:
        """
        Get an environmental variable as a ``tuple``

        :param key: Name of the variable
        :param separator: String to use to split items
        :param converter: Function to convert each value to a different type
        """

        return tuple(cls.get_array(key, separator, converter))


    @classmethod
    def get_set(cls: type[Env],
                key: str,
                separator: str = ":",
                converter: Callable[[str], Any] = str) -> set[Any]:
        """
        Get an environmental variable as a ``set``

        :param key: Name of the variable
        :param separator: String to use to split items
        :param converter: Function to convert each value to a different type
        """

        return set(cls.get_array(key, separator, converter))


    @classmethod
    def keys(cls: type[Env]) -> Iterator[str]:
        """Fetch all environmental variable names"""

        yield from os.environ


    @classmethod
    def items(cls: type[Env]) -> Iterator[tuple[str, str]]:
        """Fetch all environmental variable names and values"""

        yield from os.environ.items()


    @classmethod
    def values(cls: type[Env]) -> Iterator[str]:
        """Fetch all environmental variable values"""

        yield from os.environ.values()
class FileSize(int):
    """Converts a human-readable file size to bytes"""


    def __new__(cls: type[Self], size: int | float, unit: FileSizeUnit = FileSizeUnit.B) -> Self:
        return int.__new__(cls, FileSizeUnit.parse(unit).multiply(size))


    def __repr__(self) -> str:
        value = int(self)
        return f"FileSize({value:,} bytes)"


    def __str__(self) -> str:
        return int.__str__(self)


    @classmethod
    def parse(cls: type[Self], text: str) -> Self:
        """
        Parse a file size string in the form ``"<number> <unit>"``

        :param text: String representation of a file size
        :raises ValueError: If the text has no space-separated unit or the number is
            not a valid ``float``
        :raises AttributeError: If the text cannot be parsed
        """

        # NOTE(review): ``AttributeError`` is kept from the original docs and presumably
        # comes from ``FileSizeUnit.parse`` for unknown units -- confirm against the enum
        size_str, unit_str = text.strip().split(" ", 1)
        return cls(float(size_str), FileSizeUnit.parse(unit_str))


    def to_optimal_string(self) -> str:
        """
        Attempts to display the size as the highest whole unit

        :raises ValueError: If no unit exists for the size
        """

        index = 0
        size: int | float = int(self)

        while True:
            # stop at the first unit where the value drops below 1024, or at index 8
            if size < 1024 or index == 8:
                unit = FileSizeUnit.from_index(index)
                return f'{size:.2f} {unit}'

            try:
                index += 1
                size = self / FileSizeUnit.from_index(index).multiplier

            except IndexError:
                raise ValueError('File size is too large to convert to a string') from None


    def to_string(self, unit: FileSizeUnit, decimals: int = 2) -> str:
        """
        Convert to the specified file size unit

        :param unit: Unit to convert to
        :param decimals: Number of decimal points to round to
        """

        unit = FileSizeUnit.parse(unit)

        if unit == FileSizeUnit.BYTE:
            return f'{self} B'

        size = round(self / unit.multiplier, decimals)
        return f'{size} {unit}'
class HttpDate(datetime):
    """
    ``datetime`` object with convenience methods for parsing and creating HTTP date
    strings. All objects assume a ``UTC`` timezone if one is not specified.
    """

    FORMAT: str = "%a, %d %b %Y %H:%M:%S GMT"
    "Format to pass to datetime when (de)serializing a raw HTTP date"


    def __new__(cls: type[Self],
                year: int,
                month: int,
                day: int,
                hour: int = 0,
                minute: int = 0,
                second: int = 0,
                microsecond: int = 0,
                tzinfo: tzinfo = timezone.utc) -> Self:

        return datetime.__new__(
            cls, year, month, day, hour, minute, second, microsecond, tzinfo
        )


    def __str__(self) -> str:
        return self.to_string()


    @classmethod
    def parse(cls: type[Self], date: datetime | str | int | float) -> Self:
        """
        Parse a unix timestamp, a ``datetime`` object, or an HTTP date in string format.
        The result is always expressed in UTC.

        :param date: Data to be parsed
        :raises ValueError: If a string does not match :attr:`FORMAT`
        """

        if isinstance(date, cls):
            return date

        if isinstance(date, datetime):
            data = cls.fromisoformat(date.isoformat())

        elif isinstance(date, (int, float)):
            # pass the timezone explicitly: without it the timestamp is converted to
            # *local* wall-clock time, which would then be mislabelled as UTC below
            data = cls.fromtimestamp(date, tz = timezone.utc)

        else:
            data = cls.strptime(date, cls.FORMAT)

        # naive values are assumed to already be UTC; aware values get converted
        if data.tzinfo is None:
            return data.replace(tzinfo = timezone.utc)

        return data.astimezone(tz = timezone.utc)


    @classmethod
    def new_utc(cls: type[Self]) -> Self:
        """Create a new ``HttpDate`` object from the current UTC time"""

        return cls.now(timezone.utc)


    def timestamp(self) -> int:
        """Return the date as a unix timestamp without microseconds"""

        return int(datetime.timestamp(self))


    def to_string(self) -> str:
        """Create an HTTP Date header string from the datetime object"""

        return self.strftime(self.FORMAT)
class JsonBase(dict[str, Any]):
    """A ``dict`` with methods to convert to JSON and back"""


    @classmethod
    def parse(cls: type[Self], data: str | bytes | dict[str, Any]) -> Self:
        """
        Parse a JSON object

        :param data: JSON object to parse. An existing instance of this class is returned
            unchanged.
        :raises TypeError: When an invalid object type is provided, or the JSON text does
            not decode to an object
        """

        if isinstance(data, (str, bytes)):
            data = json.loads(data)

        if isinstance(data, cls):
            return data

        if not isinstance(data, dict):
            raise TypeError(f"Cannot parse objects of type \"{type(data).__name__}\"")

        return cls(data)


    def to_json(self, indent: int | str | None = None, **kwargs: Any) -> str:
        """
        Return the message as a JSON string

        :param indent: Number of spaces or the string to use for indention
        :param kwargs: Keyword arguments to pass to :func:`json.dumps`
        """

        return json.dumps(self, indent = indent, default = self.handle_value_dump, **kwargs)


    def handle_value_dump(self, value: Any) -> Any:
        """
        Gets called when a value is of the wrong type and needs to be converted for dumping to
        json. If the type is unknown, it will be forcibly converted to a ``str``.

        :param value: Value to be converted
        """

        if not isinstance(value, (str, int, float, bool, dict, list, tuple, type(None))):
            print(f"Warning: Cannot properly convert value of type '{type(value).__name__}'")
            return str(value)

        return value
class NamedTuple(tuple[tuple[str, Any], ...]):
    """A tuple with dict-like access of items"""

    # NOTE(review): a plain class attribute, not ``__slots__`` -- instances still carry a
    # ``__dict__``, which is where ``_keys`` is stored
    slots = ("_keys", )
    _keys: tuple[str, ...]


    def __new__(cls: type[Self],
                keys: Sequence[str],
                values: Sequence[Any]) -> Self:
        """
        Create a new ``NamedTuple`` object from two parallel sequences

        :param keys: Names to associate with each value
        :param values: Values to store, in the same order as ``keys``
        :raises ValueError: If ``keys`` and ``values`` differ in length
        """

        if len(keys) != len(values):
            raise ValueError("Keys and values must be the same length")

        # each element of the underlying tuple is a ``(key, value)`` pair
        data = tuple.__new__(cls, zip(keys, values))
        data._keys = tuple(keys)
        return data


    def __getitem__(self, key: int | str) -> Any:  # type: ignore[override]
        # a string key is translated to its position; ``tuple.index`` raises
        # ``ValueError`` for an unknown name
        if isinstance(key, str):
            key = self._keys.index(key)

        return tuple.__getitem__(self, key)[1]


    def __repr__(self) -> str:
        data = ", ".join(f"{key}={value!r}" for key, value in self.items())
        return f"NamedTuple({data})"


    @classmethod
    def from_dict(cls: type[Self], data: dict[str, Any]) -> Self:
        """
        Create a new ``NamedTuple`` object from a :class:`dict`

        :param data: Key/value pairs to store
        """

        return cls(tuple(data.keys()), tuple(data.values()))


    def keys(self) -> Iterator[str]:
        """Keys associated with each value"""

        yield from self._keys


    def items(self) -> Iterator[tuple[str, Any]]:
        """Key/value pairs as a tuple of tuples"""

        yield from self


    def to_dict(self) -> dict[str, Any]:
        """Convert to a :class:`dict`"""

        return dict(self)


    def values(self) -> Iterator[Any]:
        """Iterate through the values"""

        for _, value in self:
            yield value
@dataclass
class RunData:
    """Data returned from :func:`time_function` and :func:`time_function_pprint`"""

    runs: tuple[float, ...]
    "Elapsed time of each run"

    average: float
    "Average time of all runs"

    total: float
    "Time it took for all runs"