# Source code for locator.data.tf_dataset

"""Index-based TensorFlow dataset creation for GPU-resident genotype training."""

from __future__ import annotations

from typing import Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from .indexset import IndexSet


def build_genotype_table(filtered_genotypes: np.ndarray) -> tf.Tensor:
    """Build a device-resident, sample-major genotype tensor.

    The genotype matrix for realistic runs is small enough to live on the GPU
    for the whole run -- its size is ``n_snps * n_samples * dtype_bytes``
    (e.g. 1M x 236 int8 ~= 236 MB, 2M x 236 int8 ~= 472 MB, 4x that for float32
    dosage). Holding it on-device lets the model gather genotype batches by
    sample index without any host-to-device traffic during training.

    Args:
        filtered_genotypes: Filtered genotype matrix of shape
            ``(n_snps, n_samples)`` -- int8 allele counts for hard calls, or
            float32 for genotype-likelihood dosage. Anything accepted by
            ``np.asarray`` is allowed.

    Returns
    -------
        A ``tf.Tensor`` of shape ``(n_samples, n_snps)`` with the source dtype
        preserved, created under ``/GPU:0`` when TensorFlow reports at least
        one physical GPU and under ``/CPU:0`` otherwise.

    Raises
    ------
        ValueError: If ``filtered_genotypes`` is not two-dimensional. Without
            this check a 1-D or 3-D input would be transposed into a tensor
            whose shape silently disagrees with ``(n_samples, n_snps)``.
    """
    arr = np.asarray(filtered_genotypes)
    if arr.ndim != 2:
        raise ValueError(
            "filtered_genotypes must be 2-D with shape (n_snps, n_samples); "
            f"got an array of shape {arr.shape}"
        )
    # Sample-major rows make each sample's SNP vector contiguous, so the
    # per-batch row gather inside the model is coalesced. The native dtype is
    # kept (int8 stays int8); the cast to compute dtype happens per batch.
    sample_major = np.ascontiguousarray(arr.T)
    device = "/GPU:0" if tf.config.list_physical_devices("GPU") else "/CPU:0"
    with tf.device(device):
        return tf.constant(sample_major)


def _resolve_drop_remainder(
    drop_remainder: Optional[bool],
    training: bool,
    n_samples: int,
    batch_size: int,
) -> bool:
    """Choose the ``drop_remainder`` flag used by ``make_tf_dataset``.

    An explicit caller choice (``True``/``False``) is returned unchanged. When
    it is ``None`` the flag follows ``training``, except that a split smaller
    than one batch keeps its single partial batch: dropping it would leave an
    empty dataset and training would run zero steps per epoch.
    """
    if drop_remainder is not None:
        return drop_remainder
    return bool(training) and n_samples >= batch_size


def make_tf_dataset(
    coordinates: np.ndarray,
    index_set: IndexSet,
    split: str,
    batch_size: int = 256,
    sample_weights: Optional[np.ndarray] = None,
    training: bool = True,
    shuffle: bool = True,
    drop_remainder: Optional[bool] = None,
    prefetch: bool = True,
) -> tf.data.Dataset:
    """Create an index-based tf.data pipeline for training or validation.

    The pipeline carries only sample indices and their coordinates -- a few
    kilobytes per batch. Genotypes are gathered on the GPU inside
    ``IndexedGenotypeModel``, so the genotype matrix never enters this
    pipeline and there is no per-epoch host-to-device genotype traffic.

    Args:
        coordinates: Full coordinate array of shape ``(n_samples, 2)``.
        index_set: IndexSet containing the train/val/test/predict splits.
        split: Which split to use ('train', 'val', 'test', 'predict').
        batch_size: Batch size for the dataset.
        sample_weights: Optional per-sample weights, aligned to the split's
            index order (length must equal the split size).
        training: Whether this is for training (enables shuffling).
        shuffle: Whether to shuffle the split each epoch (only when training).
        drop_remainder: Whether to drop the final partial batch. ``None``
            (default) follows ``training``, but never drops when the split is
            smaller than ``batch_size`` -- that would produce no batches at
            all. An explicit ``True``/``False`` is always honoured.
        prefetch: Whether to prefetch batches.

    Returns
    -------
        A ``tf.data.Dataset`` yielding ``(sample_index, coordinate)`` batches,
        or ``(sample_index, coordinate, sample_weight)`` when weights are
        given. Indices are int32; coordinates and weights are float32.

    Raises
    ------
        ValueError: If the split is empty, or if ``sample_weights`` does not
            have one entry per sample in the split.
    """
    indices = np.asarray(index_set.get_split(split))
    if len(indices) == 0:
        raise ValueError(f"Split '{split}' has no samples")
    drop_remainder = _resolve_drop_remainder(
        drop_remainder, training, len(indices), batch_size
    )

    indices = indices.astype(np.int32)
    # Only the split's own coordinates enter the pipeline.
    coords = np.asarray(coordinates)[indices].astype(np.float32)

    if sample_weights is not None:
        if len(sample_weights) != len(indices):
            raise ValueError(
                f"Sample weights length ({len(sample_weights)}) must match "
                f"split size ({len(indices)})"
            )
        weights = np.asarray(sample_weights, dtype=np.float32)
        dataset = tf.data.Dataset.from_tensor_slices((indices, coords, weights))
    else:
        dataset = tf.data.Dataset.from_tensor_slices((indices, coords))

    if training and shuffle:
        # The split is a few hundred indices; a full-size buffer is a perfect
        # shuffle at negligible cost.
        dataset = dataset.shuffle(
            buffer_size=len(indices), reshuffle_each_iteration=True
        )
    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
    if prefetch:
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset
def flip_genotypes_tf(genotypes: tf.Tensor, flip_rate: float = 0.05) -> tf.Tensor:
    """Randomly flip genotype values with a given probability.

    Each entry whose value lies in ``[0, 1]`` is replaced by ``1 - value``
    with probability ``flip_rate``, so hard calls 0 and 1 swap. Every other
    value -- including 2, which this augmentation treats as missing, and any
    value outside ``[0, 1]`` -- passes through unchanged, so a flip can never
    push a value out of range. Used for training-time data augmentation.

    Args:
        genotypes: Tensor (or array-like) of genotype values; float or
            integer dtype. The result keeps its dtype and shape.
        flip_rate: Probability of flipping each eligible value.

    Returns
    -------
        Augmented genotypes tensor.
    """
    genotypes = tf.convert_to_tensor(genotypes)
    # Constants in the input's own dtype, so int8 tensors work as well as
    # float ones (Python float literals cannot be combined with int tensors).
    zero = tf.zeros([], dtype=genotypes.dtype)
    one = tf.ones([], dtype=genotypes.dtype)
    # Per-element Bernoulli(flip_rate) draw.
    mask = tf.random.uniform(tf.shape(genotypes)) < flip_rate
    # Only values in [0, 1] are eligible; 2 (missing) and anything outside
    # the range are never touched.
    is_flippable = tf.logical_and(genotypes >= zero, genotypes <= one)
    mask = tf.logical_and(mask, is_flippable)
    # Flip: 0 to 1, 1 to 0
    return tf.where(mask, one - genotypes, genotypes)


