Overview
Production ML systems require continuous monitoring to ensure models perform as expected, detect degradation, and identify when retraining is needed. This guide covers comprehensive monitoring strategies for AQI prediction systems.

Monitoring Architecture
┌─────────────────┐
│ Prediction │
│ Service │
└────────┬────────┘
│
├──────────▶ Metrics (Prometheus)
├──────────▶ Logs (ELK/Loki)
├──────────▶ Traces (Jaeger)
└──────────▶ Data Store
│
▼
┌───────────────────────┐
│ Monitoring Agent │
│ - Drift Detection │
│ - Performance Checks │
│ - Data Quality │
└───────────┬───────────┘
│
▼
┌───────────────────────┐
│ Alerting System │
│ - Slack/PagerDuty │
│ - Email │
└───────────────────────┘
Key Metrics
- Performance Metrics
- System Metrics
- Business Metrics
Track prediction accuracy and model performance:
from prometheus_client import Gauge, Histogram, Counter
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# Gauges exposing the latest headline accuracy numbers.
prediction_mae = Gauge('aqi_prediction_mae', 'Mean Absolute Error')
prediction_rmse = Gauge('aqi_prediction_rmse', 'Root Mean Squared Error')
prediction_r2 = Gauge('aqi_prediction_r2', 'R² Score')

# Histogram of absolute per-sample errors; bucket edges sit at
# progressively larger AQI deltas.
prediction_error_distribution = Histogram(
    'aqi_prediction_error',
    'Distribution of prediction errors',
    buckets=[0, 5, 10, 20, 30, 50, 100, float('inf')]
)


def track_predictions(y_true, y_pred):
    """Track prediction metrics.

    Updates the MAE/RMSE/R² gauges and feeds each sample's absolute
    error into the error-distribution histogram.
    """
    prediction_mae.set(mean_absolute_error(y_true, y_pred))
    prediction_rmse.set(np.sqrt(mean_squared_error(y_true, y_pred)))
    prediction_r2.set(r2_score(y_true, y_pred))

    # One histogram observation per sample.
    for actual, predicted in zip(y_true, y_pred):
        prediction_error_distribution.observe(abs(actual - predicted))
Monitor infrastructure and service health:
from prometheus_client import Counter, Histogram, Gauge
import time
import psutil

# Per-endpoint request counter and latency histogram.
request_count = Counter(
    'aqi_api_requests_total',
    'Total API requests',
    ['endpoint', 'status']
)
request_duration = Histogram(
    'aqi_api_request_duration_seconds',
    'Request duration',
    ['endpoint']
)

# Host-level resource gauges.
cpu_usage = Gauge('aqi_cpu_usage_percent', 'CPU usage')
memory_usage = Gauge('aqi_memory_usage_bytes', 'Memory usage')
model_load_time = Gauge('aqi_model_load_time_seconds', 'Model load time')


def track_system_metrics():
    """Update system metrics.

    Samples CPU utilization and resident memory via psutil and pushes
    them into the corresponding gauges.
    """
    memory_usage.set(psutil.virtual_memory().used)
    cpu_usage.set(psutil.cpu_percent())


def track_request(endpoint, status, duration):
    """Track API request metrics.

    Increments the request counter and records the request latency,
    both labelled by endpoint (and status for the counter).
    """
    request_duration.labels(endpoint=endpoint).observe(duration)
    request_count.labels(endpoint=endpoint, status=status).inc()
Track business-relevant KPIs:
from prometheus_client import Counter, Gauge

# Usage metrics.
daily_predictions = Counter('aqi_daily_predictions', 'Daily predictions')
unique_users = Gauge('aqi_unique_users_24h', 'Unique users (24h)')
api_quota_usage = Gauge('aqi_api_quota_usage_percent', 'API quota usage')

# High-AQI alert counter, labelled by severity category.
high_aqi_alerts = Counter(
    'aqi_high_alerts_total',
    'High AQI alerts sent',
    ['category']
)

# Geographic coverage of served predictions.
predictions_by_location = Counter(
    'aqi_predictions_by_location',
    'Predictions by location',
    ['city', 'country']
)


def track_business_metrics(prediction, location):
    """Track business metrics.

    Increments the daily counter, flags alert-worthy AQI levels
    (>150 unhealthy, >200 very unhealthy) and records the location.
    """
    daily_predictions.inc()

    # AQI above 150 is alert-worthy; 200 splits the two severity tiers.
    if prediction > 150:
        tier = 'very_unhealthy' if prediction > 200 else 'unhealthy'
        high_aqi_alerts.labels(category=tier).inc()

    predictions_by_location.labels(
        city=location['city'],
        country=location['country']
    ).inc()
