change formatting a bit, and add async close method

This commit is contained in:
Mikhail Epifanov 2024-01-13 13:43:08 +01:00
parent c7dd61e239
commit 87ede4b9cc
No known key found for this signature in database
3 changed files with 36 additions and 11 deletions

View File

@ -16,5 +16,9 @@ class Connector(ABC):
async def on_shopping_list_entry_deleted(self, space: Space, instance: ShoppingListEntry) -> None: async def on_shopping_list_entry_deleted(self, space: Space, instance: ShoppingListEntry) -> None:
pass pass
@abstractmethod
async def close(self) -> None:
pass
# TODO: Maybe add an 'IsEnabled(self) -> Bool' to here # TODO: Maybe add an 'IsEnabled(self) -> Bool' to here
# TODO: Add Recipes & possibly Meal Place listeners/hooks (And maybe more?) # TODO: Add Recipes & possibly Meal Place listeners/hooks (And maybe more?)

View File

@ -5,7 +5,7 @@ import queue
from asyncio import Task from asyncio import Task
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from multiprocessing import Queue from multiprocessing import JoinableQueue
from types import UnionType from types import UnionType
from typing import List, Any, Dict, Optional from typing import List, Any, Dict, Optional
@ -35,11 +35,11 @@ class Work:
class ConnectorManager: class ConnectorManager:
_queue: Queue _queue: JoinableQueue
_listening_to_classes = REGISTERED_CLASSES | CONNECTOR_UPDATE_CLASSES _listening_to_classes = REGISTERED_CLASSES | CONNECTOR_UPDATE_CLASSES
def __init__(self): def __init__(self):
self._queue = multiprocessing.Queue(maxsize=QUEUE_MAX_SIZE) self._queue = multiprocessing.JoinableQueue(maxsize=QUEUE_MAX_SIZE)
self._worker = multiprocessing.Process(target=self.worker, args=(self._queue,), daemon=True) self._worker = multiprocessing.Process(target=self.worker, args=(self._queue,), daemon=True)
self._worker.start() self._worker.start()
@ -60,14 +60,16 @@ class ConnectorManager:
try: try:
self._queue.put_nowait(Work(instance, action_type)) self._queue.put_nowait(Work(instance, action_type))
except queue.Full: except queue.Full:
logging.info("queue was full, so skipping %s", instance)
return return
def stop(self): def stop(self):
self._queue.join()
self._queue.close() self._queue.close()
self._worker.join() self._worker.join()
@staticmethod @staticmethod
def worker(worker_queue: Queue): def worker(worker_queue: JoinableQueue):
from django.db import connections from django.db import connections
connections.close_all() connections.close_all()
@ -77,7 +79,10 @@ class ConnectorManager:
_connectors: Dict[str, List[Connector]] = dict() _connectors: Dict[str, List[Connector]] = dict()
while True: while True:
item: Optional[Work] = worker_queue.get() try:
item: Optional[Work] = worker_queue.get()
except KeyboardInterrupt:
break
if item is None: if item is None:
break break
@ -88,18 +93,32 @@ class ConnectorManager:
connectors: Optional[List[Connector]] = _connectors.get(space.name) connectors: Optional[List[Connector]] = _connectors.get(space.name)
if connectors is None or refresh_connector_cache: if connectors is None or refresh_connector_cache:
if connectors is not None:
loop.run_until_complete(close_connectors(connectors))
with scope(space=space): with scope(space=space):
connectors: List[Connector] = [HomeAssistant(config) for config in space.homeassistantconfig_set.all() if config.enabled] connectors: List[Connector] = [HomeAssistant(config) for config in space.homeassistantconfig_set.all() if config.enabled]
_connectors[space.name] = connectors _connectors[space.name] = connectors
if len(connectors) == 0 or refresh_connector_cache: if len(connectors) == 0 or refresh_connector_cache:
return worker_queue.task_done()
continue
loop.run_until_complete(run_connectors(connectors, space, item.instance, item.actionType)) loop.run_until_complete(run_connectors(connectors, space, item.instance, item.actionType))
worker_queue.task_done()
loop.close() loop.close()
async def close_connectors(connectors: List[Connector]):
tasks: List[Task] = [asyncio.create_task(connector.close()) for connector in connectors]
try:
await asyncio.gather(*tasks, return_exceptions=False)
except BaseException:
logging.exception("received an exception while closing one of the connectors")
async def run_connectors(connectors: List[Connector], space: Space, instance: REGISTERED_CLASSES, action_type: ActionType): async def run_connectors(connectors: List[Connector], space: Space, instance: REGISTERED_CLASSES, action_type: ActionType):
tasks: List[Task] = list() tasks: List[Task] = list()

View File

@ -1,7 +1,5 @@
import logging import logging
from collections import defaultdict
from logging import Logger from logging import Logger
from typing import Dict, Any, Optional
from homeassistant_api import Client, HomeassistantAPIError, Domain from homeassistant_api import Client, HomeassistantAPIError, Domain
@ -58,16 +56,20 @@ class HomeAssistant(Connector):
except HomeassistantAPIError as err: except HomeassistantAPIError as err:
self._logger.warning(f"[HomeAssistant {self._config.name}] Received an exception from the api: {err=}, {type(err)=}") self._logger.warning(f"[HomeAssistant {self._config.name}] Received an exception from the api: {err=}, {type(err)=}")
async def close(self) -> None:
await self._client.async_cache_session.close()
def _format_shopping_list_entry(shopping_list_entry: ShoppingListEntry): def _format_shopping_list_entry(shopping_list_entry: ShoppingListEntry):
item = shopping_list_entry.food.name item = shopping_list_entry.food.name
if shopping_list_entry.amount > 0: if shopping_list_entry.amount > 0:
item += f" ({shopping_list_entry.amount:.2f}".rstrip('0').rstrip('.')
if shopping_list_entry.unit and shopping_list_entry.unit.base_unit and len(shopping_list_entry.unit.base_unit) > 0: if shopping_list_entry.unit and shopping_list_entry.unit.base_unit and len(shopping_list_entry.unit.base_unit) > 0:
item += f" ({shopping_list_entry.amount} {shopping_list_entry.unit.base_unit})" item += f" {shopping_list_entry.unit.base_unit})"
elif shopping_list_entry.unit and shopping_list_entry.unit.name and len(shopping_list_entry.unit.name) > 0: elif shopping_list_entry.unit and shopping_list_entry.unit.name and len(shopping_list_entry.unit.name) > 0:
item += f" ({shopping_list_entry.amount} {shopping_list_entry.unit.name})" item += f" {shopping_list_entry.unit.name})"
else: else:
item += f" ({shopping_list_entry.amount})" item += ")"
description = "Imported by TandoorRecipes" description = "Imported by TandoorRecipes"
if shopping_list_entry.created_by.first_name and len(shopping_list_entry.created_by.first_name) > 0: if shopping_list_entry.created_by.first_name and len(shopping_list_entry.created_by.first_name) > 0: