"""Ensemble functionality mixin for locator"""
import gc
import numpy as np
import pandas as pd
from tensorflow import keras
from .data import IndexSet, NormalizationParams, make_tf_dataset, normalize_locs
from .ensemble_model_manager import EnsembleModelManager
from .gpu_optimizer import GPUOptimizer
from .models import feature_network
from .prediction import predict_on_indices
# from tqdm import tqdm
[docs]
class EnsembleMixin:
"""Mixin class providing ensemble functionality for Locator."""
[docs]
def create_ensemble_folds(
    self, genotypes, samples, k=5, training_set_indices=None, na_action=None
):
    """Build the k-fold train/validation splits used by the ensemble.

    Args:
        genotypes: GenotypeArray containing genetic data. Accepted for
            interface symmetry; this method does not read it.
        samples: Array of sample IDs; its length defines the index range.
        k: Number of folds (default: 5).
        training_set_indices: Optional indices that may take part in the
            k-fold split. Every other sample is masked out of the split.
        na_action: How to handle NA samples ('separate', 'exclude', 'fail').
            Resolved through ``self._validate_na_action``.

    Returns:
        dict with three entries:
            - 'index_sets': the objects returned by ``IndexSet.k_fold_split``
            - 'fold_indices': legacy ``{fold: {'train', 'val', 'pred'}}`` dict
            - 'sample_status': status returned by ``_validate_na_action``

    Raises:
        ValueError: If ``training_set_indices`` holds a value outside
            ``range(len(samples))``.
    """
    na_action, status = self._validate_na_action(
        samples, na_action, "Ensemble k-fold creation"
    )
    _, locs = self._resolve_locations(samples)

    n_samples = len(samples)
    separate_na = na_action == "separate"
    # A sample counts as NA when either coordinate is missing.
    na_mask = np.isnan(locs[:, 0]) | np.isnan(locs[:, 1])

    if training_set_indices is None:
        # No restriction: nothing is outside the training pool.
        outside_training = np.zeros(n_samples, dtype=bool)
    else:
        training_set_indices = np.array(training_set_indices)
        if not np.all(np.isin(training_set_indices, range(n_samples))):
            raise ValueError("training_set_indices contains invalid indices")
        # True marks samples that are NOT part of the training pool.
        outside_training = np.ones(n_samples, dtype=bool)
        outside_training[training_set_indices] = False

    # NA samples are only held out of the split when they are handled
    # separately; otherwise just the training-pool restriction applies.
    if separate_na:
        kfold_exclude_mask = outside_training | na_mask
    else:
        kfold_exclude_mask = outside_training

    fold_index_sets = IndexSet.k_fold_split(
        n=n_samples,
        k=k,
        na_mask=kfold_exclude_mask,
        seed=self.config.get("seed", 12345),
    )

    # Legacy view: the fold's 'test' split serves as validation, and the
    # prediction set is every NA sample (or empty when not separating).
    fold_indices = {
        fold_number: {
            "train": fold.train,
            "val": fold.test,
            "pred": (
                np.where(na_mask)[0] if separate_na else np.array([], dtype=int)
            ),
        }
        for fold_number, fold in enumerate(fold_index_sets)
    }

    return {
        "index_sets": fold_index_sets,
        "fold_indices": fold_indices,
        "sample_status": status,
    }
