add startup and shutdown callbacks

This commit is contained in:
Izalia Mae 2022-04-16 06:15:11 -04:00
parent 45727d5926
commit 216a779f22
5 changed files with 190 additions and 36 deletions

View file

@ -9,6 +9,7 @@ from functools import partial
from http_router import Router, MethodNotAllowed, NotFound
from izzylib import DotDict, Path, logging, signal_handler
from jinja2.exceptions import TemplateNotFound
from threading import Event, Thread
from . import http_methods, error, __file__ as module_root
from .config import Config
@ -30,9 +31,8 @@ frontend = Path(module_root).join('../../frontend').resolve()
class ApplicationBase:
ctx = DotDict()
def __init__(self, appname='default', views=[], middleware=[], dbtype=None, dbargs={}, dbclass=Database, **kwargs):
self.ctx = DotDict()
self.name = appname
self.cfg = Config(**kwargs)
self.db = None
@ -61,6 +61,10 @@ class ApplicationBase:
self.ctx[key] = value
def __delitem__(self, key):
del self.ctx[key]
def get_route(self, path, method='GET'):
return self.router(str(path), method.upper())
@ -178,21 +182,23 @@ class Application(ApplicationBase):
def __init__(self, loop=None, **kwargs):
super().__init__(**kwargs)
if loop:
self.loop = loop
else:
if not loop:
try:
self.loop = asyncio.get_running_loop()
loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.new_event_loop()
loop = asyncio.new_event_loop()
self.loop = loop
asyncio.set_event_loop(self.loop)
self.client = self.cfg.client_class(loop, *self.cfg.client_args, **self.cfg.client_kwargs)
self._running = Event()
self._blueprints = {}
self._server = None
self._tasks = []
self._callbacks = DotDict(startup = [], shutdown = [])
if self.cfg.tpl_default:
if type(Template) == NotImplementedError:
@ -219,6 +225,41 @@ class Application(ApplicationBase):
self.template = None
@property
def running(self):
return self._running.is_set()
def add_task(self, state, callback, *args, **kwargs):
assert state in ['startup', 'shutdown']
assert asyncio.iscoroutinefunction(callback)
self._callbacks[state].append((callback, args, kwargs))
def add_startup_task(self, callback, *args, **kwargs):
self.add_task('startup', callback, *args, **kwargs)
def add_shutdown_task(self, callback, *args, **kwargs):
self.add_task('shutdown', callback, *args, **kwargs)
def remove_task(self, state, callback):
assert state in ['startup', 'shutdown']
for task_callback, _, _ in self._callbacks[state]:
if task_callback == task:
self._callbacks[state].remove(task)
def remove_startup_task(self, callback):
self.remove_task('startup', callback)
def remove_shutdown_task(self, callback):
self.remove_task('shutdown', callback)
def add_blueprint(self, bp):
assert bp.prefix not in self._blueprints.values()
@ -261,21 +302,14 @@ class Application(ApplicationBase):
)
def stop(self, *_):
if not self._server:
print('server not running')
return
def run(self):
task = self.start()
self._server.close()
for task in self._tasks:
task.cancel()
self._tasks.remove(task)
signal_handler(None)
while not task.done():
time.sleep(1)
def start(self, *tasks, log=True):
def start(self, log=True):
if self._server:
return
@ -303,23 +337,86 @@ class Application(ApplicationBase):
signal_handler(self.stop)
self._server = self.loop.run_until_complete(server)
return asyncio.ensure_future(self.handle_run_server())
for task in tasks:
asyncio.ensure_future(task, loop=self.loop)
self.loop.run_until_complete(self.handle_run_server())
def stop(self, *_):
if not self._server:
print('server not running')
return
self._running.clear()
#self._server.close()
self.loop.run_until_complete(self.handle_stop_server())
if self.cfg.sig_handler:
self.cfg.sig_handler(self, *self.cfg.sig_handler_args, **self.cfg.sig_handler_kwargs)
signal_handler(None)
async def handle_run_server(self):
while self._server.is_serving():
await asyncio.sleep(0.1)
self._running.set()
await self._server.wait_closed()
## Run startup tasks
for callback, args, kwargs in self._callbacks['startup']:
self._tasks.append(asyncio.ensure_future(callback(self, *args, **kwargs)))
## Wait for server to finish
try:
while self._server.is_serving() and self.running:
await asyncio.sleep(0.1)
await self._server.wait_closed()
except:
traceback.print_exc()
self._running.clear()
## Run shutdown tasks
for callback, args, kwargs in self._callbacks['shutdown']:
try:
await asyncio.wait_for(callback(self, *args, **kwargs), 10)
except TimeoutError:
pass
for task in self._tasks:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
self._tasks = []
self._server = None
logging.info('Server stopped')
async def handle_stop_server(self):
for callback, args, kwargs in self._callbacks['shutdown']:
if asyncio.iscoroutinefunction(callback):
await asyncio.wait_for(callback(self, *args, **kwargs), 10)
else:
callback(self, *args, **kwargs)
for task in self._tasks:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
self._tasks = []
async def handle_client(self, reader, writer):
transport = AsyncTransport(reader, writer, self.cfg.timeout)
request = None

