Coverage for custom_components/supernotify/scenario.py: 89%

140 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-11-21 23:31 +0000

1import logging 

2from typing import TYPE_CHECKING, Any 

3 

4from homeassistant.const import CONF_ENABLED 

5from homeassistant.helpers import issue_registry as ir 

6 

7from . import ( 

8 CONF_DATA, 

9 DELIVERY_SELECTION_IMPLICIT, 

10 SCENARIO_TEMPLATE_ATTRS, 

11) 

12from .common import safe_get 

13from .hass_api import HomeAssistantAPI 

14 

15if TYPE_CHECKING: 

16 from homeassistant.core import HomeAssistant 

17 from homeassistant.helpers.typing import ConfigType 

18 

19 

20from collections.abc import Iterator 

21from contextlib import contextmanager 

22 

23import voluptuous as vol 

24 

25# type: ignore[attr-defined,unused-ignore] 

26from homeassistant.components.trace import async_store_trace 

27from homeassistant.components.trace.models import ActionTrace 

28from homeassistant.const import ATTR_FRIENDLY_NAME, ATTR_NAME, CONF_ALIAS, CONF_CONDITION 

29from homeassistant.core import Context, HomeAssistant 

30from homeassistant.helpers.typing import ConfigType 

31 

32from . import ATTR_DEFAULT, ATTR_ENABLED, CONF_ACTION_GROUP_NAMES, CONF_DELIVERY, CONF_DELIVERY_SELECTION, CONF_MEDIA 

33from .delivery import Delivery 

34from .model import ConditionVariables 

35 

# Module-wide logger, named after this module so log output can be filtered per module.
_LOGGER: logging.Logger = logging.getLogger(__name__)

37 

38 

class ScenarioRegistry:
    """Holds the validated scenarios and the deliveries each enabled scenario routes to."""

    def __init__(self, scenario_configs: ConfigType | None = None) -> None:
        """Store the raw scenario configuration; scenarios are built later by ``initialize``.

        Args:
            scenario_configs: mapping of scenario name to scenario definition, or None for none.
        """
        self._config: ConfigType = scenario_configs or {}
        # scenarios that passed Scenario.validate(), keyed by scenario name
        self.scenarios: dict[str, Scenario] = {}
        # template attribute -> delivery name -> names of scenarios supplying that template
        self.content_scenario_templates: ConfigType = {}
        # enabled scenario name -> names of enabled deliveries it routes to
        self.delivery_by_scenario: dict[str, list[str]] = {}

    async def initialize(
        self,
        deliveries: dict[str, Delivery],
        implicit_deliveries: list[Delivery],
        mobile_actions: ConfigType,
        hass_api: HomeAssistantAPI,
    ) -> None:
        """Build and validate every configured scenario, then compute the delivery routing.

        Scenarios whose ``validate`` returns False are dropped.

        Args:
            deliveries: all known deliveries keyed by name; the names are the valid deliveries.
            implicit_deliveries: deliveries used by scenarios with implicit delivery selection.
            mobile_actions: its keys are the valid action group names.
            hass_api: Home Assistant wrapper handed to each Scenario.
        """
        for scenario_name, scenario_definition in self._config.items():
            scenario = Scenario(scenario_name, scenario_definition, hass_api)
            if await scenario.validate(valid_deliveries=list(deliveries), valid_action_groups=list(mobile_actions)):
                self.scenarios[scenario_name] = scenario
        self.refresh(deliveries, implicit_deliveries)

    def refresh(self, deliveries: dict[str, Delivery], implicit_deliveries: list[Delivery]) -> None:
        """Rebuild ``delivery_by_scenario`` and ``content_scenario_templates`` from the scenarios.

        Only enabled scenarios are considered. A scenario's candidate deliveries are the
        implicit deliveries (when its ``delivery_selection`` is ``DELIVERY_SELECTION_IMPLICIT``),
        followed by the names in its own ``delivery`` mapping, without duplicates.

        A candidate is skipped when the scenario's override for it has a falsy ``enabled``.
        Otherwise it is routed when the delivery itself is enabled. In either case, template
        attributes found in the override's ``data`` are indexed.

        Args:
            deliveries: all known deliveries keyed by name.
            implicit_deliveries: deliveries applied when a scenario uses implicit selection.
        """
        # Both indexes are derived state and must be rebuilt from scratch: previously only
        # delivery_by_scenario was reset, so each refresh appended duplicate scenario names
        # to the template index. clear() keeps the same dict object for any holder of it.
        self.delivery_by_scenario = {}
        self.content_scenario_templates.clear()

        for scenario_name, scenario in self.scenarios.items():
            if not scenario.enabled:
                continue
            # enabled scenarios always get an entry, even if nothing ends up routed
            routed = self.delivery_by_scenario.setdefault(scenario_name, [])
            overrides = scenario.delivery

            if scenario.delivery_selection == DELIVERY_SELECTION_IMPLICIT:
                candidates: list[str] = [d.name for d in implicit_deliveries]
            else:
                candidates = []
            for name in overrides:
                if name not in candidates:
                    candidates.append(name)

            for delivery_name in candidates:
                override = overrides.get(delivery_name)
                if not safe_get(override, CONF_ENABLED, True):
                    continue
                delivery = deliveries.get(delivery_name)
                if delivery is None:
                    # previously a KeyError that aborted the whole refresh
                    _LOGGER.warning(
                        "SUPERNOTIFY Scenario %s refers to unknown delivery %s, ignoring", scenario_name, delivery_name
                    )
                    continue
                if delivery.enabled:
                    routed.append(delivery_name)

                # index per-scenario, per-delivery templates (SCENARIO_TEMPLATE_ATTRS) found in the override data
                delivery_data = safe_get(override, CONF_DATA, {})
                for template_field in SCENARIO_TEMPLATE_ATTRS:
                    if delivery_data.get(template_field) is not None:
                        by_delivery = self.content_scenario_templates.setdefault(template_field, {})
                        by_delivery.setdefault(delivery_name, []).append(scenario_name)