Data Quality Monitoring
Input Validation
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple
from dataclasses import dataclass


@dataclass
class ValidationResult:
    """Outcome of a validation pass: overall verdict plus details."""
    is_valid: bool        # True when no blocking errors were found
    errors: List[str]     # blocking problems
    warnings: List[str]   # suspicious but non-blocking findings


class InputValidator:
    """Validate input features for AQI prediction.

    Checks missing-value rates, physically plausible value ranges and a
    few suspicious patterns (constant columns, excessive zeros).
    """

    def __init__(self):
        # Physically plausible (min, max) bounds for each required feature.
        self.feature_ranges = {
            'pm25': (0, 500),
            'pm10': (0, 600),
            'no2': (0, 200),
            'so2': (0, 100),
            'co': (0, 50),
            'o3': (0, 300),
            'temperature': (-50, 60),
            'humidity': (0, 100),
            'wind_speed': (0, 50),
            'pressure': (900, 1100)
        }
        self.missing_threshold = 0.1  # error when >10% of a column is missing

    def validate(self, data: pd.DataFrame) -> ValidationResult:
        """Validate input data.

        Args:
            data: Feature frame; must contain every key of `feature_ranges`.

        Returns:
            ValidationResult whose `is_valid` is False when any error exists.
        """
        errors: List[str] = []
        warnings: List[str] = []

        # Guard: every ratio below divides by len(data), so an empty frame
        # would otherwise produce division-by-zero artifacts.
        if len(data) == 0:
            return ValidationResult(False, ["Input data is empty"], [])

        # Missing-value rate per column.
        missing_pct = data.isnull().sum() / len(data)
        for col, pct in missing_pct.items():
            if pct > self.missing_threshold:
                errors.append(f"{col}: {pct:.1%} missing values")

        # Range checks for each required feature.
        for col, (min_val, max_val) in self.feature_ranges.items():
            if col not in data.columns:
                errors.append(f"Missing required column: {col}")
                continue
            out_of_range = (
                (data[col] < min_val) | (data[col] > max_val)
            ).sum()
            if out_of_range > 0:
                pct = out_of_range / len(data)
                # More than 1% out of range is an error; below that, warn.
                if pct > 0.01:
                    errors.append(
                        f"{col}: {out_of_range} values out of range "
                        f"[{min_val}, {max_val}]"
                    )
                else:
                    warnings.append(
                        f"{col}: {out_of_range} values out of range"
                    )

        # Heuristics for suspicious-looking numeric columns.
        for col in data.select_dtypes(include=[np.number]).columns:
            # A constant column often signals a stuck sensor or join bug.
            if data[col].nunique() == 1:
                warnings.append(f"{col}: constant value {data[col].iloc[0]}")
            # Mostly-zero columns often signal missing data encoded as 0.
            zero_pct = (data[col] == 0).sum() / len(data)
            if zero_pct > 0.5:
                warnings.append(f"{col}: {zero_pct:.1%} zero values")

        return ValidationResult(len(errors) == 0, errors, warnings)
Set up alerts for critical data quality issues. Invalid inputs can lead to unreliable predictions.
Prediction Drift Detection
Statistical Drift Detection
- PSI Calculation
- KS Test
- ADWIN
Population Stability Index for distribution shift:
import numpy as np
import pandas as pd


def calculate_psi(expected: np.ndarray, actual: np.ndarray, bins=10) -> float:
    """Calculate Population Stability Index.

    Args:
        expected: Reference (e.g. training-time) values.
        actual: Current values to compare against the reference.
        bins: Number of quantile bins derived from ``expected``.

    Returns:
        PSI value: 0 for identical distributions, larger means more shift.
    """
    # Bin edges from the reference distribution's quantiles; duplicates
    # collapse when the reference has heavy ties.
    breakpoints = np.unique(
        np.percentile(expected, np.linspace(0, 100, bins + 1))
    ).astype(float)

    if breakpoints.size < 2:
        # Constant reference distribution: fall back to a single bin.
        breakpoints = np.array([-np.inf, np.inf])
    else:
        # Extend the outer edges so `actual` values outside the reference
        # range are still counted; np.histogram would otherwise silently
        # drop them and understate the shift.
        breakpoints[0] = -np.inf
        breakpoints[-1] = np.inf

    # Per-bin mass for each dataset.
    expected_percents = np.histogram(expected, breakpoints)[0] / len(expected)
    actual_percents = np.histogram(actual, breakpoints)[0] / len(actual)

    # Clamp empty bins to a tiny mass to avoid log(0)/division by zero.
    expected_percents = np.where(expected_percents == 0, 0.0001, expected_percents)
    actual_percents = np.where(actual_percents == 0, 0.0001, actual_percents)

    psi = np.sum((actual_percents - expected_percents) *
                 np.log(actual_percents / expected_percents))
    return float(psi)


