add startup and shutdown callbacks
This commit is contained in:
parent
45727d5926
commit
216a779f22
|
@ -9,6 +9,7 @@ from functools import partial
|
|||
from http_router import Router, MethodNotAllowed, NotFound
|
||||
from izzylib import DotDict, Path, logging, signal_handler
|
||||
from jinja2.exceptions import TemplateNotFound
|
||||
from threading import Event, Thread
|
||||
|
||||
from . import http_methods, error, __file__ as module_root
|
||||
from .config import Config
|
||||
|
@ -30,9 +31,8 @@ frontend = Path(module_root).join('../../frontend').resolve()
|
|||
|
||||
|
||||
class ApplicationBase:
|
||||
ctx = DotDict()
|
||||
|
||||
def __init__(self, appname='default', views=[], middleware=[], dbtype=None, dbargs={}, dbclass=Database, **kwargs):
|
||||
self.ctx = DotDict()
|
||||
self.name = appname
|
||||
self.cfg = Config(**kwargs)
|
||||
self.db = None
|
||||
|
@ -61,6 +61,10 @@ class ApplicationBase:
|
|||
self.ctx[key] = value
|
||||
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self.ctx[key]
|
||||
|
||||
|
||||
def get_route(self, path, method='GET'):
|
||||
return self.router(str(path), method.upper())
|
||||
|
||||
|
@ -178,21 +182,23 @@ class Application(ApplicationBase):
|
|||
def __init__(self, loop=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if loop:
|
||||
self.loop = loop
|
||||
|
||||
else:
|
||||
if not loop:
|
||||
try:
|
||||
self.loop = asyncio.get_running_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
except RuntimeError:
|
||||
self.loop = asyncio.new_event_loop()
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
self.loop = loop
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
self.client = self.cfg.client_class(loop, *self.cfg.client_args, **self.cfg.client_kwargs)
|
||||
|
||||
self._running = Event()
|
||||
self._blueprints = {}
|
||||
self._server = None
|
||||
self._tasks = []
|
||||
self._callbacks = DotDict(startup = [], shutdown = [])
|
||||
|
||||
if self.cfg.tpl_default:
|
||||
if type(Template) == NotImplementedError:
|
||||
|
@ -219,6 +225,41 @@ class Application(ApplicationBase):
|
|||
self.template = None
|
||||
|
||||
|
||||
@property
|
||||
def running(self):
|
||||
return self._running.is_set()
|
||||
|
||||
|
||||
def add_task(self, state, callback, *args, **kwargs):
|
||||
assert state in ['startup', 'shutdown']
|
||||
assert asyncio.iscoroutinefunction(callback)
|
||||
self._callbacks[state].append((callback, args, kwargs))
|
||||
|
||||
|
||||
def add_startup_task(self, callback, *args, **kwargs):
|
||||
self.add_task('startup', callback, *args, **kwargs)
|
||||
|
||||
|
||||
def add_shutdown_task(self, callback, *args, **kwargs):
|
||||
self.add_task('shutdown', callback, *args, **kwargs)
|
||||
|
||||
|
||||
def remove_task(self, state, callback):
|
||||
assert state in ['startup', 'shutdown']
|
||||
|
||||
for task_callback, _, _ in self._callbacks[state]:
|
||||
if task_callback == task:
|
||||
self._callbacks[state].remove(task)
|
||||
|
||||
|
||||
def remove_startup_task(self, callback):
|
||||
self.remove_task('startup', callback)
|
||||
|
||||
|
||||
def remove_shutdown_task(self, callback):
|
||||
self.remove_task('shutdown', callback)
|
||||
|
||||
|
||||
def add_blueprint(self, bp):
|
||||
assert bp.prefix not in self._blueprints.values()
|
||||
|
||||
|
@ -261,21 +302,14 @@ class Application(ApplicationBase):
|
|||
)
|
||||
|
||||
|
||||
def stop(self, *_):
|
||||
if not self._server:
|
||||
print('server not running')
|
||||
return
|
||||
def run(self):
|
||||
task = self.start()
|
||||
|
||||
self._server.close()
|
||||
|
||||
for task in self._tasks:
|
||||
task.cancel()
|
||||
self._tasks.remove(task)
|
||||
|
||||
signal_handler(None)
|
||||
while not task.done():
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def start(self, *tasks, log=True):
|
||||
def start(self, log=True):
|
||||
if self._server:
|
||||
return
|
||||
|
||||
|
@ -303,23 +337,86 @@ class Application(ApplicationBase):
|
|||
|
||||
signal_handler(self.stop)
|
||||
self._server = self.loop.run_until_complete(server)
|
||||
return asyncio.ensure_future(self.handle_run_server())
|
||||
|
||||
for task in tasks:
|
||||
asyncio.ensure_future(task, loop=self.loop)
|
||||
|
||||
self.loop.run_until_complete(self.handle_run_server())
|
||||
def stop(self, *_):
|
||||
if not self._server:
|
||||
print('server not running')
|
||||
return
|
||||
|
||||
self._running.clear()
|
||||
#self._server.close()
|
||||
|
||||
self.loop.run_until_complete(self.handle_stop_server())
|
||||
|
||||
if self.cfg.sig_handler:
|
||||
self.cfg.sig_handler(self, *self.cfg.sig_handler_args, **self.cfg.sig_handler_kwargs)
|
||||
|
||||
signal_handler(None)
|
||||
|
||||
|
||||
async def handle_run_server(self):
|
||||
while self._server.is_serving():
|
||||
await asyncio.sleep(0.1)
|
||||
self._running.set()
|
||||
|
||||
await self._server.wait_closed()
|
||||
## Run startup tasks
|
||||
for callback, args, kwargs in self._callbacks['startup']:
|
||||
self._tasks.append(asyncio.ensure_future(callback(self, *args, **kwargs)))
|
||||
|
||||
## Wait for server to finish
|
||||
try:
|
||||
while self._server.is_serving() and self.running:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
await self._server.wait_closed()
|
||||
|
||||
except:
|
||||
traceback.print_exc()
|
||||
|
||||
self._running.clear()
|
||||
|
||||
## Run shutdown tasks
|
||||
for callback, args, kwargs in self._callbacks['shutdown']:
|
||||
try:
|
||||
await asyncio.wait_for(callback(self, *args, **kwargs), 10)
|
||||
|
||||
except TimeoutError:
|
||||
pass
|
||||
|
||||
for task in self._tasks:
|
||||
task.cancel()
|
||||
|
||||
try:
|
||||
await task
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._tasks = []
|
||||
self._server = None
|
||||
|
||||
logging.info('Server stopped')
|
||||
|
||||
|
||||
async def handle_stop_server(self):
|
||||
for callback, args, kwargs in self._callbacks['shutdown']:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await asyncio.wait_for(callback(self, *args, **kwargs), 10)
|
||||
|
||||
else:
|
||||
callback(self, *args, **kwargs)
|
||||
|
||||
for task in self._tasks:
|
||||
task.cancel()
|
||||
|
||||
try:
|
||||
await task
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._tasks = []
|
||||
|
||||
|
||||
async def handle_client(self, reader, writer):
|
||||
transport = AsyncTransport(reader, writer, self.cfg.timeout)
|
||||
request = None
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from izzylib import (
|
||||
BaseConfig,
|
||||
LowerDotDict,
|
||||
boolean
|
||||
LowerDotDict
|
||||
)
|
||||
|
||||
from .request import ServerRequest
|
||||
|
@ -49,7 +48,7 @@ class Config(BaseConfig):
|
|||
|
||||
self._startup = False
|
||||
self.default_headers.update(kwargs.pop('default_headers', {}))
|
||||
self.set_data(kwargs)
|
||||
self.update(kwargs)
|
||||
|
||||
if not self.default_headers.get('server'):
|
||||
self.default_headers['server'] = f'{self.name}/{__version__}'
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from datetime import datetime, timezone, timedelta
|
||||
from izzylib import DotDict, Path, boolean, logging
|
||||
from izzylib import DotDict, Path, convert_to_boolean, logging
|
||||
|
||||
|
||||
UtcTime = timezone.utc
|
||||
|
@ -240,7 +240,7 @@ class CookieItem:
|
|||
|
||||
@secure.setter
|
||||
def secure(self, data):
|
||||
self.args['Secure'] = boolean(data)
|
||||
self.args['Secure'] = convert_to_boolean(data)
|
||||
|
||||
|
||||
@property
|
||||
|
@ -250,7 +250,7 @@ class CookieItem:
|
|||
|
||||
@httponly.setter
|
||||
def httponly(self, data):
|
||||
self.args['HttpOnly'] = boolean(data)
|
||||
self.args['HttpOnly'] = convert_to_boolean(data)
|
||||
|
||||
|
||||
@property
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import json, traceback
|
||||
import base64, hashlib, json, traceback
|
||||
|
||||
from datetime import datetime
|
||||
from izzylib import MultiDotDict
|
||||
|
@ -192,6 +192,12 @@ class ServerResponse:
|
|||
transport.write(self.compile(body=False))
|
||||
|
||||
|
||||
async def set_websocket(self, transport, protocol, headers={}):
|
||||
raise RuntimeError('Not implemented yet')
|
||||
|
||||
return WebSocketHandler(self.request, self, protocol, headers)
|
||||
|
||||
|
||||
def set_cookie(self, key, value, **kwargs):
|
||||
self.cookies[key] = CookieItem(key, value, **kwargs)
|
||||
|
||||
|
@ -199,3 +205,56 @@ class ServerResponse:
|
|||
def compile(self, body=True):
|
||||
first = first_line(status=self.status)
|
||||
return create_message(first, self.headers, self.cookies, self.body if body else None)
|
||||
|
||||
|
||||
class WebSocketHandler:
|
||||
def __init__(self, request, response, protocol='sample', headers={}):
|
||||
self.request = request
|
||||
self.response = response
|
||||
self.protocol = protocol
|
||||
self.headers = headers
|
||||
|
||||
self.started = False
|
||||
self.ended = False
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return request.transport
|
||||
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.stop()
|
||||
|
||||
|
||||
def start(self):
|
||||
if self.started:
|
||||
return
|
||||
|
||||
self.response.headers.update(headers)
|
||||
self.response.headers.update(transport.app.cfg.default_headers)
|
||||
self.headers.setall('Upgrade', 'websocket')
|
||||
self.headers.setall('Connection', 'upgrade')
|
||||
#self.headers.setall('WebSocket-Origin', self.url)
|
||||
#self.headers.setall('WebSocket-Location', self.url)
|
||||
self.headers.setall('sec-websocket-protocol', 'sample')
|
||||
|
||||
if (hash := self.request.headers.get('sec-websocket-key')):
|
||||
accept_text = hash + '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
|
||||
accept_hash = hashlib.sha1(hash.encode())
|
||||
self.headers.setall('sec-websocket-accept', base64.b64encode(accept_hash).decode())
|
||||
|
||||
self.request.transport.write(self.compile(body=False))
|
||||
self.started = true
|
||||
|
||||
|
||||
def stop(self):
|
||||
if not self.started:
|
||||
raise AttributeError('Response not started yet')
|
||||
|
||||
if self.ended:
|
||||
return
|
||||
|
||||
self.request.transport.write('> EOF')
|
||||
|
||||
self.ended = True
|
||||
|
|
|
@ -3,7 +3,7 @@ import codecs, traceback, os, json, xml
|
|||
from functools import partial
|
||||
from hamlish_jinja import HamlishExtension
|
||||
from izzylib import DotDict, Path, izzylog
|
||||
from jinja2 import Environment, FileSystemLoader, ChoiceLoader, select_autoescape, Markup
|
||||
from jinja2 import Environment, FileSystemLoader, ChoiceLoader, select_autoescape
|
||||
from os import listdir, makedirs
|
||||
from os.path import isfile, isdir, getmtime, abspath
|
||||
from xml.dom import minidom
|
||||
|
@ -39,7 +39,6 @@ class Template(Environment):
|
|||
self.add_search_path(Path(path))
|
||||
|
||||
self.globals.update({
|
||||
'markup': Markup,
|
||||
'cleanhtml': lambda text: ''.join(xml.etree.ElementTree.fromstring(text).itertext()),
|
||||
'color': Color,
|
||||
'lighten': partial(color_func, 'lighten'),
|
||||
|
|
Loading…
Reference in a new issue