"""Model manager for ensemble predictions."""
import json
import os
from typing import Dict, List, Optional
import numpy as np
import pandas as pd
from tensorflow import keras
from .data import NormalizationParams
class EnsembleModelManager:
    """Manages multiple models for ensemble predictions.

    This class handles:

    - Saving and loading ensemble models with metadata
    - Lazy loading of model weights
    - Efficient storage of normalization parameters
    - Model versioning (an ``ensemble_version`` tag in the metadata file)
    """

    # Name of the JSON file describing the ensemble inside ``base_path``.
    _METADATA_FILENAME = "ensemble_metadata.json"
    # Top-level metadata keys that are always derived from the saved models;
    # user-supplied ``ensemble_metadata`` cannot override them.
    _RESERVED_METADATA_KEYS = frozenset({"n_models", "models"})

    def __init__(self, base_path: str):
        """Initialize model manager.

        Args:
            base_path: Base path (directory) for saving/loading models
        """
        self.base_path = base_path
        # Per-model info dicts, populated by ``load_ensemble``.
        self.models_info: List[Dict] = []
        # Models already built by ``get_model``, keyed by fold index.
        self.loaded_models: Dict[int, "keras.Model"] = {}

    def save_ensemble(
        self, models_info: List[Dict], ensemble_metadata: Optional[Dict] = None
    ) -> None:
        """Save ensemble models and metadata.

        Each model's weights go to ``model_fold_{i}.weights.h5`` (``i`` is the
        position in ``models_info``), and one ``ensemble_metadata.json`` file
        describes the whole ensemble.

        Args:
            models_info: List of model info dictionaries from training. Each
                entry must provide ``"model"``, ``"fold"``, ``"norm_params"``,
                ``"train_indices"`` and ``"val_indices"``; ``"history"`` is
                optional.
            ensemble_metadata: Optional extra metadata stored at the top level
                of the JSON file. The derived keys ``"n_models"`` and
                ``"models"`` are ignored if present.
        """
        os.makedirs(self.base_path, exist_ok=True)

        metadata = {
            "n_models": len(models_info),
            "ensemble_version": "1.0",
            "models": [],
        }
        if ensemble_metadata:
            # Keep the derived entries authoritative: a user-supplied "models"
            # value would otherwise replace the list the loop below appends to
            # (mutating the caller's object), and "n_models" could contradict
            # the real model count.
            metadata.update(
                {
                    key: value
                    for key, value in ensemble_metadata.items()
                    if key not in self._RESERVED_METADATA_KEYS
                }
            )

        for i, model_info in enumerate(models_info):
            weights_file = f"model_fold_{i}.weights.h5"
            model_info["model"].save_weights(
                os.path.join(self.base_path, weights_file)
            )

            model_meta = {
                "fold": model_info["fold"],
                # Stored relative to base_path; load_ensemble re-joins it with
                # its own base_path.
                "model_path": weights_file,
                "norm_params": model_info["norm_params"],
                "train_indices_count": len(model_info["train_indices"]),
                "val_indices_count": len(model_info["val_indices"]),
            }
            history = model_info.get("history")
            if history:
                model_meta.update(self._summarize_history(history))
            metadata["models"].append(model_meta)

        # Serialize before opening the file so an unserializable value cannot
        # leave a truncated metadata file on disk.
        payload = json.dumps(metadata, indent=2, default=self._json_default)
        metadata_path = os.path.join(self.base_path, self._METADATA_FILENAME)
        with open(metadata_path, "w", encoding="utf-8") as f:
            f.write(payload)
        print(f"Saved ensemble with {len(models_info)} models to {self.base_path}")

    @staticmethod
    def _summarize_history(history) -> Dict:
        """Extract final losses and epoch count from a ``.history`` dict.

        Returns an empty dict when ``history`` has no dict ``history``
        attribute (e.g. a Keras ``History`` object is expected).
        """
        logs = getattr(history, "history", None)
        summary: Dict = {}
        if not isinstance(logs, dict):
            return summary
        if "loss" in logs and len(logs["loss"]) > 0:
            summary["final_loss"] = float(logs["loss"][-1])
        if "val_loss" in logs and len(logs["val_loss"]) > 0:
            summary["final_val_loss"] = float(logs["val_loss"][-1])
        if "loss" in logs:
            summary["epochs_trained"] = len(logs["loss"])
        return summary

    @staticmethod
    def _json_default(value):
        """``json`` fallback converting NumPy scalars/arrays to Python types."""
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        raise TypeError(
            f"Object of type {type(value).__name__} is not JSON serializable"
        )

    def load_ensemble(self, model_builder_fn=None) -> List[Dict]:
        """Load ensemble metadata; weights are loaded lazily by ``get_model``.

        Any models cached from a previous load are discarded.

        Args:
            model_builder_fn: Callable taking ``n_features`` and returning a
                model into which the saved weights can be loaded. Required
                before ``get_model`` can be used.

        Returns
        -------
        List of model info dictionaries with keys ``"fold"``, ``"norm_params"``,
        ``"model_path"`` (joined with ``base_path``), ``"model"`` (always
        ``None``) and ``"model_builder_fn"``.
        """
        metadata_path = os.path.join(self.base_path, self._METADATA_FILENAME)
        with open(metadata_path, encoding="utf-8") as f:
            metadata = json.load(f)

        # Cached models belong to the previously loaded ensemble/builder and
        # would otherwise be returned by get_model for the same fold index.
        self.loaded_models.clear()
        self.models_info = [
            {
                "fold": model_meta["fold"],
                "norm_params": model_meta["norm_params"],
                "model_path": os.path.join(self.base_path, model_meta["model_path"]),
                "model": None,  # Lazy loading: built on demand by get_model.
                "model_builder_fn": model_builder_fn,
            }
            for model_meta in metadata["models"]
        ]
        print(f"Loaded ensemble metadata for {len(self.models_info)} models")
        return self.models_info

    def _get_model_info(self, fold: int) -> Dict:
        """Return ``self.models_info[fold]`` with a descriptive error if absent."""
        try:
            return self.models_info[fold]
        except IndexError:
            raise IndexError(
                f"Fold {fold} is out of range for an ensemble of "
                f"{len(self.models_info)} model(s); was load_ensemble() called?"
            ) from None

    def get_model(self, fold: int, n_features: int) -> "keras.Model":
        """Get a specific model, loading if necessary.

        Args:
            fold: Fold index, i.e. the position in ``models_info`` (which is
                the ``i`` of the ``model_fold_{i}`` weights file). This is not
                necessarily the stored ``"fold"`` value.
            n_features: Number of features for model construction. Ignored
                when the model for ``fold`` is already cached.

        Returns
        -------
        Loaded model

        Raises
        ------
        IndexError: If ``fold`` does not refer to a loaded model.
        ValueError: If no model builder function was given to ``load_ensemble``.
        """
        if fold in self.loaded_models:
            return self.loaded_models[fold]

        model_info = self._get_model_info(fold)
        builder = model_info["model_builder_fn"]
        if builder is None:
            raise ValueError("Model builder function required for lazy loading")

        model = builder(n_features)
        model.load_weights(model_info["model_path"])
        self.loaded_models[fold] = model
        return model

    def get_normalization_params(self, fold: int) -> "NormalizationParams":
        """Get normalization parameters for a specific fold.

        Args:
            fold: Fold index (position in ``models_info``)

        Returns
        -------
        NormalizationParams instance

        Raises
        ------
        IndexError: If ``fold`` does not refer to a loaded model.
        """
        params = self._get_model_info(fold)["norm_params"]
        return NormalizationParams(
            meanlong=params["meanlong"],
            sdlong=params["sdlong"],
            meanlat=params["meanlat"],
            sdlat=params["sdlat"],
        )

    def get_averaged_normalization_params(self) -> "NormalizationParams":
        """Get averaged normalization parameters across all folds.

        Returns
        -------
        Averaged NormalizationParams

        Raises
        ------
        ValueError: If no models are loaded (averaging would yield NaN).
        """
        if not self.models_info:
            raise ValueError(
                "No models loaded; call load_ensemble() before averaging "
                "normalization parameters"
            )
        # NOTE(review): standard deviations are averaged arithmetically, which
        # is not a pooled standard deviation — confirm this is intended.
        avg_params = {
            key: float(np.mean([info["norm_params"][key] for info in self.models_info]))
            for key in ("meanlong", "sdlong", "meanlat", "sdlat")
        }
        return NormalizationParams(**avg_params)

    def save_predictions(
        self, predictions: pd.DataFrame, prediction_type: str = "ensemble"
    ) -> None:
        """Save predictions to ``predictions_{prediction_type}.csv`` in base_path.

        Creates ``base_path`` if it does not exist yet.

        Args:
            predictions: DataFrame with predictions
            prediction_type: Type of predictions (e.g., "ensemble", "fold_0")
        """
        os.makedirs(self.base_path, exist_ok=True)
        pred_path = os.path.join(self.base_path, f"predictions_{prediction_type}.csv")
        predictions.to_csv(pred_path, index=False)
        print(f"Saved predictions to {pred_path}")

    def clear_cache(self) -> None:
        """Clear loaded models from memory.

        NOTE: ``keras.backend.clear_session`` resets Keras' global state, not
        only the models cached by this manager.
        """
        self.loaded_models.clear()
        keras.backend.clear_session()