def interpret_psi(psi_value: float) -> str:
    """Interpret a PSI value against the conventional 0.1/0.25 thresholds."""
    if psi_value < 0.1:
        return "No significant change"
    elif psi_value < 0.25:
        return "Small change detected"
    else:
        return "Major shift detected - retrain model"


# Example usage (requires a fitted `model` and feature matrices in scope):
# reference_predictions = model.predict(X_reference)
# current_predictions = model.predict(X_current)
# psi = calculate_psi(reference_predictions, current_predictions)
# print(f"PSI: {psi:.4f} - {interpret_psi(psi)}")
Kolmogorov-Smirnov test for distribution comparison:
from scipy import stats
import numpy as np
import pandas as pd
from typing import Dict, List


def detect_drift_ks(reference: np.ndarray, current: np.ndarray,
                    alpha=0.05) -> Dict:
    """Detect drift using the two-sample Kolmogorov-Smirnov test.

    Args:
        reference: Baseline sample (e.g. training-time feature values).
        current: Recent sample to compare against the baseline.
        alpha: Significance level; drift is flagged when p < alpha.

    Returns:
        Dict with 'drift_detected', 'statistic', 'p_value' and 'alpha'.
    """
    statistic, p_value = stats.ks_2samp(reference, current)
    return {
        'drift_detected': p_value < alpha,
        'statistic': statistic,
        'p_value': p_value,
        'alpha': alpha
    }


def monitor_feature_drift(reference_df: pd.DataFrame,
                          current_df: pd.DataFrame,
                          features: List[str]) -> pd.DataFrame:
    """Run the KS drift test for each feature; one result row per feature."""
    results = []
    for feature in features:
        result = detect_drift_ks(
            reference_df[feature].values,
            current_df[feature].values
        )
        result['feature'] = feature
        results.append(result)
    return pd.DataFrame(results)


# Example usage (requires reference/current feature frames in scope):
# drift_results = monitor_feature_drift(
#     X_reference, X_current,
#     ['pm25', 'pm10', 'no2', 'temperature', 'humidity']
# )
# drifted = drift_results[drift_results['drift_detected']]
# print(f"Features with drift: {drifted['feature'].tolist()}")
Adaptive Windowing for online drift detection:
from river import drift
import numpy as np


class OnlineDriftDetector:
    """Online drift detection using ADWIN (adaptive windowing)."""

    def __init__(self):
        # delta controls sensitivity: smaller -> fewer false alarms.
        self.adwin = drift.ADWIN(delta=0.002)
        self.error_stream = []

    def update(self, y_true: float, y_pred: float) -> bool:
        """Feed one prediction's absolute error; return True if drift fired."""
        current_error = abs(y_true - y_pred)
        self.error_stream.append(current_error)
        self.adwin.update(current_error)
        return True if self.adwin.drift_detected else False

    def get_statistics(self) -> Dict:
        """Return the running mean error, sample count and recent errors."""
        return {
            'mean_error': self.adwin.estimation,
            'n_samples': self.adwin.total,
            'recent_errors': self.error_stream[-100:]
        }


# Example usage (requires `y_test` and `predictions` in scope):
detector = OnlineDriftDetector()
for y_true, y_pred in zip(y_test, predictions):
    if detector.update(y_true, y_pred):
        print("Drift detected! Consider retraining.")
        # Trigger retraining pipeline
Use multiple drift detection methods for robust monitoring. Different methods detect different types of drift.
Performance Degradation Alerts
Alert Configuration
import smtplib
from email.mime.text import MIMEText
from typing import List, Dict
import requests
from dataclasses import dataclass


@dataclass
class Alert:
    """A single monitoring alert routed to one or more channels."""
    severity: str  # 'warning', 'error', 'critical'
    title: str
    message: str
    metrics: Dict


