"""Training functionality for locator"""
import json
import warnings
from datetime import datetime
import h5py
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from .data import (
IndexSet,
build_genotype_table,
filter_dosage_matrix,
is_dosage_matrix,
make_tf_dataset,
normalize_locs,
)
from .data import filter_snps_legacy as filter_snps
from .gpu_optimizer import GPUOptimizer
from .models import (
PCA_GATE_NAME,
PCA_LAYER_NAME,
IndexedGenotypeModel,
build_optimizer,
create_network,
euclidean_distance_loss,
feature_network,
loss_with_range_penalty,
rasterize_species_range,
)
from .pca import compute_pca_projection_gram, scree_elbow
from .sample_weights import weight_samples
[docs]
class TrainingMixin:
"""Mixin class providing training functionality for Locator."""
def _split_train_test(
    self, genotypes, locations, train_split=0.9, na_action="separate"
):
    """Partition samples into train/test (and prediction) index sets.

    Only index arrays are produced; genotype rows are gathered later by the
    tf.data / model pipeline, so no per-split genotype copies are made here.

    Args:
        genotypes: Genotype data for all samples. Not read by this method;
            kept for interface compatibility.
        locations: Array of coordinates (x, y) per sample; a NaN in column 0
            marks a sample whose location is unknown.
        train_split: Fraction of samples for training (default: 0.9); the
            remaining ``1 - train_split`` goes to the test split.
        na_action: NA handling mode forwarded to ``IndexSet.random_split``
            ('separate', 'exclude', 'fail').

    Returns
    -------
    tuple: (index_set, train_idx, test_idx, train_locs, test_locs, pred_idx)
        index_set: IndexSet holding the train/test (and any predict) indices
        train_idx: Training sample indices
        test_idx: Test sample indices
        train_locs: Unnormalized locations of the training samples
        test_locs: Unnormalized locations of the test samples
        pred_idx: Prediction indices -- every sample in 'separate' mode,
            otherwise the IndexSet's "predict" split (empty if it has none)
    """
    missing = np.isnan(locations[:, 0])
    total = len(locations)
    index_set = IndexSet.random_split(
        n=total,
        splits={"train": train_split, "test": 1.0 - train_split},
        na_mask=missing,
        na_action=na_action,
    )
    train_idx = index_set.train
    test_idx = index_set.test

    if na_action == "separate":
        # In 'separate' mode predictions are produced for every sample.
        pred_idx = np.arange(total)
    elif "predict" in index_set.indices:
        pred_idx = index_set.get_split("predict")
    else:
        pred_idx = np.array([], dtype=int)

    return (
        index_set,
        train_idx,
        test_idx,
        locations[train_idx],
        locations[test_idx],
        pred_idx,
    )
def _create_callbacks(self, boot=0):
    """Create the Keras callbacks used for training.

    Args:
        boot: Bootstrap replicate number (default: 0). Only used in the
            checkpoint file name when ``config["bootstrap"]`` is true.

    Returns
    -------
    list: A ModelCheckpoint (only when ``save_fold_models`` is true, the
        default), followed by EarlyStopping and ReduceLROnPlateau. All
        three monitor ``val_loss``.
    """
    patience = self.config.get("patience", 100)
    keras_verbose = self.config.get("keras_verbose", 1)
    should_save = self.config.get("save_fold_models", True)
    callbacks = []
    if should_save:
        filepath = (
            f"{self.config['out']}_boot{boot}.weights.h5"
            if self.config.get("bootstrap", False)
            else f"{self.config['out']}.weights.h5"
        )
        # Persist only the best (lowest val_loss) weights seen so far.
        callbacks.append(
            keras.callbacks.ModelCheckpoint(
                filepath=filepath,
                verbose=keras_verbose,
                save_best_only=True,
                save_weights_only=True,
                monitor="val_loss",
                save_freq="epoch",
                mode="min",  # Explicitly set mode for clarity
            )
        )
    callbacks.append(
        keras.callbacks.EarlyStopping(
            monitor="val_loss",
            min_delta=0,
            patience=patience,
            # With no checkpoint file, nothing else keeps the best-epoch
            # weights. Restore them in memory so the model does not end on
            # the last epoch, up to `patience` epochs past the best one.
            restore_best_weights=not should_save,
        )
    )
    callbacks.append(
        keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss",
            factor=0.5,
            # Halve the LR after patience // 6 stalled epochs, well before
            # EarlyStopping gives up.
            patience=patience // 6,
            verbose=keras_verbose,
            mode="auto",
            min_delta=0,
            cooldown=0,
            min_lr=0,
        )
    )
    return callbacks
[docs]
def set_sample_weights(self, wdict):
    """Set sample weights for training and record them in the config.

    Args:
        wdict (dict): Dictionary returned by utils.weight_samples() containing
            sample weights. It is stored as ``self.sample_weights``, and every
            key/value pair is also copied into ``config["weight_samples"]``
            (after ``enabled`` is set to True, so ``wdict`` may override it).
    """
    self.sample_weights = wdict
    # The "weight_samples" config section is optional (other code reads it
    # with .get(..., {})); create it instead of raising KeyError.
    weight_cfg = self.config.get("weight_samples")
    if weight_cfg is None:
        weight_cfg = self.config["weight_samples"] = {}
    weight_cfg["enabled"] = True
    weight_cfg.update(wdict)
