Coverage for custom_components/supernotify/hass_api.py: 84%
237 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-11-21 23:31 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-11-21 23:31 +0000
1from __future__ import annotations
3import logging
4import socket
5import threading
6from contextlib import contextmanager
7from dataclasses import asdict
8from functools import partial
9from typing import TYPE_CHECKING, Any
11from homeassistant.components import mqtt
12from homeassistant.components.group import expand_entity_ids
13from homeassistant.components.trace import async_setup, async_store_trace # type: ignore[attr-defined,unused-ignore]
14from homeassistant.components.trace.const import DATA_TRACE
15from homeassistant.components.trace.models import ActionTrace
16from homeassistant.core import Context as HomeAssistantContext
17from homeassistant.core import HomeAssistant, SupportsResponse
18from homeassistant.helpers.json import json_dumps
19from homeassistant.helpers.template import Template
20from homeassistant.helpers.trace import trace_get, trace_path
21from homeassistant.helpers.typing import ConfigType
23if TYPE_CHECKING:
24 from collections.abc import Iterator
26 from homeassistant.core import ServiceResponse, State
27 from homeassistant.helpers.condition import ConditionCheckerType
29 from .model import ConditionVariables
30from homeassistant.helpers import condition as condition
31from homeassistant.helpers import device_registry, entity_registry
32from homeassistant.helpers import issue_registry as ir
33from homeassistant.helpers.network import get_url
35from . import (
36 DOMAIN,
37)
39if TYPE_CHECKING:
40 from homeassistant.core import HomeAssistant
41 from homeassistant.helpers.device_registry import DeviceEntry, DeviceRegistry
42 from homeassistant.helpers.typing import ConfigType
# Module-level logger used by HomeAssistantAPI and the helper functions below.
_LOGGER = logging.getLogger(__name__)
class HomeAssistantAPI:
    """Facade over a HomeAssistant instance used throughout Supernotify.

    The wrapper can be constructed without a ``hass`` instance. In that case most
    methods return a neutral value (``None``, ``False``, ``[]``) or do nothing,
    while ``call_service`` and ``evaluate_condition`` raise ``ValueError``.
    """

    def __init__(self, hass: HomeAssistant | None = None) -> None:
        self._hass = hass
        self.internal_url: str = ""
        self.external_url: str = ""
        self.hass_name: str = "!UNDEFINED!"
        # Registries are fetched lazily and cached by entity_registry()/device_registry()
        self._entity_registry: entity_registry.EntityRegistry | None = None
        self._device_registry: device_registry.DeviceRegistry | None = None
        # (domain, service) -> {"supports_response": ..., "schema": ...}, filled by call_service()
        self._service_info: dict[tuple[str, str], Any] = {}

    def initialize(self) -> None:
        """Resolve the instance name and the internal/external base URLs.

        If the internal URL cannot be resolved, it falls back to ``http://<hostname>``.
        If the external URL cannot be resolved, it falls back to the internal URL.
        Both failures are logged as warnings.
        """
        if self._hass:
            self.hass_name = self._hass.config.location_name
            try:
                self.internal_url = get_url(self._hass, prefer_external=False)
            except Exception as e:
                self.internal_url = f"http://{socket.gethostname()}"
                _LOGGER.warning("SUPERNOTIFY could not get internal hass url, defaulting to %s: %s", self.internal_url, e)
            try:
                self.external_url = get_url(self._hass, prefer_external=True)
            except Exception as e:
                _LOGGER.warning("SUPERNOTIFY could not get external hass url, defaulting to internal url: %s", e)
                self.external_url = self.internal_url
        else:
            _LOGGER.warning("SUPERNOTIFY Configured without HomeAssistant instance")

        _LOGGER.debug(
            "SUPERNOTIFY Configured for HomeAssistant instance %s at %s , %s",
            self.hass_name,
            self.internal_url,
            self.external_url,
        )

        if not self.internal_url or not self.internal_url.startswith("http"):
            _LOGGER.warning("SUPERNOTIFY invalid internal hass url %s", self.internal_url)

    def in_hass_loop(self) -> bool:
        """Return True when called from the thread that runs hass's event loop."""
        return self._hass is not None and self._hass.loop_thread_id == threading.get_ident()

    def get_state(self, entity_id: str) -> State | None:
        """Return the current state of ``entity_id``, or None if unknown or there is no hass."""
        if not self._hass:
            return None
        return self._hass.states.get(entity_id)

    def set_state(self, entity_id: str, state: str) -> None:
        """Set an entity state, choosing the loop-safe API for the calling thread."""
        if not self._hass:
            return
        if self.in_hass_loop():
            self._hass.states.async_set(entity_id, state)
        else:
            self._hass.states.set(entity_id, state)

    def has_service(self, domain: str, service: str) -> bool:
        """Return True if ``domain.service`` is registered (always False without hass)."""
        if not self._hass:
            return False
        return self._hass.services.has_service(domain, service)

    def _lookup_service_info(self, domain: str, service: str) -> dict[str, Any] | None:
        """Return cached ``supports_response``/``schema`` for a service, or None if it is not registered."""
        cached = self._service_info.get((domain, service))
        if cached is not None:
            return cached
        if not self._hass:
            return None
        service_obj = self._hass.services.async_services().get(domain, {}).get(service)
        if service_obj is None:
            # Not cached, so a service registered later will still be picked up
            return None
        info: dict[str, Any] = {
            "supports_response": service_obj.supports_response,
            "schema": service_obj.schema,
        }
        self._service_info[domain, service] = info
        return info

    async def call_service(
        self,
        domain: str,
        service: str,
        service_data: dict[str, Any] | None = None,
        target_data: dict[str, Any] | None = None,
        debug: bool = False,
    ) -> ServiceResponse | None:
        """Call a Home Assistant service, adapting the call to what the service supports.

        Args:
            domain: service domain.
            service: service name within the domain.
            service_data: data passed to the service.
            target_data: optional ``target`` for the call.
            debug: when True, wait for the call to complete, request a response if the
                service can give one, and log any response at info level.

        Returns:
            The value returned by ``hass.services.async_call``.

        Raises:
            ValueError: if no HomeAssistant instance is configured.
        """
        if not self._hass:
            raise ValueError("HomeAssistant not available")
        return_response: bool = debug
        try:
            service_info = self._lookup_service_info(domain, service)
            supports_response = service_info.get("supports_response") if service_info else None
            if supports_response is None:
                _LOGGER.debug("SUPERNOTIFY Unable to find service info for %s.%s", domain, service)
            elif supports_response == SupportsResponse.NONE:
                return_response = False
            elif supports_response == SupportsResponse.ONLY:
                return_response = True
        except Exception as e:
            _LOGGER.warning("SUPERNOTIFY Unable to get service info for %s.%s: %s", domain, service, e)
        # Home Assistant only returns responses from blocking calls, so a response-only
        # service must be called blocking even when not debugging
        blocking: bool = debug or return_response

        response: ServiceResponse | None = await self._hass.services.async_call(
            domain,
            service,
            service_data=service_data,
            blocking=blocking,
            context=None,
            target=target_data,
            return_response=return_response,
        )
        if response is not None and debug:
            _LOGGER.info("SUPERNOTIFY Service %s.%s response: %s", domain, service, response)
        return response

    def expand_group(self, entity_ids: str | list[str]) -> list[str]:
        """Expand group entity ids into member entity ids ([] without hass)."""
        if self._hass is None:
            return []
        return expand_entity_ids(self._hass, entity_ids)

    def template(self, template_format: str) -> Template:
        """Build a Template from ``template_format`` bound to this hass instance (which may be None)."""
        return Template(template_format, self._hass)

    async def trace_condition(
        self,
        condition_config: ConfigType,
        condition_variables: ConditionVariables | None = None,
        strict: bool = False,
        validate: bool = False,
        trace_name: str | None = None,
    ) -> tuple[bool | None, ActionTrace | None]:
        """Evaluate a condition inside a stored Home Assistant trace.

        Sets up the trace integration on first use if needed.

        Returns:
            ``(result, trace)``. Both are None when there is no hass instance.
        """
        result: bool | None = None
        this_trace: ActionTrace | None = None
        if self._hass:
            if DATA_TRACE not in self._hass.data:
                await async_setup(self._hass, {})
            with trace_action(self._hass, trace_name or "anon_condition") as cond_trace:
                cond_trace.set_trace(trace_get())
                this_trace = cond_trace
                with trace_path(["condition", "conditions"]):
                    result = await self.evaluate_condition(
                        condition_config, condition_variables, strict=strict, validate=validate
                    )
            # Only build the (potentially large) trace dict when it will be logged
            if _LOGGER.isEnabledFor(logging.DEBUG):
                _LOGGER.debug(cond_trace.as_dict())
        return result, this_trace

    async def evaluate_condition(
        self,
        condition_config: ConfigType,
        condition_variables: ConditionVariables | None = None,
        strict: bool = False,
        validate: bool = False,
    ) -> bool | None:
        """Compile and evaluate a Home Assistant condition config.

        Args:
            condition_config: condition configuration.
            condition_variables: dataclass converted with ``asdict`` into template variables.
            strict: render templates in strict mode for the duration of this call.
            validate: validate the config with Home Assistant before compiling it.

        Raises:
            ValueError: if no HomeAssistant instance is configured.
            Exception: any validation or evaluation failure is logged and re-raised.
        """
        if self._hass is None:
            raise ValueError("HomeAssistant not available")

        try:
            if validate:
                condition_config = await condition.async_validate_condition_config(self._hass, condition_config)
            if strict:
                force_strict_template_mode(condition_config, undo=False)
            test: ConditionCheckerType = await condition.async_from_config(self._hass, condition_config)
            return test(self._hass, asdict(condition_variables) if condition_variables else None)
        except Exception as e:
            _LOGGER.error("SUPERNOTIFY Condition eval failed: %s", e)
            raise
        finally:
            if strict:
                # Restore the original Template objects so the config isn't left permanently wrapped
                force_strict_template_mode(condition_config, undo=True)

    def abs_url(self, fragment: str | None, prefer_external: bool = True) -> str | None:
        """Turn a URL fragment into an absolute URL on the external (or internal) base.

        Fragments that already start with ``http`` are returned unchanged. An empty or
        None fragment gives None.
        """
        base_url = self.external_url if prefer_external else self.internal_url
        if fragment:
            if fragment.startswith("http"):
                return fragment
            if fragment.startswith("/"):
                return base_url + fragment
            return base_url + "/" + fragment
        return None

    def raise_issue(
        self,
        issue_id: str,
        issue_key: str,
        issue_map: dict[str, str],
        severity: ir.IssueSeverity = ir.IssueSeverity.WARNING,
        learn_more_url: str = "https://supernotify.rhizomatics.org.uk",
        is_fixable: bool = False,
    ) -> None:
        """Create a Home Assistant repair issue under this integration's DOMAIN (no-op without hass)."""
        if not self._hass:
            return
        ir.async_create_issue(
            self._hass,
            DOMAIN,
            issue_id,
            translation_key=issue_key,
            translation_placeholders=issue_map,
            severity=severity,
            learn_more_url=learn_more_url,
            is_fixable=is_fixable,
        )

    def discover_devices(self, discover_domain: str) -> list[DeviceEntry]:
        """Return enabled registry devices that have an identifier in ``discover_domain``.

        The first element of an identifier tuple is compared with ``discover_domain``.
        Each device is returned at most once, even if several identifiers match.
        """
        dev_reg: DeviceRegistry | None = self.device_registry()
        if dev_reg is None:
            _LOGGER.warning("SUPERNOTIFY Unable to discover devices for %s - no device registry found", discover_domain)
            return []

        devices: list[DeviceEntry] = []
        all_devs = enabled_devs = 0
        for dev in dev_reg.devices.values():
            all_devs += 1
            if dev.disabled:
                continue
            enabled_devs += 1
            for identifier in dev.identifiers:
                if identifier and len(identifier) > 1 and identifier[0] == discover_domain:
                    _LOGGER.debug("SUPERNOTIFY discovered device %s for id %s", dev.name, identifier)
                    devices.append(dev)
                    # One match is enough; further identifiers would only duplicate the device
                    break
                if identifier:
                    # HomeKit has triples for identifiers, other domains may behave similarly
                    _LOGGER.debug("SUPERNOTIFY Unexpected device %s id: %s", dev.name, identifier)
                else:
                    _LOGGER.debug("SUPERNOTIFY Unexpected device %s without id", dev.name)
        _LOGGER.info(
            "SUPERNOTIFY %s device discovery, all=%s, enabled=%s, found=%s",
            discover_domain,
            all_devs,
            enabled_devs,
            len(devices),
        )
        return devices

    def domain_for_device(self, device_id: str, domains: list[str]) -> str | None:
        """Return the first of ``domains`` that owns an identifier of device ``device_id``.

        Returns None if there is no device registry, the device is unknown, or no
        identifier matches. The last two cases are logged as a warning.
        """
        dev_reg = self.device_registry()
        if dev_reg is None:
            return None
        device: DeviceEntry | None = dev_reg.async_get(device_id)
        if device:
            # Identifiers may be longer than pairs (e.g. HomeKit triples), so index instead of unpacking
            matching_domains = [ident[0] for ident in device.identifiers if ident and ident[0] in domains]
            if matching_domains:
                # TODO: limited to first domain found, unlikely to be more
                return matching_domains[0]
        _LOGGER.warning(
            "SUPERNOTIFY A target that looks like a device_id can't be matched to supported integration: %s",
            device_id,
        )
        return None

    def entity_registry(self) -> entity_registry.EntityRegistry | None:
        """Return the hass entity registry, fetched once and then cached.

        Every component otherwise ends up creating its own registry accessor, so the
        registry is looked up once here. Failures are logged and give None.
        """
        if self._entity_registry is not None:
            return self._entity_registry
        if self._hass:
            try:
                self._entity_registry = entity_registry.async_get(self._hass)
            except Exception as e:
                _LOGGER.warning("SUPERNOTIFY Unable to get entity registry: %s", e)
        return self._entity_registry

    def device_registry(self) -> device_registry.DeviceRegistry | None:
        """Return the hass device registry, fetched once and then cached.

        Every component otherwise ends up creating its own registry accessor, so the
        registry is looked up once here. Failures are logged and give None.
        """
        if self._device_registry is not None:
            return self._device_registry
        if self._hass:
            try:
                self._device_registry = device_registry.async_get(self._hass)
            except Exception as e:
                _LOGGER.warning("SUPERNOTIFY Unable to get device registry: %s", e)
        return self._device_registry

    async def mqtt_available(self, raise_on_error: bool = True) -> bool:
        """Wait for the MQTT client and return True if it is available.

        Failures are logged. They are re-raised when ``raise_on_error`` is True and
        otherwise give False.
        """
        if self._hass:
            try:
                return await mqtt.async_wait_for_mqtt_client(self._hass) is True
            except Exception:
                _LOGGER.exception("SUPERNOTIFY MQTT integration failed on available check")
                if raise_on_error:
                    raise
        return False

    async def mqtt_publish(
        self, topic: str, payload: Any = None, qos: int = 0, retain: bool = False, raise_on_error: bool = True
    ) -> None:
        """Publish ``payload`` as JSON to an MQTT topic (no-op without hass).

        Failures are logged, and re-raised when ``raise_on_error`` is True.
        """
        if self._hass:
            try:
                await mqtt.async_publish(
                    self._hass,
                    topic=topic,
                    payload=json_dumps(payload),
                    qos=qos,
                    retain=retain,
                )
            except Exception:
                _LOGGER.exception("SUPERNOTIFY MQTT publish failed to %s", topic)
                if raise_on_error:
                    raise
def force_strict_template_mode(condition: ConfigType, undo: bool = False) -> None:
    """Wrap (or unwrap) every Template in a condition config so it renders strictly.

    The config is modified in place, including templates nested inside dicts and
    lists (e.g. the ``conditions`` list of an ``and``/``or`` condition).

    Args:
        condition: condition config to modify. None is accepted and ignored.
        undo: when False, replace each Template with a proxy whose
            ``async_render_to_info`` is called with ``strict=True``. When True,
            swap those proxies back for the original Template objects.
    """

    class TemplateWrapper:
        """Proxy to a Template that forces strict rendering in async_render_to_info."""

        def __init__(self, obj: Template) -> None:
            self._obj = obj

        def __getattr__(self, name: str) -> Any:
            # Only reached for attributes missing on the proxy. Guard `_obj` so a partially
            # constructed proxy (e.g. during copy) raises instead of recursing forever.
            if name == "_obj":
                raise AttributeError(name)
            if name == "async_render_to_info":
                return partial(self._obj.async_render_to_info, strict=True)
            return getattr(self._obj, name)

    def wrap_template(cond: Any, undo: bool) -> None:
        """Replace convertible values of a mapping or list in place, recursing into containers."""
        entries = enumerate(cond) if isinstance(cond, list) else cond.items()
        # Snapshot the entries so values can be replaced while walking them
        for key, val in list(entries):
            if not undo and isinstance(val, Template) and hasattr(val, "_env"):
                cond[key] = TemplateWrapper(val)
            elif undo and isinstance(val, TemplateWrapper):
                cond[key] = val._obj
            elif isinstance(val, (dict, list)):
                wrap_template(val, undo)

    if condition is not None:
        wrap_template(condition, undo)
@contextmanager
def trace_action(
    hass: HomeAssistant,
    item_id: str,
    config: dict[str, Any] | None = None,
    context: HomeAssistantContext | None = None,
    stored_traces: int = 5,
) -> Iterator[ActionTrace]:
    """Trace execution of a condition.

    Creates an ActionTrace for ``item_id`` and stores it in hass, keeping up to
    ``stored_traces`` entries, then yields it. If ``item_id`` is non-empty, an
    exception raised in the body is recorded on the trace before propagating, and
    the trace is marked finished on exit.
    """
    action_trace = ActionTrace(item_id, config, None, context or HomeAssistantContext())
    async_store_trace(hass, action_trace, stored_traces)

    record_outcome = bool(item_id)
    try:
        yield action_trace
    except Exception as err:
        if record_outcome:
            action_trace.set_error(err)
        raise
    finally:
        if record_outcome:
            action_trace.finished()