# mqtt_service.py (4.0 KB)

  1. import logging
  2. import paho.mqtt.client as mqtt
  3. from pydantic import BaseModel
  4. import json
  5. from typing import Dict, Callable, Optional
  6. from src.models.device import DeviceStatus
  # Module-level logger named after this module; MQTTService uses it for
  # connection, subscription, publish and message-processing log lines.
  logger = logging.getLogger(__name__)
  # Validated form of one "cell inserted" MQTT message, as built by
  # MQTTService.on_message and handed to the registered insertion callback.
  # (Documented with comments rather than a docstring so the pydantic schema
  # of the model is left untouched.)
  class InsertedCell(BaseModel):
      # Device number; on_message takes it from the MQTT topic, not the payload.
      device_id: int
      # Slot on the device; on_message uses it to look up the callback.
      # Presumably supplied by the JSON payload -- confirm with the publisher.
      slot_id: int
      # Identifier of the inserted cell; presumably supplied by the JSON
      # payload as well -- confirm with the publisher.
      cell_id: int
  class MQTTService:
      """MQTT front end for cell-insertion events and measurement results.

      Incoming messages on ``cells_inserted/device_<id>`` are parsed into
      ``InsertedCell`` objects and passed to the callback registered for the
      reported slot. Finished measurements are published as JSON on
      ``measurement_done/<id>``.
      """

      def __init__(self, config: dict):
          """Create the MQTT client and, unless debug mode is on, connect it.

          Args:
              config: Application configuration. The keys ``broker_address``,
                  ``port``, ``keepalive``, ``username``, ``password`` and
                  ``debug`` are read from ``config['mqtt']``.

          Raises:
              KeyError: If one of the ``mqtt`` settings is missing.
              ConnectionError: If the configured broker refuses the connection.
          """
          self.config = config
          mqtt_config = config['mqtt']
          broker_address = mqtt_config['broker_address']
          port = mqtt_config['port']
          keepalive = mqtt_config['keepalive']
          username = mqtt_config['username']
          password = mqtt_config['password']
          debug = mqtt_config['debug']

          self.client = mqtt.Client()
          self.client.username_pw_set(username, password)
          self.client.on_connect = self.on_connect
          self.client.on_message = self.on_message

          # device_id -> number of slots on that device
          self.devices: dict[int, int] = {}
          # device_id -> slot_id -> callback invoked with the InsertedCell
          self.insertion_callbacks: Dict[int, Dict[int, Callable]] = {}

          if debug:
              # Debug mode: the client object exists but never connects.
              logger.info("No MQTT in debug mode")
              return

          if broker_address == "debug":
              # Placeholder broker address: use the public test broker.
              self.client.connect("test.mosquitto.org", 1883)
              # Without the network loop neither on_connect nor on_message
              # would ever be invoked for this connection.
              self.client.loop_start()
              return

          try:
              self.client.connect(broker_address, port, keepalive)
          except ConnectionRefusedError as e:
              raise ConnectionError(f"Failed to connect to MQTT broker at {broker_address}:{port}") from e
          self.client.loop_start()

      def register_device(self, device_id, num_slots, callback: Optional[Callable] = None):
          """Register a device and, optionally, one callback for all its slots.

          Registering an already known device replaces its slot count and its
          callbacks. If the client is already connected, the device topics are
          subscribed right away; otherwise ``on_connect`` subscribes them once
          the connection is established.

          Args:
              device_id: Identifier of the device (used in the topic name).
              num_slots: Number of slots; callbacks are stored for slots
                  ``0 .. num_slots - 1``.
              callback: Called with the ``InsertedCell`` when a cell is
                  inserted into any slot of this device. When omitted, no
                  callbacks are stored for the device.
          """
          self.devices[device_id] = num_slots
          self.insertion_callbacks[device_id] = (
              {slot: callback for slot in range(num_slots)} if callback else {}
          )
          # on_connect only covers devices known at (re)connect time, so a
          # device added later has to be subscribed here.
          if self.client.is_connected():
              self._subscribe_device_topics(device_id)

      def _subscribe_device_topics(self, device_id: int):
          """Subscribe to all topics for a specific device."""
          topics = [
              f"cells_inserted/device_{device_id}",
          ]
          for topic in topics:
              self.client.subscribe(topic)
              logger.info(f"Subscribed to {topic}")

      def on_connect(self, client, userdata, flags, rc):
          """Connection callback: subscribe the topics of every known device.

          Raises:
              ConnectionError: If ``rc`` is non-zero, i.e. the broker did not
                  accept the connection.
          """
          if rc != 0:
              raise ConnectionError(f"Failed to connect, return code {rc}")
          logger.info("Connected to MQTT Broker!")
          # Resubscribe to all device topics on reconnect. Iterate a snapshot:
          # register_device may add devices from another thread meanwhile.
          for device_id in list(self.devices):
              self._subscribe_device_topics(device_id)

      def on_message(self, client, userdata, msg):
          """Message callback: decode an insertion message and dispatch it.

          The device id is taken from the topic (``<prefix>/device_<id>``),
          the remaining ``InsertedCell`` fields from the JSON payload. Any
          error, including one raised by the callback, is logged and not
          propagated.
          """
          try:
              payload = json.loads(msg.payload.decode())
              topic = msg.topic
              # Extract device_id number from topic
              device_id = int(topic.split('/')[1].split('_')[1])
              inserted_cell = InsertedCell(device_id=device_id, **payload)
              logger.info(f"Cell inserted: {inserted_cell}")
              callback = self.insertion_callbacks.get(device_id, {}).get(inserted_cell.slot_id)
              if callback is None:
                  logger.warning(f"No callback for insertion {inserted_cell}")
              else:
                  callback(inserted_cell)
          except Exception as e:
              # Callback boundary: log with traceback instead of propagating
              # the error out of the message handler.
              logger.exception(f"Error processing MQTT message: {e}")

      def cell_finished(self, device_id: int, slot_id: int, cell_id: int, capacity: float, status: "DeviceStatus"):
          """Publish a message for a cell finishing measurement.

          Args:
              device_id: Registered device that measured the cell.
              slot_id: Slot the cell was measured in.
              cell_id: Identifier of the measured cell.
              capacity: Measured capacity; published rounded to 4 decimals.
              status: Status whose ``name`` is published.

          Raises:
              ValueError: If ``device_id`` has not been registered.
          """
          if device_id not in self.devices:
              raise ValueError(f"Device {device_id} not registered")
          topic = f"measurement_done/{device_id}"
          payload = {
              "device_id": device_id,
              "slot_id": slot_id,
              "cell_id": cell_id,
              "capacity": round(capacity, 4),
              "status": status.name
          }
          self.client.publish(topic, json.dumps(payload))
          logger.info(f"MQTT msg published for {topic}: {payload}")

      def cleanup(self):
          """Stop the client's network loop and disconnect from the broker."""
          self.client.loop_stop()
          self.client.disconnect()