[docs]
def train(  # noqa: C901
    self,
    *,  # Force keyword arguments
    genotypes,
    samples,
    sample_data_file=None,
    boot=None,
    train_gen=None,
    test_gen=None,
    pred_gen=None,
    train_locs=None,
    test_locs=None,
    setup_only=False,
    na_action=None,
    site_order=None,
):
    """Train the Locator model on genotype and location data.

    This method trains the neural network model to predict geographic locations from genetic data.
    It supports both standard training and advanced workflows such as bootstrapping, by accepting
    pre-processed genotype and location arrays. The model is configured using the parameters
    provided at initialization.

    Args:
        genotypes (allel.GenotypeArray or np.ndarray): Genotype data for all samples. Either
            (n_sites, n_samples, ploidy) or a 2-D dosage matrix (n_sites, n_samples); see
            ``_filter_genotypes``.
        samples (np.ndarray): Array of sample IDs corresponding to the genotype data.
        sample_data_file (str, optional): Path to a tab-delimited file with columns 'sampleID', 'x', 'y' for sample locations. Used if not provided in config or as a DataFrame.
        boot (int, optional): Bootstrap replicate number. Used for bootstrapping analyses. Defaults to None.
        train_gen (np.ndarray, optional): Pre-processed training genotype data. Used for bootstrapping. If None, will be generated from `genotypes`. Defaults to None.
        test_gen (np.ndarray, optional): Pre-processed test genotype data. Used for bootstrapping. If None, will be generated from `genotypes`. Defaults to None.
        pred_gen (np.ndarray, optional): Pre-processed prediction genotype data. Used for bootstrapping. If None, will be generated from `genotypes`. Defaults to None.
        train_locs (np.ndarray, optional): Pre-processed training locations. Used for bootstrapping. If None, will be generated from sample data. Defaults to None.
        test_locs (np.ndarray, optional): Pre-processed test locations. Used for bootstrapping. If None, will be generated from sample data. Defaults to None.
        setup_only (bool, optional): If True, only sets up the model and data without training. Defaults to False.
        na_action (str, optional): How to handle NA samples ('separate', 'exclude', 'fail').
            If None, uses self.na_action. Defaults to None.
        site_order (np.ndarray, optional): Array of SNP indices for bootstrap resampling.
            If provided, SNPs will be reordered according to these indices during training.
            Used for bootstrap analyses to resample SNPs with replacement. Forces a model
            rebuild for this call.

    Returns
    -------
    keras.callbacks.History or None: The Keras training history object if training is performed, or None if `setup_only` is True.

    Raises
    ------
    ValueError: If required sample data is missing or improperly formatted, if na_action is
        'fail' and some samples lack coordinates, or if ``pca_components`` is combined with
        ``train_gen`` / ``site_order`` resampling.

    Example:
        >>> # Standard training
        >>> loc = Locator({"out": "analysis", "sample_data": "samples.txt", "zarr": "genotypes.zarr"})
        >>> genotypes, samples = loc.load_genotypes(zarr="genotypes.zarr")
        >>> history = loc.train(genotypes=genotypes, samples=samples)
        >>> # Bootstrapping with pre-processed data
        >>> history = loc.train(
        ...     genotypes=None,
        ...     samples=samples,
        ...     boot=1,
        ...     train_gen=boot_train_gen,
        ...     test_gen=boot_test_gen,
        ...     pred_gen=boot_pred_gen,
        ...     train_locs=boot_train_locs,
        ...     test_locs=boot_test_locs
        ... )
    """
    # Store samples and site_order
    self.samples = samples
    self.site_order = site_order
    # PCA init is fit on the resident training split; resampled/pre-processed
    # inputs do not match that, so the combination is rejected up front.
    if self.config.get("pca_components") is not None and (
        train_gen is not None or site_order is not None
    ):
        raise ValueError(
            "pca_components is not supported with bootstrap/jacknife "
            "resampling (pre-processed train_gen or site_order)."
        )
    # Use instance default if na_action not specified
    if na_action is None:
        na_action = self.na_action
    # Get sample status
    status = self.get_sample_status(samples)
    # Report status
    print(
        f"Training data: {status['n_known']} samples with coordinates, {status['n_na']} without"
    )
    if status["n_na"] > 0:
        print(f"NA handling mode: {na_action}")
    # Apply NA action
    if na_action == "fail" and status["n_na"] > 0:
        raise ValueError(
            f"Found {status['n_na']} samples without coordinates. "
            f"Set na_action='separate' or 'exclude' to proceed."
        )
    sample_data, locs = self._resolve_locations(samples, sample_data_file)
    # Apply 'exclude' mode if needed
    if na_action == "exclude" and status["n_na"] > 0:
        print(f"Excluding {status['n_na']} samples without coordinates")
        # Filter to only known samples
        mask = status["known_indices"]
        # NOTE(review): indexes `genotypes` on axis 1; with genotypes=None
        # (the pre-processed bootstrap path) this would fail.
        genotypes = genotypes[:, mask]
        samples = samples[mask]
        locs = locs[mask]
        # Update sample data to match
        # NOTE(review): sample_data is not read again in this method.
        sample_data = sample_data.iloc[mask]
    # Filter SNPs if not using pre-processed data
    if train_gen is None:
        self._filter_genotypes(genotypes)
        # Split data using IndexSet approach (no arrays created)
        (
            self.index_set,
            train,
            test,
            trainlocs,
            testlocs,
            pred,
        ) = self._split_train_test(
            self.filtered_genotypes,
            locs,  # Use unnormalized locations for split
            train_split=self.config.get("train_split", 0.9),
            na_action=na_action,
        )
        # Set array attributes to None for compatibility
        self.traingen = None
        self.testgen = None
        # For 'separate' mode, create predgen for backward compatibility
        if na_action == "separate" and len(pred) > 0:
            # Sample-major copy: (n_pred, n_snps)
            self.predgen = np.transpose(self.filtered_genotypes[:, pred])
        elif len(pred) == 0:
            # Create empty array with correct shape
            self.predgen = np.zeros(
                (0, self.filtered_genotypes.shape[0]),
                dtype=self.filtered_genotypes.dtype,
            )
        else:
            self.predgen = None
        # Normalize locations and store for each split using helper method
        normalized_locs = self._normalize_and_store_locations(
            locs, samples, train, test
        )
        # Calculate sample weights using helper method
        # Pass unnormalized training locations
        train_locs_unnormed = locs[train]
        self._calculate_sample_weights(train, train_locs=train_locs_unnormed)
        # Store prediction indices
        self.pred_indices = pred
        splits = {"Training": train, "Validation": test}
        if len(pred) > 0:
            splits["Prediction"] = pred
        self._report_split_summary(
            splits, len(samples), self.filtered_genotypes.shape[0]
        )
    else:
        # Use pre-processed data (for bootstrapping)
        # NOTE(review): this branch does not update self.index_set, yet
        # _build_datasets_and_fit() below builds its datasets from
        # self.index_set and the model's genotype table, not from train_gen /
        # test_gen -- confirm this path is still exercised as intended.
        self.traingen = train_gen
        self.testgen = test_gen
        self.predgen = pred_gen
        # For pre-processed data, we still need to normalize locations to get the normalization parameters
        (
            self.meanlong,
            self.sdlong,
            self.meanlat,
            self.sdlat,
            self.unnormedlocs,
            normalized_locs,
        ) = normalize_locs(locs)
        # Use provided locations if available
        if train_locs is not None and test_locs is not None:
            self.trainlocs = train_locs
            self.testlocs = test_locs
        else:
            # Get train/test indices and locations from original split
            # (a fresh random split of the samples with known locations)
            train = np.where(~np.isnan(normalized_locs[:, 0]))[0]
            test = np.random.choice(
                train,
                round((1 - self.config.get("train_split", 0.9)) * len(train)),
                replace=False,
            )
            train = np.setdiff1d(train, test)
            self.trainlocs = normalized_locs[train]
            self.testlocs = normalized_locs[test]
    # Create the model. Rebuild when site_order is given so the bootstrap /
    # jacknife SNP resampling baked into the model matches this replicate.
    if self.model is None or site_order is not None:
        # Determine input shape
        if self.traingen is not None:
            input_shape = self.traingen.shape[1]
        else:
            # When using efficient pipeline, get shape from filtered_genotypes
            # If site_order is provided, use its length (for jacknife/bootstrap)
            if site_order is not None:
                input_shape = len(site_order)
            else:
                input_shape = self.filtered_genotypes.shape[0]
        self.model = self._create_model(
            input_shape=input_shape, site_order=site_order
        )
    # Return early if setup_only
    if setup_only:
        return None
    # NOTE(review): boot defaults to None here; with config["bootstrap"] set
    # and no boot given, the checkpoint name becomes "..._bootNone...".
    callbacks = self._create_callbacks(boot=boot)
    self._build_datasets_and_fit(normalized_locs, callbacks)
    self._save_training_artifacts(boot=boot)
    return self.history
