"""Neural network model definitions"""
from typing import Callable, Optional

import geopandas as gpd
import numpy as np
import tensorflow as tf
from affine import Affine
from rasterio.features import rasterize
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import layers
# Name given to the optional linear projection layer built by create_network;
# callers look the layer up by this name to set its weights.
PCA_LAYER_NAME = "pca_projection"
# Name given to the GradientGate layer placed directly after the projection.
PCA_GATE_NAME = "pca_finetune_gate"
def build_optimizer(algo, learning_rate, weight_decay=0.004):
    """Construct a Keras optimizer selected by name.

    Args:
        algo: Optimizer name, "adam" or "adamw"; matched case-insensitively.
        learning_rate: Learning rate handed to the optimizer.
        weight_decay: Weight-decay coefficient; only AdamW receives it.

    Returns:
        The configured Keras optimizer instance.

    Raises:
        ValueError: If ``algo`` is neither "adam" nor "adamw".
    """
    algo = algo.lower()

    # Reject unknown names up front so the happy path below stays flat.
    if algo not in ("adam", "adamw"):
        raise ValueError(f"Unsupported optimizer: {algo}")

    if algo == "adamw":
        return keras.optimizers.AdamW(
            learning_rate=learning_rate,
            weight_decay=weight_decay,
        )
    return keras.optimizers.Adam(learning_rate=learning_rate)
# Species-range rasterization.
def rasterize_species_range(shapefile_path, resolution=0.1):
    """Rasterize a species-range vector file into a 0/1 mask.

    All geometries in the file are merged into one shape and burned onto a
    regular grid that spans the layer's bounding box.

    The returned transform is anchored at ``(xmin, ymin)`` with a *positive*
    y scale, so row 0 of the mask sits at the minimum y coordinate (the grid
    is south-up, not the usual north-up raster layout). ``mask_lookup``
    relies on this when it turns coordinates into row/column indices.

    Args:
        shapefile_path: Path to any vector file readable by
            ``geopandas.read_file``.
        resolution: Cell size, in the units of the layer's coordinates.
            Must be positive.

    Returns:
        A tuple ``(mask, transform)``: ``mask`` is a float32 array of shape
        ``(height, width)`` holding 1.0 in burned cells and 0.0 elsewhere;
        ``transform`` is the ``Affine`` mapping (col, row) to coordinates.

    Raises:
        ValueError: If ``resolution`` is not positive or the file contains
            no geometries.
    """
    # `not x > 0` also rejects NaN, which `x <= 0` would let through.
    if not resolution > 0:
        raise ValueError(f"resolution must be positive, got {resolution!r}")

    gdf = gpd.read_file(shapefile_path)
    if gdf.empty:
        raise ValueError(f"No geometries found in {shapefile_path!r}")

    # GeoPandas >= 1.0 deprecates `unary_union` in favour of `union_all()`;
    # fall back to the old attribute on older installs.
    geom = gdf.union_all() if hasattr(gdf, "union_all") else gdf.unary_union

    xmin, ymin, xmax, ymax = gdf.total_bounds

    # Round the cell counts *up* so the grid covers the whole bounding box.
    # Truncation drops the last partial cell, and can drop a full one when
    # float rounding leaves an exact multiple slightly low (0.3 / 0.1 < 3).
    # The floor of one cell keeps point/line-like extents rasterizable.
    width = max(1, int(np.ceil((xmax - xmin) / resolution)))
    height = max(1, int(np.ceil((ymax - ymin) / resolution)))

    transform = Affine.translation(xmin, ymin) * Affine.scale(
        resolution, resolution
    )
    mask = rasterize(
        [(geom, 1)],
        out_shape=(height, width),
        transform=transform,
        fill=0,
        dtype="uint8",
    )
    return mask.astype(np.float32), transform
def euclidean_distance_loss(y_true, y_pred):
    """Per-sample Euclidean distance between predicted and true coordinates.

    Args:
        y_true: Tensor of true coordinates.
        y_pred: Tensor of predicted coordinates.

    Returns:
        Tensor of distances, reduced over the last axis.
    """
    offset = y_pred - y_true
    squared_distance = K.sum(K.square(offset), axis=-1)
    return K.sqrt(squared_distance)
def mask_lookup(pred_coords, mask_tensor, transform, resolution):
    """Look up the range-mask value under each predicted coordinate.

    The transform is corner-anchored: cell ``(row, col)`` covers
    ``[c + col * res, c + (col + 1) * res)`` in x and likewise from ``f`` in
    y. The cell containing a point is therefore the *floor* of its
    fractional index. Rounding instead would shift every lookup by half a
    cell.

    Coordinates that fall outside the raster return 0 rather than the value
    of the nearest edge cell, so a prediction beyond the mask extent is never
    reported as in range.

    Args:
        pred_coords: ``(batch_size, 2)`` float tensor of ``[lon, lat]``.
        mask_tensor: ``(H, W)`` float mask; row 0 lies at ``transform.f``.
        transform: ``Affine`` produced by the rasterize step; only its
            origin (``c``, ``f``) is read here.
        resolution: Grid cell size used when the mask was rasterized.

    Returns:
        ``(batch_size,)`` tensor holding the mask value for each coordinate,
        or 0 where the coordinate lies outside the raster.
    """
    n_rows = mask_tensor.shape[0]
    n_cols = mask_tensor.shape[1]

    lon = pred_coords[:, 0]
    lat = pred_coords[:, 1]

    # Fractional position -> index of the containing cell.
    col_f = tf.floor((lon - transform.c) / resolution)
    row_f = tf.floor((lat - transform.f) / resolution)

    # Record which points are really on the grid before clamping.
    inside = tf.logical_and(
        tf.logical_and(col_f >= 0, col_f <= n_cols - 1),
        tf.logical_and(row_f >= 0, row_f <= n_rows - 1),
    )

    # Clamp only so the gather never receives an out-of-bounds index; the
    # values gathered for outside points are discarded below.
    col = tf.cast(tf.clip_by_value(col_f, 0, n_cols - 1), tf.int32)
    row = tf.cast(tf.clip_by_value(row_f, 0, n_rows - 1), tf.int32)

    idx = tf.stack([row, col], axis=-1)
    values = tf.gather_nd(mask_tensor, idx)
    return tf.where(inside, values, tf.zeros_like(values))
# Range-penalized loss.
def loss_with_range_penalty(
    y_true, y_pred, mask_tensor, transform, resolution, penalty_weight=1.0
):
    """Euclidean distance loss plus a penalty for off-range predictions.

    The penalty is ``(1 - m) ** 2`` where ``m`` is the mask value that
    ``mask_lookup`` returns for the predicted coordinate. For a 0/1 mask
    that is 1 for a prediction on a zero cell and 0 for one on a unit cell.

    NOTE(review): ``mask_lookup`` indexes the mask with integer indices, so
    the penalty term presumably contributes no gradient with respect to
    ``y_pred`` -- confirm this is intended.

    Args:
        y_true: Tensor of true coordinates.
        y_pred: Tensor of predicted ``[lon, lat]`` coordinates.
        mask_tensor: ``(H, W)`` range mask.
        transform: ``Affine`` from the rasterize step.
        resolution: Grid cell size of the mask.
        penalty_weight: Multiplier applied to the penalty term.

    Returns:
        Per-sample loss: distance plus the weighted penalty.
    """
    squared_error = tf.square(y_pred - y_true)
    distance = tf.sqrt(tf.reduce_sum(squared_error, axis=-1))

    # Mask value under each prediction.
    in_range = mask_lookup(y_pred, mask_tensor, transform, resolution)
    off_range_penalty = tf.square(1.0 - in_range)

    return distance + penalty_weight * off_range_penalty
class GradientGate(keras.layers.Layer):
    """Identity layer whose backward pass is scaled by a scalar gate.

    The forward output always equals the input. The gradient flowing back to
    the input -- and so to the weights of whatever produced it -- is
    multiplied by ``gate``. With ``gate == 0`` the upstream PCA-initialized
    projection stays at its initial values (phase 1); with ``gate == 1`` it
    trains normally (phase 2).

    ``gate`` is a non-trainable variable, so switching phases is a variable
    assignment rather than a change to the training graph. This lets one
    compiled training function serve both phases and successive folds.
    """

    def build(self, input_shape):
        """Allocate the scalar gate; it is initialized to 0 (closed)."""
        self.gate = self.add_weight(
            name="gate",
            shape=(),
            initializer="zeros",
            trainable=False,
            dtype="float32",
        )
        super().build(input_shape)

    def call(self, inputs):
        """Pass ``inputs`` through, scaling the gradient by ``gate``."""
        # `detached` carries the value but no gradient; `live_part` is
        # numerically zero and carries the whole gradient. Scaling only
        # `live_part` leaves the value alone and scales d(out)/d(in).
        detached = tf.stop_gradient(inputs)
        live_part = inputs - detached
        return detached + self.gate * live_part
# Network construction.
def create_network(
    input_shape: int,
    width: int = 256,
    n_layers: int = 8,
    dropout_prop: float = 0.25,
    pca_components: Optional[int] = None,
    optimizer_config: Optional[dict] = None,
    loss_fn: Optional[Callable] = None,
) -> keras.Model:
    """Create a neural network model for geographic location prediction.

    Layer order: optional PCA projection and gradient gate, batch
    normalization, ``n_layers // 2`` dense ELU layers, dropout, the remaining
    dense ELU layers, then two stacked ``Dense(2)`` layers with no explicit
    activation. The second of those is the model output.

    :param input_shape: Number of input features (SNPs).
    :type input_shape: int
    :param width: Width of the dense layers, defaults to 256.
    :type width: int, optional
    :param n_layers: Total number of hidden dense layers (excluding the two
        final layers), defaults to 8. For an odd count the extra layer goes
        after the dropout.
    :type n_layers: int, optional
    :param dropout_prop: Dropout proportion for the middle dropout layer,
        defaults to 0.25.
    :type dropout_prop: float, optional
    :param pca_components: If set, prepend a linear projection layer named
        "pca_projection" of this width, followed by a gradient gate named
        "pca_finetune_gate". The caller is responsible for initializing the
        projection weights with PCA loadings. Defaults to None (no
        projection layer).
    :type pca_components: int, optional
    :param optimizer_config: Optimizer configuration, a dict with keys
        "algo" (str): "adam" or "adamw";
        "learning_rate" (float);
        "weight_decay" (float, optional, only used for "adamw").
        Defaults to None, which compiles with the optimizer string "Adam".
    :type optimizer_config: dict, optional
    :param loss_fn: Loss function to use. Defaults to None, which selects
        euclidean_distance_loss.
    :type loss_fn: callable, optional
    :raises KeyError: If ``optimizer_config`` lacks "algo" or
        "learning_rate".
    :raises ValueError: If ``optimizer_config["algo"]`` is not supported.
    :return: Compiled Keras model ready for training.
    :rtype: keras.Model

    Example:
        >>> model = create_network(input_shape=1000)
        >>> model.summary()
    """
    inputs = keras.Input(shape=(input_shape,))

    # Optional PCA-initialized linear projection as the first layer. Placed
    # before BatchNormalization so it sees raw genotype counts, which lets a
    # caller set its weights to PCA loadings and reproduce PCA scores exactly.
    # Pinned to float32: it computes raw_counts @ loadings + bias, where the
    # two terms are large and nearly cancel, so float16 loses the result.
    if pca_components is not None:
        x = layers.Dense(
            pca_components,
            activation="linear",
            name=PCA_LAYER_NAME,
            dtype="float32",
        )(inputs)
        # Gradient gate: switches the projection between frozen and
        # fine-tuning by flipping a variable, so the graph never changes.
        x = GradientGate(name=PCA_GATE_NAME, dtype="float32")(x)
    else:
        x = inputs
    x = layers.BatchNormalization()(x)

    # Split the hidden layers around the dropout with integer arithmetic:
    # floor(n / 2) before it, the rest (ceil(n / 2)) after it.
    n_layers = int(n_layers)
    n_before_dropout = n_layers // 2
    n_after_dropout = n_layers - n_before_dropout

    for _ in range(n_before_dropout):
        x = layers.Dense(width, activation="elu")(x)

    x = layers.Dropout(dropout_prop)(x)

    for _ in range(n_after_dropout):
        x = layers.Dense(width, activation="elu")(x)

    # Two final coordinate prediction layers.
    x = layers.Dense(2)(x)
    outputs = layers.Dense(2)(x)

    model = keras.Model(inputs=inputs, outputs=outputs, name="locator_network")

    if optimizer_config is None:
        optimizer = "Adam"
    else:
        optimizer = build_optimizer(
            optimizer_config["algo"],
            optimizer_config["learning_rate"],
            optimizer_config.get("weight_decay", 0.004),
        )

    if loss_fn is None:
        loss_fn = euclidean_distance_loss

    model.compile(optimizer=optimizer, loss=loss_fn)
    return model
