"""Plotting functionality for locator predictions"""
import base64
import io
from pathlib import Path
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.subplots as sp
import seaborn as sns
from geopy.distance import geodesic
from scipy.stats import gaussian_kde
# Public names exported by a star-import of this module. The private helper
# ``_handle_plot_display`` is not listed here.
__all__ = [
    "kde_predict",
    "plot_predictions",
    "plot_error_summary",
    "plot_interactive_error_map",
    "plot_sample_weights",
    "PlottingMixin",
]
def _handle_plot_display(show=None):
    """Decide whether the current matplotlib figure(s) should be shown.

    Args:
        show: Tri-state switch. ``True`` always calls ``plt.show()``,
            ``False`` never does, and ``None`` shows only when the code runs
            inside IPython/Jupyter.
    """
    if show is not None:
        # Explicit request from the caller: follow it as given.
        if show:
            plt.show()
        return

    # Auto-detect mode: ``get_ipython`` is only defined inside
    # IPython/Jupyter sessions, so a NameError means "not interactive".
    try:
        get_ipython()
        plt.show()
    except NameError:
        return
[docs]
def kde_predict(x_coords, y_coords, xlim=(0, 50), ylim=(0, 50), n_points=100):
    """Evaluate a 2-D Gaussian kernel density estimate on a regular grid.

    Helper used by plot_predictions() to turn a cloud of predicted
    coordinates into a density surface that can be contoured.

    Args:
        x_coords (array-like): X coordinates (longitude values) of the points
        y_coords (array-like): Y coordinates (latitude values) of the points
        xlim (tuple): (min, max) x extent of the evaluation grid. Default: (0, 50)
        ylim (tuple): (min, max) y extent of the evaluation grid. Default: (0, 50)
        n_points (int): Number of grid points along each axis. Default: 100

    Returns
    -------
    tuple: A 3-tuple containing:
        - **x_grid** (*numpy.ndarray*): X coordinates of the mesh grid
        - **y_grid** (*numpy.ndarray*): Y coordinates of the mesh grid
        - **density** (*numpy.ndarray*): Density values at each grid point

        ``(None, None, None)`` is returned when any step of the estimation
        raises; the error is printed rather than propagated.

    Note:
        Density estimation is delegated to scipy.stats.gaussian_kde.
        Grid limits should match the geographic extent of your predictions.
    """
    try:
        # Fit the estimator on the stacked (2, n) coordinate array.
        estimator = gaussian_kde(np.vstack([x_coords, y_coords]))

        # Regular mesh spanning the requested plot range.
        x_axis = np.linspace(xlim[0], xlim[1], n_points)
        y_axis = np.linspace(ylim[0], ylim[1], n_points)
        grid_x, grid_y = np.meshgrid(x_axis, y_axis)

        # Evaluate on the flattened mesh, then fold back into grid shape.
        grid_points = np.vstack([grid_x.ravel(), grid_y.ravel()])
        density = estimator(grid_points).reshape(grid_x.shape)
    except Exception as e:
        print(f"KDE failed: {e}")
        return None, None, None
    return grid_x, grid_y, density
