"""
Sample weighting functionality for Locator.
This module provides various methods for calculating sample weights based on
geographic distribution, including KDE-based weights with bandwidth optimization,
histogram-based weights, and loading pre-calculated weights.
"""
import hashlib
from typing import Any, Dict, Optional, Tuple

import numpy as np
import pandas as pd
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KernelDensity
class BandwidthOptimizer:
    """
    Manages bandwidth calculation and caching for KDE weights.

    This class provides optimized bandwidth selection by caching results
    to avoid redundant grid searches across multiple analyses. Automatically
    generated cache keys combine a hash of the location values with every
    grid-search setting that affects the result, so a cached bandwidth is
    only reused for the same data searched the same way.
    """

    def __init__(self):
        """Initialize the bandwidth optimizer with empty cache."""
        # cache key -> optimal bandwidth found by an earlier grid search
        self._cache = {}

    @staticmethod
    def _make_cache_key(
        locations: np.ndarray,
        n_bandwidths: int,
        min_bw: float,
        max_bw: float,
        cv: int,
    ) -> str:
        """
        Build a cache key identifying a dataset and a bandwidth search.

        Args:
            locations: Array of coordinates to fingerprint
            n_bandwidths: Number of bandwidth values in the search grid
            min_bw: Smallest bandwidth in the search grid
            max_bw: Largest bandwidth in the search grid
            cv: Cross-validation setting used by the search
        Returns
        -------
        String that (barring hash collisions) differs whenever the data
        values, their shape, or any search setting differs
        """
        # Hash the actual values rather than summary statistics: different
        # datasets can share the same size, mean and standard deviation.
        arr = np.ascontiguousarray(locations, dtype=float)
        digest = hashlib.sha1(arr.tobytes()).hexdigest()
        return (
            f"shape={arr.shape}_sha1={digest}"
            f"_grid={n_bandwidths}:{min_bw}:{max_bw}_cv={cv}"
        )

    def get_bandwidth(
        self,
        locations: np.ndarray,
        cache_key: Optional[str] = None,
        bandwidth: Optional[float] = None,
        n_bandwidths: int = 100,
        min_bw: float = 0.1,
        max_bw: float = 10.0,
        cv: int = 5,
        n_jobs: int = -1,
        verbose: bool = False,
    ) -> float:
        """
        Get optimal bandwidth, using cache if available.

        Args:
            locations: Array of shape (n_samples, 2) with x, y coordinates
            cache_key: Key for caching results. If None, a key is derived
                from a hash of ``locations`` plus the search settings
                (n_bandwidths, min_bw, max_bw, cv)
            bandwidth: Pre-specified bandwidth. If provided, it is returned
                unchanged and neither the cache nor the search is used
            n_bandwidths: Number of bandwidth values to test
            min_bw: Minimum bandwidth value
            max_bw: Maximum bandwidth value
            cv: Number of cross-validation folds
            n_jobs: Number of parallel jobs (-1 for all cores)
            verbose: Whether to print progress
        Returns
        -------
        Optimal bandwidth value
        Raises
        ------
        ValueError: If a search is required and fewer than 2 locations are given
        """
        # Return pre-specified bandwidth if provided
        if bandwidth is not None:
            return bandwidth
        if cache_key is None:
            cache_key = self._make_cache_key(
                locations, n_bandwidths, min_bw, max_bw, cv
            )
        if cache_key in self._cache:
            cached_bw = self._cache[cache_key]
            if verbose:
                print(
                    f"Using cached bandwidth for key '{cache_key}': {cached_bw:.3f}"
                )
            return cached_bw
        # Same guard as calculate_optimal_bandwidth: fail with a clear message
        # rather than deep inside the cross-validation split.
        if len(locations) < 2:
            raise ValueError("Need at least 2 locations to calculate bandwidth")
        if verbose:
            print(
                f"Calculating optimal bandwidth ({n_bandwidths} values from {min_bw} to {max_bw})..."
            )
        bandwidths = np.linspace(min_bw, max_bw, n_bandwidths)
        grid = GridSearchCV(
            KernelDensity(kernel="gaussian"),
            {"bandwidth": bandwidths},
            cv=cv,
            n_jobs=n_jobs,
        )
        grid.fit(locations)
        optimal_bw = grid.best_params_["bandwidth"]
        self._cache[cache_key] = optimal_bw
        if verbose:
            print(
                f"Optimal bandwidth: {optimal_bw:.3f} (CV score: {grid.best_score_:.3f})"
            )
        return optimal_bw

    def clear_cache(self, cache_key: Optional[str] = None):
        """
        Clear bandwidth cache.

        Args:
            cache_key: Specific key to clear. If None, clears entire cache.
                Keys that are not cached are ignored.
        """
        if cache_key is None:
            self._cache.clear()
        else:
            self._cache.pop(cache_key, None)
# Module-wide optimizer, created lazily on first request, so that every
# caller of get_global_bandwidth_optimizer() shares one bandwidth cache.
_global_optimizer = None


