502 lines
15 KiB
Python

"""
VitalLink Backend API
FastAPI server for managing patients, wristbands, and real-time data
"""
import asyncio
import json
import time
from collections import defaultdict
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import Dict, List, Optional

from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from triage_engine import TriageEngine, TriageLevel, VitalSigns, triage_from_vitals
# ============================================================================
# LIFESPAN MANAGEMENT
# ============================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Print a startup banner, hand control to the app, then print a shutdown notice.

    Passed to ``FastAPI(lifespan=...)``: the code before ``yield`` runs at
    startup and the code after it runs at shutdown. ``app`` is not used.
    """
    rule = "=" * 80
    # NOTE(review): the URLs below are fixed strings; they do not reflect the
    # host/port the server was actually started with.
    banner = (
        rule,
        "VitalLink Backend API Started",
        rule,
        "API Documentation: http://localhost:8000/docs",
        "WebSocket Endpoint: ws://localhost:8000/ws",
        rule,
    )
    for banner_line in banner:
        print(banner_line)

    yield

    print("\nVitalLink Backend API Shutting Down")
# ============================================================================
# APP INITIALIZATION
# ============================================================================
# The application object; `lifespan` supplies the startup/shutdown hooks.
app = FastAPI(
    title="VitalLink API",
    version="1.0.0",
    lifespan=lifespan,
)

# Fully open CORS policy: any origin, method and header is accepted.
# NOTE(review): combining wildcard origins with allow_credentials=True is
# discouraged by the FastAPI/CORS documentation; consider listing the
# dashboard origin(s) explicitly before exposing this beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# ============================================================================
# DATA MODELS
# ============================================================================
class PatientCheckIn(BaseModel):
    """Request body for ``POST /api/checkin``.

    Field names are camelCase as sent by the client; ``check_in_patient``
    copies them into the snake_case ``Patient`` model.
    """

    firstName: str
    lastName: str
    dob: str  # date of birth, kept as an unparsed string (format not validated here)
    symptoms: List[str]
    # Presumably one of "mild" / "moderate" / "severe" (the values that
    # calculate_priority_score scores) — not enforced by this model.
    severity: str
class Patient(BaseModel):
    """A checked-in patient as stored in ``patients_db``.

    Created by ``check_in_patient``. ``current_tier`` and ``last_vitals`` are
    rewritten as vitals arrive, and ``is_active`` is cleared on discharge.
    """

    patient_id: str  # e.g. "P100001", assigned sequentially at check-in
    band_id: str  # wristband ID such as "VitalLink-1000", taken from available_bands
    first_name: str
    last_name: str
    dob: str
    symptoms: List[str]
    severity: str
    check_in_time: datetime  # naive local time from datetime.now(); basis for wait times
    # Typically "NORMAL", "ALERT" or "EMERGENCY"; may also hold the triage
    # engine's tier_name when calculate_priority_score assigns it.
    current_tier: str = "NORMAL"
    # Most recent VitalsData.dict(), or None until the first reading arrives.
    last_vitals: Optional[dict] = None
    # False after discharge; inactive patients are excluded from the queue and
    # from the active-patient statistics.
    is_active: bool = True
class VitalsData(BaseModel):
    """One vitals reading posted to ``POST /api/vitals`` (sent by the base station)."""

    band_id: str
    patient_id: str  # must be a key of patients_db, otherwise the endpoint returns 404
    timestamp: float  # NOTE(review): unit/epoch not visible here — presumably Unix seconds
    # Tier reported by the sender; receive_vitals computes its own tier and
    # does not read this field (it is only stored and re-broadcast).
    tier: str
    hr_bpm: int  # heart rate in beats per minute
    spo2: int  # blood-oxygen saturation, presumably in percent
    temp_c: float  # body temperature, presumably in degrees Celsius
    activity: float
    flags: List[str]  # sender-defined flags; not interpreted by this server code
    seq: int  # presumably a per-band sequence number — not checked here
class QueuePosition(BaseModel):
    """One entry of the ``GET /api/queue`` response (list sorted by priority_score, descending)."""

    patient_id: str
    band_id: str
    name: str  # "<first_name> <last_name>"
    tier: str
    priority_score: float  # value from calculate_priority_score
    wait_time_minutes: int  # whole minutes since check-in (truncated)
    # Latest reported vitals; each is 0 when no reading has been received yet.
    last_hr: int
    last_spo2: int
    last_temp: float
# ============================================================================
# IN-MEMORY STORAGE
# ============================================================================
# Every patient ever checked in, keyed by patient_id. Discharged patients stay
# here with is_active=False.
patients_db: Dict[str, Patient] = dict()
# Received readings per patient_id, oldest first (receive_vitals keeps the
# newest 1000).
vitals_history: Dict[str, List[VitalsData]] = defaultdict(list)
# Tier-stability state per patient_id, created lazily by receive_vitals.
tier_trackers: Dict[str, Dict] = dict()
# Pool of free wristband IDs, "VitalLink-1000" through "VitalLink-1FFF"
# (4096 bands). Check-in takes from the front; discharge appends to the back.
available_bands = [f"VitalLink-{band_no:04X}" for band_no in range(0x1000, 0x2000)]
# Open /ws connections that broadcast_update() pushes messages to.
active_websockets: List[WebSocket] = list()
# Last payload POSTed to /api/wristband-details, replaced wholesale each time.
wristband_details_cache = dict()
# ============================================================================
# PRIORITY ALGORITHM
# ============================================================================
def calculate_priority_score(patient: Patient) -> float:
    """Return the queue priority score for *patient* (the queue sorts it descending).

    Before any vitals have arrived, the score comes from the self-reported
    severity alone: "severe" -> 50, "moderate" -> 30, "mild" -> 20, anything
    else -> 20. Once ``patient.last_vitals`` is set, the score is
    ``TriageEngine.assess_patient(...)["priority_score"]``, computed from the
    latest vitals, symptoms, severity and whole minutes waited.

    This function deliberately does not modify *patient*. ``current_tier`` is
    owned by ``receive_vitals``, which only changes it after enough
    confirming readings. Overwriting it here with the engine's unconfirmed
    tier made the tier reported by /api/queue and /api/stats flip between the
    two values depending on which endpoint ran last.
    """
    latest = patient.last_vitals
    if not latest:
        severity_scores = {"severe": 50, "moderate": 30, "mild": 20}
        return severity_scores.get(patient.severity, 20)

    # last_vitals is a VitalsData.dict(), so these defaults only apply if a
    # key is somehow missing.
    vitals = VitalSigns(
        heart_rate=latest.get("hr_bpm", 75),
        spo2=latest.get("spo2", 98),
        temperature=latest.get("temp_c", 37.0),
        activity=latest.get("activity", 0.0),
    )
    wait_minutes = int((datetime.now() - patient.check_in_time).total_seconds() / 60)

    assessment = TriageEngine.assess_patient(
        vitals=vitals,
        symptoms=patient.symptoms,
        severity=patient.severity,
        age=None,  # Patient has no age field yet
        preexisting=[],  # Patient has no pre-existing-conditions field yet
        wait_time_minutes=wait_minutes,
    )
    return assessment["priority_score"]
# ============================================================================
# API ENDPOINTS
# ============================================================================
@app.get("/")
async def root():
    """Info endpoint: API name, version, docs path and a fixed "running" status."""
    info = dict(
        message="VitalLink Backend API",
        version="1.0.0",
        docs="/docs",
        status="running",
    )
    return info
@app.post("/api/checkin")
async def check_in_patient(data: PatientCheckIn):
    """Register a new patient and assign them the next free wristband.

    Returns the new ``patient_id`` and ``band_id``. Raises HTTP 503 when no
    wristband is free. Connected WebSocket clients receive a
    ``patient_added`` message containing the full patient record.
    """
    if not available_bands:
        raise HTTPException(status_code=503, detail="No wristbands available")

    # patients_db never shrinks (discharge only flags is_active=False), so its
    # size yields unique sequential IDs: P100001, P100002, ...
    patient_id = f"P{len(patients_db) + 100001}"
    band_id = available_bands.pop(0)

    patient = Patient(
        patient_id=patient_id,
        band_id=band_id,
        first_name=data.firstName,
        last_name=data.lastName,
        dob=data.dob,
        symptoms=data.symptoms,
        severity=data.severity,
        check_in_time=datetime.now(),
        current_tier="NORMAL",
    )
    patients_db[patient_id] = patient

    # WebSocket.send_json serializes with plain json.dumps, which cannot
    # encode a datetime. Send check_in_time as an ISO-8601 string so the
    # broadcast does not fail for every connected client.
    patient_payload = patient.dict()
    patient_payload["check_in_time"] = patient.check_in_time.isoformat()
    await broadcast_update({"type": "patient_added", "patient": patient_payload})

    return {
        "patient_id": patient_id,
        "band_id": band_id,
        "message": "Check-in successful",
    }
def _apply_tier_stability(
    tracker: dict, suggested_tier: str, now: float
) -> Optional[int]:
    """Feed one suggested tier into *tracker*; return the confirmation count if the tier changed.

    A tier different from the confirmed one must be suggested on consecutive
    readings before it is adopted:

    * escalation (NORMAL -> ALERT/EMERGENCY, ALERT -> EMERGENCY): 2 readings
      and at least 10 s since the current tier was confirmed (or tracking began);
    * anything else: 5 readings and at least 60 s.

    Mutates *tracker* in place. Returns ``None`` when the confirmed tier is
    unchanged.
    """
    current_tier = tracker["current_tier"]
    if suggested_tier == current_tier:
        # Vitals agree with the confirmed tier: abandon any pending change.
        tracker["consecutive_readings"] += 1
        tracker["pending_tier"] = None
        tracker["pending_count"] = 0
        return None

    # Count consecutive suggestions of the same new tier.
    if tracker["pending_tier"] == suggested_tier:
        tracker["pending_count"] += 1
    else:
        tracker["pending_tier"] = suggested_tier
        tracker["pending_count"] = 1

    is_upgrade = (
        suggested_tier == "EMERGENCY" and current_tier in ["ALERT", "NORMAL"]
    ) or (suggested_tier == "ALERT" and current_tier == "NORMAL")
    if is_upgrade:
        # Escalate quickly.
        required_confirmations, min_time_required = 2, 10.0
    else:
        # Every non-escalation is treated as a de-escalation and needs more
        # evidence before the tier drops.
        required_confirmations, min_time_required = 5, 60.0

    confirmations = tracker["pending_count"]
    time_in_tier = now - tracker["tier_since"]
    if confirmations < required_confirmations or time_in_tier < min_time_required:
        return None

    tracker["current_tier"] = suggested_tier
    tracker["tier_since"] = now
    tracker["consecutive_readings"] = 0
    tracker["pending_tier"] = None
    tracker["pending_count"] = 0
    return confirmations


@app.post("/api/vitals")
async def receive_vitals(data: VitalsData):
    """Receive vitals data from the base station, with tier stability.

    Steps: update the patient's confirmed tier via ``_apply_tier_stability``,
    store the reading (keeping the newest 1000 per patient), check recent
    history for deterioration, and broadcast a ``vitals_update`` message.

    Returns ``status``, the confirmed ``tier``, the ``suggested_tier`` computed
    from this reading alone, and ``confirmed`` (whether the two are equal).
    Raises HTTP 404 if ``data.patient_id`` is unknown.

    NOTE(review): readings for discharged (is_active=False) patients are still
    accepted.
    """
    patient_id = data.patient_id
    if patient_id not in patients_db:
        raise HTTPException(status_code=404, detail="Patient not found")
    patient = patients_db[patient_id]
    now = time.time()

    # First reading for this patient: tracking starts at NORMAL.
    if patient_id not in tier_trackers:
        tier_trackers[patient_id] = {
            "current_tier": "NORMAL",
            "tier_since": now,
            "consecutive_readings": 0,
            "pending_tier": None,
            "pending_count": 0,
        }
    tracker = tier_trackers[patient_id]

    vitals = VitalSigns(
        heart_rate=data.hr_bpm,
        spo2=data.spo2,
        temperature=data.temp_c,
        activity=data.activity,
    )
    # Tier these vitals alone point to; it only becomes the patient's tier
    # once enough consecutive readings confirm it.
    suggested_tier = triage_from_vitals(data.hr_bpm, data.spo2, data.temp_c)

    old_tier = tracker["current_tier"]
    confirmations = _apply_tier_stability(tracker, suggested_tier, now)
    if confirmations is not None:
        # The count is captured before the tracker reset; reading
        # pending_count after the reset always reported 0.
        print(
            f"🔄 TIER CHANGE: {patient_id} {old_tier} → {tracker['current_tier']} "
            f"(confirmed after {confirmations} readings)"
        )

    final_tier = tracker["current_tier"]
    patient.current_tier = final_tier
    patient.last_vitals = data.dict()

    history = vitals_history[patient_id]
    history.append(data)
    if len(history) > 1000:
        del history[:-1000]  # bound memory: keep only the newest 1000 readings

    # Compare the newest reading with the two or three readings before it.
    if len(history) >= 3:
        previous = [
            VitalSigns(
                heart_rate=v.hr_bpm,
                spo2=v.spo2,
                temperature=v.temp_c,
                activity=v.activity,
            )
            for v in history[-4:-1]
        ]
        deterioration = TriageEngine.detect_deterioration(vitals, previous)
        if deterioration["deteriorating"]:
            print(f"⚠️ DETERIORATION DETECTED: {patient_id}")
            print(f" Concerns: {', '.join(deterioration['concerns'])}")
            # Seed a pending ALERT so that a single further ALERT suggestion
            # can confirm it. A NORMAL suggestion on the next reading clears
            # it again (see _apply_tier_stability).
            if final_tier == "NORMAL" and tracker["pending_tier"] != "ALERT":
                tracker["pending_tier"] = "ALERT"
                tracker["pending_count"] = 1
                print(" ⬆️ Escalation initiated due to deterioration")

    await broadcast_update(
        {
            "type": "vitals_update",
            "patient_id": patient_id,
            "vitals": data.dict(),
            "tier": final_tier,
        }
    )
    return {
        "status": "received",
        "tier": final_tier,
        "suggested_tier": suggested_tier,
        "confirmed": suggested_tier == final_tier,
    }
@app.get("/api/queue")
async def get_queue():
    """Return every active patient as a QueuePosition, highest priority_score first.

    Patients with no vitals yet report 0 for last_hr / last_spo2 / last_temp.
    Ties keep patients_db insertion order (the sort is stable).
    """
    entries = []
    for patient in patients_db.values():
        if not patient.is_active:
            continue
        score = calculate_priority_score(patient)
        waited = int((datetime.now() - patient.check_in_time).total_seconds() / 60)
        latest = patient.last_vitals or {}
        entries.append(
            QueuePosition(
                patient_id=patient.patient_id,
                band_id=patient.band_id,
                name=f"{patient.first_name} {patient.last_name}",
                tier=patient.current_tier,
                priority_score=score,
                wait_time_minutes=waited,
                last_hr=latest.get("hr_bpm", 0),
                last_spo2=latest.get("spo2", 0),
                last_temp=latest.get("temp_c", 0),
            )
        )
    return sorted(entries, key=lambda entry: entry.priority_score, reverse=True)
@app.get("/api/patients/{patient_id}")
async def get_patient_details(patient_id: str):
    """Return one patient's record, their 50 most recent readings and current priority score.

    Raises HTTP 404 for an unknown patient_id.
    """
    patient = patients_db.get(patient_id)
    if patient is None:
        raise HTTPException(status_code=404, detail="Patient not found")

    record = patient.dict()
    recent_readings = [
        reading.dict() for reading in vitals_history.get(patient_id, [])[-50:]
    ]
    return {
        "patient": record,
        "vitals_history": recent_readings,
        "priority_score": calculate_priority_score(patient),
    }
@app.post("/api/patients/{patient_id}/discharge")
async def discharge_patient(patient_id: str):
    """Mark a patient inactive and return their wristband to the free pool.

    Raises HTTP 404 for an unknown patient_id. Raises HTTP 409 if the patient
    has already been discharged: appending the band a second time would put
    a duplicate band ID (possibly already reissued) into available_bands and
    let two patients be given the same wristband.
    """
    patient = patients_db.get(patient_id)
    if patient is None:
        raise HTTPException(status_code=404, detail="Patient not found")
    if not patient.is_active:
        raise HTTPException(status_code=409, detail="Patient already discharged")

    patient.is_active = False
    available_bands.append(patient.band_id)
    await broadcast_update({"type": "patient_discharged", "patient_id": patient_id})
    return {"message": "Patient discharged", "band_returned": patient.band_id}
@app.get("/api/stats")
async def get_statistics():
    """Summarize current state.

    Includes patient counts, active patients per tier, free wristbands,
    stored readings and the mean wait of active patients (minutes, 1 decimal).

    ``tier_breakdown`` always contains EMERGENCY, ALERT and NORMAL. Any other
    tier name found on a patient is counted under its own key instead of
    raising KeyError and failing the whole request with HTTP 500.

    ``total_vitals_received`` counts readings currently held in
    vitals_history, which receive_vitals caps at 1000 per patient, so it can
    be lower than the number of readings ever received.
    """
    now = datetime.now()
    active_patients = [p for p in patients_db.values() if p.is_active]

    tier_counts = {"EMERGENCY": 0, "ALERT": 0, "NORMAL": 0}
    for patient in active_patients:
        tier = patient.current_tier
        tier_counts[tier] = tier_counts.get(tier, 0) + 1

    wait_times = [
        (now - p.check_in_time).total_seconds() / 60 for p in active_patients
    ]
    avg_wait = sum(wait_times) / len(wait_times) if wait_times else 0

    return {
        "total_patients": len(patients_db),
        "active_patients": len(active_patients),
        "tier_breakdown": tier_counts,
        "available_bands": len(available_bands),
        "total_vitals_received": sum(len(v) for v in vitals_history.values()),
        "average_wait_minutes": round(avg_wait, 1),
    }
# ============================================================================
# WRISTBAND ENDPOINTS
# ============================================================================
@app.post("/api/wristband-details")
async def update_wristband_details(data: dict):
    """Store the wristband-details payload sent by the wristband system.

    The payload replaces any previous one as-is (no validation) and is served
    back by GET /api/wristband-details.
    """
    global wristband_details_cache
    wristband_details_cache = data
    return dict(status="updated")
@app.get("/api/wristband-details")
async def get_cached_wristband_details():
    """Return the most recently POSTed wristband details ({} until the first POST)."""
    cached = wristband_details_cache
    return cached
# ============================================================================
# WEBSOCKET
# ============================================================================
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Register a client for the updates pushed by broadcast_update().

    Sends a "connected" greeting, then reads and discards client messages
    until the client disconnects. The socket is removed from
    active_websockets however the handler exits. Errors other than a normal
    disconnect (and task cancellation) now propagate to the server instead
    of being silently swallowed by a bare ``except:``.
    """
    await websocket.accept()
    active_websockets.append(websocket)
    try:
        # Inside the try so a failed greeting cannot leave a dead socket registered.
        await websocket.send_json(
            {"type": "connected", "message": "Connected to VitalLink server"}
        )
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass  # normal client disconnect
    finally:
        # broadcast_update may already have dropped this socket after a failed
        # send; guard so list.remove() cannot raise ValueError.
        if websocket in active_websockets:
            active_websockets.remove(websocket)