[docs]
def train_ensemble(
    self,
    genotypes,
    samples,
    k=5,
    training_set_indices=None,
    na_action=None,
    augment_data=False,
    flip_rate=0.05,
    save_fold_models=True,
    verbose=True,
    use_model_manager=True,
    use_mixed_precision=None,
    patience_multiplier=1.0,
):
    """Train one model per fold and keep the results on the instance.

    The folds come from ``create_ensemble_folds``; each is trained by
    ``_train_single_fold``. Per-fold results are accumulated in
    ``self._ensemble_models``, ``self._ensemble_histories`` and
    ``self._ensemble_norm_params``, and the fold-averaged normalization
    parameters are written to ``self.meanlong`` / ``sdlong`` / ``meanlat`` /
    ``sdlat``.

    Args:
        genotypes: GenotypeArray containing genetic data.
        samples: Array of sample IDs.
        k: Number of folds/models in the ensemble (default: 5).
        training_set_indices: Optional indices restricting the k-fold split.
        na_action: How to handle NA samples ('separate', 'exclude', 'fail').
        augment_data: When true, stores an augmentation config in
            ``self.config["augmentation"]`` (default: False).
        flip_rate: Flip rate recorded in the augmentation config
            (default: 0.05).
        save_fold_models: Whether per-fold weights are checkpointed
            (default: True).
        verbose: Whether to print progress (default: True).
        use_model_manager: Whether to also save through
            ``EnsembleModelManager`` (only when ``save_fold_models`` is true).
        use_mixed_precision: Forwarded to
            ``setup_ensemble_gpu_optimization`` (default: None).
        patience_multiplier: Forwarded to the per-fold callbacks
            (default: 1.0).

    Returns:
        dict with:
            - 'histories': per-fold training histories
            - 'models': per-fold model info dicts
            - 'normalization_params': fold-averaged normalization parameters
            - 'fold_info': the result of ``create_ensemble_folds``
    """
    # GPU setup always runs; only the announcement depends on verbosity.
    mixed_precision_on = self.setup_ensemble_gpu_optimization(use_mixed_precision)
    if verbose and mixed_precision_on:
        print("Mixed precision training enabled for ensemble")

    self.samples = samples
    fold_info = self.create_ensemble_folds(
        genotypes, samples, k, training_set_indices, na_action
    )
    # SNP filtering is done once and shared by every fold.
    filtered_genotypes = self._filter_genotypes(genotypes)

    # Keep the unfiltered genotypes and fold layout for later prediction.
    self._ensemble_genotypes = genotypes
    self._ensemble_fold_info = fold_info
    self._ensemble_models = []
    self._ensemble_histories = []
    self._ensemble_norm_params = []

    augment_config = (
        {"enabled": True, "flip_rate": flip_rate} if augment_data else None
    )
    if augment_config is not None:
        self.config["augmentation"] = augment_config

    _, locs = self._resolve_locations(samples)

    for fold_idx in range(k):
        if verbose:
            print(f"\nTraining fold {fold_idx + 1}/{k}")

        fold_result = self._train_single_fold(
            fold_idx=fold_idx,
            index_set=fold_info["index_sets"][fold_idx],
            filtered_genotypes=filtered_genotypes,
            samples=samples,
            locs=locs,
            augment_config=augment_config,
            save_fold_models=save_fold_models,
            patience_multiplier=patience_multiplier,
            verbose=verbose,
        )
        # Same path the fold's ModelCheckpoint callback writes to.
        fold_result["weights_file"] = (
            f"{self.config['out']}_fold{fold_idx}.weights.h5"
            if save_fold_models
            else None
        )

        self._ensemble_models.append(fold_result)
        self._ensemble_histories.append(fold_result["history"])
        self._ensemble_norm_params.append(fold_result["norm_params"])

        self._clear_fold_memory()

    avg_norm_params = self._average_normalization_params(self._ensemble_norm_params)
    for stat_name in ("meanlong", "sdlong", "meanlat", "sdlat"):
        setattr(self, stat_name, avg_norm_params[stat_name])

    if use_model_manager and save_fold_models:
        manager = EnsembleModelManager(f"{self.config['out']}_ensemble")
        # DataFrames in the config are dropped before it is handed over.
        serializable_config = {
            key: value
            for key, value in self.config.items()
            if not isinstance(value, pd.DataFrame)
        }
        manager.save_ensemble(
            self._ensemble_models,
            {
                "k_folds": k,
                "na_action": na_action or self.na_action,
                "augment_data": augment_data,
                "config": serializable_config,
            },
        )

    return {
        "histories": self._ensemble_histories,
        "models": self._ensemble_models,
        "normalization_params": avg_norm_params,
        "fold_info": fold_info,
    }