def _check_feature_pair(name: str, gen, locs) -> None:
    """Raise ValueError unless ``locs`` is present and row-aligned with ``gen``."""
    if locs is None:
        raise ValueError(f"{name}_gen was given without {name}_locs")
    if len(gen) != len(locs):
        raise ValueError(
            f"{name}_gen has {len(gen)} rows but {name}_locs has "
            f"{len(locs)} rows"
        )


def make_tf_dataset_from_arrays(
    train_gen: np.ndarray,
    train_locs: np.ndarray,
    test_gen: Optional[np.ndarray] = None,
    test_locs: Optional[np.ndarray] = None,
    val_gen: Optional[np.ndarray] = None,
    val_locs: Optional[np.ndarray] = None,
    batch_size: int = 256,
    cache: bool = True,
    prefetch: bool = True,
) -> Union[tf.data.Dataset, Tuple[tf.data.Dataset, ...]]:
    """Legacy helper: build feature-based datasets from pre-split arrays.

    Unlike :func:`make_tf_dataset` (which is index-based), this yields
    ``(genotype_features, coordinates)`` batches directly from the supplied
    sample-major arrays. Both arrays are converted to float32.

    Args:
        train_gen: Training genotypes of shape ``(n_train, n_features)``.
        train_locs: Training locations of shape ``(n_train, 2)``.
        test_gen: Optional test genotypes.
        test_locs: Test locations; required when ``test_gen`` is given.
        val_gen: Optional validation genotypes.
        val_locs: Validation locations; required when ``val_gen`` is given.
        batch_size: Batch size. The training set drops its final partial
            batch, unless it is smaller than one batch (then that single
            partial batch is kept so the dataset is not empty).
        cache: Whether to cache each dataset in memory.
        prefetch: Whether to prefetch batches.

    Returns
    -------
        The training dataset alone when neither ``test_gen`` nor ``val_gen``
        is given; otherwise a tuple of the training dataset followed by the
        test dataset (if ``test_gen`` is given) and then the validation
        dataset (if ``val_gen`` is given).

    Raises
    ------
        ValueError: If the training set is empty, if a genotype array is given
            without its locations, or if a pair's row counts differ.
    """
    # Validate every pair before building anything, so bad arguments fail
    # with a clear message instead of deep inside tf.data.
    _check_feature_pair("train", train_gen, train_locs)
    if len(train_gen) == 0:
        raise ValueError("train_gen has no samples")
    if test_gen is not None:
        _check_feature_pair("test", test_gen, test_locs)
    if val_gen is not None:
        _check_feature_pair("val", val_gen, val_locs)

    def _build(gen, locs, training):
        features = np.asarray(gen, dtype=np.float32)
        targets = np.asarray(locs, dtype=np.float32)
        n_rows = len(features)
        ds = tf.data.Dataset.from_tensor_slices((features, targets))
        if cache:
            ds = ds.cache()
        if training:
            ds = ds.shuffle(n_rows, reshuffle_each_iteration=True)
        # Dropping the remainder of a set smaller than one batch would leave
        # zero batches, so keep the lone partial batch in that case.
        ds = ds.batch(batch_size, drop_remainder=training and n_rows >= batch_size)
        if prefetch:
            ds = ds.prefetch(tf.data.AUTOTUNE)
        return ds

    datasets = [_build(train_gen, train_locs, training=True)]
    if test_gen is not None:
        datasets.append(_build(test_gen, test_locs, training=False))
    if val_gen is not None:
        datasets.append(_build(val_gen, val_locs, training=False))
    return datasets[0] if len(datasets) == 1 else tuple(datasets)