[docs]
def plot_predictions(  # noqa: C901
    predictions,
    locator,
    out_prefix,
    samples=None,
    n_samples=9,
    n_cols=3,
    plot_map=False,
    width=5,
    height=4,
    dpi=300,
    n_levels=3,
    show=None,
):
    """Plot locator predictions from jacknife, bootstrap, or windows analyses.

    This function visualizes predictions from any of locator's prediction methods
    that generate multiple predictions per sample. It creates a grid of subplots,
    one per sample, showing the distribution of predictions as KDE contours.

    The function expects prediction data with:
        - A 'sampleID' column
        - Multiple prediction columns ('x_0', 'x_1'... and 'y_0', 'y_1'...)

    For each sample, the plot shows:
        - KDE contours of predictions (blue lines)
        - True location if known (red star)
        - All training sample locations (gray circles)

    Args:
        predictions (pandas.DataFrame or str): DataFrame, path to a predictions
            file, or path to a directory containing ``*predlocs.*`` files.
            Output from any of:
            - ``locator.run_jacknife(return_df=True)``
            - ``locator.run_bootstraps(return_df=True)``
            - ``locator.run_windows(return_df=True)``
        locator (Locator): Locator instance; ``locator.config["sample_data"]``
            is read as a DataFrame or as a path to a tab-separated file
        out_prefix (str): Prefix for output files. Plot saved as
            {out_prefix}_predictions.pdf. A falsy value skips saving.
        samples (list, optional): List of sample IDs to plot. If None, randomly selects n_samples
        n_samples (int): Number of samples to plot if samples not specified. Default: 9
        n_cols (int): Number of columns in plot grid. Default: 3
        plot_map (bool): Whether to plot on a geographic map (requires cartopy). Default: False
        width (float): Width of each subplot in inches. Default: 5
        height (float): Height of each subplot in inches. Default: 4
        dpi (int): DPI resolution for output figure. Default: 300
        n_levels (int): Number of KDE contour levels to plot. Levels are placed
            at evenly spaced percentiles between the 85th and 99th percentile
            of the positive density values. Default: 3
        show (bool or None): Whether to display plot. None=auto-detect environment. Default: None

    Returns
    -------
    None: Saves plot to file and optionally displays it

    Raises
    ------
    ValueError: If a predictions directory contains no ``*predlocs.*`` files,
        or the predictions have no ``x_*`` / ``y_*`` columns

    Examples
    --------
    For jacknife analysis::

        predictions = locator.run_jacknife(genotypes, samples, return_df=True)
        plot_predictions(predictions, locator, "jacknife_example")

    For bootstrap analysis::

        predictions = locator.run_bootstraps(genotypes, samples, return_df=True)
        plot_predictions(predictions, locator, "bootstrap_example")

    For windows analysis::

        predictions = locator.run_windows(genotypes, samples, return_df=True)
        plot_predictions(predictions, locator, "windows_example")

    Plot specific samples::

        plot_predictions(predictions, locator, "selected",
                         samples=['HG001', 'HG002', 'HG003'])

    Note:
        - Requires matplotlib and scipy for KDE calculation
        - If plot_map=True, requires cartopy for geographic projections
        - Plot limits are derived from the range of all predictions (10% padding)
        - KDE may fail for samples with very few predictions; such samples are
          drawn without contours
    """
    # Load predictions from a DataFrame, a single file, or a directory of files.
    if isinstance(predictions, (str, Path)):
        pred_path = Path(predictions)
        if pred_path.is_file():
            preds = pd.read_csv(pred_path)
        else:
            pred_files = list(pred_path.glob("*predlocs.*"))
            if not pred_files:
                # pd.concat([]) would also raise ValueError, just less clearly.
                raise ValueError(
                    f"No prediction files matching '*predlocs.*' found in {pred_path}"
                )
            preds = pd.concat([pd.read_csv(f) for f in pred_files])
    else:
        preds = predictions

    # Get sample data from locator
    sample_source = locator.config["sample_data"]
    if isinstance(sample_source, pd.DataFrame):
        samples_df = sample_source.copy()
    else:
        samples_df = pd.read_csv(
            sample_source, sep="\t", na_values="NA", quotechar='"'
        )
    samples_df.columns = samples_df.columns.str.strip('"')
    # Only string-typed IDs can carry stray quotes; the ``.str`` accessor
    # raises on numeric columns, so skip the clean-up for those.
    if "sampleID" in samples_df.columns and pd.api.types.is_string_dtype(
        samples_df["sampleID"]
    ):
        samples_df["sampleID"] = samples_df["sampleID"].str.strip('"')

    # Select samples to plot if not provided
    if samples is None:
        available_samples = preds["sampleID"].unique()
        samples = np.random.choice(
            available_samples,
            size=min(n_samples, len(available_samples)),
            replace=False,
        )

    # Identify the replicate prediction columns before creating any figure so
    # that bad input fails early and leaves no open figure behind.
    x_cols = [col for col in preds.columns if col.startswith("x_")]
    y_cols = [col for col in preds.columns if col.startswith("y_")]
    if not x_cols or not y_cols:
        raise ValueError(
            "predictions must contain 'x_*' and 'y_*' prediction columns"
        )

    # Global plot limits with 10% padding; NaN-aware so that a few missing
    # replicate predictions do not turn the limits into NaN.
    x_all = preds[x_cols].to_numpy(dtype=float).ravel()
    y_all = preds[y_cols].to_numpy(dtype=float).ravel()
    x_lo, x_hi = np.nanmin(x_all), np.nanmax(x_all)
    y_lo, y_hi = np.nanmin(y_all), np.nanmax(y_all)
    padding = 0.1
    x_pad = (x_hi - x_lo) * padding
    y_pad = (y_hi - y_lo) * padding
    xlim = (x_lo - x_pad, x_hi + x_pad)
    ylim = (y_lo - y_pad, y_hi + y_pad)

    # Training locations are identical for every subplot: compute them once.
    # Only samples that have true locations (not NA) are drawn.
    training_locs = samples_df[pd.notna(samples_df["x"]) & pd.notna(samples_df["y"])]

    # Percentiles at which contour levels are drawn, honouring ``n_levels``.
    contour_percentiles = np.linspace(85, 99, max(int(n_levels), 1))

    # Create figure
    n_rows = int(np.ceil(len(samples) / n_cols))
    fig = plt.figure(figsize=(width * n_cols, height * n_rows), dpi=dpi)
    try:
        for i, sample in enumerate(samples, 1):
            ax = fig.add_subplot(
                n_rows, n_cols, i, projection=ccrs.PlateCarree() if plot_map else None
            )
            sample_preds = preds[preds["sampleID"] == sample]
            sample_true = samples_df[samples_df["sampleID"] == sample]

            if plot_map:
                ax.add_feature(cfeature.LAND, facecolor="lightgray")
                ax.add_feature(cfeature.COASTLINE, linewidth=0.5)
            else:
                ax.set_xlim(xlim)
                ax.set_ylim(ylim)

            # Plot all training sample locations as background
            if not training_locs.empty:
                ax.scatter(
                    training_locs["x"],
                    training_locs["y"],
                    c="gray",
                    marker="o",
                    s=20,
                    facecolors="none",
                    alpha=0.5,
                    linewidth=0.5,
                    label="Training samples",
                )

            # Pool every replicate prediction for this sample and estimate
            # its density over the shared plot extent.
            x_preds = sample_preds[x_cols].values.ravel()
            y_preds = sample_preds[y_cols].values.ravel()
            xx, yy, density = kde_predict(x_preds, y_preds, xlim=xlim, ylim=ylim)
            if density is not None:
                positive = density[density > 0]
                if positive.size:
                    # contour() requires strictly increasing levels, so
                    # collapse duplicate percentile values.
                    levels = np.unique(np.percentile(positive, contour_percentiles))
                    ax.contour(
                        xx,
                        yy,
                        density,
                        levels=levels,
                        colors="blue",
                        alpha=0.8,
                        linewidths=0.5,
                    )

            # Plot true location if it exists and is not NA
            if len(sample_true) > 0 and pd.notna(sample_true.iloc[0]["x"]):
                ax.scatter(
                    sample_true.iloc[0]["x"],
                    sample_true.iloc[0]["y"],
                    c="red",
                    marker="*",
                    s=100,
                    label="True",
                )
            ax.set_title(f"Sample {sample}")

        plt.tight_layout()
        if out_prefix:
            plt.savefig(f"{out_prefix}_predictions.pdf")
        _handle_plot_display(show)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
    return None
