change formatting a bit, and add async close method
This commit is contained in:
parent
c7dd61e239
commit
87ede4b9cc
@ -16,5 +16,9 @@ class Connector(ABC):
|
||||
async def on_shopping_list_entry_deleted(self, space: Space, instance: ShoppingListEntry) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
# TODO: Maybe add an 'IsEnabled(self) -> Bool' to here
|
||||
# TODO: Add Recipes & possibly Meal Place listeners/hooks (And maybe more?)
|
||||
|
@ -5,7 +5,7 @@ import queue
|
||||
from asyncio import Task
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from multiprocessing import Queue
|
||||
from multiprocessing import JoinableQueue
|
||||
from types import UnionType
|
||||
from typing import List, Any, Dict, Optional
|
||||
|
||||
@ -35,11 +35,11 @@ class Work:
|
||||
|
||||
|
||||
class ConnectorManager:
|
||||
_queue: Queue
|
||||
_queue: JoinableQueue
|
||||
_listening_to_classes = REGISTERED_CLASSES | CONNECTOR_UPDATE_CLASSES
|
||||
|
||||
def __init__(self):
|
||||
self._queue = multiprocessing.Queue(maxsize=QUEUE_MAX_SIZE)
|
||||
self._queue = multiprocessing.JoinableQueue(maxsize=QUEUE_MAX_SIZE)
|
||||
self._worker = multiprocessing.Process(target=self.worker, args=(self._queue,), daemon=True)
|
||||
self._worker.start()
|
||||
|
||||
@ -60,14 +60,16 @@ class ConnectorManager:
|
||||
try:
|
||||
self._queue.put_nowait(Work(instance, action_type))
|
||||
except queue.Full:
|
||||
logging.info("queue was full, so skipping %s", instance)
|
||||
return
|
||||
|
||||
def stop(self):
|
||||
self._queue.join()
|
||||
self._queue.close()
|
||||
self._worker.join()
|
||||
|
||||
@staticmethod
|
||||
def worker(worker_queue: Queue):
|
||||
def worker(worker_queue: JoinableQueue):
|
||||
from django.db import connections
|
||||
connections.close_all()
|
||||
|
||||
@ -77,7 +79,10 @@ class ConnectorManager:
|
||||
_connectors: Dict[str, List[Connector]] = dict()
|
||||
|
||||
while True:
|
||||
item: Optional[Work] = worker_queue.get()
|
||||
try:
|
||||
item: Optional[Work] = worker_queue.get()
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
if item is None:
|
||||
break
|
||||
|
||||
@ -88,18 +93,32 @@ class ConnectorManager:
|
||||
connectors: Optional[List[Connector]] = _connectors.get(space.name)
|
||||
|
||||
if connectors is None or refresh_connector_cache:
|
||||
if connectors is not None:
|
||||
loop.run_until_complete(close_connectors(connectors))
|
||||
|
||||
with scope(space=space):
|
||||
connectors: List[Connector] = [HomeAssistant(config) for config in space.homeassistantconfig_set.all() if config.enabled]
|
||||
_connectors[space.name] = connectors
|
||||
|
||||
if len(connectors) == 0 or refresh_connector_cache:
|
||||
return
|
||||
worker_queue.task_done()
|
||||
continue
|
||||
|
||||
loop.run_until_complete(run_connectors(connectors, space, item.instance, item.actionType))
|
||||
worker_queue.task_done()
|
||||
|
||||
loop.close()
|
||||
|
||||
|
||||
async def close_connectors(connectors: List[Connector]):
|
||||
tasks: List[Task] = [asyncio.create_task(connector.close()) for connector in connectors]
|
||||
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except BaseException:
|
||||
logging.exception("received an exception while closing one of the connectors")
|
||||
|
||||
|
||||
async def run_connectors(connectors: List[Connector], space: Space, instance: REGISTERED_CLASSES, action_type: ActionType):
|
||||
tasks: List[Task] = list()
|
||||
|
||||
|
@ -1,7 +1,5 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from logging import Logger
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from homeassistant_api import Client, HomeassistantAPIError, Domain
|
||||
|
||||
@ -58,16 +56,20 @@ class HomeAssistant(Connector):
|
||||
except HomeassistantAPIError as err:
|
||||
self._logger.warning(f"[HomeAssistant {self._config.name}] Received an exception from the api: {err=}, {type(err)=}")
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._client.async_cache_session.close()
|
||||
|
||||
|
||||
def _format_shopping_list_entry(shopping_list_entry: ShoppingListEntry):
|
||||
item = shopping_list_entry.food.name
|
||||
if shopping_list_entry.amount > 0:
|
||||
item += f" ({shopping_list_entry.amount:.2f}".rstrip('0').rstrip('.')
|
||||
if shopping_list_entry.unit and shopping_list_entry.unit.base_unit and len(shopping_list_entry.unit.base_unit) > 0:
|
||||
item += f" ({shopping_list_entry.amount} {shopping_list_entry.unit.base_unit})"
|
||||
item += f" {shopping_list_entry.unit.base_unit})"
|
||||
elif shopping_list_entry.unit and shopping_list_entry.unit.name and len(shopping_list_entry.unit.name) > 0:
|
||||
item += f" ({shopping_list_entry.amount} {shopping_list_entry.unit.name})"
|
||||
item += f" {shopping_list_entry.unit.name})"
|
||||
else:
|
||||
item += f" ({shopping_list_entry.amount})"
|
||||
item += ")"
|
||||
|
||||
description = "Imported by TandoorRecipes"
|
||||
if shopping_list_entry.created_by.first_name and len(shopping_list_entry.created_by.first_name) > 0:
|
||||
|
Loading…
Reference in New Issue
Block a user