import zmq
import time
import logging
from typing import Any
from .utils import _get_utils_dict
_level_map = {
    "debug": 10,
    "info": 20,
    "warning": 30,
    "error": 40,
    "critical": 50,
}
_commands: dict[str, dict[str, Any]] = {}
_commands["commands"] = {
    "command": "commands",
    "description": "get all possible commands",
    "message": "",  # message to log to file
    "level": "info",  # log level
    "source": "client",
}
_commands["error"] = {
    "command": "error",
    "message": "Invalid message",
    "source": "client",
}
[docs]
class SocketClient:
    """Client to send messages to server"""
    def __init__(
        self,
        port: int,
        commands: dict,
        mode: str = "json",
        request_timeout: int = 5000,
        logger: logging.Logger | None = None,
    ) -> None:
        self._mode = mode
        self._REQUEST_TIMEOUT = request_timeout
        self._commands = commands
        self._REQUEST_RETRIES = 3
        self._SERVER_ENDPOINT = f"tcp://localhost:{port}"
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REQ)
        self._socket.connect(self._SERVER_ENDPOINT)
        self._logger = logger
[docs]
    def log(self, message: str, level: str = "info") -> None:
        """Log message
        Args:
            message (str): message to log
            level (str, optional): log level. Defaults to
            "info". Possible values: "debug", "info", "warning", "error", "critical"
        """
        if self._logger:
            self._logger.log(_level_map[level], message)
        else:
            print(f"{level.upper()}: {message}") 
[docs]
    def command(self, command: dict) -> dict:
        """Send command to board
        Args:
            command (dict): message to send
        Returns:
            reply (dict): reply from board
        """
        if not isinstance(command, dict):
            return self._invalid_command_response(
                "Invalid command type, must be dictionary, check possible commands"
            )
        return self._attempt_command(command) 
    def _attempt_command(self, command: dict) -> dict:
        self._send(command)
        retries_left = self._REQUEST_RETRIES
        while retries_left > 0:
            try:
                if (self._socket.poll(self._REQUEST_TIMEOUT) & zmq.POLLIN) != 0:
                    return self._receive()
            except zmq.ZMQError as e:
                self.log(f"ZMQ Error: {e}", level="error")
            retries_left -= 1
            self.log("No response from server, retrying...", level="warning")
            time.sleep(2)
            self._reset_socket()
        return {
            "command": command["command"],
            "direction": "reply",
            "message": "Server seems to be offline. Connection error",
        }
    def _reset_socket(self) -> None:
        self._socket.setsockopt(zmq.LINGER, 0)
        self._socket.close()
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REQ)
        self._socket.connect(self._SERVER_ENDPOINT)
    def _invalid_command_response(self, message: str) -> dict:
        self.log(message, level="error")
        reply = self.commands["error"].copy()
        reply["message"] = message
        reply["source"] = "client"
        return reply
    def _send(self, obj: dict) -> None:
        """Send a json object"""
        if self._mode == "json":
            self._socket.send_json(obj)
        else:
            self._socket.send_pyobj(obj)
    def _receive(self) -> dict:
        """Receive a json or pickle object"""
        if self._mode == "json":
            message = self._socket.recv_json()
        else:
            message = self._socket.recv_pyobj()
        if not isinstance(message, dict):
            self.log("Received invalid message", level="error")
            reply = self._commands["error"].copy()
            reply["message"] = "Received invalid message"
            return reply
        return message
[docs]
    def get_commands(self) -> dict:
        """Get all possible commands
        Returns:
            dict: all possible commands
        """
        return self.command(self._commands["commands"]) 
 
class BoardControl(SocketClient):
    """Control BrainAccess Board via messages.

    Looks up the server's socket port with ``_get_utils_dict()`` and connects a
    JSON-mode ``SocketClient`` to it, using the module-level ``_commands``
    templates.
    """

    def __init__(
        self, logger: logging.Logger | None = None, request_timeout: int = 5000
    ) -> None:
        """Create the board controller.

        Args:
            logger (logging.Logger | None, optional): logger for messages;
                when None, messages are printed. Defaults to None.
            request_timeout (int, optional): per-reply timeout in milliseconds,
                forwarded to ``SocketClient``. Defaults to 5000.

        Raises:
            Exception: if the socket port cannot be determined (for example
                when ``_get_utils_dict()`` returns None).
        """
        # ``SocketClient.__init__`` has not run yet. Set the logger now so that
        # ``self.log`` works in the error path below instead of raising
        # AttributeError and hiding the real problem.
        self._logger = logger
        try:
            utils = _get_utils_dict()
            if utils is None:
                raise Exception("Board is not connected")
            # NOTE(review): despite the helper's name, the result is read via
            # attribute access; confirm ``_get_utils_dict`` returns such an object.
            port = utils.socket_port
        except Exception as e:
            self.log(f"Socket port not found: {str(e)}", level="error")
            raise Exception("Socket port not found, please restart the app") from e
        super().__init__(
            port, _commands, logger=logger, mode="json", request_timeout=request_timeout
        )
        self.log(f"Board Control created using port: {port}")