Coverage for custom_components/supernotify/hass_api.py: 88%
458 statements
« prev ^ index » next coverage.py v7.10.6, created at 2026-04-01 15:06 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2026-04-01 15:06 +0000
1from __future__ import annotations
3import logging
4from dataclasses import dataclass
5from functools import partial
6from typing import TYPE_CHECKING, Any
8import voluptuous as vol
9from homeassistant.components.person import ATTR_USER_ID
10from homeassistant.const import CONF_ACTION, CONF_DEVICE_ID
11from homeassistant.helpers.aiohttp_client import async_get_clientsession
12from homeassistant.helpers.event import async_track_state_change_event, async_track_time_change
13from homeassistant.util import slugify
15if TYPE_CHECKING:
16 import asyncio
17 from collections.abc import Callable, Iterable, Iterator
19 import aiohttp
20 from homeassistant.core import CALLBACK_TYPE, HomeAssistant, Service, ServiceResponse, State
21 from homeassistant.helpers.entity import Entity
22 from homeassistant.helpers.entity_registry import EntityRegistry
23 from homeassistant.helpers.typing import ConfigType
24 from homeassistant.util.event_type import EventType
26 from .schema import ConditionsFunc
28import socket
29import threading
30from contextlib import contextmanager
31from typing import TYPE_CHECKING, cast
33import homeassistant.components.trace
34from homeassistant.components import mqtt
35from homeassistant.components.group import expand_entity_ids
36from homeassistant.components.trace.const import DATA_TRACE
37from homeassistant.components.trace.models import ActionTrace
38from homeassistant.components.trace.util import async_store_trace
39from homeassistant.core import Context as HomeAssistantContext
40from homeassistant.core import HomeAssistant, SupportsResponse
41from homeassistant.exceptions import ConditionError, ConditionErrorContainer, IntegrationError
42from homeassistant.helpers import condition as condition
43from homeassistant.helpers import device_registry as dr
44from homeassistant.helpers import entity_registry as er
45from homeassistant.helpers import issue_registry as ir
46from homeassistant.helpers.json import json_dumps
47from homeassistant.helpers.network import get_url
48from homeassistant.helpers.template import Template
49from homeassistant.helpers.trace import trace_get, trace_path
50from homeassistant.helpers.typing import ConfigType
52from . import DOMAIN
53from .const import CONF_DEVICE_LABELS, CONF_DEVICE_TRACKER, CONF_MOBILE_APP_ID
54from .model import ConditionVariables, SelectionRule
56if TYPE_CHECKING:
57 from homeassistant.core import HomeAssistant
58 from homeassistant.helpers.device_registry import DeviceEntry, DeviceRegistry
59 from homeassistant.helpers.typing import ConfigType
# Keys for mobile app config-entry data and device attributes. They are declared locally,
# not imported from homeassistant.components.mobile_app.const, so that loading this
# module does not trigger that integration's import/dependency chain.
CONF_USER_ID: str = "user_id"
ATTR_OS_NAME: str = "os_name"
ATTR_OS_VERSION: str = "os_version"
ATTR_APP_VERSION: str = "app_version"
ATTR_DEVICE_NAME: str = "device_name"
ATTR_MANUFACTURER: str = "manufacturer"
ATTR_MODEL: str = "model"

_LOGGER: logging.Logger = logging.getLogger(__name__)
@dataclass
class DeviceInfo:
    """Description of a Home Assistant device found in the device registry.

    Core fields come from the device entry and its config entries. ``mobile_app_id``,
    ``device_tracker`` and ``action`` start as None and are filled in afterwards
    for mobile app devices.
    """

    device_id: str
    device_labels: list[str] | None = None
    mobile_app_id: str | None = None
    device_name: str | None = None
    device_tracker: str | None = None
    action: str | None = None
    user_id: str | None = None
    area_id: str | None = None
    manufacturer: str | None = None
    model: str | None = None
    os_name: str | None = None
    os_version: str | None = None
    app_version: str | None = None
    identifiers: set[tuple[str, str]] | None = None

    def as_dict(self) -> dict[str, str | list[str] | None]:
        """Return the device attributes keyed by their Home Assistant config/attribute names.

        ``area_id`` and ``identifiers`` are not included.
        """
        return {
            CONF_MOBILE_APP_ID: self.mobile_app_id,
            ATTR_DEVICE_NAME: self.device_name,
            CONF_DEVICE_ID: self.device_id,
            CONF_USER_ID: self.user_id,
            CONF_DEVICE_TRACKER: self.device_tracker,
            CONF_ACTION: self.action,
            ATTR_OS_NAME: self.os_name,
            ATTR_OS_VERSION: self.os_version,
            ATTR_APP_VERSION: self.app_version,
            ATTR_MANUFACTURER: self.manufacturer,
            ATTR_MODEL: self.model,
            CONF_DEVICE_LABELS: self.device_labels,
        }

    def __eq__(self, other: Any) -> bool:
        """Compare by ``as_dict()`` content (test support).

        Fields missing from ``as_dict()`` (``area_id``, ``identifiers``) do not affect
        equality. Objects without a callable ``as_dict`` (including None) compare
        unequal instead of raising AttributeError.
        """
        other_as_dict = getattr(other, "as_dict", None)
        if not callable(other_as_dict):
            # let Python fall back to its default comparison (identity -> False)
            return NotImplemented
        return other_as_dict() == self.as_dict()