[docs]
def plot_error_summary(  # noqa: C901
    predictions,
    sample_data,
    out_prefix=None,
    plot_map=True,
    width=20,
    height=10,
    dpi=300,
    use_geodesic=True,
    include_training_locs=True,
    show=None,
    return_merged=False,
):
    """
    Plot summary of prediction errors from holdout analysis.

    Creates a comprehensive error visualization with two panels:

    1. **Map/Scatter panel**: Shows true locations colored by prediction error,
       with lines connecting true and predicted locations
    2. **Histogram panel**: Distribution of errors with summary statistics

    This function is designed for analyzing results from holdout methods like:
        - ``run_holdouts()``
        - ``run_k_fold_holdouts()``
        - ``run_leave_one_out()``

    Args:
        predictions (pandas.DataFrame): DataFrame with columns:
            - ``sampleID``: Sample identifiers
            - ``x_pred``: Predicted longitude
            - ``y_pred``: Predicted latitude
        sample_data (pandas.DataFrame or str): DataFrame or path to TSV file with columns:
            - ``sampleID``: Sample identifiers (must match predictions)
            - ``x``: True longitude
            - ``y``: True latitude
        out_prefix (str, optional): Prefix for output files. If provided, saves as
            {out_prefix}_error_summary.png. Default: None
        plot_map (bool): Whether to plot on a geographic map using cartopy projection.
            If False, uses regular scatter plot. Default: True
        width (float): Figure width in inches. Default: 20
        height (float): Figure height in inches. Default: 10
        dpi (int): Figure resolution in dots per inch. Default: 300
        use_geodesic (bool): If True, calculate geodesic distances in kilometers.
            If False, use Euclidean distances in coordinate units. Default: True
        include_training_locs (bool): Whether to plot training locations (gray circles)
            and use their extent for map bounds. Default: True
        show (bool or None): Whether to display plot. None=auto-detect environment,
            True=always show, False=never show. Default: None
        return_merged (bool): If True, return the internal merged DataFrame used for plotting.
            Default: False

    Returns
    -------
    None: Saves plot to file and optionally displays it.
        If return_merged is True, returns the internal merged DataFrame containing
        prediction errors and true locations.

    Raises
    ------
    ValueError: If predictions or sample_data are empty, have missing columns,
        or have no matching samples

    Examples
    --------
    Basic usage with k-fold results::

        predictions = locator.run_k_fold_holdouts(genotypes, samples, return_df=True)
        plot_error_summary(predictions, "samples.tsv", "kfold_errors")

    With DataFrame input and Euclidean distances::

        plot_error_summary(predictions, sample_df,
                           out_prefix="holdout_errors",
                           use_geodesic=False)

    Without map projection::

        plot_error_summary(predictions, sample_df,
                           plot_map=False,
                           width=10, height=5)

    Return merged DataFrame::

        merged = plot_error_summary(predictions, sample_df, return_merged=True)

    Note:
        - Summary statistics shown: mean, median, max error, R² for x and y
        - Training locations help visualize geographic sampling bias
        - Geodesic distances account for Earth's curvature
        - Map projection requires cartopy to be installed
        - Font-size overrides apply only while this figure is built; the
          global matplotlib rcParams are restored afterwards
    """
    # Validate predictions input
    if predictions.empty:
        raise ValueError("Predictions DataFrame cannot be empty")

    # Load sample_data from a DataFrame or a tab-separated file
    if isinstance(sample_data, pd.DataFrame):
        samples = sample_data.copy()
    elif isinstance(sample_data, (str, Path)):
        sample_path = Path(sample_data)
        if not sample_path.is_file():
            raise ValueError(f"sample_data file {sample_data} does not exist")
        samples = pd.read_csv(sample_path, sep="\t")
    else:
        raise ValueError("sample_data must be a DataFrame or a valid file path")
    if samples.empty:
        raise ValueError("Sample data cannot be empty")

    # Validate required columns in predictions and samples
    required_pred_cols = ["sampleID", "x_pred", "y_pred"]
    required_sample_cols = ["sampleID", "x", "y"]
    missing_pred_cols = [
        col for col in required_pred_cols if col not in predictions.columns
    ]
    missing_sample_cols = [
        col for col in required_sample_cols if col not in samples.columns
    ]
    if missing_pred_cols:
        raise ValueError(f"Missing required columns in predictions: {missing_pred_cols}")
    if missing_sample_cols:
        raise ValueError(
            f"Missing required columns in sample data: {missing_sample_cols}"
        )
    samples = samples.rename(columns={"x": "x_true", "y": "y_true"})

    # Merge predictions with true locations
    merged = predictions.merge(samples[["sampleID", "x_true", "y_true"]], on="sampleID")
    if merged.empty:
        raise ValueError("No matching samples found between predictions and sample data")

    # Calculate errors
    if use_geodesic:
        # geopy expects (latitude, longitude) pairs, i.e. (y, x).
        merged["error"] = merged.apply(
            lambda row: (
                geodesic(
                    (row["y_true"], row["x_true"]), (row["y_pred"], row["x_pred"])
                ).kilometers
            ),
            axis=1,
        )
        error_units = "km"
    else:
        merged["error"] = np.sqrt(
            (merged["x_pred"] - merged["x_true"]) ** 2
            + (merged["y_pred"] - merged["y_true"]) ** 2
        )
        error_units = "coordinate units"

    font_overrides = {
        "font.size": 12,
        "axes.labelsize": 14,
        "axes.titlesize": 14,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 12,
    }
    # rc_context keeps the font overrides local to this figure instead of
    # permanently changing the caller's global matplotlib configuration.
    with plt.rc_context(font_overrides):
        fig = plt.figure(figsize=(width, height), dpi=dpi)
        try:
            # The map layout gives the map two thirds of the width; the plain
            # scatter layout splits the figure in half.
            if plot_map:
                gs = fig.add_gridspec(1, 3)
                map_ax = fig.add_subplot(gs[0:2], projection=ccrs.PlateCarree())
                hist_spec = gs[2]
            else:
                gs = fig.add_gridspec(1, 2)
                map_ax = fig.add_subplot(gs[0])
                hist_spec = gs[1]

            # Common axis setup
            map_ax.set_xticks([])
            map_ax.set_yticks([])
            if plot_map:
                map_ax.add_feature(cfeature.LAND, facecolor="lightgray")
                map_ax.add_feature(cfeature.COASTLINE, linewidth=0.5)

            # Determine bounds and optionally plot training locations
            if include_training_locs:
                x_min, x_max = samples["x_true"].min(), samples["x_true"].max()
                y_min, y_max = samples["y_true"].min(), samples["y_true"].max()
                # Training locations are the samples that were not predicted.
                training_mask = ~samples["sampleID"].isin(predictions["sampleID"])
                training_locs = samples[training_mask]
                if not training_locs.empty:
                    map_ax.scatter(
                        training_locs["x_true"],
                        training_locs["y_true"],
                        c="gray",
                        marker="o",
                        s=20,
                        alpha=0.5,
                        label="Training locations",
                    )
            else:
                x_min, x_max = merged["x_true"].min(), merged["x_true"].max()
                y_min, y_max = merged["y_true"].min(), merged["y_true"].max()

            padding = 0.1
            x_pad = (x_max - x_min) * padding
            y_pad = (y_max - y_min) * padding
            if plot_map:
                # Use set_extent only for map projections.
                map_ax.set_extent(
                    [x_min - x_pad, x_max + x_pad, y_min - y_pad, y_max + y_pad]
                )
            else:
                # For regular axes, set x and y limits.
                map_ax.set_xlim(x_min - x_pad, x_max + x_pad)
                map_ax.set_ylim(y_min - y_pad, y_max + y_pad)

            # Plot scatter, colorbar, and error connections
            scatter = map_ax.scatter(
                merged["x_true"],
                merged["y_true"],
                c=merged["error"],
                cmap="RdYlBu_r",
                s=20,
                **({"label": "Test locations"} if plot_map else {}),
            )
            cbar = plt.colorbar(scatter, ax=map_ax, label=f"Error ({error_units})")
            cbar.outline.set_visible(False)

            # One true->predicted segment per sample. With 2-D inputs plot()
            # draws each column as its own line, replacing a per-row loop.
            map_ax.plot(
                np.vstack([merged["x_true"], merged["x_pred"]]),
                np.vstack([merged["y_true"], merged["y_pred"]]),
                "k-",
                linewidth=0.5,
                alpha=0.5,
            )
            if plot_map and include_training_locs:
                map_ax.legend(loc="upper right")

            # Set up histogram panel (common to both layouts)
            hist_ax = fig.add_subplot(hist_spec)
            sns.histplot(data=merged, x="error", ax=hist_ax)
            hist_ax.set_xlabel(f"Error ({error_units})", fontsize=14)
            hist_ax.set_ylabel("Count", fontsize=14)
            stats_text = (
                f"Mean error: {merged['error'].mean():.2f} {error_units}\n"
                f"Median error: {merged['error'].median():.2f} {error_units}\n"
                f"Max error: {merged['error'].max():.2f} {error_units}\n"
                f"R² (x): {np.corrcoef(merged['x_pred'], merged['x_true'])[0, 1] ** 2:.3f}\n"
                f"R² (y): {np.corrcoef(merged['y_pred'], merged['y_true'])[0, 1] ** 2:.3f}"
            )
            hist_ax.text(
                0.95,
                0.95,
                stats_text,
                transform=hist_ax.transAxes,
                verticalalignment="top",
                horizontalalignment="right",
                bbox=dict(facecolor="white", alpha=0.8),
                fontsize=12,
            )

            plt.tight_layout()
            if out_prefix:
                plt.savefig(f"{out_prefix}_error_summary.png")
            _handle_plot_display(show)
        finally:
            # Release the figure even if drawing or saving failed.
            plt.close(fig)

    if return_merged:
        return merged
    return None
def plot_interactive_error_map(
    predictions,
    sample_data,
    out_prefix=None,
    width=1200,
    height=600,
    use_geodesic=True,
    include_training_locs=True,
    show_histogram=True,
):
    """Create an interactive map of prediction errors using Plotly.

    This function creates an interactive visualization of prediction errors from
    holdout analyses, with hover tooltips showing detailed information for each
    sample. The plot uses Plotly's geographic features to display results on a
    world map with coastlines and land features.

    Args:
        predictions (pandas.DataFrame): DataFrame with columns:
            - ``sampleID``: Sample identifiers
            - ``x_pred``: Predicted longitude
            - ``y_pred``: Predicted latitude
        sample_data (pandas.DataFrame or str): DataFrame or path to TSV file with columns:
            - ``sampleID``: Sample identifiers (must match predictions)
            - ``x``: True longitude
            - ``y``: True latitude
        out_prefix (str, optional): Prefix for output files. If provided, saves as
            {out_prefix}_interactive_error_map.html. Default: None
        width (int): Width of plot in pixels. Default: 1200
        height (int): Height of plot in pixels. Default: 600
        use_geodesic (bool): If True, calculate geodesic distances in kilometers.
            If False, use Euclidean distances in coordinate units. Default: True
        include_training_locs (bool): Whether to plot training locations (gray circles).
            Default: True
        show_histogram (bool): Whether to include error distribution histogram panel.
            Default: True

    Returns
    -------
    plotly.graph_objects.Figure: The interactive figure object. Can be displayed
        directly in Jupyter notebooks with fig.show() or just 'fig' in a cell.

    Raises
    ------
    ValueError: If predictions or sample_data are empty, have missing columns,
        or have no matching samples

    Examples
    --------
    Basic usage::

        fig = plot_interactive_error_map(predictions, sample_data)
        fig.show()  # Display in notebook

    Save to file::

        plot_interactive_error_map(predictions, sample_data,
                                   out_prefix="analysis")

    Without histogram panel::

        fig = plot_interactive_error_map(predictions, sample_data,
                                         show_histogram=False)

    Note:
        - All true-to-predicted connector lines are drawn by a single trace
        - Without the histogram panel, the summary statistics box sits in the
          bottom-left corner so that it does not cover the legend (top-left)
    """
    # Validate and load data
    if predictions.empty:
        raise ValueError("Predictions DataFrame cannot be empty")

    # Load sample data
    if isinstance(sample_data, pd.DataFrame):
        samples = sample_data.copy()
    elif isinstance(sample_data, (str, Path)):
        sample_path = Path(sample_data)
        if not sample_path.is_file():
            raise ValueError(f"sample_data file {sample_data} does not exist")
        samples = pd.read_csv(sample_path, sep="\t")
    else:
        raise ValueError("sample_data must be a DataFrame or a valid file path")
    if samples.empty:
        raise ValueError("Sample data cannot be empty")

    # Validate columns
    required_pred_cols = ["sampleID", "x_pred", "y_pred"]
    required_sample_cols = ["sampleID", "x", "y"]
    missing_pred_cols = [
        col for col in required_pred_cols if col not in predictions.columns
    ]
    missing_sample_cols = [
        col for col in required_sample_cols if col not in samples.columns
    ]
    if missing_pred_cols:
        raise ValueError(f"Missing required columns in predictions: {missing_pred_cols}")
    if missing_sample_cols:
        raise ValueError(
            f"Missing required columns in sample data: {missing_sample_cols}"
        )

    # Rename columns for clarity
    samples = samples.rename(columns={"x": "x_true", "y": "y_true"})

    # Merge predictions with true locations
    merged = predictions.merge(samples[["sampleID", "x_true", "y_true"]], on="sampleID")
    if merged.empty:
        raise ValueError("No matching samples found between predictions and sample data")

    # Calculate errors
    if use_geodesic:
        # geopy expects (latitude, longitude) pairs, i.e. (y, x).
        merged["error"] = merged.apply(
            lambda row: (
                geodesic(
                    (row["y_true"], row["x_true"]), (row["y_pred"], row["x_pred"])
                ).kilometers
            ),
            axis=1,
        )
        error_units = "km"
    else:
        merged["error"] = np.sqrt(
            (merged["x_pred"] - merged["x_true"]) ** 2
            + (merged["y_pred"] - merged["y_true"]) ** 2
        )
        error_units = "coordinate units"

    # Create subplots
    if show_histogram:
        # Two-panel layout with map and histogram
        fig = sp.make_subplots(
            rows=1,
            cols=2,
            column_widths=[0.7, 0.3],
            subplot_titles=("", "Error Distribution"),
            horizontal_spacing=0.05,
            specs=[[{"type": "geo"}, {"type": "xy"}]],
        )
    else:
        # Single geo panel
        fig = go.Figure()

    def _add_geo_trace(trace):
        # Geo traces go to the left cell in the two-panel layout and straight
        # onto the figure otherwise.
        if show_histogram:
            fig.add_trace(trace, row=1, col=1)
        else:
            fig.add_trace(trace)

    # Determine bounds for the plot
    if include_training_locs:
        x_min, x_max = samples["x_true"].min(), samples["x_true"].max()
        y_min, y_max = samples["y_true"].min(), samples["y_true"].max()
    else:
        x_min, x_max = merged["x_true"].min(), merged["x_true"].max()
        y_min, y_max = merged["y_true"].min(), merged["y_true"].max()
    padding = 0.1
    x_range = x_max - x_min
    y_range = y_max - y_min
    x_min_padded = x_min - x_range * padding
    x_max_padded = x_max + x_range * padding
    y_min_padded = y_min - y_range * padding
    y_max_padded = y_max + y_range * padding

    # First, add training locations if requested
    if include_training_locs:
        # Training locations are the samples that were not predicted.
        training_mask = ~samples["sampleID"].isin(predictions["sampleID"])
        training_locs = samples[training_mask]
        if not training_locs.empty:
            _add_geo_trace(
                go.Scattergeo(
                    lon=training_locs["x_true"],
                    lat=training_locs["y_true"],
                    mode="markers",
                    marker=dict(
                        size=4,
                        color="rgba(128, 128, 128, 0.5)",  # Gray with alpha
                        symbol="circle-open",
                        line=dict(width=0.5, color="gray"),
                    ),
                    name="Training locations",
                    hoverinfo="skip",
                    showlegend=True,
                )
            )

    # Lines connecting true and predicted locations. A ``None`` after every
    # pair breaks the line, so one trace can draw all segments; one trace per
    # sample makes large figures heavy and slow to render.
    line_lon = []
    line_lat = []
    for x_true, x_pred, y_true, y_pred in zip(
        merged["x_true"], merged["x_pred"], merged["y_true"], merged["y_pred"]
    ):
        line_lon.extend((x_true, x_pred, None))
        line_lat.extend((y_true, y_pred, None))
    _add_geo_trace(
        go.Scattergeo(
            lon=line_lon,
            lat=line_lat,
            mode="lines",
            line=dict(color="black", width=1),
            opacity=0.3,
            showlegend=False,
            hoverinfo="skip",
        )
    )

    # Add scatter plot of true locations colored by error
    hover_text = merged.apply(
        lambda row: (
            f"Sample: {row['sampleID']}<br>"
            f"Error: {row['error']:.2f} {error_units}<br>"
            f"True: ({row['x_true']:.2f}, {row['y_true']:.2f})<br>"
            f"Predicted: ({row['x_pred']:.2f}, {row['y_pred']:.2f})"
        ),
        axis=1,
    )
    _add_geo_trace(
        go.Scattergeo(
            lon=merged["x_true"],
            lat=merged["y_true"],
            mode="markers",
            marker=dict(
                size=8,
                color=merged["error"],
                colorscale="RdYlBu_r",
                showscale=True,
                colorbar=dict(
                    title=f"Error ({error_units})",
                    x=1.02,
                    xanchor="left",
                    len=0.8,
                    thickness=15,
                    yanchor="middle",
                    y=0.5,
                ),
                line=dict(width=0.5, color="rgba(255,255,255,0.8)"),  # White outline
            ),
            text=hover_text,
            hoverinfo="text",
            name="Test locations",
            showlegend=True,
        )
    )

    # Add histogram if requested
    if show_histogram:
        fig.add_trace(
            go.Histogram(
                x=merged["error"],
                name="Error distribution",
                showlegend=False,
                marker=dict(color="steelblue", line=dict(color="white", width=1)),
                opacity=0.8,
            ),
            row=1,
            col=2,
        )

    # Calculate statistics
    mean_error = merged["error"].mean()
    median_error = merged["error"].median()
    max_error = merged["error"].max()
    r2_x = np.corrcoef(merged["x_pred"], merged["x_true"])[0, 1] ** 2
    r2_y = np.corrcoef(merged["y_pred"], merged["y_true"])[0, 1] ** 2
    stats_text = (
        f"Mean: {mean_error:.2f} {error_units}<br>"
        f"Median: {median_error:.2f} {error_units}<br>"
        f"Max: {max_error:.2f} {error_units}<br>"
        f"R² (x): {r2_x:.3f}<br>"
        f"R² (y): {r2_y:.3f}"
    )

    # Add statistics annotation
    annotation_style = dict(
        text=stats_text,
        showarrow=False,
        bgcolor="rgba(255, 255, 255, 0.9)",
        bordercolor="rgba(0, 0, 0, 0.2)",
        borderwidth=1,
        font=dict(size=11, family="monospace"),
        align="left",
    )
    if show_histogram:
        # Top-right corner of the histogram panel.
        fig.add_annotation(
            x=0.98,
            y=0.98,
            xref="x2 domain",
            yref="y2 domain",
            xanchor="right",
            yanchor="top",
            **annotation_style,
        )
    else:
        # The legend is anchored at the top-left of the paper (see the layout
        # below), so the stats box goes to the bottom-left to avoid overlap.
        fig.add_annotation(
            x=0.02,
            y=0.02,
            xref="paper",
            yref="paper",
            xanchor="left",
            yanchor="bottom",
            **annotation_style,
        )

    # Configure geo subplot
    fig.update_geos(
        projection_type="natural earth",
        showland=True,
        landcolor="rgb(243, 243, 243)",
        coastlinecolor="rgb(204, 204, 204)",
        coastlinewidth=0.5,
        showlakes=True,
        lakecolor="white",
        showocean=True,
        oceancolor="rgb(230, 245, 255)",
        lataxis=dict(range=[y_min_padded, y_max_padded]),
        lonaxis=dict(range=[x_min_padded, x_max_padded]),
        bgcolor="white",
        showcountries=True,
        countrycolor="rgb(204, 204, 204)",
        countrywidth=0.5,
        showsubunits=False,
        showframe=False,
        resolution=50,  # 50m resolution
    )

    # Histogram panel axes (if present)
    if show_histogram:
        fig.update_xaxes(title_text=f"Error ({error_units})", row=1, col=2)
        fig.update_yaxes(title_text="Count", row=1, col=2)

    fig.update_layout(
        width=width,
        height=height,
        title_text="Prediction Error Map",
        title_font_size=16,
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.98,
            xanchor="left",
            x=0.02,
            bgcolor="rgba(255, 255, 255, 0.9)",
            bordercolor="rgba(0, 0, 0, 0.2)",
            borderwidth=1,
            font=dict(size=11),
        ),
        template="plotly_white",
        hovermode="closest",
        plot_bgcolor="white",
        paper_bgcolor="white",
        margin=dict(l=10, r=10, t=40, b=10),
    )

    # Save to HTML file
    if out_prefix:
        output_file = f"{out_prefix}_interactive_error_map.html"
        fig.write_html(output_file)
        print(f"Interactive map saved to: {output_file}")

    # Return the figure so it can be displayed in Jupyter notebooks
    return fig
[docs]
def plot_sample_weights(
    locator,
    out_prefix=None,
    plot_map=True,
    width=5,
    height=3,
    dpi=300,
    show=None,
):
    """Plot sample weights assigned to training locations.

    Visualizes the geographic distribution of sample weights used during training.
    This is useful for understanding which regions are upweighted or downweighted
    based on sampling density.

    Sample weights are typically computed using:
        - Kernel density (KD) method: Upweights samples in sparse regions
        - Histogram binning method: Based on 2D histogram counts

    The plot uses a log-scale color mapping to better show weight variations.

    Args:
        locator (Locator): Locator instance that has been trained with sample
            weighting enabled. ``locator._sample_data_df`` and
            ``locator.sample_weights["sample_weights_df"]`` are read.
        out_prefix (str, optional): Prefix for output files. If provided, saves as
            {out_prefix}_sample_weights.png. Default: None
        plot_map (bool): Whether to plot on a geographic map using cartopy projection.
            If False, uses regular scatter plot with equal aspect ratio. Default: True
        width (float): Figure width in inches. Default: 5
        height (float): Figure height in inches. Default: 3
        dpi (int): Figure resolution in dots per inch. Default: 300
        show (bool or None): Whether to display plot. None=auto-detect environment,
            True=always show, False=never show. Default: None

    Returns
    -------
    None: Saves plot to file and optionally displays it

    Raises
    ------
    ValueError: If locator doesn't have computed sample weights, or if
        required data is missing, empty, or shares no sample IDs

    Examples
    --------
    After training with KDE weighting::

        config = {
            "weight_samples": {
                "enabled": True,
                "method": "KD"
            }
        }
        locator = Locator(config)
        locator.train(genotypes, samples)
        plot_sample_weights(locator, "kde_weights")

    With histogram binning weights::

        config = {
            "weight_samples": {
                "enabled": True,
                "method": "hist",
                "xbins": 20,
                "ybins": 20
            }
        }
        locator = Locator(config)
        locator.train(genotypes, samples)
        plot_sample_weights(locator, "hist_weights", plot_map=False)

    Note:
        - Requires that locator was trained with weight_samples enabled
        - Log scale coloring helps visualize large weight variations
        - Higher weights (yellow) indicate undersampled regions
        - Lower weights (purple) indicate oversampled regions
        - Map projection requires cartopy to be installed
        - Font-size overrides apply only while this figure is built; the
          global matplotlib rcParams are restored afterwards
    """
    # Fetch the inputs, translating "weights were never computed" into the
    # documented ValueError instead of an AttributeError/TypeError/KeyError.
    try:
        sample_data = locator._sample_data_df
        sample_weights = locator.sample_weights["sample_weights_df"]
    except (AttributeError, KeyError, TypeError) as err:
        raise ValueError(
            "locator has no computed sample weights; train it with "
            "weight_samples enabled first"
        ) from err
    if sample_data is None or sample_weights is None:
        raise ValueError(
            "locator has no computed sample weights; train it with "
            "weight_samples enabled first"
        )

    # Validate inputs
    if sample_data.empty or sample_weights.empty:
        raise ValueError("Sample data and weights cannot be empty DataFrames")

    # Check for required columns
    required_weight_cols = ["sampleID", "sample_weight"]
    required_sample_cols = ["sampleID", "x", "y"]
    missing_weight_cols = [
        col for col in required_weight_cols if col not in sample_weights.columns
    ]
    missing_sample_cols = [
        col for col in required_sample_cols if col not in sample_data.columns
    ]
    if missing_weight_cols:
        raise ValueError(
            f"Missing required columns in sample weights: {missing_weight_cols}"
        )
    if missing_sample_cols:
        raise ValueError(
            f"Missing required columns in sample data: {missing_sample_cols}"
        )

    # Attach coordinates to each weight
    merged = sample_weights.merge(sample_data, on="sampleID")
    if merged.empty:
        raise ValueError(
            "No matching samples found between sample data and sample weights"
        )

    # Data bounds with 10% padding on each side
    x_min, x_max = merged["x"].min(), merged["x"].max()
    y_min, y_max = merged["y"].min(), merged["y"].max()
    padding = 0.1
    x_pad = (x_max - x_min) * padding
    y_pad = (y_max - y_min) * padding

    font_overrides = {
        "font.size": 12,
        "axes.labelsize": 14,
        "axes.titlesize": 14,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 12,
    }
    # rc_context keeps the font overrides local to this figure instead of
    # permanently changing the caller's global matplotlib configuration.
    with plt.rc_context(font_overrides):
        fig = plt.figure(figsize=(width, height), dpi=dpi)
        try:
            # Only the first cell of the 1x2 grid is used, in both layouts.
            gs = fig.add_gridspec(1, 2)
            if plot_map:
                ax1 = fig.add_subplot(gs[0], projection=ccrs.PlateCarree())
                ax1.set_xticks([])
                ax1.set_yticks([])
                ax1.add_feature(cfeature.LAND, facecolor="lightgray")
                ax1.add_feature(cfeature.COASTLINE, linewidth=0.5)
                ax1.set_extent(
                    [x_min - x_pad, x_max + x_pad, y_min - y_pad, y_max + y_pad]
                )
            else:
                ax1 = fig.add_subplot(gs[0])
                ax1.set(
                    xlim=(x_min - x_pad, x_max + x_pad),
                    ylim=(y_min - y_pad, y_max + y_pad),
                )

            # Training locations colored by weight on a log scale
            scatter = ax1.scatter(
                merged["x"],
                merged["y"],
                c=merged["sample_weight"],
                cmap="viridis",
                s=10,
                label="Training locations",
                norm=matplotlib.colors.LogNorm(),
            )
            cbar = plt.colorbar(scatter, ax=ax1, label="Sample Weights")
            cbar.outline.set_visible(False)
            if not plot_map:
                plt.gca().set_aspect("equal")

            if out_prefix:
                plt.savefig(f"{out_prefix}_sample_weights.png")
            _handle_plot_display(show)
        finally:
            # Release the figure even if drawing or saving failed.
            plt.close(fig)
    return None