View file

@ -1,7 +1,6 @@
from izzylib import (
BaseConfig,
LowerDotDict,
boolean
LowerDotDict
)
from .request import ServerRequest
@ -49,7 +48,7 @@ class Config(BaseConfig):
self._startup = False
self.default_headers.update(kwargs.pop('default_headers', {}))
self.set_data(kwargs)
self.update(kwargs)
if not self.default_headers.get('server'):
self.default_headers['server'] = f'{self.name}/{__version__}'

View file

@ -1,5 +1,5 @@
from datetime import datetime, timezone, timedelta
from izzylib import DotDict, Path, boolean, logging
from izzylib import DotDict, Path, convert_to_boolean, logging
UtcTime = timezone.utc
@ -240,7 +240,7 @@ class CookieItem:
@secure.setter
def secure(self, data):
self.args['Secure'] = boolean(data)
self.args['Secure'] = convert_to_boolean(data)
@property
@ -250,7 +250,7 @@ class CookieItem:
@httponly.setter
def httponly(self, data):
self.args['HttpOnly'] = boolean(data)
self.args['HttpOnly'] = convert_to_boolean(data)
@property

View file

@ -1,4 +1,4 @@
import json, traceback
import base64, hashlib, json, traceback
from datetime import datetime
from izzylib import MultiDotDict
@ -192,6 +192,12 @@ class ServerResponse:
transport.write(self.compile(body=False))
async def set_websocket(self, transport, protocol, headers={}):
raise RuntimeError('Not implemented yet')
return WebSocketHandler(self.request, self, protocol, headers)
def set_cookie(self, key, value, **kwargs):
self.cookies[key] = CookieItem(key, value, **kwargs)
@ -199,3 +205,56 @@ class ServerResponse:
def compile(self, body=True):
first = first_line(status=self.status)
return create_message(first, self.headers, self.cookies, self.body if body else None)
class WebSocketHandler:
def __init__(self, request, response, protocol='sample', headers={}):
self.request = request
self.response = response
self.protocol = protocol
self.headers = headers
self.started = False
self.ended = False
def __enter__(self):
self.start()
return request.transport
def __exit__(self, *args):
self.stop()
def start(self):
if self.started:
return
self.response.headers.update(headers)
self.response.headers.update(transport.app.cfg.default_headers)
self.headers.setall('Upgrade', 'websocket')
self.headers.setall('Connection', 'upgrade')
#self.headers.setall('WebSocket-Origin', self.url)
#self.headers.setall('WebSocket-Location', self.url)
self.headers.setall('sec-websocket-protocol', 'sample')
if (hash := self.request.headers.get('sec-websocket-key')):
accept_text = hash + '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
accept_hash = hashlib.sha1(hash.encode())
self.headers.setall('sec-websocket-accept', base64.b64encode(accept_hash).decode())
self.request.transport.write(self.compile(body=False))
self.started = true
def stop(self):
if not self.started:
raise AttributeError('Response not started yet')
if self.ended:
return
self.request.transport.write('> EOF')
self.ended = True

View file

@ -3,7 +3,7 @@ import codecs, traceback, os, json, xml
from functools import partial
from hamlish_jinja import HamlishExtension
from izzylib import DotDict, Path, izzylog
from jinja2 import Environment, FileSystemLoader, ChoiceLoader, select_autoescape, Markup
from jinja2 import Environment, FileSystemLoader, ChoiceLoader, select_autoescape
from os import listdir, makedirs
from os.path import isfile, isdir, getmtime, abspath
from xml.dom import minidom
@ -39,7 +39,6 @@ class Template(Environment):
self.add_search_path(Path(path))
self.globals.update({
'markup': Markup,
'cleanhtml': lambda text: ''.join(xml.etree.ElementTree.fromstring(text).itertext()),
'color': Color,
'lighten': partial(color_func, 'lighten'),