working 2.0 clone
custom_components/remote_homeassistant/__init__.py (new file, 745 lines)
@@ -0,0 +1,745 @@
"""
Connect two Home Assistant instances via the Websocket API.

For more details about this component, please refer to the documentation at
https://home-assistant.io/components/remote_homeassistant/
"""
import asyncio
import copy
import fnmatch
import inspect
import logging
import re
from contextlib import suppress

import aiohttp
import homeassistant.components.websocket_api.auth as api
import homeassistant.helpers.config_validation as cv
import voluptuous as vol
from homeassistant.config import DATA_CUSTOMIZE
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
from homeassistant.const import (CONF_ABOVE, CONF_ACCESS_TOKEN, CONF_BELOW,
                                 CONF_DOMAINS, CONF_ENTITIES, CONF_ENTITY_ID,
                                 CONF_EXCLUDE, CONF_HOST, CONF_INCLUDE,
                                 CONF_PORT, CONF_UNIT_OF_MEASUREMENT,
                                 CONF_VERIFY_SSL, EVENT_CALL_SERVICE,
                                 EVENT_HOMEASSISTANT_STOP, EVENT_STATE_CHANGED,
                                 SERVICE_RELOAD)
from homeassistant.core import (Context, EventOrigin, HomeAssistant, callback,
                                split_entity_id)
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.reload import async_integration_yaml_config
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
from homeassistant.setup import async_setup_component

from custom_components.remote_homeassistant.views import DiscoveryInfoView

from .const import (CONF_EXCLUDE_DOMAINS, CONF_EXCLUDE_ENTITIES,
                    CONF_INCLUDE_DOMAINS, CONF_INCLUDE_ENTITIES,
                    CONF_LOAD_COMPONENTS, CONF_OPTIONS, CONF_REMOTE_CONNECTION,
                    CONF_SERVICE_PREFIX, CONF_SERVICES, CONF_UNSUB_LISTENER,
                    DOMAIN, REMOTE_ID, DEFAULT_MAX_MSG_SIZE)
from .proxy_services import ProxyServices
from .rest_api import UnsupportedVersion, async_get_discovery_info

_LOGGER = logging.getLogger(__name__)

PLATFORMS = ["sensor"]

CONF_INSTANCES = "instances"
CONF_SECURE = "secure"
CONF_SUBSCRIBE_EVENTS = "subscribe_events"
CONF_ENTITY_PREFIX = "entity_prefix"
CONF_FILTER = "filter"
CONF_MAX_MSG_SIZE = "max_message_size"

STATE_INIT = "initializing"
STATE_CONNECTING = "connecting"
STATE_CONNECTED = "connected"
STATE_AUTH_INVALID = "auth_invalid"
STATE_AUTH_REQUIRED = "auth_required"
STATE_RECONNECTING = "reconnecting"
STATE_DISCONNECTED = "disconnected"

DEFAULT_ENTITY_PREFIX = ""

INSTANCES_SCHEMA = vol.Schema(
    {
        vol.Required(CONF_HOST): cv.string,
        vol.Optional(CONF_PORT, default=8123): cv.port,
        vol.Optional(CONF_SECURE, default=False): cv.boolean,
        vol.Optional(CONF_VERIFY_SSL, default=True): cv.boolean,
        vol.Required(CONF_ACCESS_TOKEN): cv.string,
        vol.Optional(CONF_MAX_MSG_SIZE, default=DEFAULT_MAX_MSG_SIZE): vol.Coerce(int),
        vol.Optional(CONF_EXCLUDE, default={}): vol.Schema(
            {
                vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
                vol.Optional(CONF_DOMAINS, default=[]): vol.All(
                    cv.ensure_list, [cv.string]
                ),
            }
        ),
        vol.Optional(CONF_INCLUDE, default={}): vol.Schema(
            {
                vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
                vol.Optional(CONF_DOMAINS, default=[]): vol.All(
                    cv.ensure_list, [cv.string]
                ),
            }
        ),
        vol.Optional(CONF_FILTER, default=[]): vol.All(
            cv.ensure_list,
            [
                vol.Schema(
                    {
                        vol.Optional(CONF_ENTITY_ID): cv.string,
                        vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string,
                        vol.Optional(CONF_ABOVE): vol.Coerce(float),
                        vol.Optional(CONF_BELOW): vol.Coerce(float),
                    }
                )
            ],
        ),
        vol.Optional(CONF_SUBSCRIBE_EVENTS): cv.ensure_list,
        vol.Optional(CONF_ENTITY_PREFIX, default=DEFAULT_ENTITY_PREFIX): cv.string,
        vol.Optional(CONF_LOAD_COMPONENTS): cv.ensure_list,
        vol.Required(CONF_SERVICE_PREFIX, default="remote_"): cv.string,
        vol.Optional(CONF_SERVICES): cv.ensure_list,
    }
)

CONFIG_SCHEMA = vol.Schema(
    {
        DOMAIN: vol.Schema(
            {
                vol.Required(CONF_INSTANCES): vol.All(
                    cv.ensure_list, [INSTANCES_SCHEMA]
                ),
            }
        ),
    },
    extra=vol.ALLOW_EXTRA,
)
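
# Illustrative configuration.yaml entry (keys follow INSTANCES_SCHEMA above;
# host and token values are placeholders):
#
# remote_homeassistant:
#   instances:
#     - host: remote.example.org
#       port: 8123
#       secure: true
#       access_token: !secret remote_access_token
#       entity_prefix: "remote_"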

HEARTBEAT_INTERVAL = 20
HEARTBEAT_TIMEOUT = 5

INTERNALLY_USED_EVENTS = [EVENT_STATE_CHANGED]


def async_yaml_to_config_entry(instance_conf):
    """Convert YAML config into data and options used by a config entry."""
    conf = instance_conf.copy()
    options = {}

    if CONF_INCLUDE in conf:
        include = conf.pop(CONF_INCLUDE)
        if CONF_ENTITIES in include:
            options[CONF_INCLUDE_ENTITIES] = include[CONF_ENTITIES]
        if CONF_DOMAINS in include:
            options[CONF_INCLUDE_DOMAINS] = include[CONF_DOMAINS]

    if CONF_EXCLUDE in conf:
        exclude = conf.pop(CONF_EXCLUDE)
        if CONF_ENTITIES in exclude:
            options[CONF_EXCLUDE_ENTITIES] = exclude[CONF_ENTITIES]
        if CONF_DOMAINS in exclude:
            options[CONF_EXCLUDE_DOMAINS] = exclude[CONF_DOMAINS]

    for option in [
        CONF_FILTER,
        CONF_SUBSCRIBE_EVENTS,
        CONF_ENTITY_PREFIX,
        CONF_LOAD_COMPONENTS,
        CONF_SERVICE_PREFIX,
        CONF_SERVICES,
    ]:
        if option in conf:
            options[option] = conf.pop(option)

    return conf, options


async def _async_update_config_entry_if_from_yaml(hass, entries_by_id, conf):
    """Update a config entry with the latest yaml."""
    try:
        info = await async_get_discovery_info(
            hass,
            conf[CONF_HOST],
            conf[CONF_PORT],
            conf[CONF_SECURE],
            conf[CONF_ACCESS_TOKEN],
            conf[CONF_VERIFY_SSL],
        )
    except Exception:
        _LOGGER.exception(f"reload of {conf[CONF_HOST]} failed")
    else:
        entry = entries_by_id.get(info["uuid"])
        if entry:
            data, options = async_yaml_to_config_entry(conf)
            hass.config_entries.async_update_entry(entry, data=data, options=options)


async def setup_remote_instance(hass: HomeAssistantType):
    hass.http.register_view(DiscoveryInfoView())


async def async_setup(hass: HomeAssistantType, config: ConfigType):
    """Set up the remote_homeassistant component."""
    hass.data.setdefault(DOMAIN, {})

    async def _handle_reload(service):
        """Handle reload service call."""
        config = await async_integration_yaml_config(hass, DOMAIN)

        if not config or DOMAIN not in config:
            return

        current_entries = hass.config_entries.async_entries(DOMAIN)
        entries_by_id = {entry.unique_id: entry for entry in current_entries}

        instances = config[DOMAIN][CONF_INSTANCES]
        update_tasks = [
            _async_update_config_entry_if_from_yaml(hass, entries_by_id, instance)
            for instance in instances
        ]

        await asyncio.gather(*update_tasks)

    hass.async_create_task(setup_remote_instance(hass))

    hass.helpers.service.async_register_admin_service(
        DOMAIN,
        SERVICE_RELOAD,
        _handle_reload,
    )

    instances = config.get(DOMAIN, {}).get(CONF_INSTANCES, [])
    for instance in instances:
        hass.async_create_task(
            hass.config_entries.flow.async_init(
                DOMAIN, context={"source": SOURCE_IMPORT}, data=instance
            )
        )

    return True


async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
    """Set up Remote Home-Assistant from a config entry."""
    _async_import_options_from_yaml(hass, entry)
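    # An entry with unique_id REMOTE_ID represents this instance acting as a
    # "remote" node: it only exposes the discovery endpoint and never opens an
    # outgoing connection.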
    if entry.unique_id == REMOTE_ID:
        hass.async_create_task(setup_remote_instance(hass))
        return True
    else:
        remote = RemoteConnection(hass, entry)

        hass.data[DOMAIN][entry.entry_id] = {
            CONF_REMOTE_CONNECTION: remote,
            CONF_UNSUB_LISTENER: entry.add_update_listener(_update_listener),
        }

        async def setup_components_and_platforms():
            """Set up platforms and initiate connection."""
            for domain in entry.options.get(CONF_LOAD_COMPONENTS, []):
                hass.async_create_task(async_setup_component(hass, domain, {}))

            await asyncio.gather(
                *[
                    hass.config_entries.async_forward_entry_setup(entry, platform)
                    for platform in PLATFORMS
                ]
            )
            await remote.async_connect()

        hass.async_create_task(setup_components_and_platforms())

        return True


async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry):
    """Unload a config entry."""
    unload_ok = all(
        await asyncio.gather(
            *[
                hass.config_entries.async_forward_entry_unload(entry, platform)
                for platform in PLATFORMS
            ]
        )
    )

    if unload_ok:
        data = hass.data[DOMAIN].pop(entry.entry_id)
        await data[CONF_REMOTE_CONNECTION].async_stop()
        data[CONF_UNSUB_LISTENER]()

    return unload_ok


@callback
def _async_import_options_from_yaml(hass: HomeAssistant, entry: ConfigEntry):
    """Import options from YAML into options section of config entry."""
    if CONF_OPTIONS in entry.data:
        data = entry.data.copy()
        options = data.pop(CONF_OPTIONS)
        hass.config_entries.async_update_entry(entry, data=data, options=options)


async def _update_listener(hass, config_entry):
    """Update listener."""
    await hass.config_entries.async_reload(config_entry.entry_id)


class RemoteConnection(object):
    """A Websocket connection to a remote home-assistant instance."""

    def __init__(self, hass, config_entry):
        """Initialize the connection."""
        self._hass = hass
        self._entry = config_entry
        self._secure = config_entry.data.get(CONF_SECURE, False)
        self._verify_ssl = config_entry.data.get(CONF_VERIFY_SSL, False)
        self._access_token = config_entry.data.get(CONF_ACCESS_TOKEN)
        self._max_msg_size = config_entry.data.get(CONF_MAX_MSG_SIZE)

        # see homeassistant/components/influxdb/__init__.py
        # for include/exclude logic
        self._whitelist_e = set(config_entry.options.get(CONF_INCLUDE_ENTITIES, []))
        self._whitelist_d = set(config_entry.options.get(CONF_INCLUDE_DOMAINS, []))
        self._blacklist_e = set(config_entry.options.get(CONF_EXCLUDE_ENTITIES, []))
        self._blacklist_d = set(config_entry.options.get(CONF_EXCLUDE_DOMAINS, []))

        self._filter = [
            {
                CONF_ENTITY_ID: re.compile(fnmatch.translate(f.get(CONF_ENTITY_ID)))
                if f.get(CONF_ENTITY_ID)
                else None,
                CONF_UNIT_OF_MEASUREMENT: f.get(CONF_UNIT_OF_MEASUREMENT),
                CONF_ABOVE: f.get(CONF_ABOVE),
                CONF_BELOW: f.get(CONF_BELOW),
            }
            for f in config_entry.options.get(CONF_FILTER, [])
        ]

        self._subscribe_events = set(
            config_entry.options.get(CONF_SUBSCRIBE_EVENTS, []) + INTERNALLY_USED_EVENTS
        )
        self._entity_prefix = config_entry.options.get(CONF_ENTITY_PREFIX, "")

        self._connection = None
        self._heartbeat_task = None
        self._is_stopping = False
        self._entities = set()
        self._all_entity_names = set()
        self._handlers = {}
        self._remove_listener = None
        self.proxy_services = ProxyServices(hass, config_entry, self)

        self.set_connection_state(STATE_CONNECTING)

        self.__id = 1

    def _prefixed_entity_id(self, entity_id):
        if self._entity_prefix:
            domain, object_id = split_entity_id(entity_id)
            object_id = self._entity_prefix + object_id
            entity_id = domain + "." + object_id
            return entity_id
        return entity_id

    def set_connection_state(self, state):
        """Change current connection state."""
        signal = f"remote_homeassistant_{self._entry.unique_id}"
        async_dispatcher_send(self._hass, signal, state)

    @callback
    def _get_url(self):
        """Get url to connect to."""
        return "%s://%s:%s/api/websocket" % (
            "wss" if self._secure else "ws",
            self._entry.data[CONF_HOST],
            self._entry.data[CONF_PORT],
        )

    async def async_connect(self):
        """Connect to remote home-assistant websocket..."""

        async def _async_stop_handler(event):
            """Stop when Home Assistant is shutting down."""
            await self.async_stop()

        async def _async_instance_get_info():
            """Fetch discovery info from remote instance."""
            try:
                return await async_get_discovery_info(
                    self._hass,
                    self._entry.data[CONF_HOST],
                    self._entry.data[CONF_PORT],
                    self._secure,
                    self._access_token,
                    self._verify_ssl,
                )
            except OSError:
                _LOGGER.exception("failed to connect")
            except UnsupportedVersion:
                _LOGGER.error("Unsupported version, at least 0.111 is required.")
            except Exception:
                _LOGGER.exception("failed to fetch instance info")
            return None

        @callback
        def _async_instance_id_match(info):
            """Verify if remote instance id matches the expected id."""
            if not info:
                return False
            if info and info["uuid"] != self._entry.unique_id:
                _LOGGER.error(
                    "instance id not matching: %s != %s",
                    info["uuid"],
                    self._entry.unique_id,
                )
                return False
            return True

        url = self._get_url()

        session = async_get_clientsession(self._hass, self._verify_ssl)
        self.set_connection_state(STATE_CONNECTING)
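
        # Retry until the remote instance is reachable and reports the uuid
        # stored in this config entry.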
        while True:
            info = await _async_instance_get_info()

            # Verify we are talking to correct instance
            if not _async_instance_id_match(info):
                self.set_connection_state(STATE_RECONNECTING)
                await asyncio.sleep(10)
                continue

            try:
                _LOGGER.info("Connecting to %s", url)
                self._connection = await session.ws_connect(url, max_msg_size=self._max_msg_size)
            except aiohttp.client_exceptions.ClientError:
                _LOGGER.error("Could not connect to %s, retry in 10 seconds...", url)
                self.set_connection_state(STATE_RECONNECTING)
                await asyncio.sleep(10)
            else:
                _LOGGER.info("Connected to home-assistant websocket at %s", url)
                break

        self._hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_stop_handler)

        device_registry = dr.async_get(self._hass)
        device_registry.async_get_or_create(
            config_entry_id=self._entry.entry_id,
            identifiers={(DOMAIN, f"remote_{self._entry.unique_id}")},
            name=info.get("location_name"),
            manufacturer="Home Assistant",
            model=info.get("installation_type"),
            sw_version=info.get("ha_version"),
        )
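
        # Run the receive loop and the heartbeat as background tasks; _recv
        # handles authentication and dispatches incoming messages.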
        asyncio.ensure_future(self._recv())
        self._heartbeat_task = self._hass.loop.create_task(self._heartbeat_loop())

    async def _heartbeat_loop(self):
        """Send periodic heartbeats to remote instance."""
        while not self._connection.closed:
            await asyncio.sleep(HEARTBEAT_INTERVAL)

            _LOGGER.debug("Sending ping")
            event = asyncio.Event()

            def resp(message):
                _LOGGER.debug("Got pong: %s", message)
                event.set()

            await self.call(resp, "ping")

            try:
                await asyncio.wait_for(event.wait(), HEARTBEAT_TIMEOUT)
            except asyncio.TimeoutError:
                _LOGGER.error("heartbeat failed")

                # Schedule closing on event loop to avoid deadlock
                asyncio.ensure_future(self._connection.close())
                break

    async def async_stop(self):
        """Close connection."""
        self._is_stopping = True
        if self._connection is not None:
            await self._connection.close()
        await self.proxy_services.unload()
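
    # The websocket API expects every command to carry a unique, increasing
    # "id"; _next_id() keeps a simple counter and call() maps each id to the
    # handler that should receive the matching response.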
    def _next_id(self):
        _id = self.__id
        self.__id += 1
        return _id

    async def call(self, callback, message_type, **extra_args):
        _id = self._next_id()
        self._handlers[_id] = callback
        try:
            await self._connection.send_json(
                {"id": _id, "type": message_type, **extra_args}
            )
        except aiohttp.client_exceptions.ClientError as err:
            _LOGGER.error("remote websocket connection closed: %s", err)
            await self._disconnected()

    async def _disconnected(self):
        # Remove all published entries
        for entity in self._entities:
            self._hass.states.async_remove(entity)
        if self._heartbeat_task is not None:
            self._heartbeat_task.cancel()
            try:
                await self._heartbeat_task
            except asyncio.CancelledError:
                pass
        if self._remove_listener is not None:
            self._remove_listener()

        self.set_connection_state(STATE_DISCONNECTED)
        self._heartbeat_task = None
        self._remove_listener = None
        self._entities = set()
        self._all_entity_names = set()
        if not self._is_stopping:
            asyncio.ensure_future(self.async_connect())

    async def _recv(self):
        while not self._connection.closed:
            try:
                data = await self._connection.receive()
            except aiohttp.client_exceptions.ClientError as err:
                _LOGGER.error("remote websocket connection closed: %s", err)
                break

            if not data:
                break

            if data.type in (
                aiohttp.WSMsgType.CLOSE,
                aiohttp.WSMsgType.CLOSED,
                aiohttp.WSMsgType.CLOSING,
            ):
                _LOGGER.debug("websocket connection is closing")
                break

            if data.type == aiohttp.WSMsgType.ERROR:
                _LOGGER.error("websocket connection had an error")
                if data.data.code == aiohttp.WSCloseCode.MESSAGE_TOO_BIG:
                    _LOGGER.error(f"please consider increasing message size with `{CONF_MAX_MSG_SIZE}`")
                break

            try:
                message = data.json()
            except TypeError as err:
                _LOGGER.error("could not decode data (%s) as json: %s", data, err)
                break

            if message is None:
                break

            _LOGGER.debug("received: %s", message)

            if message["type"] == api.TYPE_AUTH_OK:
                self.set_connection_state(STATE_CONNECTED)
                await self._init()
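
            # Auth phase of the handshake: reply to auth_required with the
            # configured access token.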
            elif message["type"] == api.TYPE_AUTH_REQUIRED:
                if self._access_token:
                    data = {"type": api.TYPE_AUTH, "access_token": self._access_token}
                else:
                    _LOGGER.error("Access token required, but not provided")
                    self.set_connection_state(STATE_AUTH_REQUIRED)
                    return
                try:
                    await self._connection.send_json(data)
                except Exception as err:
                    _LOGGER.error("could not send data to remote connection: %s", err)
                    break

            elif message["type"] == api.TYPE_AUTH_INVALID:
                _LOGGER.error("Auth invalid, check your access token")
                self.set_connection_state(STATE_AUTH_INVALID)
                await self._connection.close()
                return

            else:
                callback = self._handlers.get(message["id"])
                if callback is not None:
                    if inspect.iscoroutinefunction(callback):
                        await callback(message)
                    else:
                        callback(message)

        await self._disconnected()

    async def _init(self):
        async def forward_event(event):
            """Send local event to remote instance.

            The affected entity_id has to originate from that remote instance,
            otherwise the event is discarded.
            """
            event_data = event.data
            service_data = event_data["service_data"]

            if not service_data:
                return

            entity_ids = service_data.get("entity_id", None)

            if not entity_ids:
                return

            if isinstance(entity_ids, str):
                entity_ids = (entity_ids.lower(),)

            entities = {entity_id.lower() for entity_id in self._entities}

            entity_ids = entities.intersection(entity_ids)

            if not entity_ids:
                return

            if self._entity_prefix:

                def _remove_prefix(entity_id):
                    domain, object_id = split_entity_id(entity_id)
                    object_id = object_id.replace(self._entity_prefix.lower(), "", 1)
                    return domain + "." + object_id

                entity_ids = {_remove_prefix(entity_id) for entity_id in entity_ids}

            event_data = copy.deepcopy(event_data)
            event_data["service_data"]["entity_id"] = list(entity_ids)

            # Remove service_call_id parameter - websocket API
            # doesn't accept that one
            event_data.pop("service_call_id", None)

            _id = self._next_id()
            data = {"id": _id, "type": event.event_type, **event_data}

            _LOGGER.debug("forward event: %s", data)

            try:
                await self._connection.send_json(data)
            except Exception as err:
                _LOGGER.error("could not send data to remote connection: %s", err)
                await self._disconnected()

        def state_changed(entity_id, state, attr):
            """Publish remote state change on local instance."""
            domain, object_id = split_entity_id(entity_id)

            self._all_entity_names.add(entity_id)

            if entity_id in self._blacklist_e or domain in self._blacklist_d:
                return

            if (
                (self._whitelist_e or self._whitelist_d)
                and entity_id not in self._whitelist_e
                and domain not in self._whitelist_d
            ):
                return
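
            # Numeric filters: a matching filter drops states below CONF_BELOW
            # or above CONF_ABOVE; non-numeric states pass through unchanged.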
            for f in self._filter:
                if f[CONF_ENTITY_ID] and not f[CONF_ENTITY_ID].match(entity_id):
                    continue
                if f[CONF_UNIT_OF_MEASUREMENT]:
                    if CONF_UNIT_OF_MEASUREMENT not in attr:
                        continue
                    if f[CONF_UNIT_OF_MEASUREMENT] != attr[CONF_UNIT_OF_MEASUREMENT]:
                        continue
                try:
                    if f[CONF_BELOW] and float(state) < f[CONF_BELOW]:
                        _LOGGER.info(
                            "%s: ignoring state '%s', because below '%s'",
                            entity_id,
                            state,
                            f[CONF_BELOW],
                        )
                        return
                    if f[CONF_ABOVE] and float(state) > f[CONF_ABOVE]:
                        _LOGGER.info(
                            "%s: ignoring state '%s', because above '%s'",
                            entity_id,
                            state,
                            f[CONF_ABOVE],
                        )
                        return
                except ValueError:
                    pass

            entity_id = self._prefixed_entity_id(entity_id)

            # Add local customization data
            if DATA_CUSTOMIZE in self._hass.data:
                attr.update(self._hass.data[DATA_CUSTOMIZE].get(entity_id))

            self._entities.add(entity_id)
            self._hass.states.async_set(entity_id, state, attr)

        def fire_event(message):
            """Publish remote event on local instance."""
            if message["type"] == "result":
                return

            if message["type"] != "event":
                return

            if message["event"]["event_type"] == "state_changed":
                data = message["event"]["data"]
                entity_id = data["entity_id"]
                if not data["new_state"]:
                    entity_id = self._prefixed_entity_id(entity_id)
                    # entity was removed in the remote instance
                    with suppress(ValueError, AttributeError, KeyError):
                        self._entities.remove(entity_id)
                    with suppress(ValueError, AttributeError, KeyError):
                        self._all_entity_names.remove(entity_id)
                    self._hass.states.async_remove(entity_id)
                    return

                state = data["new_state"]["state"]
                attr = data["new_state"]["attributes"]
                state_changed(entity_id, state, attr)
            else:
                event = message["event"]
                self._hass.bus.async_fire(
                    event_type=event["event_type"],
                    event_data=event["data"],
                    context=Context(
                        id=event["context"].get("id"),
                        user_id=event["context"].get("user_id"),
                        parent_id=event["context"].get("parent_id"),
                    ),
                    origin=EventOrigin.remote,
                )

        def got_states(message):
            """Called when list of remote states is available."""
            for entity in message["result"]:
                entity_id = entity["entity_id"]
                state = entity["state"]
                attributes = entity["attributes"]

                state_changed(entity_id, state, attributes)

        self._remove_listener = self._hass.bus.async_listen(
            EVENT_CALL_SERVICE, forward_event
        )

        for event in self._subscribe_events:
            await self.call(fire_event, "subscribe_events", event_type=event)

        await self.call(got_states, "get_states")

        await self.proxy_services.load()