class AlertManager:
    """Manage alerts for model monitoring."""

    def __init__(self, config: Dict):
        self.config = config
        self.alert_history = []

    def send_email_alert(self, alert: Alert):
        """Send email alert."""
        body = f"{alert.message}\n\nMetrics: {alert.metrics}"
        msg = MIMEText(body)
        msg['Subject'] = f"[{alert.severity.upper()}] {alert.title}"
        msg['From'] = self.config['email_from']
        msg['To'] = ', '.join(self.config['email_to'])
        with smtplib.SMTP(self.config['smtp_host']) as server:
            server.send_message(msg)

    def send_slack_alert(self, alert: Alert):
        """Send Slack alert."""
        # Attachment colour encodes severity.
        color_map = {
            'warning': '#FFA500',
            'error': '#FF0000',
            'critical': '#8B0000'
        }
        fields = [
            {'title': key, 'value': str(val), 'short': True}
            for key, val in alert.metrics.items()
        ]
        attachment = {
            'color': color_map[alert.severity],
            'title': alert.title,
            'text': alert.message,
            'fields': fields
        }
        requests.post(self.config['slack_webhook'],
                      json={'attachments': [attachment]})

    def send_pagerduty_alert(self, alert: Alert):
        """Send PagerDuty alert."""
        # Only critical alerts page someone.
        if alert.severity != 'critical':
            return
        event = {
            'routing_key': self.config['pagerduty_key'],
            'event_action': 'trigger',
            'payload': {
                'summary': alert.title,
                'severity': 'critical',
                'source': 'aqi-predictor',
                'custom_details': alert.metrics
            }
        }
        requests.post(
            'https://events.pagerduty.com/v2/enqueue',
            json=event
        )

    def trigger_alert(self, alert: Alert):
        """Trigger alert through configured channels."""
        self.alert_history.append(alert)
        # Fan out to channels according to severity.
        if alert.severity in ['error', 'critical']:
            self.send_email_alert(alert)
            self.send_slack_alert(alert)
        if alert.severity == 'critical':
            self.send_pagerduty_alert(alert)
Logging and Tracing
Structured Logging
import logging
import json
from datetime import datetime, timezone
from typing import Dict, Any


class StructuredLogger:
    """Structured logging for ML predictions (one JSON object per line)."""

    def __init__(self, name: str):
        self.logger = logging.getLogger(name)
        self.logger.setLevel(logging.INFO)
        # logging.getLogger() returns a shared instance per name, so only
        # attach the JSON handler once to avoid duplicated log lines.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(self.JSONFormatter())
            self.logger.addHandler(handler)

    class JSONFormatter(logging.Formatter):
        """Serialize a LogRecord to a single JSON document."""

        def format(self, record):
            log_data = {
                # Timezone-aware UTC timestamp (datetime.utcnow() is
                # deprecated and produces naive timestamps).
                'timestamp': datetime.now(timezone.utc).isoformat(),
                'level': record.levelname,
                'message': record.getMessage(),
                'logger': record.name
            }
            # log_prediction/log_error nest their context under a single
            # 'extra' attribute (see below); merge it into the document.
            context = getattr(record, 'extra', None)
            if isinstance(context, dict):
                log_data.update(context)
            # default=str keeps logging from crashing on non-JSON values.
            return json.dumps(log_data, default=str)

    def log_prediction(self, request_id: str, features: Dict,
                       prediction: float, metadata: Dict):
        """Log prediction with full context."""
        # BUG FIX: logging's `extra=` kwarg spreads its keys directly onto
        # the record's __dict__, so the original formatter's
        # hasattr(record, 'extra') never fired and the context was dropped.
        # Nesting the payload under an 'extra' key creates the attribute
        # the formatter actually reads.
        self.logger.info(
            'Prediction made',
            extra={'extra': {
                'request_id': request_id,
                'features': features,
                'prediction': prediction,
                'model_version': metadata.get('model_version'),
                'processing_time_ms': metadata.get('processing_time'),
                'confidence': metadata.get('confidence')
            }}
        )

    def log_error(self, request_id: str, error: Exception, context: Dict):
        """Log error with context."""
        self.logger.error(
            f'Prediction error: {str(error)}',
            extra={'extra': {
                'request_id': request_id,
                'error_type': type(error).__name__,
                'context': context
            }}
        )


# Example usage
logger = StructuredLogger('aqi_predictor')
logger.log_prediction(
    request_id='abc-123',
    features={'pm25': 35.5, 'temperature': 22.5},
    prediction=75.3,
    metadata={'model_version': 'v1.2', 'processing_time': 45.2}
)
Distributed Tracing
from opentelemetry import trace
from opentelemetry.exporter.jaeger.thrift import JaegerExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
import time

