import asyncio
from datetime import datetime, timezone
import sys
from pathlib import Path
# Put the directory one level above this file's own directory on sys.path so
# the `powersensor_local` imports below can be resolved; presumably this lets
# the file run directly from a source checkout (TODO confirm). It must execute
# before those imports, and it only appends when the path is not already there.
project_root = str(Path(__file__).parents[1])
if project_root not in sys.path:
    sys.path.append(project_root)
from powersensor_local.listener import PowersensorListener
from powersensor_local.xlatemsg import translate_raw_message
# How often, in seconds, the expiry timer sweeps the known-device table.
EXPIRY_CHECK_INTERVAL_S = 30
# Seconds without any message after which a device counts as expired and a
# device_lost event is emitted for it (300 s == 5 minutes).
EXPIRY_TIMEOUT_S = 300
def _make_events(obj, relayer):
    """Translate one raw device message into a list of event dicts.

    The mapping returned by ``translate_raw_message(obj, relayer)`` is
    walked once. Each value dict is tagged in place with its key under
    ``'event'``. The tagged value dicts are then returned in the mapping's
    iteration order.
    """
    translated = translate_raw_message(obj, relayer)
    for event_name, payload in translated.items():
        payload['event'] = event_name
    return list(translated.values())
class PowersensorDevices:
    """Abstraction interface for the unified event stream from all Powersensor
    devices on the local network.
    """

    def __init__(self, broadcast_address='<broadcast>'):
        """Creates a fresh instance, without scanning for devices.

        Parameters:
        -----------
        broadcast_address : str
            Passed straight through to PowersensorListener; defaults to
            '<broadcast>'.
        """
        self._event_cb = None
        self._ps = PowersensorListener(broadcast_address)
        # Known devices, keyed by MAC address -> _Device.
        self._devices = dict()
        # Result of the most recent scan (one entry per gateway found).
        # Initialised here so the attribute exists before the first scan.
        self._ips = []
        # Periodic expiry checker; created by start(), torn down by stop().
        self._timer = None

    async def start(self, async_event_callback):
        """Registers the async event callback function and starts the scan
        of the local network to discover present devices. The callback is
        an asynchronous callable that receives a single event dict.

        Parameters:
        -----------
        async_event_callback : Callable
            A callable asynchronous method for handling json messages. Example::

                async def your_callback(event: dict):
                    pass

        Returns:
        --------
        int
            The number of gateway plugs found by the scan.

        Known Events:
        -------------
            * scan_complete
                Indicates the discovery of Powersensor devices has completed.
                Emitted in response to start() and rescan() calls.
                The number of found gateways (plugs) is reported.::

                    { event: "scan_complete", gateway_count: N }

            * device_found
                A new device found on the network.
                The order found devices are announced is not fixed.::

                    { event: "device_found",
                      device_type: "plug" or "sensor",
                      mac: "...",
                    }

                An optional field named "via" is present for sensor devices, and
                shows the MAC address of the gateway the sensor is communicating
                via.

            * device_lost
                A device appears to no longer be present on the network.::

                    { event: "device_lost", mac: "..." }

        Additionally, all events described in xlatemsg.translate_raw_message
        may be issued. The event name is inserted into the field 'event'.

        Powersensor devices aren't found directly as they are typically not
        on the network, but are instead detected when they relay data through
        a plug via long-range radio.
        """
        self._event_cb = async_event_callback
        await self._on_scanned(await self._ps.scan())
        # Calling start() again without stop() must not leave the previous
        # expiry timer task running alongside the new one.
        if self._timer:
            self._timer.terminate()
        self._timer = self._Timer(EXPIRY_CHECK_INTERVAL_S, self._on_timer)
        return len(self._ips)

    async def rescan(self):
        """Performs a fresh scan of the network to discover added devices,
        or devices which have changed their IP address for some reason.
        A scan_complete event is emitted when the scan finishes."""
        await self._on_scanned(await self._ps.scan())

    async def stop(self):
        """Stops the event streaming and disconnects from the devices.
        To restart the event streaming, call start() again.

        The callback is always dropped and the expiry timer always stopped,
        even if shutting down the underlying listener raises. Known devices
        and their subscription state are retained.
        """
        try:
            await self._ps.unsubscribe()
            await self._ps.stop()
        finally:
            self._event_cb = None
            if self._timer:
                self._timer.terminate()
                self._timer = None

    def subscribe(self, mac):
        """Subscribes to events from the device with the given MAC address.

        Has no effect for a MAC that is not yet known. A device becomes known
        just before its device_found event is emitted, so subscribing from
        within that callback works.
        """
        device = self._devices.get(mac)
        if device:
            device.subscribed = True

    def unsubscribe(self, mac):
        """Unsubscribes from events from the given MAC address.
        Has no effect for a MAC that is not known."""
        device = self._devices.get(mac)
        if device:
            device.subscribed = False

    async def _on_scanned(self, ips):
        """Record a scan result, announce it, then (re)attach the handler."""
        self._ips = ips
        if self._event_cb:
            ev = {
                'event': 'scan_complete',
                'gateway_count': len(ips),
            }
            await self._event_cb(ev)
        # NOTE(review): this runs on every start()/rescan(); whether repeated
        # subscribe() calls replace or stack handlers depends on
        # PowersensorListener -- confirm.
        await self._ps.subscribe(self._on_msg)

    async def _on_msg(self, obj):
        """Handle one raw message from the listener.

        Registers unseen devices (emitting device_found), refreshes the
        device's activity timestamp, and forwards translated events for
        subscribed devices.
        """
        mac = obj.get('mac')
        if not mac:
            # A message without a MAC cannot be attributed to any device;
            # ignore it rather than failing the device lookup below.
            return
        via = obj.get('via')
        device = self._devices.get(mac)
        if device is None:
            await self._add_device(mac, obj.get('device'), via)
            device = self._devices[mac]
        device.mark_active()
        if self._event_cb and device.subscribed:
            # Events are attributed to the relaying gateway when present,
            # otherwise to the device itself.
            for ev in _make_events(obj, via or mac):
                await self._event_cb(ev)

    async def _on_timer(self):
        """Drop every device that has been silent longer than the timeout."""
        # Iterate over a snapshot: _remove_device mutates self._devices.
        for device in list(self._devices.values()):
            if device.has_expired():
                await self._remove_device(device.mac)

    async def _add_device(self, mac, typ, via):
        """Register a device and emit device_found (with 'via' if given)."""
        self._devices[mac] = self._Device(mac, typ, via)
        if self._event_cb:
            ev = {
                'event': 'device_found',
                'device_type': typ,
                'mac': mac,
            }
            if via:
                ev['via'] = via
            await self._event_cb(ev)

    async def _remove_device(self, mac):
        """Forget a device and emit device_lost; no-op for unknown MACs."""
        if self._devices.pop(mac, None) is None:
            return
        if self._event_cb:
            await self._event_cb({'event': 'device_lost', 'mac': mac})

    ### Supporting classes ###
    class _Device:
        """Per-device bookkeeping: identity, subscription flag, last activity."""

        def __init__(self, mac, typ, via):
            self.mac = mac
            self.type = typ
            self.via = via
            self.subscribed = False
            self._last_active = datetime.now(timezone.utc)

        def mark_active(self):
            """Record that a message was just seen from this device."""
            self._last_active = datetime.now(timezone.utc)

        def has_expired(self):
            """True if no message arrived within EXPIRY_TIMEOUT_S seconds."""
            now = datetime.now(timezone.utc)
            delta = now - self._last_active
            return delta.total_seconds() > EXPIRY_TIMEOUT_S

    class _Timer:
        """Awaits ``callback()`` every ``interval_s`` seconds on its own task
        until terminate() is called."""

        def __init__(self, interval_s, callback):
            self._terminate = False
            self._interval = interval_s
            self._callback = callback
            self._task = asyncio.create_task(self._run())

        def terminate(self):
            """Stop ticking and cancel the underlying task."""
            self._terminate = True
            self._task.cancel()

        async def _run(self):
            while not self._terminate:
                await asyncio.sleep(self._interval)
                try:
                    await self._callback()
                except asyncio.CancelledError:
                    # terminate() was called; let cancellation propagate.
                    raise
                except Exception as exc:
                    # Report and keep ticking. Otherwise a single failing
                    # callback (e.g. a user event handler raising) would
                    # silently stop all future expiry checks.
                    asyncio.get_running_loop().call_exception_handler({
                        'message': 'Unhandled exception in periodic timer callback',
                        'exception': exc,
                    })