class IndexedGenotypeModel(keras.Model):
    """Wrap a coordinate network so its input is a batch of sample indices.

    Instead of receiving genotype rows, this model keeps the full
    sample-major genotype matrix as a tensor and looks rows up by index.
    The design intent is that the matrix is small enough to stay on the GPU
    for the entire run (n_snps x n_samples x dtype bytes; sub-GB as int8),
    so each step's only host-to-device transfer is the index vector. The row
    gather, dtype cast, and optional augmentation then run on the device,
    and ``inner`` receives the same ``(batch, n_snps)`` float features it
    would get from a genotype batch.

    The table is stored as a plain tensor rather than a tracked weight, which
    keeps it out of ``get_weights()`` and checkpoints. ``save_weights``,
    ``load_weights`` and ``get_layer`` are forwarded to ``inner``, so the
    on-disk weight format and the PCA-layer lookup match a bare network.

    Parameters
    ----------
    inner : keras.Model
        Coordinate-prediction network from ``create_network``; consumes
        ``(batch, n_snps)`` genotype features.
    genotype_table : tf.Tensor
        Sample-major genotype matrix, shape ``(n_samples, n_snps)``, native
        dtype (int8 hard calls or float32 dosage).
    site_order : array-like or None
        Optional SNP resampling order (bootstrap/jackknife), applied as a
        per-batch column gather after the row gather.
    augment : dict or None
        Optional augmentation config with ``enabled`` and ``flip_rate`` keys;
        genotype flipping is applied only during training.
    """

    def __init__(self, inner, genotype_table, site_order=None, augment=None):
        super().__init__(name="indexed_genotype_model")
        self.inner = inner

        # Kept as a plain tensor on purpose. A tf.Variable attribute would be
        # tracked by Keras and copied on every get_weights() call (for
        # example by EarlyStopping(restore_best_weights=True)); a constant
        # tensor is captured by reference into the traced call instead.
        self._table = genotype_table

        if site_order is None:
            self._site_order = None
        else:
            self._site_order = tf.constant(
                np.asarray(site_order), dtype=tf.int32
            )

        self._augment = augment if augment else {}

        # Local import: keeps this module free of an import-time dependency
        # on the data subpackage.
        from .data.tf_dataset import flip_genotypes_tf

        self._flip_fn = flip_genotypes_tf

    @property
    def genotype_table(self):
        """Return the genotype table that ``call`` gathers rows from.

        The compiled ``call`` captures this tensor by reference, so a model
        may only be reused while its table is unchanged.
        """
        return self._table

    def call(self, idx, training=False):
        """Gather the genotype rows for ``idx`` and run ``inner`` on them."""
        sample_idx = tf.cast(idx, tf.int32)

        # Row gather -> (batch, n_snps).
        genotypes = tf.gather(self._table, sample_idx, axis=0)

        # Optional column gather that applies the SNP resampling order.
        if self._site_order is not None:
            genotypes = tf.gather(genotypes, self._site_order, axis=1)

        # float32 keeps the optional pca_projection layer (float32, raw
        # counts) exact; mixed-precision layers inside inner downcast
        # internally.
        genotypes = tf.cast(genotypes, tf.float32)

        # Augmentation is training-only and opt-in through the config.
        if training:
            if self._augment.get("enabled", False):
                flip_rate = self._augment.get("flip_rate", 0.05)
                genotypes = self._flip_fn(genotypes, flip_rate)

        return self.inner(genotypes, training=training)

    def get_layer(self, *args, **kwargs):
        """Forward to ``inner`` so inner layers (e.g. PCA_LAYER_NAME) resolve."""
        return self.inner.get_layer(*args, **kwargs)

    def save_weights(self, *args, **kwargs):
        """Save only the inner network's weights, via ``inner.save_weights``."""
        return self.inner.save_weights(*args, **kwargs)

    def load_weights(self, *args, **kwargs):
        """Load weights into the inner network, via ``inner.load_weights``."""
        return self.inner.load_weights(*args, **kwargs)
def feature_network(model):
    """Return the network that takes ``(batch, n_snps)`` genotype features.

    Prediction and batch-size probing feed genotype features directly. An
    IndexedGenotypeModel is unwrapped to its ``inner`` network; any other
    model is returned as-is.
    """
    if isinstance(model, IndexedGenotypeModel):
        return model.inner
    return model
# Names exported by `from <module> import *`: the layer-name constants, the
# model classes and helpers, the optimizer/network builders, and the loss and
# range-mask utilities defined in this module.
__all__ = [
    "PCA_LAYER_NAME",
    "PCA_GATE_NAME",
    "GradientGate",
    "IndexedGenotypeModel",
    "feature_network",
    "build_optimizer",
    "create_network",
    "euclidean_distance_loss",
    "loss_with_range_penalty",
    "rasterize_species_range",
    "mask_lookup",
]