[docs]
def train_holdout(  # noqa: C901
    self,
    genotypes=None,
    samples=None,
    k=10,
    holdout_indices=None,
    filtered_genotypes=None,
):
    """Train the model while holding out samples with known locations.

    The held-out samples are excluded from both training and validation.
    Their genotypes and normalized locations are stored via
    ``_store_holdout_state`` for a later predict_holdout() call. The
    remaining known-location samples are shuffled and split train/validation
    by ``config["train_split"]``.

    Args:
        genotypes: Array of genotype data. Required unless
            filtered_genotypes is provided.
        samples: Sample IDs corresponding to genotypes
        k: Number of samples to hold out (ignored if holdout_indices provided).
            Must be less than the number of samples with known locations.
        holdout_indices: Optional specific indices of samples to hold out.
            All must refer to samples with known locations.
        filtered_genotypes: Pre-filtered allele count array. If provided,
            skips internal filter_snps call and avoids loading the full
            genotype array. Used by parallel dispatch to share one
            filtered copy across all workers.

    Returns
    -------
    keras.callbacks.History object from model training

    Raises
    ------
    ValueError: If holdout_indices include unknown-location samples, if k is
        too large, or if no samples remain for training after the holdout.
    """
    # Store samples
    self.samples = samples
    _, locs = self._resolve_locations(samples)
    # Get indices of samples with known locations
    known_idx = np.where(~np.isnan(locs[:, 0]))[0]
    # Determine holdout indices
    if holdout_indices is not None:
        holdout_idx = np.array(holdout_indices)
        # Membership test is O(len(holdout_idx) * len(known_idx)).
        if not all(idx in known_idx for idx in holdout_idx):
            raise ValueError(
                "All holdout_indices must be indices of samples with known locations"
            )
    else:
        if k >= len(known_idx):
            raise ValueError(
                f"k ({k}) must be less than number of samples with known locations ({len(known_idx)})"
            )
        holdout_idx = np.random.choice(known_idx, k, replace=False)
    self._filter_genotypes(genotypes, filtered_genotypes)
    # Get available samples for training (exclude holdout and NA samples)
    available_indices = np.setdiff1d(known_idx, holdout_idx)
    n_available = len(available_indices)
    if n_available == 0:
        raise ValueError("No samples available for training after holdout")
    # Split available samples into train/test
    # NOTE(review): with very few available samples, int(n * train_split)
    # can be 0, leaving an empty training split.
    train_split = self.config.get("train_split", 0.9)
    n_train = int(n_available * train_split)
    np.random.shuffle(available_indices)
    train_indices = available_indices[:n_train]
    test_indices = available_indices[n_train:]
    # Create IndexSet for efficient data handling
    n_samples = len(locs)
    self.index_set = IndexSet(
        indices={
            "train": train_indices,
            "test": test_indices,
            "holdout": holdout_idx,
        },
        total_samples=n_samples,
        na_mask=np.isnan(locs[:, 0]),
    )
    # Normalize locations and store for each split
    normalized_locs = self._normalize_and_store_locations(
        locs, samples, train_indices, test_indices
    )
    self._store_holdout_state(holdout_idx, normalized_locs)
    self._report_split_summary(
        {
            "Training": train_indices,
            "Validation": test_indices,
            "Holdout": holdout_idx,
        },
        len(samples),
        self.filtered_genotypes.shape[0],
    )
    # Handle sample weights if enabled
    # NOTE(review): unlike train(), no train_locs is passed, so
    # _calculate_sample_weights falls back to self.unnormedlocs -- confirm
    # that array lines up with train_indices.
    self._calculate_sample_weights(train_indices)
    # Build the model, or reuse the previous fold's. Reuse is valid only
    # when the genotype table is unchanged (k-fold/LOO sharing one filtered
    # array, as a Ray actor does); it keeps the compiled training function,
    # avoiding a per-fold XLA recompile.
    current_table = self._get_genotype_table()
    if (
        isinstance(self.model, IndexedGenotypeModel)
        and self.model.genotype_table is current_table
    ):
        self._reset_model_for_fold()
    else:
        self.model = self._create_model(input_shape=self.filtered_genotypes.shape[0])
    # Create callbacks
    # For train_holdout, we might want to skip saving intermediate models
    # to reduce file I/O overhead during k-fold cross-validation
    if self.config.get("holdout_no_intermediate_saves", False):
        # Minimal callbacks without file saves; best weights are restored
        # in memory by EarlyStopping instead of a checkpoint file.
        callbacks = [
            keras.callbacks.EarlyStopping(
                monitor="val_loss",
                min_delta=0,
                patience=self.config.get("patience", 100),
                restore_best_weights=True,
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor="val_loss",
                factor=0.5,
                patience=self.config.get("patience", 100) // 6,
                verbose=self.config.get("keras_verbose", 1),
                mode="auto",
                min_delta=0,
                min_lr=1e-5,
            ),
        ]
    else:
        callbacks = self._create_callbacks()
    self._build_datasets_and_fit(normalized_locs, callbacks, keras_verbose=0)
    self._save_training_artifacts(save=self.config.get("save_fold_models", True))
    return self.history
