change formatting a bit, and add async close method

This commit is contained in:
Mikhail Epifanov 2024-01-13 13:43:08 +01:00
parent c7dd61e239
commit 87ede4b9cc
No known key found for this signature in database
3 changed files with 36 additions and 11 deletions

View File

@ -16,5 +16,9 @@ class Connector(ABC):
async def on_shopping_list_entry_deleted(self, space: Space, instance: ShoppingListEntry) -> None:
pass
@abstractmethod
async def close(self) -> None:
pass
# TODO: Maybe add an 'IsEnabled(self) -> Bool' to here
# TODO: Add Recipes & possibly Meal Place listeners/hooks (And maybe more?)

View File

@ -5,7 +5,7 @@ import queue
from asyncio import Task
from dataclasses import dataclass
from enum import Enum
from multiprocessing import Queue
from multiprocessing import JoinableQueue
from types import UnionType
from typing import List, Any, Dict, Optional
@ -35,11 +35,11 @@ class Work:
class ConnectorManager:
_queue: Queue
_queue: JoinableQueue
_listening_to_classes = REGISTERED_CLASSES | CONNECTOR_UPDATE_CLASSES
def __init__(self):
self._queue = multiprocessing.Queue(maxsize=QUEUE_MAX_SIZE)
self._queue = multiprocessing.JoinableQueue(maxsize=QUEUE_MAX_SIZE)
self._worker = multiprocessing.Process(target=self.worker, args=(self._queue,), daemon=True)
self._worker.start()
@ -60,14 +60,16 @@ class ConnectorManager:
try:
self._queue.put_nowait(Work(instance, action_type))
except queue.Full:
logging.info("queue was full, so skipping %s", instance)
return
def stop(self):
self._queue.join()
self._queue.close()
self._worker.join()
@staticmethod
def worker(worker_queue: Queue):
def worker(worker_queue: JoinableQueue):
from django.db import connections
connections.close_all()
@ -77,7 +79,10 @@ class ConnectorManager:
_connectors: Dict[str, List[Connector]] = dict()
while True:
item: Optional[Work] = worker_queue.get()
try:
item: Optional[Work] = worker_queue.get()
except KeyboardInterrupt:
break
if item is None:
break
@ -88,18 +93,32 @@ class ConnectorManager:
connectors: Optional[List[Connector]] = _connectors.get(space.name)
if connectors is None or refresh_connector_cache:
if connectors is not None:
loop.run_until_complete(close_connectors(connectors))
with scope(space=space):
connectors: List[Connector] = [HomeAssistant(config) for config in space.homeassistantconfig_set.all() if config.enabled]
_connectors[space.name] = connectors
if len(connectors) == 0 or refresh_connector_cache:
return
worker_queue.task_done()
continue
loop.run_until_complete(run_connectors(connectors, space, item.instance, item.actionType))
worker_queue.task_done()
loop.close()
async def close_connectors(connectors: List[Connector]):
tasks: List[Task] = [asyncio.create_task(connector.close()) for connector in connectors]
try:
await asyncio.gather(*tasks, return_exceptions=False)
except BaseException:
logging.exception("received an exception while closing one of the connectors")
async def run_connectors(connectors: List[Connector], space: Space, instance: REGISTERED_CLASSES, action_type: ActionType):
tasks: List[Task] = list()

View File

@ -1,7 +1,5 @@
import logging
from collections import defaultdict
from logging import Logger
from typing import Dict, Any, Optional
from homeassistant_api import Client, HomeassistantAPIError, Domain
@ -58,16 +56,20 @@ class HomeAssistant(Connector):
except HomeassistantAPIError as err:
self._logger.warning(f"[HomeAssistant {self._config.name}] Received an exception from the api: {err=}, {type(err)=}")
async def close(self) -> None:
await self._client.async_cache_session.close()
def _format_shopping_list_entry(shopping_list_entry: ShoppingListEntry):
item = shopping_list_entry.food.name
if shopping_list_entry.amount > 0:
item += f" ({shopping_list_entry.amount:.2f}".rstrip('0').rstrip('.')
if shopping_list_entry.unit and shopping_list_entry.unit.base_unit and len(shopping_list_entry.unit.base_unit) > 0:
item += f" ({shopping_list_entry.amount} {shopping_list_entry.unit.base_unit})"
item += f" {shopping_list_entry.unit.base_unit})"
elif shopping_list_entry.unit and shopping_list_entry.unit.name and len(shopping_list_entry.unit.name) > 0:
item += f" ({shopping_list_entry.amount} {shopping_list_entry.unit.name})"
item += f" {shopping_list_entry.unit.name})"
else:
item += f" ({shopping_list_entry.amount})"
item += ")"
description = "Imported by TandoorRecipes"
if shopping_list_entry.created_by.first_name and len(shopping_list_entry.created_by.first_name) > 0: