Coverage for custom_components/supernotify/hass_api.py: 84%

237 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-11-21 23:31 +0000

1from __future__ import annotations 

2 

3import logging 

4import socket 

5import threading 

6from contextlib import contextmanager 

7from dataclasses import asdict 

8from functools import partial 

9from typing import TYPE_CHECKING, Any 

10 

11from homeassistant.components import mqtt 

12from homeassistant.components.group import expand_entity_ids 

13from homeassistant.components.trace import async_setup, async_store_trace # type: ignore[attr-defined,unused-ignore] 

14from homeassistant.components.trace.const import DATA_TRACE 

15from homeassistant.components.trace.models import ActionTrace 

16from homeassistant.core import Context as HomeAssistantContext 

17from homeassistant.core import HomeAssistant, SupportsResponse 

18from homeassistant.helpers.json import json_dumps 

19from homeassistant.helpers.template import Template 

20from homeassistant.helpers.trace import trace_get, trace_path 

21from homeassistant.helpers.typing import ConfigType 

22 

23if TYPE_CHECKING: 

24 from collections.abc import Iterator 

25 

26 from homeassistant.core import ServiceResponse, State 

27 from homeassistant.helpers.condition import ConditionCheckerType 

28 

29 from .model import ConditionVariables 

30from homeassistant.helpers import condition as condition 

31from homeassistant.helpers import device_registry, entity_registry 

32from homeassistant.helpers import issue_registry as ir 

33from homeassistant.helpers.network import get_url 

34 

35from . import ( 

36 DOMAIN, 

37) 

38 

39if TYPE_CHECKING: 

40 from homeassistant.core import HomeAssistant 

41 from homeassistant.helpers.device_registry import DeviceEntry, DeviceRegistry 

42 from homeassistant.helpers.typing import ConfigType 

43 

44 

45_LOGGER = logging.getLogger(__name__) 

46 

47 

48class HomeAssistantAPI: 

49 def __init__(self, hass: HomeAssistant | None = None) -> None: 

50 self._hass = hass 

51 self.internal_url: str = "" 

52 self.external_url: str = "" 

53 self.hass_name: str = "!UNDEFINED!" 

54 self._entity_registry: entity_registry.EntityRegistry | None = None 

55 self._device_registry: device_registry.DeviceRegistry | None = None 

56 self._service_info: dict[tuple[str, str], Any] = {} 

57 

58 def initialize(self) -> None: 

59 if self._hass: 

60 self.hass_name = self._hass.config.location_name 

61 try: 

62 self.internal_url = get_url(self._hass, prefer_external=False) 

63 except Exception as e: 

64 self.internal_url = f"http://{socket.gethostname()}" 

65 _LOGGER.warning("SUPERNOTIFY could not get internal hass url, defaulting to %s: %s", self.internal_url, e) 

66 try: 

67 self.external_url = get_url(self._hass, prefer_external=True) 

68 except Exception as e: 

69 _LOGGER.warning("SUPERNOTIFY could not get external hass url, defaulting to internal url: %s", e) 

70 self.external_url = self.internal_url 

71 else: 

72 _LOGGER.warning("SUPERNOTIFY Configured without HomeAssistant instance") 

73 

74 _LOGGER.debug( 

75 "SUPERNOTIFY Configured for HomeAssistant instance %s at %s , %s", 

76 self.hass_name, 

77 self.internal_url, 

78 self.external_url, 

79 ) 

80 

81 if not self.internal_url or not self.internal_url.startswith("http"): 

82 _LOGGER.warning("SUPERNOTIFY invalid internal hass url %s", self.internal_url) 

83 

84 def in_hass_loop(self) -> bool: 

85 return self._hass is not None and self._hass.loop_thread_id == threading.get_ident() 

86 

87 def get_state(self, entity_id: str) -> State | None: 

88 if not self._hass: 

89 return None 

90 return self._hass.states.get(entity_id) 

91 

92 def set_state(self, entity_id: str, state: str) -> None: 

93 if not self._hass: 

94 return 

95 if self.in_hass_loop(): 

96 self._hass.states.async_set(entity_id, state) 

97 else: 

98 self._hass.states.set(entity_id, state) 

99 

100 def has_service(self, domain: str, service: str) -> bool: 

101 if not self._hass: 

102 return False 

103 return self._hass.services.has_service(domain, service) 

104 

105 async def call_service( 

106 self, 

107 domain: str, 

108 service: str, 

109 service_data: dict[str, Any] | None = None, 

110 target_data: dict[str, Any] | None = None, 

111 debug: bool = False, 

112 ) -> ServiceResponse | None: 

113 if not self._hass: 

114 raise ValueError("HomeAssistant not available") 

115 return_response: bool = debug 

116 blocking: bool = debug 

117 try: 

118 if (domain, service) not in self._service_info: 

119 service_objs = self._hass.services.async_services() 

120 service_obj = service_objs.get(domain, {}).get(service, {}) 

121 self._service_info[domain, service] = { 

122 "supports_response": service_obj.supports_response, 

123 "schema": service_obj.schema, 

124 } 

125 service_info = self._service_info.get((domain, service), {}) 