def _save_model_metadata(self, boot=0):
    """Save model metadata, including normalization parameters, to the HDF5 weights file.

    Coordinate-normalization and preprocessing parameters, plus the run
    config as a JSON string, are written as HDF5 attributes on the weights
    file. This is the same path ``_create_callbacks`` checkpoints to, and the
    file is opened in append mode. This lets the model be used for
    predictions in a new session. Any failure is reported as a warning and
    never aborts training.

    Args:
        boot: Bootstrap iteration number (default: 0). Selects
            ``{out}_boot{boot}.weights.h5`` when ``config["bootstrap"]`` is
            true.
    """
    # Determine the weights file path
    if self.config.get("bootstrap", False):
        filepath = f"{self.config['out']}_boot{boot}.weights.h5"
    else:
        filepath = f"{self.config['out']}.weights.h5"

    def _json_default(obj):
        # numpy scalars (e.g. values produced by numpy code and written back
        # into the config) are not JSON-serializable by default; unwrap them.
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    # Open the HDF5 file and add metadata as attributes
    try:
        with h5py.File(filepath, "a") as f:
            # Save normalization parameters (identity defaults if unset)
            f.attrs["coord_meanlong"] = (
                self.meanlong if self.meanlong is not None else 0.0
            )
            f.attrs["coord_sdlong"] = self.sdlong if self.sdlong is not None else 1.0
            f.attrs["coord_meanlat"] = (
                self.meanlat if self.meanlat is not None else 0.0
            )
            f.attrs["coord_sdlat"] = self.sdlat if self.sdlat is not None else 1.0
            # Save preprocessing parameters (-1 encodes "no max_SNPs limit")
            f.attrs["min_mac"] = self.config.get("min_mac", 2)
            f.attrs["max_SNPs"] = (
                self.config.get("max_SNPs", None)
                if self.config.get("max_SNPs") is not None
                else -1
            )
            f.attrs["impute_missing"] = self.config.get("impute_missing", False)
            f.attrs["n_samples"] = (
                len(self.samples) if self.samples is not None else 0
            )
            f.attrs["n_snps"] = (
                self.filtered_genotypes.shape[0]
                if hasattr(self, "filtered_genotypes")
                and self.filtered_genotypes is not None
                else 0
            )
            # Save metadata version for future compatibility
            f.attrs["metadata_version"] = "1.0"
            f.attrs["locator_version"] = "0.1.0"  # Should get from package version
            f.attrs["save_date"] = datetime.now().isoformat()
            # Save config as JSON string for full reproducibility
            config_to_save = self.config.copy()
            # Remove non-serializable items
            non_serializable_keys = [
                "genotypes",
                "sample_data",
                "genotype_data",
                "species_range_geom",
            ]
            for key in non_serializable_keys:
                config_to_save.pop(key, None)
            # weight_samples can hold a DataFrame ("weightdf") and the arrays
            # that set_sample_weights() copies in from weight_samples()
            # output (e.g. "sample_weights"). Drop those; json cannot encode
            # them.
            if "weight_samples" in config_to_save and isinstance(
                config_to_save["weight_samples"], dict
            ):
                config_to_save["weight_samples"] = {
                    key: value
                    for key, value in config_to_save["weight_samples"].items()
                    if key != "weightdf"
                    and not isinstance(
                        value, (np.ndarray, pd.DataFrame, pd.Series)
                    )
                }
            f.attrs["config_json"] = json.dumps(
                config_to_save, default=_json_default
            )
        print(f"Model metadata saved to {filepath}")
    except Exception as e:
        # Best effort: don't fail training if metadata save fails
        warnings.warn(f"Failed to save model metadata: {e}")