def get_global_bandwidth_optimizer() -> BandwidthOptimizer:
    """Return the shared BandwidthOptimizer, creating it on first use."""
    global _global_optimizer
    if _global_optimizer is not None:
        return _global_optimizer
    _global_optimizer = BandwidthOptimizer()
    return _global_optimizer
def calculate_optimal_bandwidth(
    locations: np.ndarray,
    n_bandwidths: int = 100,
    min_bw: float = 0.1,
    max_bw: float = 10.0,
    cv: int = 5,
    n_jobs: int = -1,
    verbose: bool = False,
) -> Tuple[float, Dict[str, Any]]:
    """
    Grid-search a Gaussian KDE bandwidth by cross-validation, without caching.

    Evenly spaced candidate bandwidths between ``min_bw`` and ``max_bw`` are
    scored with ``GridSearchCV``; the candidate with the best mean test score
    wins. Nothing is stored, so every call repeats the search.

    Args:
        locations: Array of shape (n_samples, 2) with x, y coordinates
        n_bandwidths: How many candidate bandwidths to try
        min_bw: Smallest candidate bandwidth
        max_bw: Largest candidate bandwidth
        cv: Number of cross-validation folds
        n_jobs: Parallel jobs for the search (-1 uses all cores)
        verbose: If True, print the search range and the winning bandwidth
    Returns
    -------
    Tuple of (optimal_bandwidth, info_dict) where info_dict contains:
        - 'bandwidth': the winning bandwidth (same as the first element)
        - 'cv_scores': mean test score for each candidate, in grid order
        - 'bandwidths_tested': the candidate bandwidth array
        - 'best_score': mean test score of the winning bandwidth
    Raises
    ------
    ValueError: If fewer than 2 locations are supplied
    """
    if len(locations) < 2:
        raise ValueError("Need at least 2 locations to calculate bandwidth")

    if verbose:
        print(
            f"Calculating optimal KDE bandwidth using {n_bandwidths} values from {min_bw} to {max_bw}"
        )
    candidates = np.linspace(min_bw, max_bw, n_bandwidths)

    search = GridSearchCV(
        estimator=KernelDensity(kernel="gaussian"),
        param_grid={"bandwidth": candidates},
        cv=cv,
        n_jobs=n_jobs,
    )
    search.fit(locations)

    best_bw = search.best_params_["bandwidth"]
    best_cv = search.best_score_
    if verbose:
        print(f"Optimal bandwidth: {best_bw:.3f} (CV score: {best_cv:.3f})")

    details = dict(
        bandwidth=best_bw,
        cv_scores=search.cv_results_["mean_test_score"],
        bandwidths_tested=candidates,
        best_score=best_cv,
    )
    return best_bw, details
def weight_samples(
    method: str,
    trainlocs: Optional[np.ndarray] = None,
    trainsamps: Optional[np.ndarray] = None,
    weightdf: Optional[pd.DataFrame] = None,
    xbins: Optional[int] = None,
    ybins: Optional[int] = None,
    lam: Optional[float] = None,
    bandwidth: Optional[float] = None,
    cache_bandwidth: bool = True,
    n_bandwidths: int = 100,
) -> Dict[str, Any]:
    """
    Calculate weights for training data based on the specified method.

    Args:
        method: Method for calculating weights ('KD', 'histogram', or 'load')
        trainlocs: Training locations (required for KD and histogram methods)
        trainsamps: Training sample IDs (required for the load method)
        weightdf: DataFrame containing pre-calculated sample weights
            (required for the load method)
        xbins: Number of bins in x direction for histogram method (default 10)
        ybins: Number of bins in y direction for histogram method (default 10)
        lam: Exponent for KDE weights (default 1.0)
        bandwidth: Bandwidth for KDE (if None, will be calculated)
        cache_bandwidth: Whether to use bandwidth caching for KDE
        n_bandwidths: Number of bandwidth values to test if calculating
    Returns
    -------
    Dictionary containing:
        - 'method': weighting method used
        - 'sample_weights': array of weights
        - 'sample_weights_df': DataFrame with sampleID and weights
        - 'xbins', 'ybins', 'lam', 'bandwidth': the arguments exactly as
          passed in (None when a default was applied or the bandwidth was
          estimated)
    Raises
    ------
    ValueError: If the method is unknown or its required inputs are missing
    """
    if method == "KD":
        if trainlocs is None:
            raise ValueError("trainlocs required for KD method")
        weights = _make_kd_weights(
            trainlocs,
            lam=1.0 if lam is None else lam,
            bandwidth=bandwidth,
            cache_bandwidth=cache_bandwidth,
            n_bandwidths=n_bandwidths,
        )
        df = pd.DataFrame({"sampleID": trainsamps, "sample_weight": weights})
    elif method == "histogram":
        if trainlocs is None:
            raise ValueError("trainlocs required for histogram method")
        weights = _make_histogram_weights(
            trainlocs,
            xbins=10 if xbins is None else xbins,
            ybins=10 if ybins is None else ybins,
        )
        df = pd.DataFrame({"sampleID": trainsamps, "sample_weight": weights})
    elif method == "load":
        if weightdf is None:
            raise ValueError("weightdf required for load method")
        # Without IDs there is nothing to look up; fail clearly instead of
        # with a TypeError from iterating None.
        if trainsamps is None:
            raise ValueError("trainsamps required for load method")
        df = _load_sample_weights(weightdf, trainsamps)
        weights = df["sample_weight"].values
    else:
        raise ValueError("Invalid method. Choose 'KD', 'histogram', or 'load'.")
    return {
        "method": method,
        "sample_weights": weights,
        "sample_weights_df": df,
        "xbins": xbins,
        "ybins": ybins,
        "lam": lam,
        "bandwidth": bandwidth,
    }
