TandoorRecipes/cookbook/connectors/connector_manager.py

150 lines
5.2 KiB
Python

import asyncio
import logging
import multiprocessing
import queue
from asyncio import Task
from dataclasses import dataclass
from enum import Enum
from multiprocessing import JoinableQueue
from types import UnionType
from typing import List, Any, Dict, Optional
from django_scopes import scope
from cookbook.connectors.connector import Connector
from cookbook.connectors.example import Example
from cookbook.connectors.homeassistant import HomeAssistant
from cookbook.models import ShoppingListEntry, Recipe, MealPlan, Space, HomeAssistantConfig, ExampleConfig
# Force the 'fork' start method for the worker process that ConnectorManager spawns,
# presumably so the child inherits this process's already-initialised Django state
# (see the linked ticket).
# NOTE(review): this runs at import time; set_start_method() raises RuntimeError if a
# start method was already set in this process, and 'fork' is unavailable on Windows.
multiprocessing.set_start_method('fork') # https://code.djangoproject.com/ticket/31169
# Maximum number of pending Work items; ConnectorManager drops (and logs) events when full.
QUEUE_MAX_SIZE = 25
# Models whose changes are enqueued as Work items.
# NOTE: run_connectors currently dispatches only ShoppingListEntry; Recipe/MealPlan are ignored there.
REGISTERED_CLASSES: UnionType = ShoppingListEntry | Recipe | MealPlan
# Connector configuration models; a change to one rebuilds that space's cached connectors
# instead of being forwarded to them.
CONNECTOR_UPDATE_CLASSES: UnionType = HomeAssistantConfig | ExampleConfig
class ActionType(Enum):
    """Kind of model change carried by a Work item."""
    CREATED = 1  # ConnectorManager.__call__ received created=True
    UPDATED = 2  # ConnectorManager.__call__ received created=False
    DELETED = 3  # ConnectorManager.__call__ received an ``origin`` keyword argument
@dataclass
class Work:
    """One queued model change, handed from ConnectorManager to its worker process."""
    # The changed model instance. It crosses the process boundary through a
    # multiprocessing queue, so it is pickled. __call__ also enqueues connector
    # config instances, hence the widened annotation.
    instance: REGISTERED_CLASSES | CONNECTOR_UPDATE_CLASSES
    # What happened to ``instance``; the worker reads it as ``item.actionType``.
    actionType: ActionType