[docs]
class PlottingMixin:
    """Mixin class providing plotting functionality for Locator.

    This mixin is inherited by the main Locator class to provide visualization
    methods for training history and Jupyter notebook integration.

    The host class must provide a dict-like ``config`` attribute. All other
    attributes read here (``model``, ``history``, ``traingen``, ``samples``,
    ``holdout_idx``, ``pred_indices``, the normalization statistics, ...) are
    optional: a missing attribute is treated the same as ``None``.

    Methods
    -------
    _repr_html_: Generate rich HTML representation for Jupyter notebooks
    """

    def _repr_html_(self):
        """Return HTML representation of Locator instance for Jupyter notebooks.

        Generates a rich HTML display showing:
        - Model configuration parameters
        - Current model status (trained/not trained)
        - Training history plot (if available)
        - Data loading status
        - Sample weighting information
        - Holdout sample information

        This method is automatically called by Jupyter/IPython when displaying
        a Locator instance in a notebook cell.

        Returns
        -------
        str: HTML string with styled content including embedded plots

        Note:
            - Training history plot is embedded as base64 PNG
            - Holdout samples shown in collapsible list if > 0
            - Automatically detects which data has been loaded
            - Configuration values and sample IDs are HTML-escaped

        Example:
            In a Jupyter notebook::

                locator = Locator(config)
                locator  # Rich HTML display appears automatically
        """
        parts = [
            "<div style='font-family: monospace'>",
            "<h3>Locator Model</h3>",
            "<table>",
            "<tr><th style='text-align:left; padding:5px'>Configuration</th><th style='text-align:left; padding:5px'>Value</th></tr>",
        ]
        parts.extend(self._html_config_rows())
        parts.append("</table>")

        parts.append("<h4>Status:</h4>")
        parts.append("<ul>")
        parts.extend(self._html_model_status())
        parts.extend(self._html_data_status())
        parts.extend(self._html_holdout_status())
        parts.append("</ul>")
        parts.append("</div>")
        return "".join(parts)

    @staticmethod
    def _html_escape(value):
        """Return ``str(value)`` with HTML special characters escaped."""
        # Local import: the module does not import ``html`` at file level.
        from html import escape

        return escape(str(value))

    def _html_config_rows(self):
        """Return the ``<tr>`` rows of the configuration table."""
        row = (
            "<tr><td style='padding:5px'>{}</td>"
            "<td style='padding:5px'>{}</td></tr>"
        )
        key_params = (
            "train_split",
            "batch_size",
            "min_mac",
            "max_SNPs",
            "width",
            "nlayers",
            "dropout_prop",
            "max_epochs",
            "optimizer_algo",
            "learning_rate",
            "weight_decay",
            "use_range_penalty",
            "species_range_shapefile",
            "resolution",
            "penalty_weight",
            "species_range_geom",
        )
        rows = [
            row.format(param, self._html_escape(self.config[param]))
            for param in key_params
            if param in self.config
        ]

        # ``weight_samples`` is a nested dict, so it is flattened into one row
        # per sub-option; ``or {}`` tolerates an explicit ``None`` entry.
        weights = self.config.get("weight_samples") or {}
        if weights.get("enabled", False):
            rows.append(row.format("weight_samples", "True"))
            for key in ("method", "xbins", "ybins", "lam", "bandwidth"):
                value = weights.get(key)
                if value is not None:
                    rows.append(
                        row.format("weight_samples " + key, self._html_escape(value))
                    )
        return rows

    def _html_model_status(self):
        """Return list items describing the model and its training history."""
        if getattr(self, "model", None) is None:
            return ["<li>Model: Not trained</li>"]

        items = ["<li>Model: Trained ✓</li>"]
        traingen = getattr(self, "traingen", None)
        if traingen is not None:
            items.append(f"<li>Training samples: {traingen.shape[0]}</li>")
            items.append(f"<li>Features: {traingen.shape[1]}</li>")
        items.extend(self._html_training_history())
        return items

    def _html_training_history(self):
        """Return HTML fragments embedding the loss plot, or ``[]`` if no history."""
        history = getattr(self, "history", None)
        if history is None:
            return []

        hist = history.history
        # Validation loss is absent when the model was fitted without
        # validation data; plot and report it only when it was recorded.
        val_loss = hist["val_loss"] if "val_loss" in hist else None
        has_val = val_loss is not None and len(val_loss) > 0

        fig, ax = plt.subplots(figsize=(8, 4))
        try:
            ax.plot(hist["loss"], label="Training Loss", color="blue")
            ax.set_xlabel("Epoch")
            ax.set_ylabel("Training Loss")
            ax.legend()
            if has_val:
                # Validation loss gets its own y-axis (twin of the training axis).
                ax_val = ax.twinx()
                ax_val.plot(val_loss, label="Validation Loss", color="orange")
                ax_val.set_ylabel("Validation Loss")
                ax_val.legend(loc="upper center")

            buf = io.BytesIO()
            fig.savefig(buf, format="png", bbox_inches="tight")
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)
        plot_data = base64.b64encode(buf.getvalue()).decode("utf-8")

        # The heading/paragraph/image cannot live inside the status <ul>, so
        # the list is closed before them and reopened afterwards.
        fragments = ["</ul>", "<h4>Training History:</h4>"]
        if has_val:
            fragments.append(f"<p>Final validation loss: {val_loss[-1]:.4f}</p>")
        fragments.append(
            f'<img src="data:image/png;base64,{plot_data}" style="max-width:100%">'
        )
        fragments.append("<ul>")
        return fragments

    def _html_data_status(self):
        """Return list items for normalization, loaded data and sample weights."""
        items = []

        norm_stats = [
            getattr(self, name, None)
            for name in ("meanlong", "sdlong", "meanlat", "sdlat")
        ]
        if all(stat is not None for stat in norm_stats):
            items.append("<li>Location normalization: Computed ✓</li>")
        else:
            items.append("<li>Location normalization: Not computed</li>")

        sample_df = getattr(self, "_sample_data_df", None)
        if sample_df is not None:
            items.append(f"<li>Sample data loaded: {len(sample_df)} samples</li>")
        elif "sample_data" in self.config:
            items.append("<li>Sample data: Path provided</li>")
        else:
            items.append("<li>Sample data: Not provided</li>")

        genotype_df = getattr(self, "_genotype_df", None)
        if genotype_df is not None:
            items.append(f"<li>Genotype data loaded: {genotype_df.shape[1]} SNPs</li>")
        elif any(key in self.config for key in ("zarr", "vcf", "genotype_data")):
            items.append("<li>Genotype data: Path provided</li>")
        else:
            items.append("<li>Genotype data: Not provided</li>")

        # Presence of the attribute (not its value) triggers this line, as
        # before; the config lookup is guarded so a missing entry cannot raise.
        if hasattr(self, "sample_weights"):
            method = (self.config.get("weight_samples") or {}).get("method")
            items.append(
                f"<li>Samples weighted using {self._html_escape(method)}</li>"
            )
        return items

    def _html_holdout_status(self):
        """Return list items for held-out (or prediction) samples, if any."""
        samples = getattr(self, "samples", None)
        if samples is None:
            return []

        # ``holdout_idx`` takes precedence over ``pred_indices``.
        holdout_idx = getattr(self, "holdout_idx", None)
        if holdout_idx is not None:
            return self._html_sample_list("Holdout samples", samples, holdout_idx)
        pred_indices = getattr(self, "pred_indices", None)
        if pred_indices is not None:
            return self._html_sample_list("Prediction samples", samples, pred_indices)
        return []

    def _html_sample_list(self, label, samples, indices):
        """Return a count line plus a collapsible list of ``samples[indices]``."""
        count = len(indices)
        items = [f"<li>{label}: {count} samples held out</li>"]
        if count > 0:
            items.append("<li>Held out samples: <details>")
            items.append("<summary>Click to show/hide</summary>")
            items.append("<ul style='max-height:200px;overflow-y:auto'>")
            items.extend(
                f"<li>{self._html_escape(sample)}</li>" for sample in samples[indices]
            )
            items.append("</ul></details></li>")
        return items