def _create_model(self, input_shape, site_order=None):
    """Build the model used for training.

    The coordinate-prediction network from ``create_network`` is always
    built. When ``filtered_genotypes`` is loaded, that network is wrapped in
    an ``IndexedGenotypeModel`` that gathers genotype rows by sample index
    from the resident genotype table, and the wrapper is compiled here.
    Otherwise the bare network is returned, e.g. when only an architecture
    is needed to load saved weights into.

    Args:
        input_shape: Number of input features the inner network expects.
        site_order: Optional SNP resampling order for bootstrap/jacknife,
            handed to the ``IndexedGenotypeModel`` wrapper.

    Returns
    -------
    The compiled wrapper model, or the bare network when no genotype matrix
    is loaded.

    Raises
    ------
    ValueError: If ``use_range_penalty`` is set without
        ``species_range_shapefile`` or ``resolution``, or if the
        ``pca_components`` setting cannot be resolved.
    """
    cfg = self.config
    loss_fn = None
    if cfg.get("use_range_penalty"):
        if cfg.get("species_range_shapefile") is None:
            raise ValueError(
                "species_range_shapefile must be provided "
                "if use_range_penalty is True"
            )
        if cfg.get("resolution") is None:
            raise ValueError(
                "resolution must be provided if use_range_penalty is True"
            )
        mask_tensor, mask_transform = rasterize_species_range(
            cfg["species_range_shapefile"],
            resolution=cfg.get("resolution", 0.05),
        )

        def loss_fn(y_true, y_pred):  # noqa: F811
            # resolution / penalty_weight are looked up on every call.
            return loss_with_range_penalty(
                y_true,
                y_pred,
                mask_tensor=mask_tensor,
                transform=mask_transform,
                resolution=self.config.get("resolution", 0.05),
                penalty_weight=self.config.get("penalty_weight", 1.0),
            )

    self._loss_fn = loss_fn
    pca_components = self._resolve_pca_components()
    algo = cfg.get("optimizer_algo", "adam")
    learning_rate = cfg.get("learning_rate", 0.001)
    weight_decay = cfg.get("weight_decay", 0.004)
    network = create_network(
        input_shape=input_shape,
        width=cfg.get("width", 256),
        n_layers=cfg.get("nlayers", 8),
        dropout_prop=cfg.get("dropout_prop", 0.25),
        pca_components=pca_components,
        optimizer_config={
            "algo": algo,
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
        },
        loss_fn=loss_fn,
    )
    if getattr(self, "filtered_genotypes", None) is None:
        # Nothing resident to gather from: hand back the plain network.
        model = network
    else:
        model = IndexedGenotypeModel(
            network,
            self._get_genotype_table(),
            site_order=site_order,
            augment=cfg.get("augmentation"),
        )
        model.compile(
            optimizer=build_optimizer(algo, learning_rate, weight_decay),
            loss=euclidean_distance_loss if loss_fn is None else loss_fn,
        )
    if pca_components is not None:
        self._inject_pca_weights(model, pca_components)
    return model