# Wire the global tracer provider to a local Jaeger agent.
trace.set_tracer_provider(TracerProvider())
jaeger_exporter = JaegerExporter(
    agent_host_name='localhost',
    agent_port=6831,
)
trace.get_tracer_provider().add_span_processor(
    BatchSpanProcessor(jaeger_exporter)
)
tracer = trace.get_tracer(__name__)


def predict_with_tracing(features):
    """Make prediction with distributed tracing.

    Wraps the preprocess -> inference -> postprocess pipeline in nested
    spans so each stage's latency shows up separately in Jaeger.
    """
    with tracer.start_as_current_span('predict') as span:
        span.set_attribute('model.version', 'v1.2')
        span.set_attribute('features.count', len(features))

        # Each stage gets its own child span.
        with tracer.start_as_current_span('preprocess'):
            prepared = preprocess(features)

        with tracer.start_as_current_span('inference'):
            raw_output = model.predict(prepared)

        with tracer.start_as_current_span('postprocess'):
            final_result = postprocess(raw_output)

        span.set_attribute('prediction.value', final_result)
        return final_result
Structured logging and distributed tracing are essential for debugging production issues and understanding system behavior.
Dashboard Setup
Grafana Dashboard
{
"dashboard": {
"title": "AQI Predictor Monitoring",
"panels": [
{
"title": "Prediction MAE",
"targets": [{
"expr": "aqi_prediction_mae"
}],
"alert": {
"conditions": [{
"evaluator": {"params": [15], "type": "gt"},
"query": {"params": ["A", "5m", "now"]}
}]
}
},
{
"title": "Request Rate",
"targets": [{
"expr": "rate(aqi_api_requests_total[5m])"
}]
},
{
"title": "P95 Latency",
"targets": [{
"expr": "histogram_quantile(0.95, aqi_api_request_duration_seconds)"
}]
},
{
"title": "Error Rate",
"targets": [{
"expr": "rate(prediction_errors_total[5m])"
}]
}
]
}
}
Automated Retraining
class RetrainingTrigger:
    """Trigger model retraining based on monitoring signals."""

    def __init__(self, config: Dict):
        self.config = config                        # thresholds and limits
        self.drift_detector = OnlineDriftDetector()
        self.performance_window = []                # rolling MAE history

    def should_retrain(self, metrics: Dict) -> Tuple[bool, str]:
        """Determine if retraining is needed.

        Combines the accuracy threshold, drift flag, model age and the
        recent MAE trend; returns (decision, "; "-joined reasons).
        """
        reasons = []

        # Absolute accuracy threshold.
        current_mae = metrics['mae']
        if current_mae > self.config['max_mae']:
            reasons.append(f"MAE {current_mae:.2f} exceeds threshold")

        # Drift flag supplied by an upstream detector.
        if metrics.get('drift_detected'):
            reasons.append("Data drift detected")

        # Staleness check.
        model_age = metrics.get('days_since_training', 0)
        if model_age > self.config['max_model_age_days']:
            reasons.append(f"Model is {model_age} days old")

        # Rolling 7-point MAE trend: a positive slope means degradation.
        self.performance_window.append(current_mae)
        if len(self.performance_window) > 7:
            self.performance_window.pop(0)
            slope = np.polyfit(range(7), self.performance_window, 1)[0]
            if slope > 0.5:
                reasons.append("Performance degrading over time")

        return len(reasons) > 0, "; ".join(reasons)

    def trigger_retraining_pipeline(self, reason: str):
        """Trigger automated retraining pipeline."""
        print(f"Triggering retraining: {reason}")
        # Call training pipeline (e.g., Airflow, Kubeflow)
        # requests.post('http://airflow:8080/api/v1/dags/aqi_training/dagRuns')
        pass
Best Practices
Monitoring Strategy
- Monitor performance, system, and business metrics
- Set up alerts for critical degradation
- Use multiple drift detection methods
- Implement gradual rollout for model updates
- Maintain shadow deployments for testing
Data Collection
- Log all predictions with full feature context
- Store ground truth labels when available
- Track feature distributions over time
- Monitor data quality continuously
- Implement sampling for high-volume systems
Alert Fatigue
- Set appropriate thresholds to avoid false alarms
- Use escalating severity levels
- Implement alert aggregation
- Review and adjust thresholds regularly
- Document response procedures
Continuous Improvement
- Analyze monitoring data for insights
- A/B test model improvements
- Automate retraining when possible
- Conduct regular model audits
- Document lessons learned