class ConnectorManager:
    """Dispatch model changes to the connectors configured for each space.

    Calling an instance enqueues a :class:`Work` item. A single daemon worker
    process, started in ``__init__``, drains the queue. For each item it builds
    or reuses the affected space's connectors and runs them on a private
    asyncio event loop.

    Presumably registered as a Django ``post_save`` / ``post_delete`` receiver,
    since it inspects the ``created`` / ``origin`` keyword arguments. Confirm
    this at the registration site.
    """

    _queue: JoinableQueue
    # Shopping list / recipe / meal plan changes are forwarded to connectors;
    # connector-config changes only invalidate the per-space connector cache.
    _listening_to_classes = REGISTERED_CLASSES | CONNECTOR_UPDATE_CLASSES

    def __init__(self):
        # Bounded, so a stalled worker cannot make pending events pile up without limit.
        self._queue = multiprocessing.JoinableQueue(maxsize=QUEUE_MAX_SIZE)
        # daemon=True: the worker is terminated along with the parent if stop() is never called.
        self._worker = multiprocessing.Process(target=self.worker, args=(self._queue,), daemon=True)
        self._worker.start()

    def __call__(self, instance: Any, **kwargs) -> None:
        """Enqueue ``instance`` for the worker if it is a watched model.

        ``created=True`` maps to CREATED and ``created=False`` to UPDATED. The
        presence of ``origin`` maps to DELETED. Any other call is ignored. When
        the queue is full, the event is dropped and logged rather than blocking
        the caller.
        """
        if not isinstance(instance, self._listening_to_classes) or not hasattr(instance, "space"):
            return

        if "created" in kwargs:
            action_type = ActionType.CREATED if kwargs["created"] else ActionType.UPDATED
        elif "origin" in kwargs:
            action_type = ActionType.DELETED
        else:
            return

        try:
            self._queue.put_nowait(Work(instance, action_type))
        except queue.Full:
            logging.info("queue was full, so skipping %s", instance)

    def stop(self):
        """Wait for all queued work to finish, then shut the worker down and join it."""
        self._queue.join()
        # The worker only leaves its loop on a ``None`` sentinel. Without it,
        # the join() below would block forever.
        self._queue.put(None)
        self._queue.close()
        self._worker.join()

    @staticmethod
    def worker(worker_queue: JoinableQueue):
        """Worker-process entry point: process Work items until a ``None`` sentinel arrives.

        Every dequeued item, including the sentinel, is marked done exactly
        once, even when processing it raises. This keeps ``stop()`` from hanging
        on ``JoinableQueue.join()`` and stops a single bad event from killing
        the worker.
        """
        # Drop DB connections inherited through fork so this process opens its own.
        from django.db import connections
        connections.close_all()

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        # Connectors per space, keyed by the space's primary key. Nothing here
        # guarantees Space.name is unique, and a name collision would route one
        # space's events to another space's connectors.
        connectors_cache: Dict[Any, List[Connector]] = dict()

        while True:
            try:
                item: Optional[Work] = worker_queue.get()
            except KeyboardInterrupt:
                break

            if item is None:
                worker_queue.task_done()
                break

            try:
                ConnectorManager._process_item(loop, connectors_cache, item)
            except Exception:
                logging.exception("failed to process %s for %s", item.actionType, type(item.instance).__name__)
            finally:
                worker_queue.task_done()

        # Release connector resources before tearing the loop down.
        for cached_connectors in connectors_cache.values():
            loop.run_until_complete(close_connectors(cached_connectors))
        asyncio.set_event_loop(None)
        loop.close()

    @staticmethod
    def _process_item(loop: asyncio.AbstractEventLoop, connectors_cache: Dict[Any, List[Connector]], item: Work) -> None:
        """Refresh the item's space connectors if needed, then run them for ``item``."""
        # A changed connector config means the space's connectors must be rebuilt from the DB.
        refresh_connector_cache = isinstance(item.instance, CONNECTOR_UPDATE_CLASSES)

        space: Space = item.instance.space
        connectors: Optional[List[Connector]] = connectors_cache.get(space.pk)

        if connectors is None or refresh_connector_cache:
            # Evict before closing, so a failed rebuild never leaves closed connectors cached.
            stale = connectors_cache.pop(space.pk, None)
            if stale is not None:
                loop.run_until_complete(close_connectors(stale))
            with scope(space=space):
                connectors = [
                    *(HomeAssistant(config) for config in space.homeassistantconfig_set.all() if config.enabled),
                    *(Example(config) for config in space.exampleconfig_set.all() if config.enabled),
                ]
            connectors_cache[space.pk] = connectors

        # Config changes only rebuild the cache; they are not forwarded to the connectors.
        if not connectors or refresh_connector_cache:
            return

        loop.run_until_complete(run_connectors(connectors, space, item.instance, item.actionType))
async def close_connectors(connectors: List[Connector]):
    """Close every connector concurrently, logging failures instead of raising them.

    All ``close()`` calls are awaited to completion, even when some fail. No
    close task is left pending on the event loop after this returns, and every
    failure is logged rather than only the first.
    """
    tasks: List[Task] = [asyncio.create_task(connector.close()) for connector in connectors]
    # return_exceptions=True: with False, gather re-raises the first failure immediately,
    # leaving sibling tasks pending and their exceptions unretrieved.
    results = await asyncio.gather(*tasks, return_exceptions=True)
    for result in results:
        if isinstance(result, BaseException):
            logging.error("received an exception while closing one of the connectors", exc_info=result)
async def run_connectors(connectors: List[Connector], space: Space, instance: REGISTERED_CLASSES, action_type: ActionType):
tasks: List[Task] = list()
if isinstance(instance, ShoppingListEntry):
shopping_list_entry: ShoppingListEntry = instance
match action_type:
case ActionType.CREATED:
for connector in connectors:
tasks.append(asyncio.create_task(connector.on_shopping_list_entry_created(space, shopping_list_entry)))
case ActionType.UPDATED:
for connector in connectors:
tasks.append(asyncio.create_task(connector.on_shopping_list_entry_updated(space, shopping_list_entry)))
case ActionType.DELETED:
for connector in connectors:
tasks.append(asyncio.create_task(connector.on_shopping_list_entry_deleted(space, shopping_list_entry)))
if len(tasks) == 0:
return
try:
await asyncio.gather(*tasks, return_exceptions=False)
except BaseException:
logging.exception("received an exception from one of the connectors")