working 2.0 clone
This commit is contained in:
371
custom_components/remote_homeassistant/config_flow.py
Normal file
371
custom_components/remote_homeassistant/config_flow.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""Config flow for Remote Home-Assistant integration."""
|
||||
import logging
|
||||
import enum
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
import voluptuous as vol
|
||||
from homeassistant import config_entries, core
|
||||
from homeassistant.const import (CONF_ABOVE, CONF_ACCESS_TOKEN, CONF_BELOW,
|
||||
CONF_ENTITY_ID, CONF_HOST, CONF_PORT,
|
||||
CONF_UNIT_OF_MEASUREMENT, CONF_VERIFY_SSL, CONF_TYPE)
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers.instance_id import async_get
|
||||
from homeassistant.util import slugify
|
||||
|
||||
from . import async_yaml_to_config_entry
|
||||
from .const import (CONF_ENTITY_PREFIX, # pylint:disable=unused-import
|
||||
CONF_EXCLUDE_DOMAINS, CONF_EXCLUDE_ENTITIES, CONF_FILTER,
|
||||
CONF_INCLUDE_DOMAINS, CONF_INCLUDE_ENTITIES,
|
||||
CONF_LOAD_COMPONENTS, CONF_MAIN, CONF_OPTIONS, CONF_REMOTE, CONF_REMOTE_CONNECTION,
|
||||
CONF_SECURE, CONF_SERVICE_PREFIX, CONF_SERVICES, CONF_MAX_MSG_SIZE,
|
||||
CONF_SUBSCRIBE_EVENTS, DOMAIN, REMOTE_ID, DEFAULT_MAX_MSG_SIZE)
|
||||
from .rest_api import (ApiProblem, CannotConnect, EndpointMissing, InvalidAuth,
|
||||
UnsupportedVersion, async_get_discovery_info)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
ADD_NEW_EVENT = "add_new_event"
|
||||
|
||||
FILTER_OPTIONS = [CONF_ENTITY_ID, CONF_UNIT_OF_MEASUREMENT, CONF_ABOVE, CONF_BELOW]
|
||||
|
||||
|
||||
def _filter_str(index, filter):
|
||||
entity_id = filter[CONF_ENTITY_ID]
|
||||
unit = filter[CONF_UNIT_OF_MEASUREMENT]
|
||||
above = filter[CONF_ABOVE]
|
||||
below = filter[CONF_BELOW]
|
||||
return f"{index+1}. {entity_id}, unit: {unit}, above: {above}, below: {below}"
|
||||
|
||||
|
||||
async def validate_input(hass: core.HomeAssistant, conf):
|
||||
"""Validate the user input allows us to connect."""
|
||||
try:
|
||||
info = await async_get_discovery_info(
|
||||
hass,
|
||||
conf[CONF_HOST],
|
||||
conf[CONF_PORT],
|
||||
conf.get(CONF_SECURE, False),
|
||||
conf[CONF_ACCESS_TOKEN],
|
||||
conf.get(CONF_VERIFY_SSL, False),
|
||||
)
|
||||
except OSError:
|
||||
raise CannotConnect()
|
||||
|
||||
return {"title": info["location_name"], "uuid": info["uuid"]}
|
||||
|
||||
|
||||
class InstanceType(enum.Enum):
|
||||
"""Possible options for instance type."""
|
||||
|
||||
remote = "Setup as remote node"
|
||||
main = "Add a remote"
|
||||
|
||||
|
||||
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for Remote Home-Assistant."""
|
||||
|
||||
VERSION = 1
|
||||
CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_PUSH
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize a new ConfigFlow."""
|
||||
self.prefill = {CONF_PORT: 8123, CONF_SECURE: True, CONF_MAX_MSG_SIZE: DEFAULT_MAX_MSG_SIZE}
|
||||
|
||||
@staticmethod
|
||||
@callback
|
||||
def async_get_options_flow(config_entry):
|
||||
"""Get options flow for this handler."""
|
||||
return OptionsFlowHandler(config_entry)
|
||||
|
||||
async def async_step_user(self, user_input=None):
|
||||
"""Handle the initial step."""
|
||||
errors = {}
|
||||
|
||||
if user_input is not None:
|
||||
if user_input[CONF_TYPE] == CONF_REMOTE:
|
||||
await self.async_set_unique_id(REMOTE_ID)
|
||||
self._abort_if_unique_id_configured()
|
||||
return self.async_create_entry(title="Remote instance", data=user_input)
|
||||
|
||||
elif user_input[CONF_TYPE] == CONF_MAIN:
|
||||
return await self.async_step_connection_details()
|
||||
|
||||
errors["base"] = "unknown"
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_TYPE): vol.In([CONF_REMOTE, CONF_MAIN])
|
||||
}
|
||||
),
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
|
||||
async def async_step_connection_details(self, user_input=None):
|
||||
"""Handle the connection details step."""
|
||||
errors = {}
|
||||
if user_input is not None:
|
||||
try:
|
||||
info = await validate_input(self.hass, user_input)
|
||||
except ApiProblem:
|
||||
errors["base"] = "api_problem"
|
||||
except CannotConnect:
|
||||
errors["base"] = "cannot_connect"
|
||||
except InvalidAuth:
|
||||
errors["base"] = "invalid_auth"
|
||||
except UnsupportedVersion:
|
||||
errors["base"] = "unsupported_version"
|
||||
except EndpointMissing:
|
||||
errors["base"] = "missing_endpoint"
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Unexpected exception")
|
||||
errors["base"] = "unknown"
|
||||
else:
|
||||
await self.async_set_unique_id(info["uuid"])
|
||||
self._abort_if_unique_id_configured()
|
||||
return self.async_create_entry(title=info["title"], data=user_input)
|
||||
|
||||
user_input = user_input or dict()
|
||||
host = user_input.get(CONF_HOST, self.prefill.get(CONF_HOST) or vol.UNDEFINED)
|
||||
port = user_input.get(CONF_PORT, self.prefill.get(CONF_PORT) or vol.UNDEFINED)
|
||||
secure = user_input.get(CONF_SECURE, self.prefill.get(CONF_SECURE) or vol.UNDEFINED)
|
||||
max_msg_size = user_input.get(CONF_MAX_MSG_SIZE, self.prefill.get(CONF_MAX_MSG_SIZE) or vol.UNDEFINED)
|
||||
return self.async_show_form(
|
||||
step_id="connection_details",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_HOST, default=host): str,
|
||||
vol.Required(CONF_PORT, default=port): int,
|
||||
vol.Required(CONF_ACCESS_TOKEN, default=user_input.get(CONF_ACCESS_TOKEN, vol.UNDEFINED)): str,
|
||||
vol.Required(CONF_MAX_MSG_SIZE, default=max_msg_size): int,
|
||||
vol.Optional(CONF_SECURE, default=secure): bool,
|
||||
vol.Optional(CONF_VERIFY_SSL, default=user_input.get(CONF_VERIFY_SSL, True)): bool,
|
||||
}
|
||||
),
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def async_step_zeroconf(self, info):
|
||||
"""Handle instance discovered via zeroconf."""
|
||||
properties = info.properties
|
||||
port = info.port
|
||||
uuid = properties["uuid"]
|
||||
|
||||
await self.async_set_unique_id(uuid)
|
||||
self._abort_if_unique_id_configured()
|
||||
|
||||
if await async_get(self.hass) == uuid:
|
||||
return self.async_abort(reason="already_configured")
|
||||
|
||||
url = properties.get("internal_url")
|
||||
if not url:
|
||||
url = properties.get("base_url")
|
||||
url = urlparse(url)
|
||||
|
||||
self.prefill = {
|
||||
CONF_HOST: url.hostname,
|
||||
CONF_PORT: port,
|
||||
CONF_SECURE: url.scheme == "https",
|
||||
}
|
||||
|
||||
# pylint: disable=no-member # https://github.com/PyCQA/pylint/issues/3167
|
||||
self.context["identifier"] = self.unique_id
|
||||
self.context["title_placeholders"] = {"name": properties["location_name"]}
|
||||
return await self.async_step_connection_details()
|
||||
|
||||
async def async_step_import(self, user_input):
|
||||
"""Handle import from YAML."""
|
||||
try:
|
||||
info = await validate_input(self.hass, user_input)
|
||||
except Exception:
|
||||
_LOGGER.exception(f"import of {user_input[CONF_HOST]} failed")
|
||||
return self.async_abort(reason="import_failed")
|
||||
|
||||
conf, options = async_yaml_to_config_entry(user_input)
|
||||
|
||||
# Options cannot be set here, so store them in a special key and import them
|
||||
# before setting up an entry
|
||||
conf[CONF_OPTIONS] = options
|
||||
|
||||
await self.async_set_unique_id(info["uuid"])
|
||||
self._abort_if_unique_id_configured(updates=conf)
|
||||
|
||||
return self.async_create_entry(title=f"{info['title']} (YAML)", data=conf)
|
||||
|
||||
|
||||
class OptionsFlowHandler(config_entries.OptionsFlow):
|
||||
"""Handle options flow for the Home Assistant remote integration."""
|
||||
|
||||
def __init__(self, config_entry):
|
||||
"""Initialize remote_homeassistant options flow."""
|
||||
self.config_entry = config_entry
|
||||
self.filters = None
|
||||
self.events = None
|
||||
self.options = None
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
"""Manage basic options."""
|
||||
if self.config_entry.unique_id == REMOTE_ID:
|
||||
return self.async_abort(reason="not_supported")
|
||||
|
||||
if user_input is not None:
|
||||
self.options = user_input.copy()
|
||||
return await self.async_step_domain_entity_filters()
|
||||
|
||||
domains, _ = self._domains_and_entities()
|
||||
domains = set(domains + self.config_entry.options.get(CONF_LOAD_COMPONENTS, []))
|
||||
|
||||
remote = self.hass.data[DOMAIN][self.config_entry.entry_id][
|
||||
CONF_REMOTE_CONNECTION
|
||||
]
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="init",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_ENTITY_PREFIX,
|
||||
description={
|
||||
"suggested_value": self.config_entry.options.get(
|
||||
CONF_ENTITY_PREFIX
|
||||
)
|
||||
},
|
||||
): str,
|
||||
vol.Optional(
|
||||
CONF_LOAD_COMPONENTS,
|
||||
default=self._default(CONF_LOAD_COMPONENTS),
|
||||
): cv.multi_select(sorted(domains)),
|
||||
vol.Required(
|
||||
CONF_SERVICE_PREFIX, default=self.config_entry.options.get(CONF_SERVICE_PREFIX) or slugify(self.config_entry.title)
|
||||
): str,
|
||||
vol.Optional(
|
||||
CONF_SERVICES,
|
||||
default=self._default(CONF_SERVICES),
|
||||
): cv.multi_select(remote.proxy_services.services),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
async def async_step_domain_entity_filters(self, user_input=None):
|
||||
"""Manage domain and entity filters."""
|
||||
if user_input is not None:
|
||||
self.options.update(user_input)
|
||||
return await self.async_step_general_filters()
|
||||
|
||||
domains, entities = self._domains_and_entities()
|
||||
return self.async_show_form(
|
||||
step_id="domain_entity_filters",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_INCLUDE_DOMAINS,
|
||||
default=self._default(CONF_INCLUDE_DOMAINS),
|
||||
): cv.multi_select(domains),
|
||||
vol.Optional(
|
||||
CONF_INCLUDE_ENTITIES,
|
||||
default=self._default(CONF_INCLUDE_ENTITIES),
|
||||
): cv.multi_select(entities),
|
||||
vol.Optional(
|
||||
CONF_EXCLUDE_DOMAINS,
|
||||
default=self._default(CONF_EXCLUDE_DOMAINS),
|
||||
): cv.multi_select(domains),
|
||||
vol.Optional(
|
||||
CONF_EXCLUDE_ENTITIES,
|
||||
default=self._default(CONF_EXCLUDE_ENTITIES),
|
||||
): cv.multi_select(entities),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
async def async_step_general_filters(self, user_input=None):
|
||||
"""Manage domain and entity filters."""
|
||||
if user_input is not None:
|
||||
# Continue to next step if entity id is not specified
|
||||
if CONF_ENTITY_ID not in user_input:
|
||||
# Each filter string is prefixed with a number (index in self.filter+1).
|
||||
# Extract all of them and build the final filter list.
|
||||
selected_indices = [
|
||||
int(filter.split(".")[0]) - 1
|
||||
for filter in user_input.get(CONF_FILTER, [])
|
||||
]
|
||||
self.options[CONF_FILTER] = [self.filters[i] for i in selected_indices]
|
||||
return await self.async_step_events()
|
||||
|
||||
selected = user_input.get(CONF_FILTER, [])
|
||||
new_filter = {conf: user_input.get(conf) for conf in FILTER_OPTIONS}
|
||||
selected.append(_filter_str(len(self.filters), new_filter))
|
||||
self.filters.append(new_filter)
|
||||
else:
|
||||
self.filters = self.config_entry.options.get(CONF_FILTER, [])
|
||||
selected = [_filter_str(i, filter) for i, filter in enumerate(self.filters)]
|
||||
|
||||
strings = [_filter_str(i, filter) for i, filter in enumerate(self.filters)]
|
||||
return self.async_show_form(
|
||||
step_id="general_filters",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Optional(CONF_FILTER, default=selected): cv.multi_select(
|
||||
strings
|
||||
),
|
||||
vol.Optional(CONF_ENTITY_ID): str,
|
||||
vol.Optional(CONF_UNIT_OF_MEASUREMENT): str,
|
||||
vol.Optional(CONF_ABOVE): vol.Coerce(float),
|
||||
vol.Optional(CONF_BELOW): vol.Coerce(float),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
async def async_step_events(self, user_input=None):
|
||||
"""Manage event options."""
|
||||
if user_input is not None:
|
||||
if ADD_NEW_EVENT not in user_input:
|
||||
self.options[CONF_SUBSCRIBE_EVENTS] = user_input.get(
|
||||
CONF_SUBSCRIBE_EVENTS, []
|
||||
)
|
||||
return self.async_create_entry(title="", data=self.options)
|
||||
|
||||
selected = user_input.get(CONF_SUBSCRIBE_EVENTS, [])
|
||||
self.events.add(user_input[ADD_NEW_EVENT])
|
||||
selected.append(user_input[ADD_NEW_EVENT])
|
||||
else:
|
||||
self.events = set(
|
||||
self.config_entry.options.get(CONF_SUBSCRIBE_EVENTS) or []
|
||||
)
|
||||
selected = self._default(CONF_SUBSCRIBE_EVENTS)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="events",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_SUBSCRIBE_EVENTS, default=selected
|
||||
): cv.multi_select(self.events),
|
||||
vol.Optional(ADD_NEW_EVENT): str,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
def _default(self, conf):
|
||||
"""Return default value for an option."""
|
||||
return self.config_entry.options.get(conf) or vol.UNDEFINED
|
||||
|
||||
def _domains_and_entities(self):
|
||||
"""Return all entities and domains exposed by remote instance."""
|
||||
remote = self.hass.data[DOMAIN][self.config_entry.entry_id][
|
||||
CONF_REMOTE_CONNECTION
|
||||
]
|
||||
|
||||
# Include entities we have in the config explicitly, otherwise they will be
|
||||
# pre-selected and not possible to remove if they are no lobger present on
|
||||
# the remote host.
|
||||
include_entities = set(self.config_entry.options.get(CONF_INCLUDE_ENTITIES, []))
|
||||
exclude_entities = set(self.config_entry.options.get(CONF_EXCLUDE_ENTITIES, []))
|
||||
entities = sorted(
|
||||
remote._all_entity_names | include_entities | exclude_entities
|
||||
)
|
||||
domains = sorted(set([entity_id.split(".")[0] for entity_id in entities]))
|
||||
return domains, entities
|
||||
Reference in New Issue
Block a user