# Source code for lbm_suite2p_python.grid_search

"""
grid search module for parameter optimization.

provides functions to:
- run grid search over suite2p/cellpose parameters
- collect and compare results across parameter combinations
- visualize quality metrics and detection results
"""

import copy
import shutil
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import skew

from lbm_suite2p_python.postprocessing import (
    load_ops,
    load_planar_results,
    dff_shot_noise,
)
from lbm_suite2p_python.zplane import (
    plot_zplane_figures,
)

# registration parameters that require re-registration when changed
# NOTE(review): this set is not referenced anywhere in this chunk —
# presumably consumed by the grid-search runner to decide whether cached
# registration output can be reused for a parameter combination; confirm.
REGISTRATION_PARAMS = {
    "do_registration", "two_step_registration", "keep_movie_raw",
    "nimg_init", "batch_size", "maxregshift", "align_by_chan",
    "reg_tif", "reg_tif_chan2", "subpixel",
    "smooth_sigma_time", "smooth_sigma",
    "th_badframes", "norm_frames", "force_refImg", "pad_fft",
    "nonrigid", "block_size", "snr_thresh", "maxregshiftNR",
    "1Preg", "spatial_hp_reg", "pre_smooth", "spatial_taper",
}






def _combo_to_tag(combo_dict: dict) -> str:
    """Convert parameter combo dict to folder name tag."""
    tag_parts = [
        f"{k[:3]}{v:.2f}" if isinstance(v, float) else f"{k[:3]}{v}"
        for k, v in combo_dict.items()
    ]
    return "_".join(tag_parts)


