Browse Source

add mqtt api and test

Silas Gruen 10 months ago
parent
commit
598c718c79

+ 33 - 0
experiments/mqtt-test.py

@@ -0,0 +1,33 @@
+import paho.mqtt.client as mqtt
+import time
+import random  # Replace with actual voltage measurement code
+
+BROKER = "localhost"
+PORT = 1883  # Use 8883 for TLS
+TOPIC = "cells_inserted/device1"  # Change for each device
+USERNAME = "robot"
+PASSWORD = "robot"
+
+def on_connect(client, userdata, flags, rc):
+    if rc == 0:
+        print("Connected to MQTT Broker!")
+    else:
+        print(f"Failed to connect, return code {rc}\n")
+
+client = mqtt.Client()
+client.username_pw_set(USERNAME, PASSWORD)
+client.on_connect = on_connect
+
+client.connect(BROKER, PORT, 60)
+client.loop_start()
+
+try:
+    while True:
+        slot = int(random.uniform(0, 5))  # Replace with actual measurement
+        client.publish(TOPIC, slot)
+        print(f"Published: {slot} to topic {TOPIC}")
+        time.sleep(2)  # Adjust as needed
+except KeyboardInterrupt:
+    print("Disconnecting from broker")
+    client.disconnect()
+    client.loop_stop()

+ 7 - 0
robot-control/config/config.yaml

@@ -31,6 +31,13 @@ dropoff_grades:
     position: [500.0, 100.0, 50.0]
     capacity_threshold: 0
 
+mqtt:
+  broker: localhost
+  port: 1883
+  username: robot
+  password: robot
+  keepalive: 60
+
 system_settings:
   speed: 100.0
   acceleration: 500.0

+ 83 - 12
robot-control/src/api/routes.py

@@ -1,23 +1,94 @@
-from fastapi import FastAPI, HTTPException
-from typing import Dict
+import paho.mqtt.client as mqtt
+from typing import Dict, List, Callable
 from pydantic import BaseModel
+import json
+import logging
 
-app = FastAPI()
+logger = logging.getLogger(__name__)
 
-# TODO[SG]: add real API. THis is just a placeholder
+class Device:
+    def __init__(self, device_id: str, num_slots: int):
+        self.device_id = device_id
+        self.num_slots = num_slots
 
 class MeasurementResult(BaseModel):
+    device_id: str
     cell_id: str
     slot: int
     capacity: float
     status: str
 
-@app.post("/measurement/start")
-async def start_measurement(slot: int, cell_id: str):
-    # Implementation to start measurement
-    return {"status": "started", "slot": slot, "cell_id": cell_id}
+class MQTTHandler:
+    def __init__(self, broker="localhost", port=1883):
+        self.client = mqtt.Client()
+        self.devices: List[Device] = []
+        self.measurement_callbacks: Dict[str, Dict[int, Callable]] = {}
+        
+        self.client.on_connect = self.on_connect
+        self.client.on_message = self.on_message
+        
+        self.client.connect(broker, port, 60)
+        self.client.loop_start()
 
-@app.post("/measurement/complete")
-async def complete_measurement(result: MeasurementResult):
-    # Add completed measurement to robot's work queue
-    return {"status": "queued", "slot": result.slot}
+    def register_device(self, device: Device):
+        """Register a new device to handle"""
+        self.devices.append(device)
+        self.measurement_callbacks[device.device_id] = {}
+        # Subscribe to device specific topics
+        self._subscribe_device_topics(device.device_id)
+
+    def _subscribe_device_topics(self, device_id: str):
+        """Subscribe to all topics for a specific device"""
+        topics = [
+            f"measurement_done/{device_id}",
+            f"soa/{device_id}"
+        ]
+        for topic in topics:
+            self.client.subscribe(topic)
+            logger.info(f"Subscribed to {topic}")
+
+    def on_connect(self, client, userdata, flags, rc):
+        logger.info(f"Connected with result code {rc}")
+        # Resubscribe to all device topics on reconnect
+        for device in self.devices:
+            self._subscribe_device_topics(device.device_id)
+
+    def on_message(self, client, userdata, msg):
+        try:
+            payload = json.loads(msg.payload.decode())
+            topic = msg.topic
+            device_id = topic.split('/')[1]  # Extract device_id from topic
+
+            if topic.startswith("measurement_done/"):
+                result = MeasurementResult(**payload)
+                logger.info(f"Measurement complete for device {device_id}, slot {result.slot}")
+                if device_id in self.measurement_callbacks and result.slot in self.measurement_callbacks[device_id]:
+                    self.measurement_callbacks[device_id][result.slot](result)
+                    
+            elif topic.startswith("soa/"):
+                logger.info(f"SOA update for device {device_id}: {payload}")
+                # Handle SOA update here
+                
+        except Exception as e:
+            logger.error(f"Error processing message: {e}")
+
+    def start_measurement(self, device_id: str, slot: int, cell_id: str, callback: Callable = None):
+        """Publish measurement start command for specific device"""
+        if device_id not in [d.device_id for d in self.devices]:
+            raise ValueError(f"Device {device_id} not registered")
+            
+        payload = {"slot": slot, "cell_id": cell_id}
+        self.client.publish(f"cells_inserted/{device_id}", json.dumps(payload))
+        
+        if callback:
+            if device_id not in self.measurement_callbacks:
+                self.measurement_callbacks[device_id] = {}
+            self.measurement_callbacks[device_id][slot] = callback
+
+    def cleanup(self):
+        """Cleanup MQTT connection"""
+        self.client.loop_stop()
+        self.client.disconnect()
+
+# Create global MQTT handler instance
+mqtt_handler = MQTTHandler()