def _train_single_fold(
    self,
    fold_idx,
    index_set,
    filtered_genotypes,
    samples,
    locs,
    augment_config=None,
    save_fold_models=True,
    patience_multiplier=1.0,
    verbose=True,
):
    """Fit the model for one fold on the current instance.

    Normalization statistics are computed from the fold's training
    locations only and then applied to its validation locations.

    Args:
        fold_idx: Fold index.
        index_set: IndexSet for this fold; ``.train`` is the training split
            and ``.test`` is used as validation.
        filtered_genotypes: Pre-filtered genotypes; ``shape[0]`` is passed
            to ``_create_model`` as the input shape.
        samples: Sample IDs (only their count is used here).
        locs: Location coordinates indexed by sample position.
        augment_config: Augmentation configuration. Not read in this
            method; the original note says augmentation is configured via
            ``self.config["augmentation"]`` and applied inside the model.
        save_fold_models: Forwarded to ``_create_fold_callbacks``.
        patience_multiplier: Forwarded to ``_create_fold_callbacks``.
        verbose: When false, Keras verbosity is forced to 0.

    Returns:
        dict with keys 'fold', 'model', 'history', 'norm_params',
        'train_indices' and 'val_indices'.
    """
    train_idx = index_set.train
    val_idx = index_set.test  # k-fold uses 'test' as validation

    meanlong, sdlong, meanlat, sdlat, _, train_locs_scaled = normalize_locs(
        locs[train_idx]
    )
    norm_params = dict(
        meanlong=meanlong, sdlong=sdlong, meanlat=meanlat, sdlat=sdlat
    )

    # _create_model reads the fold's split and genotypes from the instance
    # (to build the IndexedGenotypeModel and fit per-fold PCA when
    # pca_components is set). The parallel ensemble actor calls this method
    # directly, bypassing train_ensemble, so both must be assigned here.
    self.index_set = index_set
    self.filtered_genotypes = filtered_genotypes
    model = self._create_model(input_shape=filtered_genotypes.shape[0])

    val_locs_scaled = self._apply_normalization(locs[val_idx], norm_params)

    # Full-length coordinate table; rows outside this fold stay at zero.
    normalized_locs = np.zeros((len(samples), 2))
    normalized_locs[train_idx] = train_locs_scaled
    normalized_locs[val_idx] = val_locs_scaled

    batch_size = self.get_ensemble_batch_size(len(train_idx), fold_idx)

    train_dataset = make_tf_dataset(
        coordinates=normalized_locs,
        index_set=index_set,
        split="train",
        batch_size=batch_size,
        training=True,
    )
    val_dataset = make_tf_dataset(
        coordinates=normalized_locs,
        index_set=index_set,
        split="test",  # k-fold uses 'test' for validation
        batch_size=batch_size,
        training=False,
    )

    callbacks = self._create_fold_callbacks(
        fold_idx, save_fold_models, patience_multiplier
    )

    keras_verbose = self.config.get("keras_verbose", 1) if verbose else 0
    history = self._fit_model(
        model,
        train_dataset,
        val_dataset,
        callbacks,
        keras_verbose=keras_verbose,
    )

    return {
        "fold": fold_idx,
        "model": model,
        "history": history,
        "norm_params": norm_params,
        "train_indices": train_idx,
        "val_indices": val_idx,
    }
def _create_fold_callbacks(
    self, fold_idx, save_fold_models=True, patience_multiplier=1.0
):
    """Assemble the Keras callbacks used while training one fold.

    Args:
        fold_idx: Fold index; used in the checkpoint file name and passed
            to ``create_ensemble_lr_scheduler``.
        save_fold_models: When true, a best-only weights checkpoint is added
            as the first callback.
        patience_multiplier: Passed to ``create_ensemble_early_stopping``.

    Returns:
        list: ``[checkpoint?, early_stopping, lr_scheduler]``.
    """
    checkpoint_callbacks = []
    if save_fold_models:
        checkpoint_callbacks.append(
            keras.callbacks.ModelCheckpoint(
                filepath=f"{self.config['out']}_fold{fold_idx}.weights.h5",
                verbose=self.config.get("keras_verbose", 1),
                save_best_only=True,
                save_weights_only=True,
                monitor="val_loss",
                save_freq="epoch",
                mode="min",
            )
        )

    early_stopping = self.create_ensemble_early_stopping(patience_multiplier)
    lr_scheduler = self.create_ensemble_lr_scheduler(fold_idx)
    return [*checkpoint_callbacks, early_stopping, lr_scheduler]