def _get_genotype_table(self):
"""Build or reuse the GPU-resident genotype table.
The table is rebuilt only when ``filtered_genotypes`` is a different
array object. A Ray actor reuses one Locator and one shared
``filtered_genotypes`` across all its folds, so the table is built once
per actor; windowed analysis filters a fresh array per window, so the
identity check correctly triggers a rebuild there. ``_filter_genotypes``
assigns ``filtered_genotypes`` by reference (never a copy), which keeps
this identity guard exact.
"""
if (
self._genotype_table is None
or self.filtered_genotypes is not self._genotype_table_src
):
self._genotype_table = build_genotype_table(self.filtered_genotypes)
self._genotype_table_src = self.filtered_genotypes
return self._genotype_table
def _resolve_pca_components(self):
"""Resolve the pca_components config value to a concrete width.
Returns None (no projection) or an int. The string ``"auto"`` is
resolved to the genotype-PCA scree elbow of the training split and
written back to the config, so every fold and the saved metadata of a
run share one rank.
"""
pca_components = self.config.get("pca_components")
if pca_components is None or isinstance(pca_components, int):
return pca_components
if pca_components != "auto":
raise ValueError(
f"pca_components must be None, an int, or 'auto'; got {pca_components!r}"
)
index_set = getattr(self, "index_set", None)
if index_set is None or getattr(self, "filtered_genotypes", None) is None:
raise ValueError(
"pca_components='auto' needs training data; pass an explicit "
"integer when building an architecture to load weights into"
)
train_geno = tf.gather(
self._get_genotype_table(),
np.asarray(index_set.train, dtype=np.int32),
axis=0,
)
rank = scree_elbow(train_geno)
self.config["pca_components"] = rank
print(f"pca_components='auto': using scree-elbow rank {rank}")
return rank
def _inject_pca_weights(self, model, pca_components):
"""Initialize the pca_projection layer with PCA loadings, gate closed.
PCA is fit on the training split only, gathered from the GPU-resident
genotype table so the eigendecomposition runs on-device with no host
round trip. The projection layer stays trainable; phase-1 training is
held at the PCA initialization by closing the gradient gate, which
needs no recompile. When training data is not available (e.g. building
an architecture to load saved weights into), this is a no-op and the
layer keeps its loaded weights.
Args:
model: The model returned by create_network with a pca_projection
layer.
pca_components: Projection width.
"""
index_set = getattr(self, "index_set", None)
filtered = getattr(self, "filtered_genotypes", None)
if index_set is None or filtered is None:
return
n_snps = filtered.shape[0]
train_idx = index_set.train
n_train = len(train_idx)
if pca_components > min(n_train, n_snps):
raise ValueError(
f"pca_components ({pca_components}) cannot exceed "
f"min(n_train={n_train}, n_snps={n_snps})"
)
# Gather the training rows from the resident table (sample-major, on
# the GPU) and fit PCA there via the Gram-matrix method.
train_geno = tf.gather(
self._get_genotype_table(),
np.asarray(train_idx, dtype=np.int32),
axis=0,
)
W, bias = compute_pca_projection_gram(train_geno, pca_components)
# Assign the loadings straight from the device tensors -- set_weights
# would round them through host memory.
projection = model.get_layer(PCA_LAYER_NAME)
projection.kernel.assign(W)
projection.bias.assign(bias)
# Close the gate: phase-1 training holds the projection at its PCA
# initialization. The layer stays trainable, so the graph is unchanged.
model.get_layer(PCA_GATE_NAME).gate.assign(0.0)
[docs]
def train_window(
    self,
    genotypes,
    samples,
    window_snp_indices,
    index_set,
    normalized_locs,
):
    """Train the model for a single genomic window using the tf.data pipeline.

    This is an internal method used by run_windows_holdouts. It filters the
    window's SNPs, builds a fresh model, stores the window's held-out
    samples (the ``test`` split of ``index_set``) for predict_holdout(), and
    trains on a random train/validation split of ``index_set``'s ``train``
    split.

    Args:
        genotypes: Full (unfiltered) genotype data with SNPs on the first
            axis: a GenotypeArray (n_sites, n_samples, ploidy) or a 2-D
            dosage matrix (n_sites, n_samples).
        samples: Sample IDs
        window_snp_indices: Indices of SNPs in this window
        index_set: Pre-computed IndexSet; its ``train`` split is divided into
            train/validation here and its ``test`` split is the holdout set.
        normalized_locs: Pre-normalized location coordinates for all samples

    Returns
    -------
    keras.callbacks.History object from model training

    Raises
    ------
    ValueError: If ``pca_components`` is configured (unsupported here).
    """
    if self.config.get("pca_components") is not None:
        raise ValueError(
            "pca_components is not supported with windowed analysis; "
            "run windows without the PCA-init projection."
        )
    # Store samples and index set
    self.samples = samples
    self.index_set = index_set
    # Index only the leading (SNP) axis so both 3-D GenotypeArrays and 2-D
    # dosage matrices (accepted by _filter_genotypes) work.
    window_genotypes = genotypes[window_snp_indices]
    self._filter_genotypes(window_genotypes)
    n_snps_filtered = self.filtered_genotypes.shape[0]
    # Create model and callbacks for this window
    self.model = self._create_model(input_shape=n_snps_filtered)
    callbacks = self._create_callbacks()
    # In window analysis, the 'test' split contains the holdout samples
    self._store_holdout_state(index_set.get_split("test"), normalized_locs)
    # Split the train indices into train/val.
    # IndexSet arrays shipped via Ray are read-only; copy before shuffling.
    train_indices = np.array(index_set.get_split("train"), copy=True)
    train_split = self.config.get("train_split", 0.9)
    n_train = int(len(train_indices) * train_split)
    np.random.shuffle(train_indices)
    actual_train = train_indices[:n_train]
    actual_val = train_indices[n_train:]
    self.trainlocs = normalized_locs[actual_train]
    self.testlocs = normalized_locs[actual_val]
    # Create a new IndexSet with the proper splits for training
    self.index_set = IndexSet(
        indices={"train": actual_train, "test": actual_val},
        total_samples=index_set.total_samples,
        na_mask=index_set.na_mask,
    )
    # Compute weights for the samples actually trained on, in train-split
    # order. Computing them on the pre-split train set would misalign them
    # with the shuffled, smaller training split.
    self._calculate_sample_weights(actual_train)
    self._build_datasets_and_fit(normalized_locs, callbacks, keras_verbose=0)
    return self.history