def _json_default(value):
    """json.dumps fallback: encode datetime values as ISO-8601 strings."""
    if isinstance(value, datetime):
        return value.isoformat()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


async def broadcast_update(message: dict):
    """Send *message* as JSON text to every connected WebSocket client.

    The message is serialized once, with the same separators and
    ``ensure_ascii=False`` that Starlette's ``send_json`` uses, and datetimes
    are encoded as ISO-8601. A payload that still cannot be serialized raises
    TypeError to the caller. Previously such an error was caught per client
    and caused every client to be dropped.

    Clients whose send fails are removed from active_websockets.
    """
    payload = json.dumps(
        message, separators=(",", ":"), ensure_ascii=False, default=_json_default
    )

    disconnected = []
    # Iterate a snapshot: each await lets websocket_endpoint coroutines remove
    # entries from the shared list, which would make the loop skip clients.
    for websocket in list(active_websockets):
        try:
            await websocket.send_text(payload)
        except Exception:
            disconnected.append(websocket)

    for websocket in disconnected:
        # The endpoint may have removed it already; avoid ValueError here,
        # which would otherwise fail the caller's HTTP request.
        if websocket in active_websockets:
            active_websockets.remove(websocket)
# ============================================================================
# RUN SERVER
# ============================================================================
if __name__ == "__main__":
    # Development entry point: serve the app with uvicorn on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")