Coverage for custom_components/supernotify/scenario.py: 89%
140 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-11-21 23:31 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-11-21 23:31 +0000
1import logging
2from typing import TYPE_CHECKING, Any
4from homeassistant.const import CONF_ENABLED
5from homeassistant.helpers import issue_registry as ir
7from . import (
8 CONF_DATA,
9 DELIVERY_SELECTION_IMPLICIT,
10 SCENARIO_TEMPLATE_ATTRS,
11)
12from .common import safe_get
13from .hass_api import HomeAssistantAPI
15if TYPE_CHECKING:
16 from homeassistant.core import HomeAssistant
17 from homeassistant.helpers.typing import ConfigType
20from collections.abc import Iterator
21from contextlib import contextmanager
23import voluptuous as vol
25# type: ignore[attr-defined,unused-ignore]
26from homeassistant.components.trace import async_store_trace
27from homeassistant.components.trace.models import ActionTrace
28from homeassistant.const import ATTR_FRIENDLY_NAME, ATTR_NAME, CONF_ALIAS, CONF_CONDITION
29from homeassistant.core import Context, HomeAssistant
30from homeassistant.helpers.typing import ConfigType
32from . import ATTR_DEFAULT, ATTR_ENABLED, CONF_ACTION_GROUP_NAMES, CONF_DELIVERY, CONF_DELIVERY_SELECTION, CONF_MEDIA
33from .delivery import Delivery
34from .model import ConditionVariables
36_LOGGER = logging.getLogger(__name__)
class ScenarioRegistry:
    """Registry of validated scenarios plus the per-scenario delivery and template indexes."""

    def __init__(self, scenario_configs: ConfigType | None = None) -> None:
        """Store the raw scenario configuration; scenarios are built by ``initialize``.

        Args:
            scenario_configs: mapping of scenario name to scenario definition, or None for none.
        """
        self._config: ConfigType = scenario_configs or {}
        # scenario name -> Scenario, holding only scenarios that passed validation
        self.scenarios: dict[str, Scenario] = {}
        # template field -> delivery name -> names of scenarios that define that template for that delivery
        self.content_scenario_templates: ConfigType = {}
        # scenario name -> names of deliveries enabled for that scenario
        self.delivery_by_scenario: dict[str, list[str]] = {}

    async def initialize(
        self,
        deliveries: dict[str, Delivery],
        implicit_deliveries: list[Delivery],
        mobile_actions: ConfigType,
        hass_api: HomeAssistantAPI,
    ) -> None:
        """Build and validate every configured scenario, then build the indexes.

        Scenarios that fail ``Scenario.validate`` are not registered. Delivery names are
        validated against the keys of ``deliveries`` and action groups against the keys
        of ``mobile_actions``.
        """
        for scenario_name, scenario_definition in self._config.items():
            scenario = Scenario(scenario_name, scenario_definition, hass_api)
            if await scenario.validate(valid_deliveries=list(deliveries), valid_action_groups=list(mobile_actions)):
                self.scenarios[scenario_name] = scenario
        self.refresh(deliveries, implicit_deliveries)

    def refresh(self, deliveries: dict[str, Delivery], implicit_deliveries: list[Delivery]) -> None:
        """Rebuild ``delivery_by_scenario`` and ``content_scenario_templates`` from scratch.

        Both indexes are reset first, so calling this repeatedly (e.g. after a delivery or
        scenario is enabled/disabled) does not accumulate duplicate entries.

        Args:
            deliveries: all known deliveries by name; each must contain every delivery
                name a scenario selects (a missing name raises KeyError).
            implicit_deliveries: deliveries added to scenarios whose delivery selection
                is ``DELIVERY_SELECTION_IMPLICIT``.
        """
        self.delivery_by_scenario = {}
        # previously never reset, so each refresh re-appended the same scenario names
        self.content_scenario_templates = {}
        for scenario_name, scenario in self.scenarios.items():
            if not scenario.enabled:
                continue
            enabled_deliveries = self.delivery_by_scenario.setdefault(scenario_name, [])
            scenario_definition_delivery = scenario.delivery
            if scenario.delivery_selection == DELIVERY_SELECTION_IMPLICIT:
                scenario_deliveries: list[str] = [d.name for d in implicit_deliveries]
            else:
                scenario_deliveries = []
            # explicitly configured deliveries follow the implicit ones, without duplicates
            scenario_deliveries.extend(s for s in scenario_definition_delivery if s not in scenario_deliveries)

            for delivery_name in scenario_deliveries:
                delivery_override = scenario_definition_delivery.get(delivery_name)
                # the scenario override and the delivery itself must both be enabled
                if safe_get(delivery_override, CONF_ENABLED, True) and deliveries[delivery_name].enabled:
                    enabled_deliveries.append(delivery_name)

                delivery_data = safe_get(delivery_override, CONF_DATA, {})
                # extract message and title templates per scenario per delivery
                for template_field in SCENARIO_TEMPLATE_ATTRS:
                    if delivery_data.get(template_field) is not None:
                        self.content_scenario_templates.setdefault(template_field, {}).setdefault(
                            delivery_name, []
                        ).append(scenario_name)