[docs]
def predict_ensemble(
    self,
    genotypes=None,
    samples=None,
    indices=None,
    include_fold_predictions=False,
    return_std=False,
    return_df=True,
    save_predictions=True,
):
    """Make predictions using the ensemble of models.

    Args:
        genotypes: GenotypeArray for prediction. If None, the genotypes
            stored by ``train_ensemble`` (``self._ensemble_genotypes``) are
            used when available.
        samples: Sample IDs. If None, the samples stored by
            ``train_ensemble`` (``self.samples``) are used when available.
        indices: Specific indices to predict on (if None, predicts all).
        include_fold_predictions: Include individual fold predictions in
            the DataFrame output.
        return_std: Also compute the standard deviation across folds.
        return_df: Return results as DataFrame (default: True).
        save_predictions: Save the DataFrame to disk (default: True); only
            relevant when ``return_df`` is true.

    Returns:
        pd.DataFrame when ``return_df`` is true; otherwise the mean array,
        or a ``(mean, std)`` tuple when ``return_std`` is true.

    Raises:
        ValueError: If no ensemble has been trained/loaded, or if genotypes
            or samples are neither passed nor stored on the instance.
    """
    # Honour the documented "if None, uses stored data" contract. Without
    # this fallback the validator below rejected None unconditionally, even
    # right after train_ensemble() had stored both values.
    if genotypes is None:
        genotypes = getattr(self, "_ensemble_genotypes", None)
    if samples is None:
        samples = getattr(self, "samples", None)

    self._validate_ensemble_prediction(genotypes, samples)

    if indices is None:
        indices = np.arange(len(samples))

    # Stacked per-fold predictions, reduced over the fold axis below.
    fold_predictions = self._collect_fold_predictions(genotypes, samples, indices)
    ensemble_mean = np.mean(fold_predictions, axis=0)
    ensemble_std = np.std(fold_predictions, axis=0) if return_std else None

    if return_df:
        return self._format_ensemble_predictions_df(
            samples,
            indices,
            ensemble_mean,
            ensemble_std,
            fold_predictions,
            include_fold_predictions,
            save_predictions,
        )
    return (ensemble_mean, ensemble_std) if return_std else ensemble_mean
def _validate_ensemble_prediction(self, genotypes, samples):
    """Check that an ensemble exists and that prediction inputs were given.

    Raises:
        ValueError: If ``self._ensemble_models`` is missing or empty, or if
            either ``genotypes`` or ``samples`` is None.
    """
    if not getattr(self, "_ensemble_models", None):
        raise ValueError("No ensemble models trained. Run train_ensemble() first.")
    # Identity checks only: comparing arrays with == would be elementwise.
    if any(argument is None for argument in (genotypes, samples)):
        raise ValueError("genotypes and samples must be provided for prediction")
def _collect_fold_predictions(self, genotypes, samples, indices):
    """Run every fold model on the same inputs and stack the results.

    Genotypes are filtered once with ``_filter_genotypes`` and shared by all
    folds. The per-fold outputs of ``_predict_single_fold`` are stacked
    along a new leading fold axis.

    Returns:
        np.ndarray: One entry per fold model, in ``self._ensemble_models``
        order.
    """
    filtered = self._filter_genotypes(genotypes)
    per_fold = [
        self._predict_single_fold(fold_info, filtered, samples, indices)
        for fold_info in self._ensemble_models
    ]
    return np.array(per_fold)  # Shape: (n_folds, n_samples, 2)