126 supports_response = service_info.get("supports_response") 

127 if supports_response is not None: 

128 if supports_response == SupportsResponse.NONE: 

129 return_response = False 

130 elif supports_response == SupportsResponse.ONLY: 

131 return_response = True 

132 else: 

133 _LOGGER.debug("SUPERNOTIFY Unable to find service info for %s.%s", domain, service) 

134 

135 except Exception: 

136 _LOGGER.warning("SUPERNOTIFY Unable to get service info for %s.%s: %s") 

137 

138 response: ServiceResponse | None = await self._hass.services.async_call( 

139 domain, 

140 service, 

141 service_data=service_data, 

142 blocking=blocking, 

143 context=None, 

144 target=target_data, 

145 return_response=return_response, 

146 ) 

147 if response is not None and debug: 

148 _LOGGER.info("SUPERNOTIFY Service %s.%s response: %s", domain, service, response) 

149 return response 

150 

151 def expand_group(self, entity_ids: str | list[str]) -> list[str]: 

152 if self._hass is None: 

153 return [] 

154 return expand_entity_ids(self._hass, entity_ids) 

155 

156 def template(self, template_format: str) -> Template: 

157 return Template(template_format, self._hass) 

158 

159 async def trace_condition( 

160 self, 

161 condition_config: ConfigType, 

162 condition_variables: ConditionVariables | None = None, 

163 strict: bool = False, 

164 validate: bool = False, 

165 trace_name: str | None = None, 

166 ) -> tuple[bool | None, ActionTrace | None]: 

167 result: bool | None = None 

168 this_trace: ActionTrace | None = None 

169 if self._hass: 

170 if DATA_TRACE not in self._hass.data: 

171 await async_setup(self._hass, {}) 

172 with trace_action(self._hass, trace_name or "anon_condition") as cond_trace: 

173 cond_trace.set_trace(trace_get()) 

174 this_trace = cond_trace 

175 with trace_path(["condition", "conditions"]) as _tp: 

176 result = await self.evaluate_condition( 

177 condition_config, condition_variables, strict=strict, validate=validate 

178 ) 

179 _LOGGER.debug(cond_trace.as_dict()) 

180 return result, this_trace 

181 

182 async def evaluate_condition( 

183 self, 

184 condition_config: ConfigType, 

185 condition_variables: ConditionVariables | None = None, 

186 strict: bool = False, 

187 validate: bool = False, 

188 ) -> bool | None: 

189 if self._hass is None: 

190 raise ValueError("HomeAssistant not available") 

191 

192 try: 

193 if validate: 

194 condition_config = await condition.async_validate_condition_config(self._hass, condition_config) 

195 if strict: 

196 force_strict_template_mode(condition_config, undo=False) 

197 test: ConditionCheckerType = await condition.async_from_config(self._hass, condition_config) 

198 return test(self._hass, asdict(condition_variables) if condition_variables else None) 

199 except Exception as e: 

200 _LOGGER.error("SUPERNOTIFY Condition eval failed: %s", e) 

201 raise 

202 finally: 

203 if strict: 

204 force_strict_template_mode(condition_config, undo=False) 

205 

206 def abs_url(self, fragment: str | None, prefer_external: bool = True) -> str | None: 

207 base_url = self.external_url if prefer_external else self.internal_url 

208 if fragment: 

209 if fragment.startswith("http"): 

210 return fragment 

211 if fragment.startswith("/"): 

212 return base_url + fragment 

213 return base_url + "/" + fragment 

214 return None 

215 

216 def raise_issue( 

217 self, 

218 issue_id: str, 

219 issue_key: str, 

220 issue_map: dict[str, str], 

221 severity: ir.IssueSeverity = ir.IssueSeverity.WARNING, 

222 learn_more_url: str = "https://supernotify.rhizomatics.org.uk", 

223 is_fixable: bool = False, 

224 ) -> None: 

225 if not self._hass: 

226 return 

227 ir.async_create_issue( 

228 self._hass, 

229 DOMAIN, 

230 issue_id, 

231 translation_key=issue_key, 

232 translation_placeholders=issue_map, 

233 severity=severity, 

234 learn_more_url=learn_more_url, 

235 is_fixable=is_fixable, 

236 ) 

237 

238 def discover_devices(self, discover_domain: str) -> list[DeviceEntry]: 

239 devices: list[DeviceEntry] = [] 

240 dev_reg: DeviceRegistry | None = self.device_registry() 

241 if dev_reg is None: 

242 _LOGGER.warning(f"SUPERNOTIFY Unable to discover devices for {discover_domain} - no device registry found") 

243 return [] 

244 

245 all_devs = enabled_devs = found_devs = 0 

246 for dev in dev_reg.devices.values(): 

247 all_devs += 1 

248 if not dev.disabled: 

249 enabled_devs += 1 

250 for identifier in dev.identifiers: 

251 if identifier and len(identifier) > 1 and identifier[0] == discover_domain: 

252 _LOGGER.debug("SUPERNOTIFY discovered device %s for id %s", dev.name, identifier) 

253 devices.append(dev) 

254 found_devs += 1 