def _calculate_sample_weights(self, train_indices, train_locs=None):
"""Calculate sample weights if enabled. Extracted to avoid duplication.
Args:
train_indices: Indices of training samples
train_locs: Optional unnormalized training locations. If None, uses self.unnormedlocs
"""
if self.config.get("weight_samples", {}).get("enabled", False):
if self.sample_weights is not None:
warnings.warn(
"""Sample weights already calculated.
Set locator.sample_weights to None in config to disable."""
)
else:
wmethod = self.config.get("weight_samples", {}).get("method")
# Use provided train_locs or fall back to self.unnormedlocs
locs_for_weights = (
train_locs if train_locs is not None else self.unnormedlocs
)
self.sample_weights = weight_samples(
wmethod,
trainlocs=locs_for_weights,
trainsamps=self.samples[train_indices],
weightdf=self.config.get("weight_samples", {}).get("weightdf"),
xbins=self.config.get("weight_samples", {}).get("xbins"),
ybins=self.config.get("weight_samples", {}).get("ybins"),
lam=self.config.get("weight_samples", {}).get("lam"),
bandwidth=self.config.get("weight_samples", {}).get("bandwidth"),
)
def _determine_batch_size(self, dataset_size):
"""Determine optimal batch size. Extracted to avoid duplication."""
batch_size = self.config.get("batch_size", 32)
verbose_batch_size = self.config.get("verbose_batch_size", False)
if self.config.get("gpu_batch_size") == "auto" and not self.config.get(
"disable_gpu", False
):
try:
# Probe the feature network: the IndexedGenotypeModel wrapper
# consumes sample indices, not genotype features.
optimal_batch = GPUOptimizer.get_optimal_batch_size(
feature_network(self.model),
input_shape=(self.filtered_genotypes.shape[0],),
target_memory_usage=0.85,
dataset_size=dataset_size,
verbose=verbose_batch_size,
)
if verbose_batch_size:
print(f"Using optimized batch size: {optimal_batch}")
batch_size = optimal_batch
except Exception as e:
if verbose_batch_size:
print(
f"Failed to optimize batch size: {e}. Using default: {batch_size}"
)
elif isinstance(self.config.get("gpu_batch_size"), int):
batch_size = self.config["gpu_batch_size"]
return batch_size
# ------------------------------------------------------------------
# Shared helpers used by train(), train_holdout(), and train_window()
# ------------------------------------------------------------------
def _resolve_locations(self, samples, sample_data_file=None):
"""Load sample metadata and return locations array.
Args:
samples: Array of sample IDs
sample_data_file: Optional path override for sample data file
Returns
-------
tuple: (sample_data DataFrame, locs array of shape (n_samples, 2))
"""
if hasattr(self, "_sample_data_df"):
return self.sort_samples(samples)
sample_data_path = sample_data_file or self.config.get("sample_data")
if not isinstance(sample_data_path, str):
raise ValueError(
"sample_data file path must be provided in config or as argument "
"when not using DataFrame input"
)
return self.sort_samples(samples, sample_data_path)
def _filter_genotypes(self, genotypes, filtered_genotypes=None):
"""Filter SNPs and store result as self.filtered_genotypes.
Two genotype input dialects are accepted:
- ``allel.GenotypeArray`` (n_sites, n_samples, ploidy): the original
path; biallelic check + MAC + max_snps + optional imputation via
``filter_snps``.
- 2D float ``np.ndarray`` (n_sites, n_samples): continuous dosage
(e.g., GL-derived expected dosage). Biallelic check is skipped (not
meaningful for continuous values); MAC and max_snps filters are
applied directly on the dosage matrix.
Args:
genotypes: Raw GenotypeArray, 2D float ndarray, or window slice
filtered_genotypes: Pre-filtered allele counts (skips filtering)
Returns
-------
np.ndarray: Filtered allele count / dosage array, shape (n_sites, n_samples)
"""
if filtered_genotypes is not None:
self.filtered_genotypes = filtered_genotypes
elif genotypes is not None:
if is_dosage_matrix(genotypes):
self.filtered_genotypes = filter_dosage_matrix(
genotypes,
min_mac=self.config.get("min_mac", 2),
max_snps=self.config.get("max_SNPs"),
)
else:
self.filtered_genotypes = filter_snps(
genotypes,
min_mac=self.config.get("min_mac", 2),
max_snps=self.config.get("max_SNPs"),
impute=self.config.get("impute_missing", False),
)
else:
raise ValueError("Either genotypes or filtered_genotypes must be provided")
return self.filtered_genotypes
def _store_holdout_state(self, holdout_idx, normalized_locs):
"""Store holdout data for use by predict_holdout().
Sets self.holdout_idx, self.holdout_gen, self.holdout_locs.
Args:
holdout_idx: Array of held-out sample indices
normalized_locs: Full normalized location array
"""
self.holdout_idx = holdout_idx
self.holdout_gen = np.asarray(
self.filtered_genotypes[:, holdout_idx].T, order="C"
)
self.holdout_locs = normalized_locs[holdout_idx]
def _build_datasets_and_fit(
    self,
    normalized_locs,
    callbacks,
    keras_verbose=None,
):
    """Build the train/validation tf.data pipelines and fit ``self.model``.

    Requires ``self.filtered_genotypes``, ``self.index_set``, ``self.model``
    and ``self.sample_weights`` to be set. Sample weights, when present,
    are attached to the training dataset only. SNP resampling (site_order)
    happens inside the model, not here.

    Args:
        normalized_locs: Full normalized location array
        callbacks: List of Keras callbacks
        keras_verbose: Verbosity for model.fit (default: config "keras_verbose", else 1)

    Returns
    -------
    keras.callbacks.History (also stored as ``self.history``)
    """
    batch_size = self._determine_batch_size(len(self.index_set.train))
    verbosity = (
        self.config.get("keras_verbose", 1) if keras_verbose is None else keras_verbose
    )
    weights = self.sample_weights["sample_weights"] if self.sample_weights else None
    shared = {
        "coordinates": normalized_locs,
        "index_set": self.index_set,
        "batch_size": batch_size,
    }
    train_dataset = make_tf_dataset(
        split="train", sample_weights=weights, training=True, **shared
    )
    val_dataset = make_tf_dataset(split="test", training=False, **shared)
    self.history = self._fit_model(
        self.model, train_dataset, val_dataset, callbacks, verbosity
    )
    return self.history