+ 11 - 0
robot-control/src/robot/config.py

@@ -81,3 +81,14 @@ class ConfigParser:
             acceleration=settings['acceleration'],
             safe_height=settings['safe_height']
         )
+
+    def get_mqtt_config(self) -> dict:
+        """Get MQTT broker configuration"""
+        mqtt_config = self.config.get('mqtt', {})
+        return {
+            'broker': mqtt_config.get('broker', 'localhost'),
+            'port': mqtt_config.get('port', 1883),
+            'username': mqtt_config.get('username', None),
+            'password': mqtt_config.get('password', None),
+            'keepalive': mqtt_config.get('keepalive', 60)
+        }

+ 39 - 20
robot-control/src/robot/controller.py

@@ -4,6 +4,7 @@ from typing import List, Tuple
 from .config import ConfigParser, SlotConfig
 from .movement import RobotMovement
 import logging
+from api.routes import MQTTHandler, Device
 
 logger = logging.getLogger(__name__)
 
@@ -38,12 +39,30 @@ class RobotController:
         self.movement = RobotMovement()
         self.movement.set_speed(self.system_settings.speed)
         
+        # Initialize MQTT handler
+        mqtt_config = self.config.get_mqtt_config()  # Add this to config parser
+        self.mqtt_handler = MQTTHandler(
+            broker=mqtt_config.get('broker'),
+            port=mqtt_config.get('port')
+        )
+        
+        # Register all devices with MQTT handler
+        for device in self.devices:
+            self.mqtt_handler.register_device(
+                Device(device_id=device.id, num_slots=len(device.slots))
+            )
+        
         # Initialize with configured values
         self.total_slots = sum(len(device.slots) for device in self.devices)
         self.work_queue: List[SlotConfig] = []
 
         self.gripper_occupied = False
 
+    async def cleanup(self):
+        """Cleanup resources on shutdown"""
+        # await self.movement.cleanup() TODO[SG]: Implement cleanup method in movement.py
+        self.mqtt_handler.cleanup()
+
     async def pick_cell_from_feeder(self):
         if self.gripper_occupied:
             logger.error("Gripper already occupied")
@@ -58,9 +77,9 @@ class RobotController:
             await self.movement.move_to_position(x, y, z)
             # Grip cell
             if await self.movement.activate_gripper():
-                self.gripper_occupied = True
-            else:
                 raise RuntimeError("Failed to grip cell")
+            self.gripper_occupied = True
+
             # Move back to safe height
             await self.movement.move_to_position(x, y, self.system_settings.safe_height)
             logger.info("Cell picked from feeder")
@@ -90,10 +109,10 @@ class RobotController:
             logger.error("Gripper not occupied")
             return
         slot = self.get_next_free_slot()
-        if slot:
-            await self.insert_cell_to_slot(cell, slot)
-        else:
+        if not slot:
             logger.error("No free slots available")
+            return
+        await self.insert_cell_to_slot(cell, slot)
 
     async def insert_cell_to_slot(self, cell: Cell, slot: SlotConfig):
         if slot.occupied:
@@ -111,12 +130,12 @@ class RobotController:
             # Move down to insertion position
             await self.movement.move_to_position(x, y, z)
             # Release cell
-            if await self.movement.deactivate_gripper():
-                slot.occupied = True
-                slot.cell_id = cell.id
-                self.gripper_occupied = False
-            else:
-                raise RuntimeError("Failed to release cell")
+            if not await self.movement.deactivate_gripper():
+                raise RuntimeError("Failed to release cell")            
+            slot.occupied = True
+            slot.cell_id = cell.id
+            self.gripper_occupied = False
+            
             # Move back to safe height
             await self.movement.move_to_position(x, y, self.system_settings.safe_height)
             logger.info(f"Cell {cell.id} inserted to slot at position {slot.position}")