[docs] def collect_grid_results( save_path: Path | str, grid_params: dict = None, ) -> pd.DataFrame: """ Collect quality metrics from all grid search combinations. Parameters ---------- save_path : Path or str Root directory containing grid search results. grid_params : dict, optional Grid parameters dict to extract parameter values from ops. If None, only combo name is included. Returns ------- pd.DataFrame DataFrame with one row per combination, containing: - combo: folder name - n_accepted, n_rejected: cell counts - snr_median, snr_iqr: signal-to-noise ratio - skew_median, skew_iqr: trace skewness - noise_median, noise_iqr: shot noise - parameter columns from grid_params """ save_path = Path(save_path) results = [] for combo_dir in sorted(save_path.iterdir()): if not combo_dir.is_dir(): continue if combo_dir.name in ("_base", "__pycache__"): continue ops_file = combo_dir / "ops.npy" if not ops_file.exists(): # check subdirectories for subdir in combo_dir.iterdir(): if subdir.is_dir() and (subdir / "ops.npy").exists(): ops_file = subdir / "ops.npy" break else: continue try: metrics = compute_combo_metrics(ops_file) result = {"combo": combo_dir.name, "path": str(ops_file), **metrics} # add grid parameters if grid_params: loaded_ops = load_ops(ops_file) for param in grid_params.keys(): result[param] = loaded_ops.get(param) results.append(result) except Exception as e: print(f"Skipping {combo_dir.name}: {e}") if not results: print("No results found!") return pd.DataFrame() df = pd.DataFrame(results) df = df.sort_values("snr_median", ascending=False) return df
def compute_combo_metrics(ops_path: Path | str, neuropil_coef: float = 0.7) -> dict: """ Compute quality metrics for a single parameter combination. Parameters ---------- ops_path : Path or str Path to ops.npy file or directory containing it. neuropil_coef : float, default 0.7 Neuropil correction coefficient. Returns ------- dict Dictionary containing: - n_accepted, n_rejected: cell counts - snr_median, snr_iqr: signal-to-noise ratio statistics - skew_median, skew_iqr: skewness statistics - noise_median, noise_iqr: shot noise statistics """ res = load_planar_results(ops_path) ops = load_ops(ops_path) fs = ops.get("fs", 30.0) iscell = res["iscell"] mask = iscell[:, 0].astype(bool) if iscell.ndim == 2 else iscell.astype(bool) n_accepted = mask.sum() n_rejected = len(mask) - n_accepted if n_accepted == 0: return { "n_accepted": 0, "n_rejected": n_rejected, "snr_median": np.nan, "snr_iqr": np.nan, "skew_median": np.nan, "skew_iqr": np.nan, "noise_median": np.nan, "noise_iqr": np.nan, } F = res["F"][mask] Fneu = res["Fneu"][mask] stat = res["stat"][mask] if isinstance(res["stat"], np.ndarray) else [ s for s, m in zip(res["stat"], mask) if m ] # neuropil-corrected fluorescence and dF/F F_corr = F - neuropil_coef * Fneu baseline = np.percentile(F_corr, 20, axis=1, keepdims=True) baseline = np.maximum(baseline, 1e-6) dff = (F_corr - baseline) / baseline # snr: signal std / noise (MAD estimator) signal = np.std(dff, axis=1) noise_est = np.median(np.abs(np.diff(dff, axis=1)), axis=1) / 0.6745 snr = signal / (noise_est + 1e-6) # shot noise using existing function shot_noise = dff_shot_noise(dff, fs) # skewness from stat if available, else compute skewness = [] for i, s in enumerate(stat): if isinstance(s, dict) and "skew" in s: skewness.append(s["skew"]) else: skewness.append(skew(dff[i])) skewness = np.array(skewness) return { "n_accepted": n_accepted, "n_rejected": n_rejected, "snr_median": np.median(snr), "snr_iqr": np.percentile(snr, 75) - np.percentile(snr, 25), 
"skew_median": np.median(skewness), "skew_iqr": np.percentile(skewness, 75) - np.percentile(skewness, 25), "noise_median": np.median(shot_noise), "noise_iqr": np.percentile(shot_noise, 75) - np.percentile(shot_noise, 25), } def get_best_parameters(df: pd.DataFrame, grid_params: dict = None) -> dict: """ Find best parameter combinations by different criteria. Parameters ---------- df : pd.DataFrame Results DataFrame from collect_grid_results(). grid_params : dict, optional Grid parameters to include in output. Returns ------- dict Dictionary with keys 'best_snr', 'best_skew', 'best_noise', each containing the best row as a dict. """ if len(df) == 0: return {} result = {} # best by SNR (highest) best_snr_idx = df["snr_median"].idxmax() result["best_snr"] = df.loc[best_snr_idx].to_dict() # best by skewness (highest = more events) best_skew_idx = df["skew_median"].idxmax() result["best_skew"] = df.loc[best_skew_idx].to_dict() # best by noise (lowest) best_noise_idx = df["noise_median"].idxmin() result["best_noise"] = df.loc[best_noise_idx].to_dict() # best by cell count (highest) best_count_idx = df["n_accepted"].idxmax() result["best_count"] = df.loc[best_count_idx].to_dict() return result def print_best_parameters(df: pd.DataFrame, grid_params: dict = None): """Print summary of best parameters by different criteria.""" best = get_best_parameters(df, grid_params) if not best: print("No results to analyze.") return print("Best Parameters by Different Criteria:") print("=" * 60) for criterion, row in best.items(): name = criterion.replace("best_", "").upper() print(f"\n{name}: {row['combo']}") print(f" Cells: {row['n_accepted']}") print(f" SNR: {row['snr_median']:.3f}") print(f" Skewness: {row['skew_median']:.3f}") print(f" Shot Noise: {row['noise_median']:.4f}") if grid_params: for p in grid_params: if p in row: print(f" {p}: {row[p]}")
[docs] def plot_grid_metrics( df: pd.DataFrame, grid_params: dict = None, save_path: Path | str = None, figsize: tuple = (15, 10), ): """ Plot quality metrics comparison across grid search combinations. Parameters ---------- df : pd.DataFrame Results DataFrame from collect_grid_results(). grid_params : dict, optional Grid parameters for parameter effect plots. save_path : Path or str, optional Path to save figure. If None, displays with plt.show(). figsize : tuple, default (15, 10) Figure size. Returns ------- matplotlib.figure.Figure """ if len(df) == 0: print("No results to plot.") return None plt.style.use("dark_background") fig, axes = plt.subplots(2, 3, figsize=figsize) metrics = [ ("snr_median", "SNR (median)", "higher is better", "#2ecc71"), ("skew_median", "Skewness (median)", "higher = more events", "#9b59b6"), ("noise_median", "Shot Noise (median)", "lower is better", "#e74c3c"), ] # row 1: metrics by combination for col, (metric, label, note, color) in enumerate(metrics): ax = axes[0, col] x = range(len(df)) ax.bar(x, df[metric], color=color, alpha=0.8, edgecolor="white") ax.set_xticks(x) ax.set_xticklabels(df["combo"], rotation=45, ha="right", fontsize=8) ax.set_ylabel(label) ax.set_title(f"{label}\n({note})") ax.axhline(df[metric].median(), color="white", linestyle="--", alpha=0.5) # row 2: parameter effects on SNR params = list(grid_params.keys()) if grid_params else [] for col, param in enumerate(params[:3]): ax = axes[1, col] if param not in df.columns: ax.axis("off") continue grouped = df.groupby(param)["snr_median"].agg(["mean", "std"]).reset_index() x = range(len(grouped)) ax.bar( x, grouped["mean"], yerr=grouped["std"].fillna(0), color="#3498db", alpha=0.8, edgecolor="white", capsize=5 ) ax.set_xticks(x) ax.set_xticklabels([str(v) for v in grouped[param]]) ax.set_xlabel(param, fontweight="bold") ax.set_ylabel("SNR (mean ± std)") ax.set_title(f"Effect of {param} on SNR") # hide unused subplots for col in range(len(params), 3): axes[1, 
col].axis("off") plt.suptitle("Grid Search Quality Metrics", fontsize=14, fontweight="bold", y=1.02) plt.tight_layout() if save_path: save_path = Path(save_path) plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="black") plt.close(fig) print(f"Saved: {save_path}") else: plt.show() return fig
def plot_grid_distributions( df: pd.DataFrame, results_dir: Path | str, n_top: int = 4, save_path: Path | str = None, figsize: tuple = (15, 5), ): """ Plot distributions of quality metrics for top combinations. Parameters ---------- df : pd.DataFrame Results DataFrame from collect_grid_results(). results_dir : Path or str Root directory containing grid search results. n_top : int, default 4 Number of top combinations to include. save_path : Path or str, optional Path to save figure. figsize : tuple, default (15, 5) Figure size. Returns ------- matplotlib.figure.Figure """ if len(df) == 0: print("No results to plot.") return None results_dir = Path(results_dir) top_combos = df.head(n_top)["combo"].tolist() plt.style.use("dark_background") fig, axes = plt.subplots(1, 3, figsize=figsize) colors = ["#2ecc71", "#3498db", "#e74c3c", "#f39c12", "#9b59b6", "#1abc9c"] for combo, color in zip(top_combos, colors): combo_dir = results_dir / combo ops_file = combo_dir / "ops.npy" if not ops_file.exists(): continue try: res = load_planar_results(ops_file) ops = load_ops(ops_file) fs = ops.get("fs", 30.0) mask = res["iscell"][:, 0].astype(bool) if mask.sum() == 0: continue F = res["F"][mask] Fneu = res["Fneu"][mask] stat = [s for s, m in zip(res["stat"], mask) if m] F_corr = F - 0.7 * Fneu baseline = np.percentile(F_corr, 20, axis=1, keepdims=True) baseline = np.maximum(baseline, 1e-6) dff = (F_corr - baseline) / baseline # compute metrics signal = np.std(dff, axis=1) noise = np.median(np.abs(np.diff(dff, axis=1)), axis=1) / 0.6745 snr = signal / (noise + 1e-6) shot_noise = dff_shot_noise(dff, fs) skewness = np.array([ s.get("skew", skew(dff[i])) for i, s in enumerate(stat) ]) # plot distributions axes[0].hist(snr, bins=30, alpha=0.5, color=color, label=combo, density=True) axes[1].hist(skewness, bins=30, alpha=0.5, color=color, label=combo, density=True) axes[2].hist(shot_noise, bins=30, alpha=0.5, color=color, label=combo, density=True) except Exception as e: print(f"Error 
loading {combo}: {e}") axes[0].set_xlabel("SNR") axes[0].set_ylabel("Density") axes[0].set_title("SNR Distribution") axes[0].legend(fontsize=8) axes[1].set_xlabel("Skewness") axes[1].set_ylabel("Density") axes[1].set_title("Skewness Distribution") axes[2].set_xlabel("Shot Noise") axes[2].set_ylabel("Density") axes[2].set_title("Shot Noise Distribution") plt.suptitle(f"Quality Metric Distributions (Top {n_top} by SNR)", fontweight="bold", y=1.02) plt.tight_layout() if save_path: save_path = Path(save_path) plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="black") plt.close(fig) print(f"Saved: {save_path}") else: plt.show() return fig def plot_grid_masks( df: pd.DataFrame, results_dir: Path | str, n_top: int = 4, img_key: str = "meanImg", save_path: Path | str = None, figsize: tuple = (12, 12), ): """ Plot detection masks for top combinations side-by-side. Parameters ---------- df : pd.DataFrame Results DataFrame from collect_grid_results(). results_dir : Path or str Root directory containing grid search results. n_top : int, default 4 Number of top combinations to show. img_key : str, default "meanImg" Background image key from ops. save_path : Path or str, optional Path to save figure. figsize : tuple, default (12, 12) Figure size. 
Returns ------- matplotlib.figure.Figure """ if len(df) == 0: print("No results to plot.") return None results_dir = Path(results_dir) top_combos = df.head(n_top) # determine grid layout n_cols = min(n_top, 2) n_rows = (n_top + n_cols - 1) // n_cols plt.style.use("dark_background") fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize) if n_top == 1: axes = np.array([[axes]]) elif n_rows == 1: axes = axes.reshape(1, -1) elif n_cols == 1: axes = axes.reshape(-1, 1) axes = axes.flatten() for ax, (_, row) in zip(axes, top_combos.iterrows()): combo_dir = results_dir / row["combo"] ops_file = combo_dir / "ops.npy" if not ops_file.exists(): ax.axis("off") continue try: ops = load_ops(ops_file) img = ops.get(img_key, ops.get("refImg", np.zeros((512, 512)))) ax.imshow( img, cmap="gray", vmin=np.percentile(img, 1), vmax=np.percentile(img, 99) ) # draw ROIs stat = np.load(combo_dir / "stat.npy", allow_pickle=True) iscell = np.load(combo_dir / "iscell.npy")[:, 0].astype(bool) for i, s in enumerate(stat): if iscell[i]: ax.scatter(s["xpix"], s["ypix"], s=0.1, c="lime", alpha=0.3) title = f"{row['combo']}\n{row['n_accepted']} cells, SNR={row['snr_median']:.2f}" ax.set_title(title, fontsize=10) except Exception as e: ax.set_title(f"{row['combo']}\nError: {e}", fontsize=8) ax.axis("off") # hide unused axes for ax in axes[len(top_combos):]: ax.axis("off") plt.suptitle(f"Top {n_top} by SNR", fontsize=14, fontweight="bold") plt.tight_layout() if save_path: save_path = Path(save_path) plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="black") plt.close(fig) print(f"Saved: {save_path}") else: plt.show() return fig
[docs] def save_grid_results(df: pd.DataFrame, save_path: Path | str): """Save grid search results to CSV.""" save_path = Path(save_path) csv_path = save_path / "grid_search_results.csv" df.to_csv(csv_path, index=False) print(f"Saved: {csv_path}") return csv_path