mqtt_handler.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import paho.mqtt.client as mqtt
  2. from typing import Callable
  3. from pydantic import BaseModel
  4. import json
  5. from robot_control.src.utils.logging import LoggerSingleton
# Logger obtained once at import time from the project's LoggerSingleton;
# MQTTHandler uses it for all of its log output
# (presumably a logging.Logger-compatible object — confirm).
logger = LoggerSingleton.get_logger()
  7. class MQTTDevice:
  8. def __init__(self, device_id: str, num_slots: int):
  9. self.device_id = device_id
  10. self.num_slots = num_slots
class MeasurementResult(BaseModel):
    """Validated payload of a ``measurement_done/<device_id>`` MQTT message.

    ``MQTTHandler.on_message`` builds it from the decoded JSON payload and
    passes it to the callback registered for ``slot_id``.
    """

    device_id: str  # device that reported the result
    slot_id: int  # slot index; selects the per-slot callback
    cell_id: int  # presumably the cell_id sent by start_measurement — TODO confirm
    capacity: float  # NOTE(review): units are not specified in this file — confirm
    status: str  # free-form status string; allowed values are not defined here
  17. class MQTTHandler:
  18. def __init__(self, broker="localhost", port=1883, username=None, password=None):
  19. self.client = mqtt.Client()
  20. self.devices: list[MQTTDevice] = []
  21. self.measurement_callbacks: dict[str, dict[int, Callable]] = {}
  22. self.client.username_pw_set(username, password)
  23. self.client.on_connect = self.on_connect
  24. self.client.on_message = self.on_message
  25. if broker == "debug":
  26. self.client.connect("test.mosquitto.org", 1883)
  27. return
  28. self.client.connect(broker, port, 60)
  29. self.client.loop_start()
  30. def register_device(self, device_id, num_slots, callback: Callable = None):
  31. """Register a new device to handle"""
  32. device = MQTTDevice(device_id, num_slots)
  33. self.devices.append(device)
  34. self.measurement_callbacks[device_id] = {}
  35. if callback:
  36. for slot in range(num_slots):
  37. self.measurement_callbacks[device_id][slot] = callback
  38. def _subscribe_device_topics(self, device_id: str):
  39. """Subscribe to all topics for a specific device"""
  40. topics = [
  41. f"measurement_done/{device_id}",
  42. f"soa/{device_id}"
  43. ]
  44. for topic in topics:
  45. self.client.subscribe(topic)
  46. logger.info(f"Subscribed to {topic}")
  47. def on_connect(self, client, userdata, flags, rc):
  48. if rc == 0:
  49. logger.info("Connected to MQTT Broker!")
  50. else:
  51. raise ConnectionError(f"Failed to connect, return code {rc}")
  52. # Resubscribe to all device topics on reconnect
  53. for device in self.devices:
  54. self._subscribe_device_topics(device.device_id)
  55. def on_message(self, client, userdata, msg):
  56. try:
  57. payload = json.loads(msg.payload.decode())
  58. topic = msg.topic
  59. device_id = topic.split('/')[1] # Extract device_id from topic
  60. if topic.startswith("measurement_done/"):
  61. result = MeasurementResult(**payload)
  62. logger.info(f"Measurement complete {result}")
  63. if result.device_id in self.measurement_callbacks and result.slot_id in self.measurement_callbacks[device_id]:
  64. self.measurement_callbacks[device_id][result.slot_id](result)
  65. else:
  66. logger.warning(f"No callback for measurement {result}")
  67. elif topic.startswith("soa/"):
  68. logger.info(f"SOA update for device {device_id}: {payload}")
  69. # TODO[SG]: Handle SOA update here
  70. except Exception as e:
  71. logger.error(f"Error processing message: {e}")
  72. def start_measurement(self, device_id: str, slot: int, cell_id: int):
  73. """Publish measurement start command for specific device"""
  74. if device_id not in [d.device_id for d in self.devices]:
  75. raise ValueError(f"Device {device_id} not registered")
  76. topic = f"cells_inserted/{device_id}"
  77. payload = {"slot_id": slot, "cell_id": cell_id}
  78. self.client.publish(topic, json.dumps(payload))
  79. logger.info(f"MQTT msg published for {topic}: {payload}")
  80. def cleanup(self):
  81. """Cleanup MQTT connection"""
  82. self.client.loop_stop()
  83. self.client.disconnect()