def _make_kd_weights(
    trainlocs: np.ndarray,
    lam: float = 1.0,
    bandwidth: Optional[float] = None,
    cache_bandwidth: bool = True,
    n_bandwidths: int = 100,
) -> np.ndarray:
    """
    Calculate weights using Kernel Density Estimation with optimized bandwidth selection.

    Each sample's weight is proportional to ``(p_max / p_i) ** lam``. Here
    ``p_i`` is the fitted Gaussian KDE density at sample ``i`` and ``p_max``
    is the highest such density. The weights are then normalized to sum to 1.

    Args:
        trainlocs: Training locations, shape (n_samples, 2)
        lam: Exponent for weights
        bandwidth: Pre-specified bandwidth. If None, will be calculated
        cache_bandwidth: Whether to use bandwidth caching (only consulted
            when ``bandwidth`` is None)
        n_bandwidths: Number of bandwidth values to test
    Returns
    -------
    Array of normalized weights (sums to 1)
    """
    # Resolve the bandwidth: explicit value, shared cache, or a fresh search.
    if bandwidth is not None:
        bw = bandwidth
    elif cache_bandwidth:
        optimizer = get_global_bandwidth_optimizer()
        bw = optimizer.get_bandwidth(trainlocs, n_bandwidths=n_bandwidths)
    else:
        bw, _ = calculate_optimal_bandwidth(
            trainlocs, n_bandwidths=n_bandwidths, verbose=False
        )

    kde = KernelDensity(bandwidth=bw, kernel="gaussian")
    kde.fit(trainlocs)
    log_density = kde.score_samples(trainlocs)

    # (p_max / p_i) ** lam == exp(lam * (log p_max - log p_i)). Working in log
    # space, and shifting by the maximum before exponentiating, keeps large
    # lam or extreme densities from overflowing to inf (which would make the
    # normalized weights NaN). The normalized result is mathematically the
    # same as inverting, dividing by the minimum, then raising to lam.
    log_weights = lam * (log_density.max() - log_density)
    weights = np.exp(log_weights - log_weights.max())
    return weights / weights.sum()
def _make_histogram_weights(
trainlocs: np.ndarray, xbins: int = 10, ybins: int = 10
) -> np.ndarray:
"""
Calculate weights using 2D histogram binning.
Args:
trainlocs: Training locations, shape (n_samples, 2)
xbins: Number of bins in x direction
ybins: Number of bins in y direction
Returns
-------
Array of weights based on inverse bin density
"""
bincount = [xbins, ybins]
# Make 2D histogram
H, xedges, yedges = np.histogram2d(trainlocs[:, 0], trainlocs[:, 1], bins=bincount)
# Sort trainlocs into bins
xbin = np.digitize(trainlocs[:, 0], xedges[1:], right=True)
ybin = np.digitize(trainlocs[:, 1], yedges[1:], right=True)
# Assign sample weights
weights = np.empty(len(trainlocs), dtype="float")
for i in range(len(trainlocs)):
weights[i] = 1 / (H[xbin[i]][ybin[i]])
weights /= min(weights)
return weights
def _load_sample_weights(weightdf: pd.DataFrame, trainsamps: list) -> pd.DataFrame:
"""
Load pre-calculated sample weights from a DataFrame.
Args:
weightdf: DataFrame with columns 'sampleID' and 'sample_weight'
trainsamps: List of training sample IDs
Returns
-------
DataFrame with sample weights for training samples
"""
if "sampleID" not in weightdf.columns or "sample_weight" not in weightdf.columns:
raise ValueError("weightdf must contain 'sampleID' and 'sample_weight' columns")
# Create a copy to avoid modifying original
df = weightdf.copy()
df.set_index("sampleID", inplace=True)
# Extract weights for training samples
weights = []
for samp in trainsamps:
if samp not in df.index:
raise ValueError(f"Sample '{samp}' not found in weight DataFrame")
w = df.loc[samp, "sample_weight"]
if isinstance(w, pd.Series):
weights.append(w.iloc[0])
else:
weights.append(w)
return pd.DataFrame({"sampleID": trainsamps, "sample_weight": weights})