85 

86 

class Scenario:
    """One named scenario: an optional condition plus delivery, media and action-group settings."""

    def __init__(self, name: str, scenario_definition: dict[str, Any], hass_api: HomeAssistantAPI) -> None:
        """Capture a scenario definition.

        Args:
            name: scenario name; a scenario named ``ATTR_DEFAULT`` is flagged as the default.
            scenario_definition: configuration mapping for this scenario.
            hass_api: wrapper used to evaluate/trace conditions and raise repair issues.
        """
        self.hass_api: HomeAssistantAPI = hass_api
        self.enabled: bool = scenario_definition.get(CONF_ENABLED, True)
        self.name: str = name
        self.alias: str | None = scenario_definition.get(CONF_ALIAS)
        self.condition: ConfigType | None = scenario_definition.get(CONF_CONDITION)
        self.media: dict[str, Any] | None = scenario_definition.get(CONF_MEDIA)
        self.delivery_selection: str | None = scenario_definition.get(CONF_DELIVERY_SELECTION)
        # Take copies: validate() prunes unknown entries from these and must not edit the
        # caller's configuration in place. ``or`` also tolerates an explicit null value.
        self.action_groups: list[str] = list(scenario_definition.get(CONF_ACTION_GROUP_NAMES) or [])
        self.delivery: dict[str, Any] = dict(scenario_definition.get(CONF_DELIVERY) or {})
        self.default: bool = self.name == ATTR_DEFAULT
        # most recent trace captured by trace(); exposed by attributes(include_trace=True)
        self.last_trace: ActionTrace | None = None
        self.condition_func = None

    async def validate(self, valid_deliveries: list[str] | None = None, valid_action_groups: list[str] | None = None) -> bool:
        """Validate the Home Assistant condition definition at initiation, and prune unknown references.

        If a condition is set, it is evaluated in strict/validate mode. Any failure is logged,
        raised as an ERROR repair issue, and rejects the scenario.

        Unknown deliveries and action groups are logged, raised as WARNING repair issues and
        removed from this scenario, without rejecting it.

        Args:
            valid_deliveries: known delivery names, or None to skip the delivery check.
            valid_action_groups: known action group names, or None to skip the action group check.

        Returns:
            False if the condition failed validation, otherwise True.
        """
        if self.condition:
            error: str | None = None
            try:
                # note: basic template syntax within conditions already validated by voluptuous checks
                await self.hass_api.evaluate_condition(self.condition, ConditionVariables(), strict=True, validate=True)
            except vol.Invalid as vi:
                _LOGGER.error(
                    "SUPERNOTIFY Condition definition for scenario %s fails Home Assistant schema check %s", self.name, vi
                )
                error = f"Schema error {vi}"
            except Exception as e:
                _LOGGER.error("SUPERNOTIFY Disabling scenario %s with error validating %s: %s", self.name, self.condition, e)
                error = f"Unknown error {e}"
            if error is not None:
                self.hass_api.raise_issue(
                    f"scenario_{self.name}_condition",
                    is_fixable=False,
                    issue_key="scenario_condition",
                    issue_map={"scenario": self.name, "error": error},
                    severity=ir.IssueSeverity.ERROR,
                    learn_more_url="https://supernotify.rhizomatics.org.uk/scenarios/",
                )
                return False

        if valid_deliveries is not None:
            known_deliveries = set(valid_deliveries)
            # snapshot first: entries are deleted from self.delivery while iterating
            invalid_deliveries = [d for d in self.delivery if d not in known_deliveries]
            for delivery_name in invalid_deliveries:
                _LOGGER.error("SUPERNOTIFY Unknown delivery %s removed from scenario %s", delivery_name, self.name)
                self.hass_api.raise_issue(
                    f"scenario_{self.name}_delivery_{delivery_name}",
                    is_fixable=False,
                    issue_key="scenario_delivery",
                    issue_map={"scenario": self.name, "delivery": delivery_name},
                    severity=ir.IssueSeverity.WARNING,
                    learn_more_url="https://supernotify.rhizomatics.org.uk/scenarios/",
                )
                del self.delivery[delivery_name]

        if valid_action_groups is not None:
            known_action_groups = set(valid_action_groups)
            invalid_action_groups = [g for g in self.action_groups if g not in known_action_groups]
            for action_group_name in invalid_action_groups:
                _LOGGER.error(
                    "SUPERNOTIFY Unknown action group %s removed from scenario %s", action_group_name, self.name
                )
                # NOTE(review): reuses the "scenario_delivery" issue key but supplies an
                # "action_group" placeholder instead of "delivery"; check the translation
                # strings expect these placeholders, or add a dedicated issue key.
                self.hass_api.raise_issue(
                    f"scenario_{self.name}_action_group_{action_group_name}",
                    is_fixable=False,
                    issue_key="scenario_delivery",
                    issue_map={"scenario": self.name, "action_group": action_group_name},
                    severity=ir.IssueSeverity.WARNING,
                    learn_more_url="https://supernotify.rhizomatics.org.uk/scenarios/",
                )
            # one remove() per collected name, so duplicated invalid names are all removed
            for action_group_name in invalid_action_groups:
                self.action_groups.remove(action_group_name)
        return True

    def attributes(self, include_condition: bool = True, include_trace: bool = False) -> dict[str, Any]:
        """Return scenario attributes.

        Args:
            include_condition: include the raw condition definition under ``condition``.
            include_trace: include the last trace (if any) as its extended dict under ``trace``.
        """
        attrs = {
            ATTR_NAME: self.name,
            ATTR_ENABLED: self.enabled,
            "media": self.media,
            "delivery_selection": self.delivery_selection,
            "action_groups": self.action_groups,
            "delivery": self.delivery,
            "default": self.default,
        }
        if self.alias:
            attrs[ATTR_FRIENDLY_NAME] = self.alias
        if include_condition:
            attrs["condition"] = self.condition
        if include_trace and self.last_trace:
            attrs["trace"] = self.last_trace.as_extended_dict()
        return attrs

    def contents(self, minimal: bool = False) -> dict[str, Any]:
        """Archive friendly view of scenario: attributes without the condition, plus trace unless minimal."""
        return self.attributes(include_condition=False, include_trace=not minimal)

    async def evaluate(self, condition_variables: ConditionVariables | None = None) -> bool:
        """Evaluate scenario conditions.

        Returns False when the scenario is disabled or has no condition. Also returns False
        when evaluation yields None (logged as a warning) or raises (logged as an error; the
        exception is not propagated).
        """
        result: bool | None = False
        if self.enabled and self.condition:
            try:
                result = await self.hass_api.evaluate_condition(self.condition, condition_variables)
                if result is None:
                    _LOGGER.warning("SUPERNOTIFY Scenario condition empty result")
            except Exception as e:
                _LOGGER.error(
                    "SUPERNOTIFY Scenario condition eval failed: %s, vars: %s",
                    e,
                    condition_variables.as_dict() if condition_variables else {},
                )
        return result if result is not None else False

    async def trace(
        self, condition_variables: ConditionVariables | None = None, strict: bool = False, validate: bool = False
    ) -> bool:
        """Trace scenario condition execution.

        Evaluates the condition via ``hass_api.trace_condition`` under the trace name
        ``scenario_<name>``. Any returned trace is kept in ``last_trace``. Unlike
        ``evaluate``, exceptions from the trace call are not caught here.

        Returns:
            The condition result, or False when disabled, without a condition, or None.
        """
        result: bool | None = False
        trace: ActionTrace | None = None
        if self.enabled and self.condition:
            result, trace = await self.hass_api.trace_condition(
                self.condition, condition_variables, strict=strict, validate=validate, trace_name=f"scenario_{self.name}"
            )
            if trace:
                self.last_trace = trace
        return result if result is not None else False

215 

216 

@contextmanager
def trace_action(
    hass: HomeAssistant,
    item_id: str,
    config: dict[str, Any],
    context: Context | None = None,
    stored_traces: int = 5,
) -> Iterator[ActionTrace]:
    """Trace execution of a scenario.

    An ActionTrace is created (with a fresh Context when none is given) and handed to
    ``async_store_trace`` with ``stored_traces`` before the ``with`` body runs. When
    ``item_id`` is truthy, an exception escaping the body is recorded on the trace before
    being re-raised, and the trace is marked finished on exit whatever the outcome.

    Yields:
        The ActionTrace being recorded.
    """
    action_trace = ActionTrace(item_id, config, None, context or Context())
    async_store_trace(hass, action_trace, stored_traces)

    # the error/finish bookkeeping only applies to traces with an id
    has_id = bool(item_id)
    try:
        yield action_trace
    except Exception as err:
        if has_id:
            action_trace.set_error(err)
        raise
    finally:
        if has_id:
            action_trace.finished()