@@ -137,13 +156,13 @@ class RobotController:
             # Move down to collection position
             await self.movement.move_to_position(x, y, z)
             # Grip cell
-            if await self.movement.activate_gripper():
-                self.gripper_occupied = True
-                cell_id = slot.cell_id
-                slot.occupied = False
-                slot.cell_id = None
-            else:
+            if not await self.movement.activate_gripper():
                 raise RuntimeError("Failed to grip cell")
+            self.gripper_occupied = True
+            cell_id = slot.cell_id
+            slot.occupied = False
+            slot.cell_id = None
+            
             # Move back to safe height
             await self.movement.move_to_position(x, y, self.system_settings.safe_height)
             logger.info(f"Cell {cell_id} collected from slot at position {slot.position}")
@@ -183,10 +202,10 @@ class RobotController:
             # Move to dropoff position
             await self.movement.move_to_position(x, y, z)
             # Release cell
-            if await self.movement.deactivate_gripper():
-                self.gripper_occupied = False
-            else:
+            if not await self.movement.deactivate_gripper():
                 raise RuntimeError("Failed to release cell")
+            self.gripper_occupied = False
+
             # Move back to safe height
             await self.movement.move_to_position(x, y, self.system_settings.safe_height)
             logger.info(f"Cell dropped off at grade {dropoff_grade.name}")

+ 54 - 31
robot-control/tests/test_api.py

@@ -1,39 +1,62 @@
 import pytest
 from fastapi.testclient import TestClient
-from api.routes import app
+from api.routes import MQTTHandler, Device, MeasurementResult, mqtt
+from unittest.mock import Mock, patch
 
 @pytest.fixture
-def client():
-    return TestClient(app)
+def mock_mqtt_client():
+    with patch('paho.mqtt.client.Client') as mock_client:
+        client_instance = Mock()
+        mock_client.return_value = client_instance
+        yield client_instance
 
-class TestAPIEndpoints:
-    def test_get_status(self, client):
-        pass
-        # response = client.get("/status")
-        # assert response.status_code == 200
-        # assert "status" in response.json()
+@pytest.fixture
+def mqtt_handler(mock_mqtt_client):
+    handler = MQTTHandler()
+    yield handler
+    handler.cleanup()
+
+class TestMQTTHandler:
+    def test_device_registration(self, mqtt_handler:MQTTHandler):
+        device = Device("test_device", 4)
+        mqtt_handler.register_device(device)
+        assert device in mqtt_handler.devices
+        assert "test_device" in mqtt_handler.measurement_callbacks
 
-    def test_get_robot_position(self, client):
-        pass
-        # response = client.get("/robot/position")
-        # assert response.status_code == 200
-        # data = response.json()
-        # assert "x" in data
-        # assert "y" in data
-        # assert "z" in data
+    def test_start_measurement(self, mqtt_handler:MQTTHandler, mock_mqtt_client:mqtt.Client):
+        device = Device("test_device", 4)
+        mqtt_handler.register_device(device)
+        
+        callback = Mock()
+        mqtt_handler.start_measurement("test_device", 1, "cell123", callback)
+        
+        mock_mqtt_client.publish.assert_called_once()
+        assert mqtt_handler.measurement_callbacks["test_device"][1] == callback
 
-    def test_get_slots_status(self, client):
-        pass
-        # response = client.get("/slots")
-        # assert response.status_code == 200
-        # slots = response.json()
-        # assert isinstance(slots, list)
-        # if len(slots) > 0:
-        #     assert "id" in slots[0]
-        #     assert "occupied" in slots[0]
+    def test_measurement_callback(self, mqtt_handler:MQTTHandler):
+        device = Device("test_device", 4)
+        mqtt_handler.register_device(device)
+        
+        callback = Mock()
+        mqtt_handler.start_measurement("test_device", 1, "cell123", callback)
+        
+        # Simulate measurement complete message
+        result = MeasurementResult(
+            device_id="test_device",
+            cell_id="cell123",
+            slot=1,
+            capacity=3000.0,
+            status="complete"
+        )
+        
+        # Simulate MQTT message
+        message = Mock()
+        message.topic = "measurement_done/test_device"
+        message.payload = result.json().encode()
+        
+        mqtt_handler.on_message(None, None, message)
+        callback.assert_called_once()
 
-    def test_emergency_stop(self, client):
-        pass
-        # response = client.post("/robot/stop")
-        # assert response.status_code == 200
-        # assert response.json()["status"] == "stopped"
+    def test_invalid_device(self, mqtt_handler:MQTTHandler):
+        with pytest.raises(ValueError):
+            mqtt_handler.start_measurement("invalid_device", 1, "cell123")