class HomeAssistantAPI:
    """Facade over the Home Assistant core used by Supernotify.

    Groups state, service, event, registry, template, condition, repair-issue and MQTT
    access behind one object. It also caches the entity/device registries, per-service
    metadata and the discovered mobile app devices.
    """

    def __init__(self, hass: HomeAssistant) -> None:
        self._hass: HomeAssistant = hass
        self.internal_url: str = ""
        self.external_url: str = ""
        self.language: str = ""
        self.hass_name: str = "!UNDEFINED!"
        self._entity_registry: er.EntityRegistry | None = None
        self._device_registry: dr.DeviceRegistry | None = None
        # (domain, service) -> {"supports_response": ..., "schema": ...}, filled lazily by service_info()
        self._service_info: dict[tuple[str, str], Any] = {}
        # unsubscribe callbacks from subscribe_*(), released by disconnect()
        self.unsubscribes: list[CALLBACK_TYPE] = []
        # mobile app lookup tables, populated by build_mobile_app_cache()
        self.mobile_apps_by_tracker: dict[str, DeviceInfo] = {}
        self.mobile_apps_by_app_id: dict[str, DeviceInfo] = {}
        self.mobile_apps_by_device_id: dict[str, DeviceInfo] = {}
        self.mobile_apps_by_user_id: dict[str, list[DeviceInfo]] = {}

    def initialize(self) -> None:
        """Read instance name, language and URLs from hass, then build the mobile app cache.

        Failures in URL lookup are logged. The internal URL falls back to
        ``http://<hostname>`` and the external URL falls back to the internal one.
        """
        self.hass_name = self._hass.config.location_name
        self.language = self._hass.config.language
        try:
            self.internal_url = get_url(self._hass, prefer_external=False)
        except Exception as e:
            self.internal_url = f"http://{socket.gethostname()}"
            _LOGGER.warning("SUPERNOTIFY could not get internal hass url, defaulting to %s: %s", self.internal_url, e)
        try:
            self.external_url = get_url(self._hass, prefer_external=True)
        except Exception as e:
            _LOGGER.warning("SUPERNOTIFY could not get external hass url, defaulting to internal url: %s", e)
            self.external_url = self.internal_url

        self.build_mobile_app_cache()

        _LOGGER.debug(
            "SUPERNOTIFY Configured for HomeAssistant instance %s at %s , %s",
            self.hass_name,
            self.internal_url,
            self.external_url,
        )

        if not self.internal_url or not self.internal_url.startswith("http"):
            _LOGGER.warning("SUPERNOTIFY invalid internal hass url %s", self.internal_url)

    def disconnect(self) -> None:
        """Call every registered unsubscribe callback, most recent first, logging failures."""
        while self.unsubscribes:
            unsub = self.unsubscribes.pop()
            try:
                _LOGGER.debug("SUPERNOTIFY unsubscribing: %s", unsub)
                unsub()
            except Exception as e:
                _LOGGER.error("SUPERNOTIFY failed to unsubscribe: %s", e)
        _LOGGER.debug("SUPERNOTIFY disconnection complete")

    def subscribe_event(self, event: EventType | str, callback: Callable) -> None:
        """Listen for a bus event; the listener is removed by disconnect()."""
        self.unsubscribes.append(self._hass.bus.async_listen(event, callback))

    def subscribe_state(self, entity_ids: str | Iterable[str], callback: Callable) -> None:
        """Track state changes of the given entities; the tracker is removed by disconnect()."""
        self.unsubscribes.append(async_track_state_change_event(self._hass, entity_ids, callback))

    def subscribe_time(self, hour: int, minute: int, second: int, callback: Callable) -> None:
        """Call back at a time of day; the tracker is removed by disconnect()."""
        self.unsubscribes.append(async_track_time_change(self._hass, callback, hour=hour, minute=minute, second=second))

    def in_hass_loop(self) -> bool:
        """True when running on the thread that owns the Home Assistant event loop."""
        return self._hass is not None and self._hass.loop_thread_id == threading.get_ident()

    def get_state(self, entity_id: str) -> State | None:
        return self._hass.states.get(entity_id)

    def is_state(self, entity_id: str, state: str) -> bool:
        return self._hass.states.is_state(entity_id, state)

    def set_state(self, entity_id: str, state: str | int | bool, attributes: dict[str, Any] | None = None) -> None:
        """Set an entity state (stringified).

        Uses the loop-only async API when already on the event loop thread, and the
        thread-safe API otherwise.
        """
        if self.in_hass_loop():
            self._hass.states.async_set(entity_id, str(state), attributes=attributes)
        else:
            self._hass.states.set(entity_id, str(state), attributes=attributes)

    def has_service(self, domain: str, service: str) -> bool:
        return self._hass.services.has_service(domain, service)

    def entity_ids_for_domain(self, domain: str) -> list[str]:
        return self._hass.states.async_entity_ids(domain)

    def domain_entity(self, domain: str, entity_id: str) -> Entity | None:
        """Look up an entity object through the component stored in ``hass.data[domain]``.

        Returns None when the domain has no ``hass.data`` entry, or when that entry has
        no ``get_entity`` method.
        """
        # TODO: must be a better hass method than this
        get_entity = getattr(self._hass.data.get(domain), "get_entity", None)
        if get_entity is None:
            return None
        return get_entity(entity_id)

    def create_job(self, func: Callable, *args: Any) -> asyncio.Future[Any]:
        """Wrap a blocking function call in a HomeAssistant awaitable job"""
        return self._hass.async_add_executor_job(func, *args)

    def fire_event(self, event_name: str, event_data: dict[str, Any] | None = None) -> None:
        self._hass.bus.async_fire(event_name, event_data)

    async def call_service(
        self,
        domain: str,
        service: str,
        service_data: dict[str, Any] | None = None,
        target: dict[str, Any] | None = None,
        return_response: bool | None = None,
        blocking: bool | None = None,
        debug: bool = False,
    ) -> ServiceResponse | None:
        """Call a Home Assistant action, deriving response/blocking flags that were not given.

        Args:
            return_response: when None, derived from the service's SupportsResponse
                (NONE -> False, ONLY -> True, OPTIONAL -> ``debug``).
            blocking: when None, True if a response is requested or ``debug`` is set.
            debug: log the response at INFO level and prefer blocking calls.

        Explicitly supplied values for ``return_response`` / ``blocking`` are always honoured.
        """
        if return_response is None:
            # unknown service, e.g. one defined in a generic action: check whether it supports a response
            supports_response: SupportsResponse = self.service_info(domain, service)
            if supports_response == SupportsResponse.NONE:
                return_response = False
            elif supports_response == SupportsResponse.ONLY:
                return_response = True
            else:
                return_response = debug
        if blocking is None:
            blocking = return_response or debug

        response: ServiceResponse | None = await self._hass.services.async_call(
            domain,
            service,
            service_data=service_data,
            blocking=blocking,
            context=None,
            target=target,
            return_response=return_response,
        )
        if response is not None and debug:
            _LOGGER.info("SUPERNOTIFY Service %s.%s response: %s", domain, service, response)
        return response

    def coerce_schema(self, domain: str, service: str, data: ConfigType) -> ConfigType:
        """Run ``data`` through the service's voluptuous schema, dropping unknown keys.

        Unknown keys are removed instead of rejected. The original ``data`` is returned
        unchanged in any of these cases: it is empty; the service or its schema cannot
        be found; the schema is not dict-based; or coercion raises.
        """
        if not data:
            return data
        try:
            if (domain, service) not in self._service_info:
                self.service_info(domain, service)
            service_info = self._service_info.get((domain, service))
            if not service_info:
                _LOGGER.info("SUPERNOTIFY No service found to pre-validate action data for %s.%s", domain, service)
                return data
            if not service_info.get("schema"):
                _LOGGER.info("SUPERNOTIFY No vol schema found to pre-validate action data for %s.%s", domain, service)
                return data
            service_schema = service_info["schema"]

            while service_schema is not None and not isinstance(service_schema, vol.Schema):
                # e.g. entity services get schema wrapped in a vol.All
                if hasattr(service_schema, "validators") and hasattr(service_schema.validators, "__iter__"):
                    # e.g. vol.All — strip extras using first dict Schema sub-validator only
                    # (don't run the full chain; other validators may require target fields not in data)
                    service_schema = next(
                        (v for v in service_schema.validators if isinstance(v, vol.Schema) or hasattr(v, "validators")),
                        None,
                    )
                else:
                    service_schema = None
            if not isinstance(service_schema, vol.Schema):
                _LOGGER.info("SUPERNOTIFY Unable to find schema for %s.%s", domain, service)
                return data
            if not isinstance(service_schema.schema, dict):
                # vol.Schema.extend() only supports dict-based schemas
                _LOGGER.debug("SUPERNOTIFY Schema for %s.%s is not dict based, not coercing", domain, service)
                return data

            coercing_schema = service_schema.extend(
                {},
                extra=vol.REMOVE_EXTRA if service_schema.extra == vol.PREVENT_EXTRA else service_schema.extra,
                required=service_schema.required,
            )
            cleaned = coercing_schema(data)
            if cleaned != data:
                _LOGGER.debug("SUPERNOTIFY Coerced data for %s.%s from %s->%s", domain, service, data, cleaned)
            return cleaned
        except Exception:
            _LOGGER.exception("SUPERNOTIFY Unable to coerce %s.%s schema for %s", domain, service, data)
            return data

    def service_info(self, domain: str, service: str) -> SupportsResponse:
        """Return the service's SupportsResponse, caching its response support and schema.

        Falls back to ``SupportsResponse.NONE`` when the service is unknown or the lookup fails.
        """
        supports_response: SupportsResponse | None = None
        try:
            if (domain, service) not in self._service_info:
                service_objs: dict[str, Service] = self._hass.services.async_services_for_domain(domain)
                service_obj: Service | None = service_objs.get(service)
                if service_obj:
                    self._service_info[domain, service] = {
                        "supports_response": service_obj.supports_response,
                        "schema": service_obj.schema,
                    }
            service_info: dict[str, Any] = self._service_info.get((domain, service), {})
            supports_response = service_info.get("supports_response")
            if supports_response is None:
                _LOGGER.debug("SUPERNOTIFY Unable to find service info for %s.%s", domain, service)
        except Exception as e:
            _LOGGER.warning("SUPERNOTIFY Unable to get service info for %s.%s: %s", domain, service, e)
        return supports_response or SupportsResponse.NONE  # default to no response

    def find_service(self, domain: str, module: str) -> str | None:
        """Find the first ``domain`` service whose handler is implemented in Python module ``module``.

        Returns:
            ``"<domain>.<service>"``, or None if no such service exists or the lookup fails.
        """
        try:
            service_objs: dict[str, Service] = self._hass.services.async_services_for_domain(domain)
            if service_objs:
                for service, domain_obj in service_objs.items():
                    if domain_obj.job and domain_obj.job.target:
                        # bound methods report their module via the owning instance
                        target_module: str | None = (
                            domain_obj.job.target.__self__.__module__
                            if hasattr(domain_obj.job.target, "__self__")
                            else domain_obj.job.target.__module__
                        )
                        if target_module == module:
                            _LOGGER.debug("SUPERNOTIFY Found service %s for domain %s", service, domain)
                            return f"{domain}.{service}"

            _LOGGER.debug("SUPERNOTIFY Unable to find service for %s", domain)
        except Exception as e:
            _LOGGER.warning("SUPERNOTIFY Unable to find service for %s: %s", domain, e)
        return None

    def http_session(self) -> aiohttp.ClientSession:
        """Client aiohttp session for async web requests"""
        return async_get_clientsession(self._hass)

    def expand_group(self, entity_ids: str | list[str]) -> list[str]:
        return expand_entity_ids(self._hass, entity_ids)

    def template(self, template_format: str) -> Template:
        return Template(template_format, self._hass)

    async def trace_conditions(
        self,
        conditions: ConditionsFunc,
        condition_variables: ConditionVariables,
        trace_name: str | None = None,
    ) -> tuple[bool | None, ActionTrace | None]:
        """Evaluate conditions inside a stored Home Assistant action trace.

        Sets up the trace component first if it is not loaded yet.

        Returns:
            The evaluation result and the ActionTrace that recorded it.
        """
        result: bool | None = None
        this_trace: ActionTrace | None = None
        if DATA_TRACE not in self._hass.data:
            _LOGGER.warning("SUPERNOTIFY tracing not configured, attempting to set up")
            await homeassistant.components.trace.async_setup(self._hass, {})  # type: ignore
        with trace_action(self._hass, trace_name or "anon_condition") as cond_trace:
            cond_trace.set_trace(trace_get())
            this_trace = cond_trace
            with trace_path(["condition", "conditions"]):
                result = self.evaluate_conditions(conditions, condition_variables)
            _LOGGER.debug("SUPERNOTIFY condition trace: %s", cond_trace.as_dict())
        return result, this_trace

    async def build_conditions(
        self, condition_config: list[ConfigType], strict: bool = False, validate: bool = False, name: str = DOMAIN
    ) -> ConditionsFunc | None:
        """Build a condition checker from config and trial-run it once with empty variables.

        Args:
            condition_config: condition definitions.
            strict: render templates in strict mode during the trial run. Any ConditionError
                reported through the logger is then logged and re-raised.
            validate: run Home Assistant's condition config validation first.
            name: name passed to Home Assistant's condition builder.

        Raises:
            Validation/build/evaluation errors are logged and re-raised. IntegrationError
            is raised if no checker could be built.
        """
        capturing_logger: ConditionErrorLoggingAdaptor = ConditionErrorLoggingAdaptor(_LOGGER)
        condition_variables: ConditionVariables = ConditionVariables()
        cond_list: list[ConfigType]
        try:
            if validate:
                cond_list = cast(
                    "list[ConfigType]", await condition.async_validate_conditions_config(self._hass, condition_config)
                )
            else:
                cond_list = condition_config
        except Exception as e:
            _LOGGER.exception("SUPERNOTIFY Conditions validation failed: %s", e)
            raise
        try:
            if strict:
                force_strict_template_mode(cond_list, undo=False)

            test: ConditionsFunc = await condition.async_conditions_from_config(
                self._hass, cond_list, cast("logging.Logger", capturing_logger), name
            )
            if test is None:
                raise IntegrationError(f"Invalid condition {condition_config}")
            test(condition_variables.as_dict())
            return test
        except Exception as e:
            _LOGGER.exception("SUPERNOTIFY Conditions eval failed: %s", e)
            raise
        finally:
            if strict:
                # undo on the list that was actually wrapped; with validate=True it is a
                # validated copy, not condition_config
                force_strict_template_mode(cond_list, undo=True)
                if capturing_logger.condition_errors:
                    for exception in capturing_logger.condition_errors:
                        _LOGGER.warning("SUPERNOTIFY Invalid condition %s:%s", condition_config, exception)
                    raise capturing_logger.condition_errors[0]

    def evaluate_conditions(
        self,
        conditions: ConditionsFunc,
        condition_variables: ConditionVariables,
    ) -> bool | None:
        """Run a condition checker with the variables as a dict (None if no variables); re-raise failures."""
        try:
            if not condition_variables:
                _LOGGER.warning("SUPERNOTIFY No cond vars provided for condition")
            return conditions(condition_variables.as_dict() if condition_variables is not None else None)
        except Exception as e:
            _LOGGER.error("SUPERNOTIFY Condition eval failed: %s", e)
            raise

    def abs_url(self, fragment: str | None, prefer_external: bool = True) -> str | None:
        """Resolve a path fragment against the external (or internal) base URL.

        Fragments that already start with ``http`` are returned unchanged, and an empty
        fragment gives None.
        """
        base_url = self.external_url if prefer_external else self.internal_url
        if fragment:
            if fragment.startswith("http"):
                return fragment
            if fragment.startswith("/"):
                return base_url + fragment
            return base_url + "/" + fragment
        return None

    def raise_issue(
        self,
        issue_id: str,
        issue_key: str,
        issue_map: dict[str, str],
        severity: ir.IssueSeverity = ir.IssueSeverity.WARNING,
        learn_more_url: str = "https://supernotify.rhizomatics.org.uk",
        is_fixable: bool = False,
    ) -> None:
        """Create a Home Assistant repair issue for this integration's domain."""
        ir.async_create_issue(
            self._hass,
            DOMAIN,
            issue_id,
            translation_key=issue_key,
            translation_placeholders=issue_map,
            severity=severity,
            learn_more_url=learn_more_url,
            is_fixable=is_fixable,
        )

    def mobile_app_by_tracker(self, device_tracker: str) -> DeviceInfo | None:
        return self.mobile_apps_by_tracker.get(device_tracker)

    def mobile_app_by_id(self, mobile_app_id: str) -> DeviceInfo | None:
        return self.mobile_apps_by_app_id.get(mobile_app_id)

    def mobile_app_by_device_id(self, device_id: str) -> DeviceInfo | None:
        return self.mobile_apps_by_device_id.get(device_id)

    def mobile_app_by_user_id(self, user_id: str) -> list[DeviceInfo] | None:
        return self.mobile_apps_by_user_id.get(user_id)

    def build_mobile_app_cache(self) -> None:
        """Index all enabled mobile app devices by app id, device id, device tracker and user id.

        Per-device failures are logged and skipped. Rebuilding replaces a device's existing
        per-user entry instead of appending a duplicate.
        """
        ent_reg: EntityRegistry | None = self.entity_registry()
        if not ent_reg:
            _LOGGER.warning("SUPERNOTIFY Unable to discover devices for - no entity registry found")
            return

        found: int = 0
        complete: int = 0
        for mobile_app_info in self.discover_devices("mobile_app"):
            try:
                mobile_app_id: str = f"mobile_app_{slugify(mobile_app_info.device_name)}"
                device_tracker: str | None = None
                notify_action: str | None = None
                if self.has_service("notify", mobile_app_id):
                    notify_action = f"notify.{mobile_app_id}"
                else:
                    _LOGGER.warning("SUPERNOTIFY Unable to find notify action <%s>", mobile_app_id)

                registry_entries = ent_reg.entities.get_entries_for_device_id(mobile_app_info.device_id)
                for reg_entry in registry_entries:
                    if reg_entry.platform == "mobile_app" and reg_entry.domain == "device_tracker":
                        device_tracker = reg_entry.entity_id

                if device_tracker and notify_action:
                    complete += 1

                mobile_app_info.mobile_app_id = mobile_app_id
                mobile_app_info.device_tracker = device_tracker
                mobile_app_info.action = notify_action

                found += 1
                self.mobile_apps_by_app_id[mobile_app_id] = mobile_app_info
                self.mobile_apps_by_device_id[mobile_app_info.device_id] = mobile_app_info
                if device_tracker:
                    self.mobile_apps_by_tracker[device_tracker] = mobile_app_info
                if mobile_app_info.user_id is not None:
                    user_apps = self.mobile_apps_by_user_id.setdefault(mobile_app_info.user_id, [])
                    # replace any earlier entry for this device so a rebuild doesn't duplicate it
                    user_apps[:] = [app for app in user_apps if app.device_id != mobile_app_info.device_id]
                    user_apps.append(mobile_app_info)

            except Exception as e:
                _LOGGER.error("SUPERNOTIFY Failure examining device %s: %s", mobile_app_info, e)

        _LOGGER.info("SUPERNOTIFY Found %s enabled mobile app devices, %s complete config", found, complete)

    def device_config_info(self, device: DeviceEntry) -> dict[str, str | None]:
        """Collect OS name/version, user id and app version from the device's config entries.

        Later config entries override earlier ones only with truthy values.
        """
        results: dict[str, str | None] = {ATTR_OS_NAME: None, ATTR_OS_VERSION: None, CONF_USER_ID: None, ATTR_APP_VERSION: None}
        for config_entry_id in device.config_entries:
            config_entry = self._hass.config_entries.async_get_entry(config_entry_id)
            if config_entry and config_entry.data:
                for attr in results:
                    results[attr] = config_entry.data.get(attr) or results[attr]
        return results

    def discover_devices(
        self,
        discover_domain: str,
        device_model_select: SelectionRule | None = None,
        device_manufacturer_select: SelectionRule | None = None,
        device_os_select: SelectionRule | None = None,
        device_area_select: SelectionRule | None = None,
        device_label_select: SelectionRule | None = None,
    ) -> list[DeviceInfo]:
        """Find enabled devices with an identifier in ``discover_domain``.

        Only devices that pass every supplied selection rule are returned. Each device
        appears at most once, even when it has several identifiers in the domain.
        """
        devices: list[DeviceInfo] = []
        dev_reg: DeviceRegistry | None = self.device_registry()
        if dev_reg is None or not hasattr(dev_reg, "devices"):
            _LOGGER.warning("SUPERNOTIFY Unable to discover devices for %s - no device registry found", discover_domain)
            return []

        all_devs = enabled_devs = found_devs = skipped_devs = 0
        for dev in dev_reg.devices.values():
            all_devs += 1

            if dev.disabled:
                _LOGGER.debug("SUPERNOTIFY excluded disabled device %s", dev.name)
                continue
            enabled_devs += 1
            for identifier in dev.identifiers:
                if identifier and len(identifier) > 1 and identifier[0] == discover_domain:
                    _LOGGER.debug("SUPERNOTIFY discovered %s device %s for id %s", dev.model, dev.name, identifier)
                    found_devs += 1
                    selected = self._select_device(
                        dev,
                        device_model_select,
                        device_manufacturer_select,
                        device_os_select,
                        device_area_select,
                        device_label_select,
                    )
                    if selected is None:
                        skipped_devs += 1
                    else:
                        devices.append(selected)
                    # a device may carry several identifiers for the same domain; handle it only once
                    break
                if identifier:
                    # HomeKit has triples for identifiers, other domains may behave similarly
                    _LOGGER.debug("SUPERNOTIFY Ignoring device %s id: %s", dev.name, identifier)
                else:
                    _LOGGER.debug("SUPERNOTIFY Unexpected %s device %s without id", dev.model, dev.name)

        _LOGGER.debug(
            "SUPERNOTIFY %s device discovery, all=%s,enabled=%s ", discover_domain, all_devs, enabled_devs
        )
        _LOGGER.debug("SUPERNOTIFY %s skipped=%s, found=%s", discover_domain, skipped_devs, found_devs)

        return devices

    def _select_device(
        self,
        dev: DeviceEntry,
        device_model_select: SelectionRule | None,
        device_manufacturer_select: SelectionRule | None,
        device_os_select: SelectionRule | None,
        device_area_select: SelectionRule | None,
        device_label_select: SelectionRule | None,
    ) -> DeviceInfo | None:
        """Apply the optional selection rules to one device; return its DeviceInfo, or None if skipped."""
        if device_model_select is not None and not device_model_select.match(dev.model):
            _LOGGER.debug("SUPERNOTIFY Skipped dev %s, no model %s match", dev.name, dev.model)
            return None
        if device_manufacturer_select is not None and not device_manufacturer_select.match(dev.manufacturer):
            _LOGGER.debug("SUPERNOTIFY Skipped dev %s, no manufacturer %s match", dev.name, dev.manufacturer)
            return None
        config_info = self.device_config_info(dev)
        if device_os_select is not None and not device_os_select.match(config_info[ATTR_OS_NAME]):
            _LOGGER.debug("SUPERNOTIFY Skipped dev %s, no OS %s match", dev.name, config_info[ATTR_OS_NAME])
            return None
        if device_area_select is not None and not device_area_select.match(dev.area_id):
            _LOGGER.debug("SUPERNOTIFY Skipped dev %s, no area %s match", dev.name, dev.area_id)
            return None
        if device_label_select is not None and not device_label_select.match(dev.labels):
            _LOGGER.debug("SUPERNOTIFY Skipped dev %s, no label %s match", dev.name, dev.labels)
            return None
        return DeviceInfo(
            device_id=dev.id,
            device_name=dev.name,
            manufacturer=dev.manufacturer,
            model=dev.model,
            area_id=dev.area_id,
            user_id=config_info[CONF_USER_ID],
            os_name=config_info[ATTR_OS_NAME],
            os_version=config_info[ATTR_OS_VERSION],
            app_version=config_info[ATTR_APP_VERSION],
            device_labels=list(dev.labels) if dev.labels else [],
            identifiers=dev.identifiers,
        )

    def domain_for_device(self, device_id: str, domains: list[str]) -> str | None:
        """Return the first of ``domains`` found among the device's registry identifiers.

        Logs a warning and returns None when the device is unknown or no domain matches.
        """
        device_registry = self.device_registry()
        if device_registry:
            device: DeviceEntry | None = device_registry.async_get(device_id)
            if device:
                # identifiers are usually (domain, id) pairs, but some integrations (e.g. HomeKit)
                # use longer tuples, so index the domain rather than unpacking a pair
                matching_domains = [ident[0] for ident in device.identifiers if ident and ident[0] in domains]
                if matching_domains:
                    # TODO: limited to first domain found, unlikely to be more
                    return matching_domains[0]
        _LOGGER.warning(
            "SUPERNOTIFY A target that looks like a device_id can't be matched to supported integration: %s",
            device_id,
        )
        return None

    def entity_registry(self) -> er.EntityRegistry | None:
        """Hass entity registry is weird, every component ends up creating its own, with a store, subscribing
        to all entities, so do it once here
        """  # noqa: D205
        if self._entity_registry is not None:
            return self._entity_registry
        try:
            self._entity_registry = er.async_get(self._hass)
        except Exception as e:
            _LOGGER.warning("SUPERNOTIFY Unable to get entity registry: %s", e)
        return self._entity_registry

    def device_registry(self) -> dr.DeviceRegistry | None:
        """Hass device registry is weird, every component ends up creating its own, with a store, subscribing
        to all devices, so do it once here
        """  # noqa: D205
        if self._device_registry is not None:
            return self._device_registry
        try:
            self._device_registry = dr.async_get(self._hass)
        except Exception as e:
            _LOGGER.warning("SUPERNOTIFY Unable to get device registry: %s", e)
        return self._device_registry

    async def mqtt_available(self, raise_on_error: bool = True) -> bool:
        """True if the MQTT client is ready; errors are logged and re-raised unless ``raise_on_error`` is False."""
        try:
            return await mqtt.async_wait_for_mqtt_client(self._hass) is True
        except Exception:
            _LOGGER.exception("SUPERNOTIFY MQTT integration failed on available check")
            if raise_on_error:
                raise
        return False

    async def mqtt_publish(
        self, topic: str, payload: Any = None, qos: int = 0, retain: bool = False, raise_on_error: bool = True
    ) -> None:
        """Publish ``payload`` JSON-encoded to an MQTT topic.

        Errors are logged and re-raised unless ``raise_on_error`` is False.
        """
        try:
            await mqtt.async_publish(
                self._hass,
                topic=topic,
                payload=json_dumps(payload),
                qos=qos,
                retain=retain,
            )
        except Exception:
            _LOGGER.exception("SUPERNOTIFY MQTT publish failed to %s", topic)
            if raise_on_error:
                raise
