"""MQTT handler that registers devices, starts measurements and dispatches results."""

import json
from typing import Callable, Optional

import paho.mqtt.client as mqtt
from pydantic import BaseModel

from robot_control.src.utils.logging import LoggerSingleton

logger = LoggerSingleton.get_logger()


class MQTTDevice:
    """A registered device, identified by its id and its number of slots."""

    def __init__(self, device_id: str, num_slots: int):
        self.device_id = device_id
        self.num_slots = num_slots


class MeasurementResult(BaseModel):
    """Payload schema of a message received on ``measurement_done/<device_id>``."""

    device_id: str
    slot_id: int
    cell_id: int
    capacity: float
    status: str


class MQTTHandler:
    """Wrap a paho MQTT client and route measurement results to per-slot callbacks.

    Attributes:
        client: The underlying paho client.
        devices: Every device registered through ``register_device``.
        measurement_callbacks: ``device_id -> slot_id -> callback`` lookup used
            when a ``measurement_done`` message arrives.
    """

    def __init__(
        self,
        broker: str = "localhost",
        port: int = 1883,
        username: Optional[str] = None,
        password: Optional[str] = None,
    ):
        """Create the client, install the callbacks and connect to the broker.

        Args:
            broker: Broker host name. The special value ``"debug"`` connects to
                ``test.mosquitto.org`` instead.
            port: Broker port (ignored when ``broker == "debug"``).
            username: Optional user name passed to ``username_pw_set``.
            password: Optional password passed to ``username_pw_set``.
        """
        self.client = self._create_client()
        self.devices: list[MQTTDevice] = []
        self.measurement_callbacks: dict[str, dict[int, Callable]] = {}

        self.client.username_pw_set(username, password)
        self.client.on_connect = self.on_connect
        self.client.on_message = self.on_message

        if broker == "debug":
            # NOTE(review): the network loop is not started in this branch, so
            # callbacks presumably never fire unless the caller drives the loop
            # itself - confirm this is intended.
            self.client.connect("test.mosquitto.org", 1883)
            return

        self.client.connect(broker, port, 60)
        self.client.loop_start()

    @staticmethod
    def _create_client():
        """Build a paho client that uses the v1 callback signatures on any paho version."""
        api_versions = getattr(mqtt, "CallbackAPIVersion", None)
        if api_versions is None:
            # paho-mqtt 1.x has no callback API selector.
            return mqtt.Client()
        # paho-mqtt 2.x: request the v1 API explicitly so that on_connect keeps
        # the (client, userdata, flags, rc) signature implemented below.
        return mqtt.Client(api_versions.VERSION1)

    def register_device(self, device_id, num_slots, callback: Optional[Callable] = None):
        """Register a new device to handle.

        Registering an already known ``device_id`` replaces the earlier entry
        instead of adding a duplicate. If the client is already connected the
        device topics are subscribed immediately; otherwise ``on_connect``
        subscribes them once the connection is established.

        Args:
            device_id: Identifier used in the device's topic names.
            num_slots: Number of slots on the device.
            callback: Optional callable installed for every slot index in
                ``range(num_slots)``; it receives the ``MeasurementResult``.
        """
        device = MQTTDevice(device_id, num_slots)
        # Single in-place update: drop any earlier registration, add the new one.
        self.devices[:] = [d for d in self.devices if d.device_id != device_id] + [device]

        self.measurement_callbacks[device_id] = (
            {slot: callback for slot in range(num_slots)} if callback else {}
        )

        # on_connect only covers devices known at (re)connect time, so a device
        # added later has to be subscribed here.
        if self.client.is_connected():
            self._subscribe_device_topics(device_id)

    def _subscribe_device_topics(self, device_id: str):
        """Subscribe to all topics for a specific device."""
        for topic in (f"measurement_done/{device_id}", f"soa/{device_id}"):
            self.client.subscribe(topic)
            logger.info(f"Subscribed to {topic}")

    def on_connect(self, client, userdata, flags, rc):
        """paho ``on_connect`` callback: (re)subscribe the topics of every device.

        Raises:
            ConnectionError: If ``rc`` is non-zero.
        """
        if rc != 0:
            # NOTE(review): this is raised from inside paho's callback context;
            # confirm it surfaces where it is expected to.
            raise ConnectionError(f"Failed to connect, return code {rc}")

        logger.info("Connected to MQTT Broker!")
        # Resubscribe to all device topics on reconnect
        for device in self.devices:
            self._subscribe_device_topics(device.device_id)

    def on_message(self, client, userdata, msg):
        """paho ``on_message`` callback: decode the payload and dispatch by topic.

        Any exception raised while processing (bad JSON, schema mismatch, a
        failing user callback, ...) is logged and swallowed.
        """
        try:
            payload = json.loads(msg.payload.decode())
            topic = msg.topic
            device_id = topic.split('/')[1]  # Extract device_id from topic

            if topic.startswith("measurement_done/"):
                result = MeasurementResult(**payload)
                logger.info(f"Measurement complete {result}")

                if result.device_id != device_id:
                    # Do not hand a callback a result that claims to come from
                    # a different device than the topic it arrived on.
                    logger.warning(
                        f"Device id mismatch on {topic}: payload reports {result.device_id}"
                    )
                    return

                # Look up by the same device_id throughout, so an unknown
                # device or slot yields a warning instead of a KeyError.
                callback = self.measurement_callbacks.get(device_id, {}).get(result.slot_id)
                if callback is not None:
                    callback(result)
                else:
                    logger.warning(f"No callback for measurement {result}")
            elif topic.startswith("soa/"):
                logger.info(f"SOA update for device {device_id}: {payload}")
                # TODO[SG]: Handle SOA update here
        except Exception as e:
            logger.error(f"Error processing message: {e}")

    def start_measurement(self, device_id: str, slot: int, cell_id: int):
        """Publish measurement start command for specific device.

        Args:
            device_id: Id of a previously registered device.
            slot: Slot index, sent as ``slot_id``.
            cell_id: Cell identifier, sent as ``cell_id``.

        Raises:
            ValueError: If ``device_id`` has not been registered.
        """
        if not any(d.device_id == device_id for d in self.devices):
            raise ValueError(f"Device {device_id} not registered")

        topic = f"cells_inserted/{device_id}"
        payload = {"slot_id": slot, "cell_id": cell_id}
        self.client.publish(topic, json.dumps(payload))
        logger.info(f"MQTT msg published for {topic}: {payload}")

    def cleanup(self):
        """Cleanup MQTT connection: stop the network loop, then disconnect."""
        self.client.loop_stop()
        self.client.disconnect()