255 elif identifier: 

256 # HomeKit has triples for identifiers, other domains may behave similarly 

257 _LOGGER.debug("SUPERNOTIFY Unexpected device %s id: %s", dev.name, identifier) 

258 else: 

259 _LOGGER.debug( # type: ignore 

260 "SUPERNOTIFY Unexpected device %s without id", dev.name 

261 ) 

262 _LOGGER.info( 

263 f"SUPERNOTIFY {discover_domain} device discovery, all={all_devs}, enabled={enabled_devs}, found={found_devs}" 

264 ) 

265 return devices 

266 

267 def domain_for_device(self, device_id: str, domains: list[str]) -> str | None: 

268 # discover domain from device registry 

269 verified_domain: str | None = None 

270 device_registry = self.device_registry() 

271 if device_registry: 

272 device: DeviceEntry | None = device_registry.async_get(device_id) 

273 if device: 

274 matching_domains = [d for d, _id in device.identifiers if d in domains] 

275 if matching_domains: 

276 # TODO: limited to first domain found, unlikely to be more 

277 return matching_domains[0] 

278 _LOGGER.warning( 

279 "SUPERNOTIFY A target that looks like a device_id can't be matched to supported integration: %s", 

280 device_id, 

281 ) 

282 return verified_domain 

283 

284 def entity_registry(self) -> entity_registry.EntityRegistry | None: 

285 """Hass entity registry is weird, every component ends up creating its own, with a store, subscribing 

286 to all entities, so do it once here 

287 """ # noqa: D205 

288 if self._entity_registry is not None: 

289 return self._entity_registry 

290 if self._hass: 

291 try: 

292 self._entity_registry = entity_registry.async_get(self._hass) 

293 except Exception as e: 

294 _LOGGER.warning("SUPERNOTIFY Unable to get entity registry: %s", e) 

295 return self._entity_registry 

296 

297 def device_registry(self) -> device_registry.DeviceRegistry | None: 

298 """Hass device registry is weird, every component ends up creating its own, with a store, subscribing 

299 to all devices, so do it once here 

300 """ # noqa: D205 

301 if self._device_registry is not None: 

302 return self._device_registry 

303 if self._hass: 

304 try: 

305 self._device_registry = device_registry.async_get(self._hass) 

306 except Exception as e: 

307 _LOGGER.warning("SUPERNOTIFY Unable to get device registry: %s", e) 

308 return self._device_registry 

309 

310 async def mqtt_available(self, raise_on_error: bool = True) -> bool: 

311 if self._hass: 

312 try: 

313 return await mqtt.async_wait_for_mqtt_client(self._hass) is True 

314 except Exception: 

315 _LOGGER.exception("SUPERNOTIFY MQTT integration failed on available check") 

316 if raise_on_error: 

317 raise 

318 return False 

319 

320 async def mqtt_publish( 

321 self, topic: str, payload: Any = None, qos: int = 0, retain: bool = False, raise_on_error: bool = True 

322 ) -> None: 

323 if self._hass: 

324 try: 

325 await mqtt.async_publish( 

326 self._hass, 

327 topic=topic, 

328 payload=json_dumps(payload), 

329 qos=qos, 

330 retain=retain, 

331 ) 

332 except Exception: 

333 _LOGGER.exception(f"SUPERNOTIFY MQTT publish failed to {topic}") 

334 if raise_on_error: 

335 raise 

336 

337 

338def force_strict_template_mode(condition: ConfigType, undo: bool = False) -> None: 

339 class TemplateWrapper: 

340 def __init__(self, obj: Template) -> None: 

341 self._obj = obj 

342 

343 def __getattr__(self, name: str) -> Any: 

344 if name == "async_render_to_info": 

345 return partial(self._obj.async_render_to_info, strict=True) 

346 return getattr(self._obj, name) 

347 

348 def __setattr__(self, name: str, value: Any) -> None: 

349 super().__setattr__(name, value) 

350 

351 def wrap_template(cond: ConfigType, undo: bool) -> None: 

352 for key, val in cond.items(): 

353 if not undo and isinstance(val, Template) and hasattr(val, "_env"): 

354 cond[key] = TemplateWrapper(val) 

355 elif undo and isinstance(val, TemplateWrapper): 

356 cond[key] = val._obj 

357 elif isinstance(val, dict): 

358 wrap_template(val, undo) 

359 

360 if condition is not None: 

361 wrap_template(condition, undo) 

362 

363 

364@contextmanager 

365def trace_action( 

366 hass: HomeAssistant, 

367 item_id: str, 

368 config: dict[str, Any] | None = None, 

369 context: HomeAssistantContext | None = None, 

370 stored_traces: int = 5, 

371) -> Iterator[ActionTrace]: 

372 """Trace execution of a condition""" 

373 trace = ActionTrace(item_id, config, None, context or HomeAssistantContext()) 

374 async_store_trace(hass, trace, stored_traces) 

375 

376 try: 

377 yield trace 

378 except Exception as ex: 

379 if item_id: 

380 trace.set_error(ex) 

381 raise 

382 finally: 

383 if item_id: 

384 trace.finished()