"""IndexSet for memory-efficient data splitting without copying arrays."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import numpy as np
@dataclass(frozen=True)
class IndexSet:
    """Container for dataset indices that avoids copying data.

    This class stores indices for different data splits (train/val/test)
    to enable memory-efficient data access without creating copies of
    large genotype arrays.

    Attributes
    ----------
    indices
        Dictionary mapping split names to numpy arrays of indices.
    total_samples
        Total number of samples in the dataset.
    na_mask
        Optional boolean mask indicating samples without coordinates.
    """

    indices: Dict[str, np.ndarray]
    total_samples: int
    na_mask: Optional[np.ndarray] = None

    def __post_init__(self) -> None:
        """Validate the IndexSet after initialization.

        Raises
        ------
        ValueError
            If any index appears more than once across the splits, if any
            index is negative, or if any index is >= ``total_samples``.
        """
        arrays = [np.asarray(split).ravel() for split in self.indices.values()]
        if not arrays:
            return
        merged = np.concatenate(arrays)
        if merged.size == 0:
            return

        # Verify no overlapping indices (vectorized; no Python list/set copy).
        if np.unique(merged).size != merged.size:
            raise ValueError("IndexSet contains overlapping indices between splits")

        # Negative values would alias positions counted from the end under
        # numpy indexing, silently bypassing the overlap check above.
        min_idx = merged.min()
        if min_idx < 0:
            raise ValueError(f"Index {int(min_idx)} is negative; indices must be >= 0")

        # Verify indices are within bounds.
        max_idx = merged.max()
        if max_idx >= self.total_samples:
            raise ValueError(
                f"Index {int(max_idx)} exceeds total_samples {self.total_samples}"
            )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _as_bool_mask(na_mask: Optional[np.ndarray]) -> Optional[np.ndarray]:
        """Return ``na_mask`` as a boolean array (``None`` passes through)."""
        # ``~`` on an integer mask is a bitwise NOT, not a logical one, so
        # the mask must be boolean before it is inverted.
        return None if na_mask is None else np.asarray(na_mask, dtype=bool)

    @staticmethod
    def _shuffled_valid_indices(
        n: int, seed: Optional[int], mask: Optional[np.ndarray]
    ) -> np.ndarray:
        """Return the indices not flagged by ``mask``, shuffled with ``seed``."""
        if mask is not None and mask.any():
            valid = np.where(~mask)[0]
        else:
            valid = np.arange(n)
        rng = np.random.RandomState(seed)
        rng.shuffle(valid)  # ``valid`` is freshly allocated, safe to shuffle in place
        return valid

    @staticmethod
    def _fold_train_test(shuffled: np.ndarray, k: int, fold: int):
        """Return ``(train, test)`` index arrays for fold ``fold`` of ``k``."""
        n_valid = len(shuffled)
        fold_size = n_valid // k
        test_start = fold * fold_size
        # The last fold absorbs the remainder when n_valid is not divisible by k.
        test_end = test_start + fold_size if fold < k - 1 else n_valid
        test_indices = shuffled[test_start:test_end]
        train_indices = np.concatenate([shuffled[:test_start], shuffled[test_end:]])
        return train_indices, test_indices

    # ------------------------------------------------------------------
    # Accessors
    # ------------------------------------------------------------------
    @property
    def train(self) -> np.ndarray:
        """Get training indices (backward compatibility)."""
        return self.indices.get("train", np.array([], dtype=int))

    @property
    def val(self) -> np.ndarray:
        """Get validation indices (backward compatibility)."""
        return self.indices.get("val", np.array([], dtype=int))

    @property
    def test(self) -> np.ndarray:
        """Get test indices (backward compatibility)."""
        return self.indices.get("test", np.array([], dtype=int))

    @property
    def hold(self) -> np.ndarray:
        """Get holdout/prediction indices (backward compatibility)."""
        # Try 'hold' first, then 'test' for compatibility.
        return self.indices.get(
            "hold", self.indices.get("test", np.array([], dtype=int))
        )

    def get_split(self, name: str) -> np.ndarray:
        """Get indices for a named split.

        Raises
        ------
        KeyError
            If ``name`` is not one of the stored splits.
        """
        if name not in self.indices:
            raise KeyError(
                f"Split '{name}' not found. Available splits: {list(self.indices.keys())}"
            )
        return self.indices[name]

    def split_sizes(self) -> Dict[str, int]:
        """Get the size of each split."""
        return {name: len(split) for name, split in self.indices.items()}

    # ------------------------------------------------------------------
    # Constructors
    # ------------------------------------------------------------------
    @classmethod
    def random_split(
        cls,
        n: int,
        splits: Optional[Dict[str, float]] = None,
        seed: Optional[int] = None,
        na_mask: Optional[np.ndarray] = None,
        na_action: str = "separate",
    ) -> "IndexSet":
        """Create random train/val/test splits.

        Parameters
        ----------
        n
            Total number of samples.
        splits
            Dictionary mapping split names to proportions (each >= 0, must
            sum to ≤ 1.0). Default: {"train": 0.8, "val": 0.1, "test": 0.1}.
            When the proportions sum to 1.0 the last split receives every
            remaining sample (absorbing rounding); when they sum to less,
            the leftover samples are assigned to no split.
        seed
            Random seed for reproducibility.
        na_mask
            Boolean mask indicating samples without coordinates.
        na_action
            How to handle NA samples ('separate', 'exclude', 'fail').
            'separate' additionally stores the NA samples as a 'predict'
            split.

        Returns
        -------
        IndexSet with random splits

        Raises
        ------
        ValueError
            If a proportion is negative, the proportions sum to more than
            1.0, NA samples are present with ``na_action='fail'``, or no
            sample has valid coordinates.
        """
        if splits is None:
            splits = {"train": 0.8, "val": 0.1, "test": 0.1}

        # Validate splits
        if any(prop < 0 for prop in splits.values()):
            raise ValueError("Split proportions must be non-negative")
        total_prop = sum(splits.values())
        if total_prop > 1.0 + 1e-10:
            raise ValueError(f"Split proportions sum to {total_prop}, must be ≤ 1.0")

        # Handle NA samples
        mask = cls._as_bool_mask(na_mask)
        has_na = mask is not None and bool(mask.any())
        if has_na and na_action == "fail":
            raise ValueError(
                "Samples without coordinates found but na_action='fail'"
            )
        if mask is not None and na_action in ("exclude", "separate"):
            # Only use samples with coordinates for train/val/test
            shuffled = np.where(~mask)[0]
            if len(shuffled) == 0:
                raise ValueError("No samples with valid coordinates")
        else:
            shuffled = np.arange(n)
        n_valid = len(shuffled)

        # Shuffle indices
        rng = np.random.RandomState(seed)
        rng.shuffle(shuffled)

        # Create splits. Only hand the remainder to the last split when the
        # caller asked for the whole dataset; otherwise honor its proportion.
        uses_all_samples = total_prop >= 1.0 - 1e-10
        last = len(splits) - 1
        indices = {}
        start = 0
        for i, (name, prop) in enumerate(splits.items()):
            if i == last and uses_all_samples:
                stop = n_valid
            else:
                stop = min(start + int(np.round(prop * n_valid)), n_valid)
            indices[name] = shuffled[start:stop]
            start = stop

        # Handle NA samples in 'separate' mode
        if has_na and na_action == "separate":
            # Add NA samples as a separate 'predict' split
            indices["predict"] = np.where(mask)[0]

        return cls(indices=indices, total_samples=n, na_mask=mask)

    @classmethod
    def from_k_fold(
        cls,
        n: int,
        k: int,
        fold: int,
        seed: Optional[int] = None,
        na_mask: Optional[np.ndarray] = None,
    ) -> "IndexSet":
        """Create train/test split for k-fold cross-validation.

        Samples flagged in ``na_mask`` are left out of both splits. Folds
        have ``n_valid // k`` samples each; the last fold also takes the
        remainder.

        Parameters
        ----------
        n
            Total number of samples.
        k
            Number of folds.
        fold
            Which fold to use as test set (0-indexed).
        seed
            Random seed for reproducibility.
        na_mask
            Boolean mask indicating samples without coordinates.

        Returns
        -------
        IndexSet with train and test splits

        Raises
        ------
        ValueError
            If ``fold`` is not in ``range(k)``.
        """
        if fold >= k or fold < 0:
            raise ValueError(f"Fold {fold} out of range for {k}-fold CV")

        mask = cls._as_bool_mask(na_mask)
        shuffled = cls._shuffled_valid_indices(n, seed, mask)
        train_indices, test_indices = cls._fold_train_test(shuffled, k, fold)
        return cls(
            indices={"train": train_indices, "test": test_indices},
            total_samples=n,
            na_mask=mask,
        )

    @classmethod
    def from_groups(
        cls,
        groups: np.ndarray,
        test_groups: List[Union[int, str]],
        na_mask: Optional[np.ndarray] = None,
    ) -> "IndexSet":
        """Create train/test split based on group membership.

        Useful for spatial or temporal cross-validation where you want
        to hold out entire groups (e.g., geographic regions).

        Parameters
        ----------
        groups
            Array of group labels for each sample.
        test_groups
            List of group labels to use as test set.
        na_mask
            Boolean mask indicating samples without coordinates; flagged
            samples are excluded from both train and test.

        Returns
        -------
        IndexSet with train and test splits
        """
        n = len(groups)
        test_mask = np.isin(groups, test_groups)
        mask = cls._as_bool_mask(na_mask)

        if mask is not None:
            # Exclude NA samples from both train and test
            keep = ~mask
            test_indices = np.where(test_mask & keep)[0]
            train_indices = np.where(~test_mask & keep)[0]
        else:
            test_indices = np.where(test_mask)[0]
            train_indices = np.where(~test_mask)[0]

        return cls(
            indices={"train": train_indices, "test": test_indices},
            total_samples=n,
            na_mask=mask,
        )

    @classmethod
    def from_manual(
        cls,
        train: np.ndarray,
        test: Optional[np.ndarray] = None,
        val: Optional[np.ndarray] = None,
        predict: Optional[np.ndarray] = None,
        total_samples: Optional[int] = None,
    ) -> "IndexSet":
        """Create IndexSet from manually specified indices.

        Parameters
        ----------
        train
            Training indices.
        test
            Test indices.
        val
            Validation indices.
        predict
            Prediction indices (samples without labels).
        total_samples
            Total number of samples. When omitted it is inferred as the
            largest supplied index plus one (0 if every split is empty).

        Returns
        -------
        IndexSet with specified splits
        """
        # np.asarray is a no-op for ndarrays and lets callers pass sequences.
        indices = {"train": np.asarray(train)}
        if test is not None:
            indices["test"] = np.asarray(test)
        if val is not None:
            indices["val"] = np.asarray(val)
        if predict is not None:
            indices["predict"] = np.asarray(predict)

        # Infer total samples if not provided
        if total_samples is None:
            maxima = [int(split.max()) for split in indices.values() if split.size]
            total_samples = max(maxima) + 1 if maxima else 0

        return cls(indices=indices, total_samples=total_samples)

    @classmethod
    def k_fold_split(
        cls,
        n: int,
        k: int,
        seed: Optional[int] = None,
        na_mask: Optional[np.ndarray] = None,
    ) -> List["IndexSet"]:
        """Create all k-fold cross-validation splits at once.

        This method generates k IndexSet objects, one for each fold,
        suitable for ensemble training or cross-validation. The indices are
        shuffled once, so the test sets of the returned folds are disjoint.

        Parameters
        ----------
        n
            Total number of samples.
        k
            Number of folds (must be >= 1).
        seed
            Random seed for reproducibility.
        na_mask
            Boolean mask indicating samples to exclude from k-fold
            (e.g., samples without coordinates or not in training set).

        Returns
        -------
        List of k IndexSet objects, one for each fold

        Raises
        ------
        ValueError
            If ``k`` is smaller than 1.
        """
        if k < 1:
            raise ValueError(f"Number of folds must be at least 1, got {k}")

        mask = cls._as_bool_mask(na_mask)
        # Shuffle indices once so every fold partitions the same permutation.
        shuffled = cls._shuffled_valid_indices(n, seed, mask)

        fold_index_sets = []
        for fold in range(k):
            train_indices, test_indices = cls._fold_train_test(shuffled, k, fold)
            fold_index_sets.append(
                cls(
                    indices={"train": train_indices, "test": test_indices},
                    total_samples=n,
                    na_mask=mask,
                )
            )
        return fold_index_sets