# Source code for locator.prediction

"""Prediction functionality for locator"""

import json
import warnings

import h5py
import numpy as np
import pandas as pd

from .data import filter_dosage_matrix, is_dosage_matrix
from .models import feature_network


def predict_on_indices(model, filtered_genotypes, indices, site_order=None, verbose=0):
    """Run the feature network on a subset of samples.

    The genotype matrix is indexed as ``filtered_genotypes[:, indices]`` and
    transposed, so each selected sample becomes one row of the network input.
    ``feature_network(model)`` supplies the network whose ``predict`` is
    called (it unwraps IndexedGenotypeModel when needed).

    Args:
        model: Model handed to ``feature_network``.
        filtered_genotypes: Genotype matrix with samples along the second axis.
        indices: Positions of the samples to predict on.
        site_order (optional): Column order applied to the transposed feature
            matrix before prediction. Defaults to None (no reordering).
        verbose: Forwarded unchanged to the network's ``predict``.

    Returns
    -------
        Whatever the feature network's ``predict`` returns for the features.
    """
    network = feature_network(model)

    # Select the requested samples, then flip to one row per sample.
    selected = filtered_genotypes[:, indices]
    sample_major = np.ascontiguousarray(selected.T)

    model_input = sample_major if site_order is None else sample_major[:, site_order]
    return network.predict(model_input, verbose=verbose)


