# test_api.py

from typing import Iterator
from unittest.mock import Mock, patch

import pytest
from fastapi.testclient import TestClient

from api.routes import MQTTHandler, Device, MeasurementResult, mqtt


  5. @pytest.fixture
  6. def mock_mqtt_client():
  7. with patch('paho.mqtt.client.Client') as mock_client:
  8. client_instance = Mock()
  9. mock_client.return_value = client_instance
  10. yield client_instance
  11. @pytest.fixture
  12. def mqtt_handler(mock_mqtt_client):
  13. handler = MQTTHandler()
  14. yield handler
  15. handler.cleanup()
  16. class TestMQTTHandler:
  17. def test_device_registration(self, mqtt_handler:MQTTHandler):
  18. device = Device("test_device", 4)
  19. mqtt_handler.register_device(device)
  20. assert device in mqtt_handler.devices
  21. assert "test_device" in mqtt_handler.measurement_callbacks
  22. def test_start_measurement(self, mqtt_handler:MQTTHandler, mock_mqtt_client:mqtt.Client):
  23. device = Device("test_device", 4)
  24. mqtt_handler.register_device(device)
  25. callback = Mock()
  26. mqtt_handler.start_measurement("test_device", 1, "cell123", callback)
  27. mock_mqtt_client.publish.assert_called_once()
  28. assert mqtt_handler.measurement_callbacks["test_device"][1] == callback
  29. def test_measurement_callback(self, mqtt_handler:MQTTHandler):
  30. device = Device("test_device", 4)
  31. mqtt_handler.register_device(device)
  32. callback = Mock()
  33. mqtt_handler.start_measurement("test_device", 1, "cell123", callback)
  34. # Simulate measurement complete message
  35. result = MeasurementResult(
  36. device_id="test_device",
  37. cell_id="cell123",
  38. slot=1,
  39. capacity=3000.0,
  40. status="complete"
  41. )
  42. # Simulate MQTT message
  43. message = Mock()
  44. message.topic = "measurement_done/test_device"
  45. message.payload = result.json().encode()
  46. mqtt_handler.on_message(None, None, message)
  47. callback.assert_called_once()
  48. def test_invalid_device(self, mqtt_handler:MQTTHandler):
  49. with pytest.raises(ValueError):
  50. mqtt_handler.start_measurement("invalid_device", 1, "cell123")