649class ConditionErrorLoggingAdaptor(logging.LoggerAdapter):
650 def __init__(self, *args: Any, **kwargs: Any) -> None:
651 super().__init__(*args, **kwargs)
652 self.condition_errors: list[ConditionError] = []
654 def capture(self, args: Any) -> None:
655 if args and isinstance(args, list | tuple):
656 for arg in args:
657 if isinstance(arg, ConditionErrorContainer):
658 self.condition_errors.extend(arg.errors)
659 elif isinstance(arg, ConditionError):
660 self.condition_errors.append(arg)
662 def error(self, msg: Any, *args: object, **kwargs: Any) -> None:
663 self.capture(args)
664 self.logger.error(msg, args, kwargs)
666 def warning(self, msg: Any, *args: Any, **kwargs: Any) -> None:
667 self.capture(args)
668 self.logger.warning(msg, args, kwargs)
class _StrictTemplateWrapper:
    """Proxy for a Template that forces ``strict=True`` on ``async_render_to_info``.

    Defined at module level (not inside force_strict_template_mode) so that the undo call
    can recognise wrappers created by an earlier call with ``isinstance``.
    """

    def __init__(self, obj: Template) -> None:
        self._obj = obj

    def __getattr__(self, name: str) -> Any:
        # only reached for names not found normally; guard _obj so an instance built without
        # __init__ (e.g. by copy/deepcopy) raises AttributeError instead of recursing forever
        if name == "_obj":
            raise AttributeError(name)
        if name == "async_render_to_info":
            return partial(self._obj.async_render_to_info, strict=True)
        return getattr(self._obj, name)

    def __repr__(self) -> str:
        return self._obj.__repr__() if self._obj else "NULL TEMPLATE"