def _fit_model(self, model, train_dataset, val_dataset, callbacks, keras_verbose):
"""Fit a model, running a two-phase PCA fine-tune when enabled.
With ``pca_components`` set and ``pca_finetune`` true: phase 1 trains
with the projection held at its PCA initialization (gradient gate
closed); phase 2 opens the gate and continues at a low learning rate.
Both phases share one compiled training function -- only the gate
variable, the optimizer state, and the learning rate are reassigned, so
phase 2 does not retrace or recompile. Otherwise a single fit runs.
Keras resets stateful callbacks (EarlyStopping, ReduceLROnPlateau) at
the start of each fit, so the same callback list is reused.
Args:
model: Compiled Keras model to train.
train_dataset: Training tf.data.Dataset.
val_dataset: Validation tf.data.Dataset.
callbacks: List of Keras callbacks.
keras_verbose: Verbosity passed to model.fit.
Returns
-------
keras.callbacks.History (phase-2 history with both phases merged
when the two-phase fine-tune runs).
"""
max_epochs = self.config.get("max_epochs", 5000)
history = model.fit(
train_dataset,
epochs=max_epochs,
validation_data=val_dataset,
callbacks=callbacks,
verbose=keras_verbose,
)
pca_components = self.config.get("pca_components")
if pca_components is None or not self.config.get("pca_finetune", True):
return history
# Phase 2: open the gradient gate so the projection fine-tunes, on a
# fresh optimizer state at a low learning rate. The graph is unchanged,
# so the compiled training function from phase 1 is reused as-is.
model.get_layer(PCA_GATE_NAME).gate.assign(1.0)
self._reset_optimizer_state(model.optimizer)
model.optimizer.learning_rate = self.config.get("pca_finetune_lr", 1e-4)
history2 = model.fit(
train_dataset,
epochs=max_epochs,
validation_data=val_dataset,
callbacks=callbacks,
verbose=keras_verbose,
)
return self._concat_histories(history, history2)
@staticmethod
def _reset_optimizer_state(optimizer):
"""Zero an optimizer's momentum/velocity slots and step counter.
Gives a reused model -- the fine-tune phase, or the next fold -- a
clean optimizer state without building a new optimizer, which would
reset the compile cache and force an XLA recompile. Under mixed
precision the optimizer is a LossScaleOptimizer wrapping the real one;
only the inner optimizer is zeroed, so the adapting loss scale is left
intact (zeroing it would stall training). The learning rate is zeroed
here too and must be set by the caller afterwards.
"""
inner = getattr(optimizer, "inner_optimizer", optimizer)
for var in inner.variables:
var.assign(tf.zeros_like(var))
@staticmethod
def _reinitialize_layer_weights(model):
"""Re-draw every layer weight from its initializer.
Used when reusing a model across folds: gives each fold a fresh random
initialization, exactly as a newly constructed network would have,
while keeping the compiled training function. Covers the weight kinds
create_network produces (Dense kernel/bias, BatchNormalization
gamma/beta and moving statistics); raises if a trainable weight is
left uncovered so a future architecture change cannot silently leak
state between folds.
"""
reset = set()
for layer in model.layers:
for weight_attr, init_attr in (
("kernel", "kernel_initializer"),
("bias", "bias_initializer"),
("gamma", "gamma_initializer"),
("beta", "beta_initializer"),
("moving_mean", "moving_mean_initializer"),
("moving_variance", "moving_variance_initializer"),
):
weight = getattr(layer, weight_attr, None)
initializer = getattr(layer, init_attr, None)
if weight is not None and initializer is not None:
weight.assign(initializer(weight.shape, weight.dtype))
reset.add(id(weight))
missed = [w for w in model.trainable_variables if id(w) not in reset]
if missed:
raise RuntimeError(
"_reinitialize_layer_weights left weights uncovered: "
f"{[w.name for w in missed]}"
)
def _reset_model_for_fold(self):
"""Reset a reused model to a fresh per-fold starting state.
Re-draws all layer weights from their initializers, zeros the
optimizer state, restores the base learning rate, and re-fits the
per-fold PCA projection. The compiled training function is left
intact, so the fold does not pay an XLA recompile.
"""
self._reinitialize_layer_weights(self.model.inner)
self._reset_optimizer_state(self.model.optimizer)
self.model.optimizer.learning_rate = self.config.get("learning_rate", 0.001)
if self.config.get("pca_components") is not None:
self._inject_pca_weights(self.model, self.config["pca_components"])
@staticmethod
def _concat_histories(history1, history2):
"""Merge two Keras History objects into one covering both phases."""
keys = set(history1.history) & set(history2.history)
history2.history = {
k: list(history1.history[k]) + list(history2.history[k]) for k in keys
}
return history2
def _save_training_artifacts(self, boot=0, save=True):
"""Save training history and model metadata to disk.
Args:
boot: Bootstrap replicate number
save: Whether to actually save (False skips)
"""
if not save or not hasattr(self, "history"):
return
hist_df = pd.DataFrame(self.history.history)
hist_df.to_csv(
f"{self.config['out']}_boot{boot}_history.txt"
if self.config.get("bootstrap", False)
else f"{self.config['out']}_history.txt",
index=False,
)
self._save_model_metadata(boot=boot)
def _report_split_summary(self, split_dict, n_total, n_snps):
"""Print data split summary if verbose_splits is enabled.
Args:
split_dict: Dict mapping label to index array,
e.g. {"Training": train_idx, "Validation": val_idx}
n_total: Total number of samples
n_snps: Number of SNPs after filtering
"""
if not self.config.get("verbose_splits", False):
return
print("\nSplit summary:")
for label, indices in split_dict.items():
n = len(indices)
pct = n / n_total * 100
print(f" {label} samples: {n} ({pct:.1f}%)")
print(f" Total samples: {n_total}")
print(f" Total SNPs: {n_snps}")
def _normalize_and_store_locations(self, locs, samples, train_indices, test_indices):
    """Fit location normalization on the training rows and apply it to all rows.

    ``normalize_locs`` computes the normalization from the training
    locations only. Its outputs are stored as ``meanlong``, ``sdlong``,
    ``meanlat``, ``sdlat`` and ``unnormedlocs``. Columns 0 and 1 of every
    row in ``locs`` are then standardized with those training statistics.
    Also sets ``trainIDs`` (training sample IDs), ``trainlocs`` (the
    normalized training locations returned by ``normalize_locs``) and
    ``testlocs`` (the test rows of the fully normalized array).

    Args:
        locs: Array of location coordinates; columns 0/1 are treated as
            longitude/latitude.
        samples: Array of sample IDs aligned with ``locs``.
        train_indices: Indices of training samples.
        test_indices: Indices of test samples.

    Returns
    -------
    normalized_locs: float64 array shaped like ``locs``, holding every
    sample's location normalized with the training parameters.
    """
    train_locs = locs[train_indices]
    self.trainIDs = samples[train_indices]
    (
        self.meanlong,
        self.sdlong,
        self.meanlat,
        self.sdlat,
        self.unnormedlocs,
        normalized_train_locs,
    ) = normalize_locs(train_locs)

    # Standardize every sample (train, test and unlabeled alike) with the
    # training statistics, one coordinate column at a time.
    normalized = np.empty_like(locs, dtype=np.float64)
    column_stats = ((self.meanlong, self.sdlong), (self.meanlat, self.sdlat))
    for col, (center, scale) in enumerate(column_stats):
        normalized[:, col] = (locs[:, col] - center) / scale

    self.trainlocs = normalized_train_locs
    self.testlocs = normalized[test_indices]
    return normalized