def _predict_single_fold(self, model_info, filtered_genotypes, samples, indices):
    """Predict with one fold's model and map the output back to raw units.

    Args:
        model_info: Fold dict providing 'model', 'weights_file' and
            'norm_params'.
        filtered_genotypes: Genotypes already passed through the SNP filter.
        samples: Sample IDs (not read in this method).
        indices: Sample positions to predict on.

    Returns:
        The fold's predictions after ``NormalizationParams.reverse``.
    """
    fold_model = model_info["model"]
    weights_path = model_info["weights_file"]
    if weights_path:
        # Reload the saved weights for this fold before predicting.
        feature_network(fold_model).load_weights(weights_path)

    scaled_predictions = predict_on_indices(fold_model, filtered_genotypes, indices)

    fold_stats = model_info["norm_params"]
    denormalizer = NormalizationParams(
        meanlong=fold_stats["meanlong"],
        sdlong=fold_stats["sdlong"],
        meanlat=fold_stats["meanlat"],
        sdlat=fold_stats["sdlat"],
    )
    return denormalizer.reverse(scaled_predictions)
def _format_ensemble_predictions_df(
    self,
    samples,
    indices,
    ensemble_mean,
    ensemble_std,
    fold_predictions,
    include_fold_predictions,
    save_predictions,
):
    """Lay out ensemble predictions as a DataFrame and optionally save it.

    Columns appear in this order: ``sampleID``, ``x``, ``y``; then
    ``x_std`` / ``y_std`` when ``ensemble_std`` is given; then
    ``x_fold{i}`` / ``y_fold{i}`` per model when
    ``include_fold_predictions`` is true.

    When ``save_predictions`` is true the frame is written to
    ``"{out}_ensemble_predictions.csv"`` (``out`` from ``self.config``).

    Returns:
        pd.DataFrame: The assembled predictions.
    """
    # Dict insertion order determines the column order of the frame.
    columns = {
        "sampleID": samples[indices],
        "x": ensemble_mean[:, 0],
        "y": ensemble_mean[:, 1],
    }

    if ensemble_std is not None:
        columns["x_std"] = ensemble_std[:, 0]
        columns["y_std"] = ensemble_std[:, 1]

    if include_fold_predictions:
        n_models = len(self._ensemble_models)
        for fold_number in range(n_models):
            columns[f"x_fold{fold_number}"] = fold_predictions[fold_number, :, 0]
            columns[f"y_fold{fold_number}"] = fold_predictions[fold_number, :, 1]

    result_df = pd.DataFrame(columns)

    if not save_predictions:
        return result_df

    filename = f"{self.config['out']}_ensemble_predictions.csv"
    result_df.to_csv(filename, index=False)
    if self.config.get("keras_verbose", 1) >= 1:
        print(f"Saved ensemble predictions to {filename}")
    return result_df
def _apply_normalization(self, locations, norm_params):
    """Standardize coordinates with previously computed statistics.

    Each row is mapped to
    ``((x - meanlong) / sdlong, (y - meanlat) / sdlat)`` using only its
    first two columns.

    Args:
        locations: Array-like of rows whose first two entries are the
            coordinates to scale.
        norm_params: Mapping with 'meanlong', 'sdlong', 'meanlat', 'sdlat'.

    Returns:
        np.ndarray: Float array of shape ``(n, 2)``. Empty input yields
        shape ``(0, 2)`` so the result can still be assigned into an
        ``(n, 2)`` coordinate table.
    """
    coords = np.asarray(locations, dtype=float)
    if coords.size == 0:
        # The row-by-row version produced shape (0,) here, which cannot be
        # broadcast into a two-column slot.
        return np.empty((0, 2), dtype=float)

    centre = np.array([norm_params["meanlong"], norm_params["meanlat"]], dtype=float)
    scale = np.array([norm_params["sdlong"], norm_params["sdlat"]], dtype=float)
    # One broadcast expression replaces the per-row Python loop.
    return (coords[:, :2] - centre) / scale
def _average_normalization_params(self, norm_params_list):
    """Return the per-key mean of the folds' normalization parameters.

    Args:
        norm_params_list: Iterable of dicts, each holding 'meanlong',
            'sdlong', 'meanlat' and 'sdlat'.

    Returns:
        dict: The same four keys, each mapped to ``np.mean`` over the folds.
    """
    stat_names = ("meanlong", "sdlong", "meanlat", "sdlat")
    return {
        stat: np.mean([fold_params[stat] for fold_params in norm_params_list])
        for stat in stat_names
    }