class PredictionMixin:
    """Mixin class providing prediction functionality for Locator.

    The methods here read state from the host object: ``model``, ``config``,
    the coordinate normalization attributes (``meanlong``, ``sdlong``,
    ``meanlat``, ``sdlat``) and, depending on the code path, ``samples``,
    ``pred_indices``, ``predgen``, ``filtered_genotypes`` and the holdout
    attributes. They also call ``sort_samples`` and ``_resolve_locations``,
    which must be supplied by the host class.
    """

    def _empty_prediction_result(self, save_preds_to_disk, return_df):
        """Build (and optionally save) the result used when nothing is predicted.

        Args:
            save_preds_to_disk (bool): Write the empty table to
                ``<out>_predlocs.csv`` when True.
            return_df (bool): Return the empty DataFrame when True.

        Returns
        -------
            pandas.DataFrame or None: Empty frame with ``sampleID``, ``x`` and
            ``y`` columns if ``return_df`` is True, otherwise None.
        """
        empty_df = pd.DataFrame(columns=["sampleID", "x", "y"])
        if save_preds_to_disk:
            # NOTE(review): empty results are written with a ".csv" suffix while
            # real predictions use ".txt" (see predict) -- confirm which one
            # downstream readers expect before unifying.
            empty_df.to_csv(f"{self.config['out']}_predlocs.csv", index=False)
        return empty_df if return_df else None

    def _denormalize_predictions(self, predictions):
        """Return a copy of ``predictions`` mapped back to coordinate units.

        Column 0 is scaled by ``sdlong`` and shifted by ``meanlong``; column 1
        is scaled by ``sdlat`` and shifted by ``meanlat``. The input array is
        left untouched.
        """
        coords = predictions.copy()
        coords[:, 0] = coords[:, 0] * self.sdlong + self.meanlong
        coords[:, 1] = coords[:, 1] * self.sdlat + self.meanlat
        return coords

    def predict(  # noqa: C901
        self,
        boot=0,
        verbose=True,
        prediction_genotypes=None,  # Deprecated - use genotypes instead
        genotypes=None,  # New: full genotype array for tf.data
        samples=None,  # New: sample IDs
        indices=None,  # New: which samples to predict (default: NA samples)
        return_df=False,
        save_preds_to_disk=True,
        site_order=None,
    ):
        """Make predictions for samples with unknown locations.

        Two code paths exist. When ``genotypes`` is given, the genotypes are
        filtered (or the stored ``filtered_genotypes`` are reused) and the
        samples selected by ``indices`` are predicted. Otherwise the legacy
        array path predicts on ``prediction_genotypes`` or ``self.predgen``.

        Args:
            boot (int, optional): Bootstrap replicate number, used in the output
                file name when bootstrap/jacknife is enabled. Defaults to 0.
            verbose (bool, optional): Forwarded to the network's ``predict``.
                Defaults to True.
            prediction_genotypes (numpy.ndarray, optional): DEPRECATED - use
                genotypes parameter. Override default prediction genotypes.
                Used for jacknife resampling. Defaults to None.
            genotypes (numpy.ndarray, optional): Full genotype array. Should be
                the original unfiltered genotypes. Defaults to None.
            samples (numpy.ndarray, optional): Sample IDs corresponding to
                genotypes. Falls back to ``self.samples``. Defaults to None.
            indices (numpy.ndarray, optional): Indices of samples to predict on.
                If None, samples whose coordinates are NaN are used, falling
                back to ``self.pred_indices``. Defaults to None.
            return_df (bool, optional): Whether to return predictions as pandas
                DataFrame. Defaults to False.
            save_preds_to_disk (bool, optional): Whether to save predictions to
                disk. Defaults to True.
            site_order (np.ndarray, optional): Array of SNP indices for
                bootstrap resampling. If provided, SNPs will be reordered
                according to these indices during prediction.

        Returns
        -------
            numpy.ndarray or pandas.DataFrame: Array of predicted coordinates
            or DataFrame with x,y coordinates and sampleID columns. When there
            is nothing to predict, an empty DataFrame (``return_df=True``) or
            None is returned.

        Raises
        ------
            ValueError: If the model is missing, or if sample IDs are needed
                but neither passed in nor stored on the object.
        """
        if self.model is None:
            raise ValueError("Model must be trained before prediction")

        if genotypes is not None:
            # New approach: full genotype matrix plus sample indices.
            if prediction_genotypes is not None:
                warnings.warn(
                    "prediction_genotypes parameter is deprecated. Use genotypes parameter instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )

            # Resolve stored sample IDs up front so sort_samples() and
            # len(samples) below never see None when IDs are available.
            if samples is None:
                samples = getattr(self, "samples", None)

            locs = None
            if indices is None:
                if hasattr(self, "_sample_data_df"):
                    _, locs = self.sort_samples(samples)
                else:
                    sample_data_path = self.config.get("sample_data")
                    if sample_data_path:
                        _, locs = self.sort_samples(samples, sample_data_path)
                    elif hasattr(self, "pred_indices"):
                        # No sample data available, fall back to pred_indices
                        indices = self.pred_indices
                    else:
                        return self._empty_prediction_result(
                            save_preds_to_disk, return_df
                        )

                if locs is not None:
                    # Samples with a NaN in either coordinate need predicting.
                    na_mask = np.isnan(locs[:, 0]) | np.isnan(locs[:, 1])
                    indices = np.where(na_mask)[0]
                    if len(indices) == 0:
                        separate_mode = (
                            hasattr(self, "config")
                            and self.config.get("na_action") == "separate"
                        )
                        if not separate_mode:
                            return self._empty_prediction_result(
                                save_preds_to_disk, return_df
                            )
                        # No NA samples: in 'separate' mode predict on all.
                        if samples is None:
                            raise ValueError(
                                "samples must be provided or stored from training"
                            )
                        indices = np.arange(len(samples))

            if len(indices) == 0:
                return self._empty_prediction_result(save_preds_to_disk, return_df)

            if samples is None:
                raise ValueError("samples must be provided or stored from training")

            # Filter genotypes using the same parameters as training
            if hasattr(self, "filtered_genotypes"):
                filtered_genotypes = self.filtered_genotypes
            elif is_dosage_matrix(genotypes):
                filtered_genotypes = filter_dosage_matrix(
                    genotypes,
                    min_mac=self.config.get("min_mac", 2),
                    max_snps=self.config.get("max_SNPs"),
                )
            else:
                from .data import filter_snps_legacy as filter_snps

                filtered_genotypes = filter_snps(
                    genotypes,
                    min_mac=self.config.get("min_mac", 2),
                    max_snps=self.config.get("max_SNPs"),
                    impute=self.config.get("impute_missing", False),
                )

            predictions = predict_on_indices(
                self.model,
                filtered_genotypes,
                indices,
                site_order=site_order,
                verbose=verbose,
            )
            prediction_indices = indices
        else:
            # Old array-based approach (for backward compatibility)
            if prediction_genotypes is not None:
                warnings.warn(
                    "Using deprecated array-based prediction. Consider using genotypes parameter for better memory efficiency.",
                    DeprecationWarning,
                    stacklevel=2,
                )

            predgen = (
                prediction_genotypes
                if prediction_genotypes is not None
                else self.predgen
            )
            if predgen is None or len(predgen) == 0:
                return self._empty_prediction_result(save_preds_to_disk, return_df)

            if site_order is not None:
                predgen = predgen[:, site_order]

            # IndexedGenotypeModel gathers genotypes by index for training; for
            # prediction we feed genotype features straight to the feature
            # network. verbose is forwarded so both paths behave alike.
            predictions = feature_network(self.model).predict(
                predgen, verbose=verbose
            )
            prediction_indices = getattr(self, "pred_indices", None)

        predictions = self._denormalize_predictions(predictions)
        pred_df = pd.DataFrame(predictions, columns=["x", "y"])

        if samples is not None and prediction_indices is not None:
            # New approach: use provided samples and indices
            pred_df.insert(0, "sampleID", samples[prediction_indices])
        elif hasattr(self, "samples") and hasattr(self, "pred_indices"):
            # Old approach: use stored values
            pred_df.insert(0, "sampleID", self.samples[self.pred_indices])

        if save_preds_to_disk:
            resampling = self.config.get("bootstrap", False) or self.config.get(
                "jacknife", False
            )
            outfile = (
                f"{self.config['out']}_boot{boot}_predlocs.txt"
                if resampling
                else f"{self.config['out']}_predlocs.txt"
            )
            pred_df.to_csv(outfile, index=False)

        return pred_df if return_df else predictions

    def load_model(self, weights_path):
        """Load a trained model from saved weights.

        Reads normalization and preprocessing metadata from the attributes of
        the HDF5 weights file, stores the normalization parameters on the
        object, merges any saved config into ``self.config`` and, when a model
        already exists, loads the weights into it.

        Args:
            weights_path (str): Path to the saved HDF5 weights file

        Returns
        -------
            dict or None: Loaded metadata (normalization, preprocessing and
            bookkeeping fields), or None when the metadata could not be read.

        Raises
        ------
            ValueError: If the weights file does not exist. A file whose
                metadata cannot be read only triggers a warning.
        """
        import os

        if not os.path.exists(weights_path):
            raise ValueError(f"Weights file not found: {weights_path}")

        def _as_text(value):
            # HDF5 string attributes can come back as bytes; str() on those
            # would yield "b'...'" instead of the text itself.
            return value.decode("utf-8") if isinstance(value, bytes) else str(value)

        metadata = {}
        try:
            with h5py.File(weights_path, "r") as f:
                attrs = f.attrs

                # Load normalization parameters
                self.meanlong = float(attrs.get("coord_meanlong", 0.0))
                self.sdlong = float(attrs.get("coord_sdlong", 1.0))
                self.meanlat = float(attrs.get("coord_meanlat", 0.0))
                self.sdlat = float(attrs.get("coord_sdlat", 1.0))
                metadata["normalization"] = {
                    "meanlong": self.meanlong,
                    "sdlong": self.sdlong,
                    "meanlat": self.meanlat,
                    "sdlat": self.sdlat,
                }

                # Load preprocessing parameters; -1 encodes "no SNP limit".
                max_snps = int(attrs.get("max_SNPs", -1))
                metadata["preprocessing"] = {
                    "min_mac": int(attrs.get("min_mac", 2)),
                    "max_SNPs": None if max_snps == -1 else max_snps,
                    "impute_missing": bool(attrs.get("impute_missing", False)),
                }

                # Load other metadata
                metadata["n_samples"] = int(attrs.get("n_samples", 0))
                metadata["n_snps"] = int(attrs.get("n_snps", 0))
                metadata["metadata_version"] = _as_text(
                    attrs.get("metadata_version", "unknown")
                )
                metadata["locator_version"] = _as_text(
                    attrs.get("locator_version", "unknown")
                )
                metadata["save_date"] = _as_text(attrs.get("save_date", "unknown"))

                # Load config if available and merge it into the current one
                config_json = attrs.get("config_json", None)
                if config_json:
                    metadata["config"] = json.loads(config_json)
                    self.config.update(metadata["config"])

            print(f"Loaded model metadata from {weights_path}")
            print(
                f"Model trained on {metadata['n_samples']} samples with {metadata['n_snps']} SNPs"
            )
            print(
                f"Normalization params: mean_long={self.meanlong:.4f}, sd_long={self.sdlong:.4f}"
            )
        except Exception as e:
            # Deliberately broad: models saved before the metadata feature must
            # still load, so any failure here degrades to a warning.
            warnings.warn(
                f"Could not load metadata from weights file: {e}\n"
                "This may be an older model without saved metadata. "
                "Normalization parameters will need to be set manually.",
                stacklevel=2,
            )
            metadata = None

        if self.model is None:
            # The architecture needs the input shape, which is only known once
            # genotypes are loaded, so it cannot be built here.
            warnings.warn(
                "Model architecture not yet created. "
                "Call train() with setup_only=True after loading genotypes.",
                stacklevel=2,
            )
        else:
            self.model.load_weights(weights_path)
            print("Loaded weights into model")

        return metadata

    def predict_from_weights(
        self,
        weights_path,
        genotypes,
        samples,
        sample_data_file=None,
        save_preds_to_disk=True,
        return_df=True,
    ):
        """Convenience method to load weights and make predictions.

        Loads the saved model, filters the genotypes with the preprocessing
        parameters stored alongside the weights (falling back to
        ``self.config``), builds the network if it does not exist yet and
        predicts on every sample whose coordinates are NaN.

        Args:
            weights_path (str): Path to saved HDF5 weights file
            genotypes (numpy.ndarray): Genotype data to predict on
            samples (numpy.ndarray): Sample IDs corresponding to genotypes
            sample_data_file (str, optional): Path to sample data file
            save_preds_to_disk (bool): Whether to save predictions to disk
            return_df (bool): Whether to return predictions as DataFrame

        Returns
        -------
            numpy.ndarray or pandas.DataFrame: Predictions. When no sample
            lacks coordinates, an empty DataFrame (``return_df=True``) or an
            empty array is returned after a warning.
        """
        metadata = self.load_model(weights_path)
        self.samples = samples

        # Samples with a NaN in either coordinate are the ones to predict.
        _, locs = self._resolve_locations(samples, sample_data_file)
        na_mask = np.isnan(locs[:, 0]) | np.isnan(locs[:, 1])
        self.pred_indices = np.where(na_mask)[0]

        if len(self.pred_indices) == 0:
            warnings.warn(
                "No samples found without coordinates. Nothing to predict.",
                stacklevel=2,
            )
            return (
                pd.DataFrame(columns=["sampleID", "x", "y"])
                if return_df
                else np.array([])
            )

        # Prefer the preprocessing settings saved with the weights.
        if metadata and "preprocessing" in metadata:
            preprocessing = metadata["preprocessing"]
            min_mac = preprocessing["min_mac"]
            max_snps = preprocessing["max_SNPs"]
            impute = preprocessing["impute_missing"]
        else:
            min_mac = self.config.get("min_mac", 2)
            max_snps = self.config.get("max_SNPs")
            impute = self.config.get("impute_missing", False)

        if is_dosage_matrix(genotypes):
            filtered_genotypes = filter_dosage_matrix(
                genotypes,
                min_mac=min_mac,
                max_snps=max_snps,
            )
        else:
            from .data import filter_snps_legacy as filter_snps

            filtered_genotypes = filter_snps(
                genotypes,
                min_mac=min_mac,
                max_snps=max_snps,
                impute=impute,
            )

        # Prepare prediction genotypes: one row per sample to predict.
        self.predgen = np.transpose(filtered_genotypes[:, self.pred_indices])

        if self.model is None:
            from .models import create_network

            n_snps = filtered_genotypes.shape[0]
            self.model = create_network(
                input_shape=n_snps,
                width=self.config.get("width", 256),
                n_layers=self.config.get("nlayers", 8),
                dropout_prop=self.config.get("dropout_prop", 0.25),
                optimizer_config={
                    "algo": self.config.get("optimizer_algo", "adam"),
                    "learning_rate": self.config.get("learning_rate", 0.001),
                    "weight_decay": self.config.get("weight_decay", 0.004),
                },
            )
            self.model.load_weights(weights_path)

        return self.predict(save_preds_to_disk=save_preds_to_disk, return_df=return_df)

    def predict_holdout(
        self,
        verbose=True,
        return_df=False,
        save_preds_to_disk=True,
        plot_summary=True,
        plot_map=True,
    ):
        """Predict locations for held out samples.

        Args:
            verbose: Print progress; also forwarded to the network's ``predict``
            return_df: Return predictions as pandas DataFrame
            save_preds_to_disk: Save predictions to
                ``<out>_holdout_predlocs.csv``
            plot_summary: Display error summary plot (only if return_df=True
                and the plotting dependencies can be imported)
            plot_map: Passed to the error summary plot

        Returns
        -------
            pandas.DataFrame or numpy.ndarray: With ``return_df=True`` a
            DataFrame with columns ``x``, ``y``, ``sampleID``, ``x_pred`` and
            ``y_pred``; otherwise an array of the denormalized ``x``/``y``
            predictions, in the same units as ``predict`` returns.

        Raises
        ------
            ValueError: If the holdout state has not been stored on the object.
        """
        required = ("holdout_idx", "holdout_locs", "holdout_gen")
        if not all(hasattr(self, name) for name in required):
            raise ValueError("No holdout data found. Run train_holdout() first.")

        if verbose:
            print("Predicting locations for holdout samples...")

        # holdout_gen is the (n_holdout, n_snps) feature slice stored by
        # _store_holdout_state; predict on it directly.
        predictions = feature_network(self.model).predict(
            self.holdout_gen, verbose=verbose
        )

        pred_df = pd.DataFrame(predictions, columns=["x", "y"])
        pred_df["sampleID"] = self.samples[self.holdout_idx]

        # Denormalize predictions
        pred_df["x"] = pred_df["x"] * self.sdlong + self.meanlong
        pred_df["y"] = pred_df["y"] * self.sdlat + self.meanlat
        pred_df["x_pred"] = pred_df["x"]
        pred_df["y_pred"] = pred_df["y"]

        if save_preds_to_disk:
            pred_df.to_csv(f"{self.config['out']}_holdout_predlocs.csv", index=False)

        if not return_df:
            # Return coordinates in real units, matching predict(); the raw
            # network output is still in normalized space.
            return pred_df[["x", "y"]].to_numpy()

        if plot_summary:
            try:
                import matplotlib.pyplot as plt  # noqa: F401
                from IPython.display import display  # noqa: F401

                from .plotting import plot_error_summary

                if hasattr(self, "_sample_data_df"):
                    sample_data = self._sample_data_df
                else:
                    sample_data = pd.read_csv(self.config["sample_data"], sep="\t")

                plot_error_summary(
                    predictions=pred_df,
                    sample_data=sample_data,
                    plot_map=plot_map,
                    width=15,
                    height=5,
                    out_prefix=self.config.get("out"),
                    show=True,  # Explicitly show since we're in a notebook
                )
            except ImportError:
                # Plotting is best-effort: without the notebook/plotting stack
                # the predictions are still returned.
                pass

        return pred_df