def force_strict_template_mode(conditions: list[ConfigType], undo: bool = False) -> None:
    """Switch templates in condition config to strict rendering, or revert that, in place.

    With ``undo=False``, every Template value (that has an ``_env``) in the condition dicts,
    including nested dicts, is replaced by a proxy that renders with ``strict=True``. With
    ``undo=True``, those proxies are swapped back for the original Template. Values held
    inside lists are not traversed.

    Args:
        conditions: condition config dicts, modified in place; None is ignored.
        undo: revert a previous call instead of applying strict mode.
    """

    def wrap_template(cond: ConfigType, undo: bool) -> ConfigType:
        # re-assigning existing keys while iterating is safe: the dict size does not change
        for key, val in cond.items():
            if not undo and isinstance(val, Template) and hasattr(val, "_env"):
                cond[key] = _StrictTemplateWrapper(val)
            elif undo and isinstance(val, _StrictTemplateWrapper):
                cond[key] = val._obj
            elif isinstance(val, dict):
                wrap_template(val, undo)
        return cond

    if conditions is None:
        return
    for cond in conditions:
        wrap_template(cond, undo)
@contextmanager
def trace_action(
    hass: HomeAssistant,
    item_id: str,
    config: dict[str, Any] | None = None,
    context: HomeAssistantContext | None = None,
    stored_traces: int = 5,
) -> Iterator[ActionTrace]:
    """Trace execution of a condition.

    Creates an ActionTrace, registers it with Home Assistant's trace store (keeping up to
    ``stored_traces``) and yields it. When ``item_id`` is non-empty, an exception raised
    in the body is recorded on the trace before being re-raised, and the trace is always
    marked finished on exit.
    """
    trace_context = context if context else HomeAssistantContext()
    action_trace = ActionTrace(item_id, config, None, trace_context)
    async_store_trace(hass, action_trace, stored_traces)

    record_outcome = bool(item_id)
    try:
        yield action_trace
    except Exception as err:
        if record_outcome:
            action_trace.set_error(err)
        raise
    finally:
        if record_outcome:
            action_trace.finished()