[docs]
def load_ensemble(self, ensemble_path):
    """Load a saved ensemble and prepare the instance for prediction.

    The loaded model info is stored in ``self._ensemble_models``, the
    manager in ``self._ensemble_model_manager``, and the manager's averaged
    normalization parameters in ``self.meanlong`` / ``sdlong`` /
    ``meanlat`` / ``sdlat``.

    Args:
        ensemble_path: Path to the saved ensemble directory.

    Returns:
        dict: ``{'n_models': int, 'normalization_params': {...}}``.
    """
    manager = EnsembleModelManager(ensemble_path)

    # The manager rebuilds each architecture through this instance's own
    # model factory, given the number of input features.
    loaded_models = manager.load_ensemble(
        model_builder_fn=lambda n_features: self._create_model(input_shape=n_features)
    )

    self._ensemble_models = loaded_models
    self._ensemble_model_manager = manager

    averaged = manager.get_averaged_normalization_params()
    self.meanlong = averaged.meanlong
    self.sdlong = averaged.sdlong
    self.meanlat = averaged.meanlat
    self.sdlat = averaged.sdlat

    summary = {"n_models": len(loaded_models)}
    summary["normalization_params"] = {
        "meanlong": self.meanlong,
        "sdlong": self.sdlong,
        "meanlat": self.meanlat,
        "sdlat": self.sdlat,
    }
    return summary
[docs]
def predict_ensemble_from_manager(
    self,
    genotypes,
    samples,
    indices=None,
    return_df=True,
    save_predictions=True,
):
    """Predict with an ensemble previously loaded via ``load_ensemble``.

    Each fold's model and normalization parameters are fetched from the
    model manager by fold index, and the denormalized predictions are
    averaged across folds.

    Args:
        genotypes: GenotypeArray for prediction.
        samples: Sample IDs.
        indices: Specific indices to predict on (if None, predicts all).
        return_df: Return results as DataFrame (default: True).
        save_predictions: Hand the DataFrame to the manager's
            ``save_predictions`` (default: True); ignored when
            ``return_df`` is false.

    Returns:
        pd.DataFrame with ``sampleID``, ``x``, ``y`` when ``return_df`` is
        true, otherwise the mean prediction array.

    Raises:
        ValueError: If no model manager is attached to the instance.
    """
    if not hasattr(self, "_ensemble_model_manager"):
        raise ValueError("No ensemble loaded. Use load_ensemble() first.")

    manager = self._ensemble_model_manager
    if indices is None:
        indices = np.arange(len(samples))

    filtered_genotypes = self._filter_genotypes(genotypes)
    n_features = filtered_genotypes.shape[0]

    per_fold = []
    for fold_idx, _ in enumerate(self._ensemble_models):
        # The manager supplies the model for this fold on request.
        fold_model = manager.get_model(fold_idx, n_features)
        scaled = predict_on_indices(fold_model, filtered_genotypes, indices)
        fold_norm = manager.get_normalization_params(fold_idx)
        per_fold.append(fold_norm.reverse(scaled))

    ensemble_mean = np.mean(np.array(per_fold), axis=0)

    if not return_df:
        return ensemble_mean

    result_df = pd.DataFrame(
        {
            "sampleID": samples[indices],
            "x": ensemble_mean[:, 0],
            "y": ensemble_mean[:, 1],
        }
    )
    if save_predictions:
        manager.save_predictions(result_df, "ensemble")
    return result_df
# Training improvement methods (consolidated from ensemble_improvements.py)
[docs]
def setup_ensemble_gpu_optimization(self, use_mixed_precision=None):
    """Configure mixed precision for ensemble training, at most once.

    Args:
        use_mixed_precision: Whether to use mixed precision training. If
            None, ``self.config["use_mixed_precision"]`` is used (default
            True when the key is absent).

    Returns:
        The result of ``GPUOptimizer.setup_mixed_precision()`` when this
        call performed the setup; otherwise False.
    """
    if use_mixed_precision is None:
        use_mixed_precision = self.config.get("use_mixed_precision", True)

    # NOTE(review): once the marker attribute exists, later calls return
    # False even if the first call enabled mixed precision — confirm that
    # callers only need "enabled by this call".
    if not use_mixed_precision or hasattr(self, "_mixed_precision_setup"):
        return False

    enabled = GPUOptimizer.setup_mixed_precision()
    self._mixed_precision_setup = True
    return enabled