class Scenario:
    """A named scenario loaded from configuration.

    Holds an optional Home Assistant condition plus media, delivery selection,
    per-delivery overrides and action group names.
    """

    def __init__(self, name: str, scenario_definition: dict[str, Any], hass_api: HomeAssistantAPI) -> None:
        """Build a scenario from its configuration mapping.

        Args:
            name: scenario name; ``default`` is True when it equals ``ATTR_DEFAULT``.
            scenario_definition: the scenario's configuration mapping.
            hass_api: used to evaluate/trace the condition and to raise repair issues.

        The delivery mapping and action group list are shallow-copied so that
        ``validate`` pruning unknown entries does not mutate the caller's configuration.
        """
        self.hass_api: HomeAssistantAPI = hass_api
        self.enabled: bool = scenario_definition.get(CONF_ENABLED, True)
        self.name: str = name
        self.alias: str | None = scenario_definition.get(CONF_ALIAS)
        self.condition: ConfigType | None = scenario_definition.get(CONF_CONDITION)
        self.media: dict[str, Any] | None = scenario_definition.get(CONF_MEDIA)
        self.delivery_selection: str | None = scenario_definition.get(CONF_DELIVERY_SELECTION)
        # copies: validate() removes unknown entries from these
        self.action_groups: list[str] = list(scenario_definition.get(CONF_ACTION_GROUP_NAMES) or [])
        self.delivery: dict[str, Any] = dict(scenario_definition.get(CONF_DELIVERY) or {})
        self.default: bool = self.name == ATTR_DEFAULT
        # most recent trace captured by ``trace``; None until one is recorded
        self.last_trace: ActionTrace | None = None
        # NOTE(review): not read within this class; presumably used elsewhere - confirm
        self.condition_func = None

    async def validate(self, valid_deliveries: list[str] | None = None, valid_action_groups: list[str] | None = None) -> bool:
        """Validate the Home Assistant condition definition at initiation.

        Args:
            valid_deliveries: known delivery names. When given, unknown deliveries are
                removed from ``self.delivery`` and a warning repair issue is raised for each.
            valid_action_groups: known action group names. When given, unknown groups are
                removed from ``self.action_groups`` and a warning repair issue is raised for each.

        Returns:
            False if the condition fails validation (an error repair issue is raised and
            deliveries/action groups are left untouched), otherwise True.
        """
        if self.condition and not await self._validate_condition():
            return False
        if valid_deliveries is not None:
            self._prune_unknown_deliveries(valid_deliveries)
        if valid_action_groups is not None:
            self._prune_unknown_action_groups(valid_action_groups)
        return True

    async def _validate_condition(self) -> bool:
        """Check the condition via Home Assistant; on failure raise a repair issue and return False."""
        error: str | None = None
        try:
            # note: basic template syntax within conditions already validated by voluptuous checks
            await self.hass_api.evaluate_condition(self.condition, ConditionVariables(), strict=True, validate=True)
        except vol.Invalid as vi:
            _LOGGER.error(
                "SUPERNOTIFY Condition definition for scenario %s fails Home Assistant schema check %s", self.name, vi
            )
            error = f"Schema error {vi}"
        except Exception as e:
            _LOGGER.error("SUPERNOTIFY Disabling scenario %s with error validating %s: %s", self.name, self.condition, e)
            error = f"Unknown error {e}"
        if error is None:
            return True
        self.hass_api.raise_issue(
            f"scenario_{self.name}_condition",
            is_fixable=False,
            issue_key="scenario_condition",
            issue_map={"scenario": self.name, "error": error},
            severity=ir.IssueSeverity.ERROR,
            learn_more_url="https://supernotify.rhizomatics.org.uk/scenarios/",
        )
        return False

    def _prune_unknown_deliveries(self, valid_deliveries: list[str]) -> None:
        """Remove deliveries not in ``valid_deliveries``, logging and raising a warning issue for each."""
        known = set(valid_deliveries)
        # materialise the list first so entries can be deleted while iterating
        for delivery_name in [name for name in self.delivery if name not in known]:
            _LOGGER.error("SUPERNOTIFY Unknown delivery %s removed from scenario %s", delivery_name, self.name)
            self.hass_api.raise_issue(
                f"scenario_{self.name}_delivery_{delivery_name}",
                is_fixable=False,
                issue_key="scenario_delivery",
                issue_map={"scenario": self.name, "delivery": delivery_name},
                severity=ir.IssueSeverity.WARNING,
                learn_more_url="https://supernotify.rhizomatics.org.uk/scenarios/",
            )
            del self.delivery[delivery_name]

    def _prune_unknown_action_groups(self, valid_action_groups: list[str]) -> None:
        """Remove action groups not in ``valid_action_groups``, logging and raising a warning issue for each."""
        known = set(valid_action_groups)
        for action_group_name in self.action_groups:
            if action_group_name in known:
                continue
            _LOGGER.error(
                "SUPERNOTIFY Unknown action group %s removed from scenario %s", action_group_name, self.name
            )
            # NOTE(review): reuses the "scenario_delivery" issue key but supplies an "action_group"
            # placeholder rather than "delivery" - confirm the translation placeholders match
            self.hass_api.raise_issue(
                f"scenario_{self.name}_action_group_{action_group_name}",
                is_fixable=False,
                issue_key="scenario_delivery",
                issue_map={"scenario": self.name, "action_group": action_group_name},
                severity=ir.IssueSeverity.WARNING,
                learn_more_url="https://supernotify.rhizomatics.org.uk/scenarios/",
            )
        self.action_groups = [group for group in self.action_groups if group in known]

    def attributes(self, include_condition: bool = True, include_trace: bool = False) -> dict[str, Any]:
        """Return scenario attributes.

        ``friendly_name`` is added only when an alias is set; ``condition`` only when
        ``include_condition``; ``trace`` only when ``include_trace`` and a trace exists.
        """
        attrs = {
            ATTR_NAME: self.name,
            ATTR_ENABLED: self.enabled,
            "media": self.media,
            "delivery_selection": self.delivery_selection,
            "action_groups": self.action_groups,
            "delivery": self.delivery,
            "default": self.default,
        }
        if self.alias:
            attrs[ATTR_FRIENDLY_NAME] = self.alias
        if include_condition:
            attrs["condition"] = self.condition
        if include_trace and self.last_trace:
            attrs["trace"] = self.last_trace.as_extended_dict()
        return attrs

    def contents(self, minimal: bool = False) -> dict[str, Any]:
        """Archive friendly view of scenario: attributes without the condition, with trace unless minimal."""
        return self.attributes(include_condition=False, include_trace=not minimal)

    async def evaluate(self, condition_variables: ConditionVariables | None = None) -> bool:
        """Evaluate the scenario condition.

        Returns False when the scenario is disabled, has no condition, the evaluation
        returns None, or the evaluation raises (the error is logged, not propagated).
        """
        result: bool | None = False
        if self.enabled and self.condition:
            try:
                result = await self.hass_api.evaluate_condition(self.condition, condition_variables)
                if result is None:
                    _LOGGER.warning("SUPERNOTIFY Scenario condition empty result")
            except Exception as e:
                _LOGGER.error(
                    "SUPERNOTIFY Scenario condition eval failed: %s, vars: %s",
                    e,
                    condition_variables.as_dict() if condition_variables else {},
                )
        return result if result is not None else False

    async def trace(
        self, condition_variables: ConditionVariables | None = None, strict: bool = False, validate: bool = False
    ) -> bool:
        """Trace scenario condition execution.

        Returns False when the scenario is disabled or has no condition; otherwise the
        traced result (None mapped to False). Any returned trace is kept in ``last_trace``.
        """
        result: bool | None = False
        action_trace: ActionTrace | None = None
        if self.enabled and self.condition:
            result, action_trace = await self.hass_api.trace_condition(
                self.condition, condition_variables, strict=strict, validate=validate, trace_name=f"scenario_{self.name}"
            )
            if action_trace:
                self.last_trace = action_trace
        return result if result is not None else False
@contextmanager
def trace_action(
    hass: HomeAssistant,
    item_id: str,
    config: dict[str, Any],
    context: Context | None = None,
    stored_traces: int = 5,
) -> Iterator[ActionTrace]:
    """Trace execution of a scenario.

    Creates an ``ActionTrace`` for ``item_id``/``config`` (falling back to a fresh
    ``Context`` when none is supplied), stores it with ``async_store_trace`` keeping
    up to ``stored_traces`` entries, and yields it to the ``with`` body.

    When ``item_id`` is non-empty, an exception escaping the body is recorded on the
    trace before being re-raised, and the trace is marked finished on exit.
    """
    action_trace = ActionTrace(item_id, config, None, context or Context())
    async_store_trace(hass, action_trace, stored_traces)

    record_outcome = bool(item_id)
    try:
        yield action_trace
    except Exception as err:
        if record_outcome:
            action_trace.set_error(err)
        raise
    finally:
        if record_outcome:
            action_trace.finished()