[docs]
def get_ensemble_batch_size(self, dataset_size, fold_idx=0):
    """Return the batch size for ensemble training, computing it once.

    The first result is cached in ``self._ensemble_batch_size`` and
    returned unchanged by every later call, whatever the arguments.

    Args:
        dataset_size: Size of the training dataset.
        fold_idx: Current fold index; the GPU-based search is only
            attempted for fold 0.

    Returns:
        The configured batch size when ``config["batch_size"]`` is set to
        something other than ``"auto"``/None; otherwise the value from
        ``GPUOptimizer.get_optimal_batch_size`` (fold 0 with a model on the
        instance) or from ``self._determine_batch_size``.
    """
    if hasattr(self, "_ensemble_batch_size"):
        return self._ensemble_batch_size

    configured = self.config.get("batch_size")
    if configured is not None and configured != "auto":
        chosen = configured
    elif fold_idx == 0 and getattr(self, "model", None) is not None:
        # Probe the GPU using the model already attached to the instance.
        chosen = GPUOptimizer.get_optimal_batch_size(
            model=feature_network(self.model),
            input_shape=(self.filtered_genotypes.shape[0],),
            dataset_size=dataset_size,
            verbose=self.config.get("keras_verbose", 1) > 0,
        )
    else:
        chosen = self._determine_batch_size(dataset_size)

    self._ensemble_batch_size = chosen
    return chosen
[docs]
def create_ensemble_early_stopping(self, patience_multiplier=1.5):
    """Build the early-stopping callback used for each fold.

    The effective patience is ``config["patience"]`` (default 100) times
    ``patience_multiplier``, truncated to an integer.

    Args:
        patience_multiplier: Factor applied to the base patience.

    Returns:
        keras.callbacks.EarlyStopping: Callback monitoring ``val_loss``.
    """
    settings = self.config
    scaled_patience = int(settings.get("patience", 100) * patience_multiplier)
    show_messages = settings.get("keras_verbose", 1) > 0

    return keras.callbacks.EarlyStopping(
        monitor="val_loss",
        min_delta=settings.get("min_delta", 0),
        patience=scaled_patience,
        restore_best_weights=settings.get("restore_best_weights", True),
        verbose=show_messages,
    )
[docs]
def create_ensemble_lr_scheduler(self, fold_idx):
    """Build the plateau-based learning-rate scheduler for one fold.

    As a side effect, when ``self.model`` exists its optimizer's learning
    rate is set to the base rate scaled by
    ``1 + 0.1 * fold_idx / config["k_folds"]`` (``k_folds`` defaults to 5).

    Args:
        fold_idx: Current fold index.

    Returns:
        keras.callbacks.ReduceLROnPlateau: Callback monitoring ``val_loss``.
    """
    settings = self.config
    base_lr = settings.get("learning_rate", 0.001)
    fold_scale = 1 + (fold_idx * 0.1) / settings.get("k_folds", 5)

    # NOTE(review): this adjusts self.model, which may not be the model
    # being trained for this fold — confirm against _create_model.
    # NOTE(review): keras.backend.set_value may be unavailable in newer
    # Keras releases — verify against the pinned version.
    current_model = getattr(self, "model", None)
    if current_model is not None:
        keras.backend.set_value(
            self.model.optimizer.learning_rate, base_lr * fold_scale
        )

    plateau_patience = settings.get("patience", 100) // 6
    return keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=plateau_patience,
        verbose=settings.get("keras_verbose", 1),
        mode="auto",
        min_delta=0,
        cooldown=0,
        min_lr=base_lr * 0.01,  # floor at 1% of the base learning rate
    )
def _clear_fold_memory(self):
    """Release memory between folds.

    Clears the Keras backend session first, then runs a garbage-collection
    pass.
    """
    backend = keras.backend
    backend.clear_session()
    gc.collect()