Source code for lbm_suite2p_python.zplane

from pathlib import Path
from typing import Tuple

import numpy as np
import pandas as pd
import tifffile
import math

import matplotlib.offsetbox
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
from matplotlib.offsetbox import VPacker, HPacker, DrawingArea
import matplotlib.gridspec as gridspec

from scipy.ndimage import distance_transform_edt

from mbo_utilities.metadata import get_param

from lbm_suite2p_python.postprocessing import (
    load_ops,
    load_planar_results,
    dff_rolling_percentile,
    dff_shot_noise,
    compute_trace_quality_score,
)
from lbm_suite2p_python.utils import (
    _resize_masks_fit_crop,
    bin1d,
)


def infer_units(f: np.ndarray) -> str:
    """
    Infer calcium imaging signal type from array values:
    - 'raw': values in hundreds or thousands (integer dtypes are always 'raw')
    - 'dff': unitless ΔF/F₀, typically ~0–1
    - 'dffp': ΔF/F₀ in percent, typically ~10–100

    Returns one of: 'raw', 'dff', 'dffp', or 'unknown' when no range matches.
    """
    f = np.asarray(f)
    if np.issubdtype(f.dtype, np.integer):
        return "raw"

    p1, p50, p99 = np.nanpercentile(f, [1, 50, 99])

    if p99 > 500 or p50 > 100:
        return "raw"
    elif 5 < p1 < 30 and 20 < p50 < 60 and 40 < p99 < 100:
        return "dffp"
    elif 0.1 < p1 < 0.2 < p50 < 0.5 < p99 < 1.0:
        return "dff"
    else:
        return "unknown"

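
# Illustrative sketch (not part of the library API): exercising the heuristic
# above on synthetic inputs. Values are made up; integer dtypes short-circuit
# to "raw" before any percentile check.
def _demo_infer_units():
    raw = np.random.poisson(800, size=(10, 1000)).astype(np.int16)
    dff = np.clip(np.random.normal(0.3, 0.1, size=(10, 1000)), 0.11, 0.99)
    print(infer_units(raw))  # "raw" (integer dtype)
    print(infer_units(dff))  # "dff" when percentiles land in the expected bands
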

def format_time(t):
    """
    Format a time value in seconds to a human-readable string.

    Parameters
    ----------
    t : float
        Time in seconds.

    Returns
    -------
    str
        Formatted time string (e.g., "30 s", "5 min", "2 h").
    """
    if t < 60:
        return f"{int(np.ceil(t))} s"
    elif t < 3600:
        return f"{int(round(t / 60))} min"
    else:
        return f"{int(round(t / 3600))} h"


def get_color_permutation(n):
    """
    Generate a permutation of indices for visually distinct color ordering.

    Uses a coprime step to spread colors evenly across the color space.

    Parameters
    ----------
    n : int
        Number of items to permute.

    Returns
    -------
    list
        A permutation of the indices 0..n-1.
    """
    for s in range(n // 2 + 1, n):
        if math.gcd(s, n) == 1:
            return [(i * s) % n for i in range(n)]
    return list(range(n))

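
# For example, with n = 10 the first coprime step above n // 2 is s = 7:
#     get_color_permutation(10) -> [0, 7, 4, 1, 8, 5, 2, 9, 6, 3]
# so adjacent traces take well-separated positions in the colormap rather
# than walking it in order.
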

class AnchoredHScaleBar(matplotlib.offsetbox.AnchoredOffsetbox):
    """
    Create an anchored horizontal scale bar.

    Parameters
    ----------
    size : float, optional
        bar length in data units (fixed; default is 1).
    label : str, optional
        text label (default is "").
    loc : int, optional
        location code (default is 2).
    ax : axes, optional
        axes to attach the bar (default uses current axes).
    pad, borderpad, ppad, sep : float, optional
        spacing parameters.
    linekw : dict, optional
        line properties.
    """

    def __init__(
        self,
        size=1,
        label="",
        loc=2,
        ax=None,
        pad=0.4,
        borderpad=0.5,
        ppad=0,
        sep=2,
        prop=None,
        frameon=True,
        linekw=None,
        **kwargs,
    ):
        if linekw is None:
            linekw = {}
        if ax is None:
            ax = plt.gca()
        trans = ax.transAxes

        size_bar = matplotlib.offsetbox.AuxTransformBox(trans)
        line = Line2D([0, size], [0, 0], **linekw)
        size_bar.add_artist(line)
        txt = matplotlib.offsetbox.TextArea(label)
        self.txt = txt
        self.vpac = VPacker(children=[size_bar, txt], align="center", pad=ppad, sep=sep)
        super().__init__(
            loc,  # noqa
            pad=pad,
            borderpad=borderpad,
            child=self.vpac,
            prop=prop,
            frameon=frameon,
            **kwargs,
        )


class AnchoredVScaleBar(matplotlib.offsetbox.AnchoredOffsetbox):
    """
    Create an anchored vertical scale bar.

    Parameters
    ----------
    height : float, optional
        Bar height in data units (default is 1).
    label : str, optional
        Text label (default is "").
    loc : int, optional
        Location code (default is 2).
    ax : axes, optional
        Axes to attach the bar (default uses current axes).
    pad, borderpad, ppad, sep : float, optional
        Spacing parameters.
    linekw : dict, optional
        Line properties.
    spacer_width : float, optional
        Width of spacer between bar and text.
    """

    def __init__(
        self,
        height=1,
        label="",
        loc=2,
        ax=None,
        pad=0.4,
        borderpad=0.5,
        ppad=0,
        sep=2,
        prop=None,
        frameon=True,
        linekw=None,
        spacer_width=6,
        **kwargs,
    ):
        if ax is None:
            ax = plt.gca()
        if linekw is None:
            linekw = {}
        trans = ax.transAxes

        size_bar = matplotlib.offsetbox.AuxTransformBox(trans)
        line = Line2D([0, 0], [0, height], **linekw)
        size_bar.add_artist(line)

        txt = matplotlib.offsetbox.TextArea(
            label, textprops=dict(rotation=90, ha="left", va="bottom")
        )
        self.txt = txt

        spacer = DrawingArea(spacer_width, 0, 0, 0)
        self.hpac = HPacker(
            children=[size_bar, spacer, txt], align="bottom", pad=ppad, sep=sep
        )
        super().__init__(
            loc,  # noqa
            pad=pad,
            borderpad=borderpad,
            child=self.hpac,
            prop=prop,
            frameon=frameon,
            **kwargs,
        )

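
# Illustrative sketch (not part of the module): attaching both scale bars to
# an ordinary axes. Because the bars are built on ax.transAxes, size/height
# are fractions of the axes, so 0.1 spans 10% of the axis; labels and the
# string location codes here are examples only.
def _demo_scale_bars():
    fig, ax = plt.subplots()
    ax.plot(np.linspace(0, 10, 200), np.sin(np.linspace(0, 10, 200)))
    linekw = dict(color="black", linewidth=2)
    hbar = AnchoredHScaleBar(size=0.1, label="10 s", loc="lower right",
                             frameon=False, linekw=linekw, ax=ax)
    vbar = AnchoredVScaleBar(height=0.1, label="1 a.u.", loc="lower left",
                             frameon=False, linekw=linekw, ax=ax)
    ax.add_artist(hbar)
    ax.add_artist(vbar)
    plt.show()
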

def plot_traces_noise(
    dff_noise,
    colors,
    fps=17.0,
    window=220,
    savepath=None,
    title="Trace Noise",
    lw=0.5,
):
    """
    Plot stacked noise traces in the same style as plot_traces.

    Parameters
    ----------
    dff_noise : ndarray
        Noise traces, shape (n_neurons, n_timepoints).
    colors : ndarray
        Per-neuron color array (RGBA rows), matching the colormap ordering
        used by plot_traces for the corresponding traces figure.
    fps : float
        Sampling rate, Hz.
    window : float
        Time window (seconds) to display.
    savepath : str or Path, optional
        If given, save to file.
    title : str
        Title for figure.
    lw : float
        Line width.
    """

    n_neurons, n_timepoints = dff_noise.shape
    data_time = np.arange(n_timepoints) / fps
    current_frame = min(int(window * fps), n_timepoints - 1)

    # auto offset based on noise traces
    p10 = np.percentile(dff_noise[:, : current_frame + 1], 10, axis=1)
    p90 = np.percentile(dff_noise[:, : current_frame + 1], 90, axis=1)
    offset = np.median(p90 - p10) * 1.2

    fig, ax = plt.subplots(figsize=(10, 6), facecolor="black")
    ax.set_facecolor("black")
    ax.tick_params(axis="x", which="both", labelbottom=False, length=0, colors="white")
    ax.tick_params(axis="y", which="both", labelleft=False, length=0, colors="white")
    for spine in ax.spines.values():
        spine.set_visible(False)

    for i in reversed(range(n_neurons)):
        trace = dff_noise[i, : current_frame + 1]
        shifted_trace = trace + i * offset
        ax.plot(
            data_time[: current_frame + 1],
            shifted_trace,
            color=colors[i],
            lw=lw,
            zorder=-i,
        )

    if title:
        fig.suptitle(title, fontsize=16, fontweight="bold", color="white")

    if savepath:
        plt.savefig(savepath, dpi=200, facecolor=fig.get_facecolor())
        plt.close(fig)
    else:
        plt.show()

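
# Illustrative sketch: driving plot_traces_noise with synthetic noise traces.
# In practice the colors array should match the one used for the paired
# traces figure; here it is rebuilt with the same colormap logic.
def _demo_plot_traces_noise():
    rng = np.random.default_rng(0)
    noise = rng.normal(0.0, 0.05, size=(12, 4000))  # (n_neurons, n_timepoints)
    colors = plt.get_cmap("tab10")(np.linspace(0, 1, 12))
    plot_traces_noise(noise, colors, fps=17.0, window=220, savepath="noise.png")
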

def plot_traces(
    f,
    save_path: str | Path = "",
    cell_indices: np.ndarray | list[int] | None = None,
    fps=17.0,
    num_neurons=20,
    window=220,
    title="",
    offset=None,
    lw=0.5,
    cmap="tab10",
    scale_bar_unit: str = None,
    mask_overlap: bool = True,
) -> None:
    """
    Plot stacked fluorescence traces with automatic offset and scale bars.

    Parameters
    ----------
    f : ndarray
        2d array of fluorescence traces (n_neurons x n_timepoints).
    save_path : str, optional
        Path to save the output plot.
    cell_indices : array-like or None
        Specific cell indices to plot. If provided, overrides num_neurons.
    fps : float
        Sampling rate in frames per second.
    num_neurons : int
        Number of neurons to display if cell_indices is None.
    window : float
        Time window (in seconds) to display.
    title : str
        Title of the figure.
    offset : float or None
        Vertical offset between traces; if None, computed automatically.
    lw : float
        Line width for the plotted traces.
    cmap : str
        Matplotlib colormap string.
    scale_bar_unit : str, optional
        Unit suffix for the vertical scale bar (e.g., "% ΔF/F₀", "a.u.").
        The numeric value is computed automatically based on the plot's
        vertical scale. If None, "a.u." is used.
    mask_overlap : bool, default True
        If True, lower traces mask (occlude) traces above them, creating a
        layered effect where each trace has a black background.
    """
    if isinstance(f, dict):
        raise ValueError("f must be a numpy array, not a dictionary")

    n_timepoints = f.shape[-1]
    data_time = np.arange(n_timepoints) / fps
    current_frame = min(int(window * fps), n_timepoints - 1)

    if cell_indices is None:
        displayed_neurons = min(num_neurons, f.shape[0])
        indices = np.arange(displayed_neurons)
    else:
        indices = np.array(cell_indices)
        if indices.dtype == bool:
            indices = np.where(indices)[0]  # convert boolean mask to int indices
        displayed_neurons = len(indices)

    if len(indices) == 0:
        return None

    if offset is None:
        p10 = np.percentile(f[indices, : current_frame + 1], 10, axis=1)
        p90 = np.percentile(f[indices, : current_frame + 1], 90, axis=1)
        offset = np.median(p90 - p10) * 1.2
        # Ensure minimum offset to prevent trace overlap
        min_offset = np.percentile(p90 - p10, 75) * 0.8
        offset = max(offset, min_offset, 1e-6)  # Absolute minimum to prevent divide-by-zero

    cmap_inst = plt.get_cmap(cmap)
    colors = cmap_inst(np.linspace(0, 1, displayed_neurons))
    perm = get_color_permutation(displayed_neurons)
    colors = colors[perm]

    # Build shifted traces array (no masking - let z-order handle overlap)
    shifted_traces = np.zeros((displayed_neurons, current_frame + 1))
    for i in range(displayed_neurons):
        trace = f[indices[i], : current_frame + 1]
        baseline = np.percentile(trace, 8)
        shifted_traces[i] = (trace - baseline) + i * offset

    # Plot traces with z-ordering (lower traces on top via higher zorder)
    fig, ax = plt.subplots(figsize=(10, 6), facecolor="black")
    ax.set_facecolor("black")
    ax.tick_params(axis="x", which="both", labelbottom=False, length=0, colors="white")
    ax.tick_params(axis="y", which="both", labelleft=False, length=0, colors="white")
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Plot from top to bottom so lower-indexed traces appear on top
    time_slice = data_time[: current_frame + 1]
    for i in range(displayed_neurons - 1, -1, -1):
        z = displayed_neurons - i  # Lower index = higher zorder = on top
        if mask_overlap:
            # Fill below trace with black to mask traces above
            ax.fill_between(
                time_slice,
                shifted_traces[i],
                y2=shifted_traces[i].min() - offset,
                color="black",
                zorder=z - 0.5,
            )
        ax.plot(
            time_slice,
            shifted_traces[i],
            color=colors[i],
            lw=lw,
            zorder=z,
        )

    time_bar_length = 0.1 * window
    if time_bar_length < 60:
        time_label = f"{time_bar_length:.0f} s"
    elif time_bar_length < 3600:
        time_label = f"{time_bar_length / 60:.0f} min"
    else:
        time_label = f"{time_bar_length / 3600:.1f} hr"

    # Set y-limits with small padding (no extra space for scalebars - they go outside)
    y_min = np.min(shifted_traces)
    y_max = np.max(shifted_traces)
    y_range = y_max - y_min
    ax.set_ylim(y_min - y_range * 0.02, y_max + y_range * 0.02)

    # Compute vertical scale bar value (10% of y-range in data units)
    scale_bar_height_frac = 0.10  # 10% of axes height
    scale_bar_data_value = y_range * scale_bar_height_frac

    # Use provided unit or default to "a.u."
    if scale_bar_unit is None:
        scale_bar_unit = "a.u."

    # Format the scale bar label with computed value
    if scale_bar_data_value >= 100:
        scale_bar_label = f"{int(round(scale_bar_data_value, -1))} {scale_bar_unit}"
    elif scale_bar_data_value >= 10:
        scale_bar_label = f"{int(round(scale_bar_data_value))} {scale_bar_unit}"
    elif scale_bar_data_value >= 1:
        scale_bar_label = f"{scale_bar_data_value:.0f} {scale_bar_unit}"
    else:
        scale_bar_label = f"{scale_bar_data_value:.2f} {scale_bar_unit}"

    # Adjust subplot to make room for scalebars at bottom and right
    fig.subplots_adjust(bottom=0.12, right=0.88)

    linekw = dict(color="white", linewidth=3)

    # Time scale bar - use fig.text for fixed position below axes
    # Get axes position in figure coordinates
    ax_pos = ax.get_position()
    time_bar_x = ax_pos.x1 - 0.02  # right side of axes
    time_bar_y = 0.07  # fixed position just below axes

    # Draw horizontal line for time scale bar
    line_width_fig = 0.08  # width in figure coords
    fig.add_artist(plt.Line2D(
        [time_bar_x - line_width_fig, time_bar_x],
        [time_bar_y, time_bar_y],
        transform=fig.transFigure,
        color="white",
        linewidth=3,
        clip_on=False,
    ))
    # Add time label
    fig.text(
        time_bar_x - line_width_fig / 2,
        time_bar_y - 0.02,
        time_label,
        ha="center",
        va="top",
        color="white",
        fontsize=10,
        transform=fig.transFigure,
    )

    # Vertical scale bar - positioned just outside right edge, bottom aligned with x-axis
    vsb = AnchoredVScaleBar(
        height=scale_bar_height_frac,
        label=scale_bar_label,
        loc="lower right",
        frameon=False,
        pad=0.5,
        sep=4,
        linekw=linekw,
        ax=ax,
        spacer_width=0,
    )
    # Position just outside right edge of axes, bottom at y=0
    vsb.set_bbox_to_anchor((1.02, 0.0), transform=ax.transAxes)
    vsb.txt._text.set_color("white")
    ax.add_artist(vsb)

    if title:
        fig.suptitle(title, fontsize=16, fontweight="bold", color="white")

    ax.set_ylabel(
        f"Neuron Count: {displayed_neurons}",
        fontsize=10,
        fontweight="bold",
        color="white",
        labelpad=5,
    )

    if save_path:
        plt.savefig(save_path, dpi=200, facecolor=fig.get_facecolor())
        plt.close(fig)
    else:
        plt.show()
    return None
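

# Illustrative call (paths and parameters are examples only): plot the first
# 20 traces of an (n_neurons, n_timepoints) dF/F array over a 220 s window.
#
#     plot_traces(dff, save_path="traces.png", fps=17.0, num_neurons=20,
#                 window=220, title="Example traces", scale_bar_unit="% dF/F")

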
def animate_traces(
    f,
    save_path="./traces.mp4",
    cell_indices=None,
    fps=17.0,
    num_neurons=20,
    window=30,
    title="",
    offset=None,
    lw=0.5,
    cmap="tab10",
    scale_bar_unit=None,
    mask_overlap=True,
    anim_fps=30,
    speed=1.0,
    dpi=150,
):
    """
    Animated version of plot_traces - scrolling window through time.

    Creates an mp4 video showing traces scrolling like an oscilloscope display.
    Visual style matches plot_traces exactly.

    Parameters
    ----------
    f : ndarray
        2d array of fluorescence traces (n_neurons x n_timepoints).
    save_path : str or Path, default "./traces.mp4"
        Output path for the animation.
    cell_indices : array-like or None
        Specific cell indices to plot. If provided, overrides num_neurons.
    fps : float, default 17.0
        Data frame rate in Hz.
    num_neurons : int, default 20
        Number of neurons to display if cell_indices is None.
    window : float, default 30
        Time window width in seconds.
    title : str, default ""
        Title for the animation.
    offset : float or None
        Vertical offset between traces; if None, computed automatically.
    lw : float, default 0.5
        Line width for traces.
    cmap : str, default "tab10"
        Matplotlib colormap.
    scale_bar_unit : str, optional
        Unit suffix for vertical scale bar. If None, uses "a.u.".
    mask_overlap : bool, default True
        If True, lower traces mask traces above them.
    anim_fps : int, default 30
        Animation frame rate (frames per second in output video).
    speed : float, default 1.0
        Playback speed multiplier (1.0 = real-time, 2.0 = 2x speed).
    dpi : int, default 150
        Output video resolution.

    Returns
    -------
    str
        Path to saved animation.
    """
    if isinstance(f, dict):
        raise ValueError("f must be a numpy array, not a dictionary")

    n_total, n_timepoints = f.shape
    data_time = np.arange(n_timepoints) / fps
    total_duration = data_time[-1]
    window_frames = int(window * fps)

    # select neurons
    if cell_indices is None:
        displayed_neurons = min(num_neurons, n_total)
        indices = np.arange(displayed_neurons)
    else:
        indices = np.array(cell_indices)
        if indices.dtype == bool:
            indices = np.where(indices)[0]
        displayed_neurons = len(indices)

    if len(indices) == 0:
        print("No neurons to display")
        return None

    # pre-compute baselines and offset (once, not per frame)
    baselines = np.percentile(f[indices], 8, axis=1)
    if offset is None:
        p10 = np.percentile(f[indices], 10, axis=1)
        p90 = np.percentile(f[indices], 90, axis=1)
        offset = np.median(p90 - p10) * 1.2
        min_offset = np.percentile(p90 - p10, 75) * 0.8
        offset = max(offset, min_offset, 1e-6)

    # colors - use same permutation as plot_traces
    cmap_inst = plt.get_cmap(cmap)
    colors = cmap_inst(np.linspace(0, 1, displayed_neurons))
    perm = get_color_permutation(displayed_neurons)
    colors = colors[perm]

    # compute y-range based on expected stacked layout
    # each trace spans ~offset, so total height is roughly (n_neurons * offset) plus some headroom
    # use per-trace percentiles to avoid outliers
    trace_ranges = []
    for idx in range(displayed_neurons):
        trace = f[indices[idx]] - baselines[idx]
        # use 1st/99th percentile to ignore spikes
        p1, p99 = np.percentile(trace, [1, 99])
        trace_ranges.append(p99 - p1)
    median_trace_range = np.median(trace_ranges)

    # y-range: stack height plus headroom for trace fluctuations
    y_min_global = -median_trace_range * 0.5
    y_max_global = (displayed_neurons - 1) * offset + median_trace_range * 1.5
    y_range = y_max_global - y_min_global

    # ensure minimum range
    if y_range < 1e-6:
        y_range = 1.0

    # setup figure (matches plot_traces)
    fig, ax = plt.subplots(figsize=(10, 6), facecolor="black")
    ax.set_facecolor("black")
    ax.tick_params(axis="x", which="both", labelbottom=False, length=0, colors="white")
    ax.tick_params(axis="y", which="both", labelleft=False, length=0, colors="white")
    for spine in ax.spines.values():
        spine.set_visible(False)
    fig.subplots_adjust(bottom=0.15, right=0.85, left=0.10, top=0.92)

    # scale bar labels
    time_bar_length = 0.1 * window
    if time_bar_length < 60:
        time_label = f"{time_bar_length:.0f} s"
    elif time_bar_length < 3600:
        time_label = f"{time_bar_length / 60:.0f} min"
    else:
        time_label = f"{time_bar_length / 3600:.1f} hr"

    scale_bar_height_frac = 0.10
    scale_bar_data_value = y_range * scale_bar_height_frac
    if scale_bar_unit is None:
        scale_bar_unit = "a.u."
    if scale_bar_data_value >= 100:
        scale_bar_label = f"{int(round(scale_bar_data_value, -1))} {scale_bar_unit}"
    elif scale_bar_data_value >= 10:
        scale_bar_label = f"{int(round(scale_bar_data_value))} {scale_bar_unit}"
    elif scale_bar_data_value >= 1:
        scale_bar_label = f"{scale_bar_data_value:.0f} {scale_bar_unit}"
    else:
        scale_bar_label = f"{scale_bar_data_value:.2f} {scale_bar_unit}"

    # create line objects (lower index = higher zorder = on top)
    lines = []
    fills = []
    for i in range(displayed_neurons - 1, -1, -1):
        z = displayed_neurons - i
        if mask_overlap:
            fill, = ax.fill([], [], color="black", zorder=z - 0.5)
            fills.append((i, fill))
        line, = ax.plot([], [], color=colors[i], lw=lw, zorder=z)
        lines.append((i, line))

    # static elements
    ax.set_ylim(y_min_global - y_range * 0.02, y_max_global + y_range * 0.02)
    if title:
        fig.suptitle(title, fontsize=16, fontweight="bold", color="white")
    ax.set_ylabel(
        f"Neuron Count: {displayed_neurons}",
        fontsize=10,
        fontweight="bold",
        color="white",
        labelpad=5,
    )

    # time scale bar (static position)
    ax_pos = ax.get_position()
    time_bar_x = ax_pos.x1 - 0.02
    time_bar_y = 0.07
    line_width_fig = 0.08
    time_line = fig.add_artist(plt.Line2D(
        [time_bar_x - line_width_fig, time_bar_x],
        [time_bar_y, time_bar_y],
        transform=fig.transFigure,
        color="white",
        linewidth=3,
        clip_on=False,
    ))
    time_text = fig.text(
        time_bar_x - line_width_fig / 2,
        time_bar_y - 0.02,
        time_label,
        ha="center",
        va="top",
        color="white",
        fontsize=10,
        transform=fig.transFigure,
    )

    # vertical scale bar
    linekw = dict(color="white", linewidth=3)
    vsb = AnchoredVScaleBar(
        height=scale_bar_height_frac,
        label=scale_bar_label,
        loc="lower right",
        frameon=False,
        pad=0.5,
        sep=4,
        linekw=linekw,
        ax=ax,
        spacer_width=0,
    )
    vsb.set_bbox_to_anchor((1.02, 0.0), transform=ax.transAxes)
    vsb.txt._text.set_color("white")
    ax.add_artist(vsb)

    # animation frames
    # each animation frame advances by (speed / anim_fps) seconds of data
    step_seconds = speed / anim_fps
    max_start_time = total_duration - window
    n_frames = int(max_start_time / step_seconds) + 1

    def init():
        for i, line in lines:
            line.set_data([], [])
        for i, fill in fills:
            fill.set_xy(np.empty((0, 2)))
        return [line for _, line in lines] + [fill for _, fill in fills]

    def update(frame):
        t_start = frame * step_seconds
        t_end = t_start + window
        i_start = int(t_start * fps)
        i_end = min(int(t_end * fps), n_timepoints)
        time_slice = data_time[i_start:i_end]

        # compute shifted traces for this window
        shifted = np.zeros((displayed_neurons, i_end - i_start))
        for idx, neuron_idx in enumerate(indices):
            trace = f[neuron_idx, i_start:i_end]
            shifted[idx] = (trace - baselines[idx]) + idx * offset

        # update lines and fills
        for i, line in lines:
            line.set_data(time_slice, shifted[i])
        if mask_overlap:
            for i, fill in fills:
                # fill from trace down to below the lowest point
                y_data = shifted[i]
                y_bottom = y_min_global - y_range * 0.1
                # create polygon: trace forward, then bottom backward
                xy = np.column_stack([
                    np.concatenate([time_slice, time_slice[::-1]]),
                    np.concatenate([y_data, np.full(len(y_data), y_bottom)])
                ])
                fill.set_xy(xy)

        ax.set_xlim(t_start, t_end)
        return [line for _, line in lines] + [fill for _, fill in fills]

    ani = FuncAnimation(
        fig,
        update,
        frames=n_frames,
        init_func=init,
        blit=True,
        interval=1000 / anim_fps,
    )

    save_path = Path(save_path)
    print(f"Saving animation to {save_path} ({n_frames} frames at {anim_fps} fps)...")
    ani.save(str(save_path), fps=anim_fps, dpi=dpi, writer="ffmpeg")
    plt.close(fig)
    print(f"Saved: {save_path}")
    return str(save_path)


def feather_mask(mask, max_alpha=0.75, edge_width=3):
    """
    Create a feathered alpha mask with soft edges.

    Parameters
    ----------
    mask : numpy.ndarray
        Binary or labeled mask (non-zero = foreground).
    max_alpha : float, optional
        Maximum alpha value at mask center. Default is 0.75.
    edge_width : int, optional
        Width of the feathered edge in pixels. Default is 3.

    Returns
    -------
    numpy.ndarray
        Alpha mask with values in [0, max_alpha].
    """
    dist_out = distance_transform_edt(mask == 0)
    alpha = np.clip((edge_width - dist_out) / edge_width, 0, 1)
    return alpha * max_alpha


def plot_masks(
    img: np.ndarray,
    stat: list[dict] | dict,
    mask_idx: np.ndarray,
    savepath: str | Path = None,
    colors=None,
    title=None,
):
    """
    Draw ROI overlays onto the mean image.

    Parameters
    ----------
    img : ndarray (Ly x Lx)
        Background image to overlay on.
    stat : list[dict]
        Suite2p ROI stat dictionaries (with "ypix", "xpix", "lam").
    mask_idx : ndarray[bool]
        Boolean array selecting which ROIs to plot.
    savepath : str or Path, optional
        Path to save the figure. If None, displays with plt.show().
    colors : ndarray or list, optional
        Array/list of RGB tuples for each selected ROI. If None, colors are
        assigned via HSV colormap.
    title : str, optional
        Title string to place on the figure.
    """
    # Normalize background image using percentile stretch for better contrast;
    # this prevents dark images when min/max are extreme outliers
    vmin = np.nanpercentile(img, 1)
    vmax = np.nanpercentile(img, 99)
    normalized = (img - vmin) / (vmax - vmin + 1e-6)
    normalized = np.clip(normalized, 0, 1)

    # Set NaN regions to 0 (black background)
    normalized = np.nan_to_num(normalized, nan=0.0)
    canvas = np.tile(normalized, (3, 1, 1)).transpose(1, 2, 0)

    # Get image dimensions for bounds checking
    Ly, Lx = img.shape[:2]

    # Assign colors if not provided
    n_masks = mask_idx.sum()
    if colors is None:
        colors = plt.cm.hsv(np.linspace(0, 1, n_masks + 1))[:, :3]  # noqa

    c = 0
    for n, s in enumerate(stat):
        if mask_idx[n]:
            ypix, xpix, lam = s["ypix"], s["xpix"], s["lam"]

            # Bounds checking - only keep pixels within image dimensions
            valid_mask = (ypix >= 0) & (ypix < Ly) & (xpix >= 0) & (xpix < Lx)
            if not np.any(valid_mask):
                c += 1
                continue  # Skip ROI if no valid pixels

            ypix = ypix[valid_mask]
            xpix = xpix[valid_mask]
            lam = lam[valid_mask]
            lam = lam / (lam.max() + 1e-10)

            col = colors[c]
            c += 1
            for k in range(3):
                canvas[ypix, xpix, k] = (
                    0.5 * canvas[ypix, xpix, k] + 0.5 * col[k] * lam
                )

    fig, ax = plt.subplots(figsize=(10, 10), facecolor="black")
    ax.set_facecolor("black")
    ax.imshow(canvas, interpolation="nearest")
    if title is not None:
        ax.set_title(title, fontsize=10, color="white", fontweight="bold")
    ax.axis("off")
    plt.tight_layout()

    if savepath:
        if Path(savepath).is_dir():
            raise ValueError("savepath must be a file path, not a directory.")
        plt.savefig(savepath, dpi=300, facecolor="black")
        plt.close(fig)
    else:
        plt.show()
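

# Illustrative sketch: the feathering used by the mask overlays (feather_mask
# above) on a toy square mask. Inside the mask dist_out is 0, so alpha equals
# max_alpha; alpha then ramps to 0 over edge_width pixels outside the edge.
def _demo_feather_mask():
    mask = np.zeros((15, 15), dtype=bool)
    mask[4:11, 4:11] = True
    alpha = feather_mask(mask, max_alpha=0.75, edge_width=3)
    print(alpha[7, 7])   # 0.75: fully opaque inside the mask
    print(alpha[7, 12])  # 0.25: 2 px outside the edge, partially feathered
    print(alpha[7, 14])  # 0.0: beyond the 3 px feather

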
def plot_projection(
    ops,
    output_directory=None,
    fig_label=None,
    vmin=None,
    vmax=None,
    add_scalebar=False,
    proj="meanImg",
    display_masks=False,
    accepted_only=False,
):
    """
    Plot a registration projection image, optionally overlaying ROI masks.

    Parameters
    ----------
    ops : dict
        Suite2p ops dictionary containing the projection image.
    output_directory : str or Path, optional
        File path to save the figure; if None, the figure is shown.
    fig_label : str, optional
        Label for the y-axis.
    vmin, vmax : float, optional
        Display range; defaults to the 2nd/98th percentiles of the image.
    add_scalebar : bool, optional
        If True and "dx" is present in ops, draw a 100 μm scale bar.
    proj : str, optional
        Projection to plot: "meanImg", "max_proj", or "meanImgE".
    display_masks : bool, optional
        If True, overlay accepted (green) and rejected (magenta) ROI masks.
    accepted_only : bool, optional
        If True, only overlay accepted ROIs.
    """
    from suite2p.detection.stats import ROI

    if proj == "meanImg":
        txt = "Mean-Image"
    elif proj == "max_proj":
        txt = "Max-Projection"
    elif proj == "meanImgE":
        txt = "Mean-Image (Enhanced)"
    else:
        raise ValueError(
            "Unknown projection type. Options are ['meanImg', 'max_proj', 'meanImgE']"
        )

    if output_directory:
        output_directory = Path(output_directory)

    data = ops[proj]
    shape = data.shape

    fig, ax = plt.subplots(figsize=(6, 6), facecolor="black")
    vmin = np.nanpercentile(data, 2) if vmin is None else vmin
    vmax = np.nanpercentile(data, 98) if vmax is None else vmax
    if vmax - vmin < 1e-6:
        vmax = vmin + 1e-6
    ax.imshow(data, cmap="gray", vmin=vmin, vmax=vmax)

    # move projection title higher if masks are displayed to avoid overlap.
    proj_title_y = 1.07 if display_masks else 1.02
    ax.text(
        0.5,
        proj_title_y,
        txt,
        transform=ax.transAxes,
        fontsize=14,
        fontweight="bold",
        fontname="Courier New",
        color="white",
        ha="center",
        va="bottom",
    )
    if fig_label:
        fig_label = fig_label.replace("_", " ").replace("-", " ").replace(".", " ")
        ax.set_ylabel(fig_label, color="white", fontweight="bold", fontsize=12)
    ax.set_xticks([])
    ax.set_yticks([])

    if display_masks:
        res = load_planar_results(ops)
        stat = res["stat"]
        iscell_mask = res["iscell"][:, 0].astype(bool)
        im = ROI.stats_dicts_to_3d_array(
            stat,
            Ly=get_param(ops, "Ly", default=512),
            Lx=get_param(ops, "Lx", default=512),
            label_id=True,
        )
        im[im == 0] = np.nan

        accepted_cells = np.sum(iscell_mask)
        rejected_cells = np.sum(~iscell_mask)
        cell_rois = _resize_masks_fit_crop(
            np.nanmax(im[iscell_mask], axis=0)
            if np.any(iscell_mask)
            else np.zeros_like(im[0]),
            shape,
        )

        green_overlay = np.zeros((*shape, 4), dtype=np.float32)
        green_overlay[..., 3] = feather_mask(cell_rois > 0, max_alpha=0.9)
        green_overlay[..., 1] = 1
        ax.imshow(green_overlay)

        if not accepted_only:
            non_cell_rois = _resize_masks_fit_crop(
                (
                    np.nanmax(im[~iscell_mask], axis=0)
                    if np.any(~iscell_mask)
                    else np.zeros_like(im[0])
                ),
                shape,
            )
            magenta_overlay = np.zeros((*shape, 4), dtype=np.float32)
            magenta_overlay[..., 0] = 1
            magenta_overlay[..., 2] = 1
            magenta_overlay[..., 3] = (non_cell_rois > 0) * 0.5
            ax.imshow(magenta_overlay)

        ax.text(
            0.37,
            1.02,
            f"Accepted: {accepted_cells:03d}",
            transform=ax.transAxes,
            fontsize=14,
            fontweight="bold",
            fontname="Courier New",
            color="lime",
            ha="right",
            va="bottom",
        )
        ax.text(
            0.63,
            1.02,
            f"Rejected: {rejected_cells:03d}",
            transform=ax.transAxes,
            fontsize=14,
            fontweight="bold",
            fontname="Courier New",
            color="magenta",
            ha="left",
            va="bottom",
        )

    if add_scalebar and "dx" in ops:
        pixel_size = ops["dx"]
        scale_bar_length = 100 / pixel_size
        scalebar_x = shape[1] * 0.05
        scalebar_y = shape[0] * 0.90
        ax.add_patch(
            Rectangle(
                (scalebar_x, scalebar_y),
                scale_bar_length,
                5,
                edgecolor="white",
                facecolor="white",
            )
        )
        ax.text(
            scalebar_x + scale_bar_length / 2,
            scalebar_y - 10,
            "100 μm",
            color="white",
            fontsize=10,
            ha="center",
            fontweight="bold",
        )

    # remove the spines that will show up as white bars
    for spine in ax.spines.values():
        spine.set_visible(False)

    plt.tight_layout()
    if output_directory:
        output_directory.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(output_directory, dpi=300, facecolor="black")
        plt.close(fig)
    else:
        plt.show()
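

# Illustrative call (paths are examples only), assuming an ops.npy produced
# by suite2p and that load_ops accepts a path:
#
#     ops = load_ops("suite2p/plane0/ops.npy")
#     plot_projection(ops, output_directory="plane0_meanImgE.png",
#                     proj="meanImgE", display_masks=True, add_scalebar=True)

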
def plot_noise_distribution(
    noise_levels: np.ndarray, output_filename=None, title="Noise Level Distribution"
):
    """
    Plot and save the distribution of noise levels across neurons as a
    standardized image.

    Parameters
    ----------
    noise_levels : np.ndarray
        1D array of noise levels for each neuron.
    output_filename : str or Path, optional
        Path to save the plot. If empty, the plot will be displayed instead
        of saved.
    title : str, optional
        Suptitle for plot, default is "Noise Level Distribution".

    See Also
    --------
    lbm_suite2p_python.dff_shot_noise
    """
    if output_filename:
        output_filename = Path(output_filename)
        if output_filename.is_dir():
            raise AttributeError(
                f"output_filename should be a fully qualified file path, not a directory: {output_filename}"
            )

    fig = plt.figure(figsize=(8, 5))
    plt.hist(noise_levels, bins=50, color="gray", alpha=0.7, edgecolor="black")

    mean_noise: float = np.mean(noise_levels)  # noqa
    plt.axvline(
        mean_noise,
        color="r",
        linestyle="dashed",
        linewidth=2,
        label=f"Mean: {mean_noise:.2f}",
    )

    plt.xlabel("Noise Level", fontsize=14, fontweight="bold")
    plt.ylabel("Number of Neurons", fontsize=14, fontweight="bold")
    plt.title(title, fontsize=16, fontweight="bold")
    plt.legend(fontsize=12)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)

    if output_filename:
        plt.savefig(output_filename, dpi=200, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()
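

# Illustrative call: in practice noise_levels would come from
# lbm_suite2p_python.dff_shot_noise; synthetic values shown here.
#
#     noise_levels = np.abs(np.random.normal(1.0, 0.3, size=500))
#     plot_noise_distribution(noise_levels, output_filename="noise_hist.png")

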
def plot_rastermap(
    spks,
    model,
    neuron_bin_size=None,
    fps=17,
    vmin=0,
    vmax=0.8,
    xmin=0,
    xmax=None,
    save_path=None,
    title=None,
    title_kwargs=None,
    fig_text=None,
):
    """
    Plot a rastermap-sorted activity heatmap with time/neuron scale bars.

    Parameters
    ----------
    spks : ndarray
        Deconvolved activity, shape (n_neurons, n_timepoints).
    model : rastermap model
        Fitted model providing the sort order (isort) and fit parameters.
    neuron_bin_size : int, optional
        Number of neurons to average per "superneuron" row; if None, chosen
        so roughly 500 rows are displayed.
    fps : float, optional
        Frame rate used for the time scale bar.
    vmin, vmax : float, optional
        Color limits for the heatmap.
    xmin, xmax : int, optional
        Time range (in frames) to display.
    save_path : str or Path, optional
        If given, save the figure; otherwise show it.
    title : str, optional
        Figure suptitle.
    title_kwargs : dict, optional
        Keyword arguments for the suptitle.
    fig_text : str, optional
        Summary text below the heatmap; auto-generated if None.
    """
    n_neurons, n_timepoints = spks.shape
    if title_kwargs is None:
        title_kwargs = dict(fontsize=14, fontweight="bold", color="white")

    if neuron_bin_size is None:
        neuron_bin_size = max(1, int(np.ceil(n_neurons / 500)))
    else:
        neuron_bin_size = max(1, min(neuron_bin_size, n_neurons))

    sn = bin1d(spks[model.isort], neuron_bin_size, axis=0)
    if xmax is None or xmax < xmin or xmax > sn.shape[1]:
        xmax = sn.shape[1]
    sn = sn[:, xmin:xmax]

    current_time = np.round((xmax - xmin) / fps, 1)
    current_neurons = sn.shape[0]

    fig, ax = plt.subplots(figsize=(6, 3), dpi=200)
    img = ax.imshow(sn, cmap="gray_r", vmin=vmin, vmax=vmax, aspect="auto")

    fig.patch.set_facecolor("black")
    ax.set_facecolor("black")
    ax.tick_params(axis="both", labelbottom=False, labelleft=False, length=0)
    for spine in ax.spines.values():
        spine.set_visible(False)

    heatmap_pos = ax.get_position()
    scalebar_length = heatmap_pos.width * 0.1  # 10% width of heatmap
    scalebar_duration = np.round(
        current_time * 0.1  # noqa
    )  # 10% of the displayed time in heatmap
    x_start = heatmap_pos.x1 - scalebar_length
    x_end = heatmap_pos.x1
    y_position = heatmap_pos.y0

    fig.lines.append(
        plt.Line2D(
            [x_start, x_end],
            [y_position - 0.03, y_position - 0.03],
            transform=fig.transFigure,
            color="white",
            linewidth=2,
            solid_capstyle="butt",
        )
    )
    fig.text(
        x=(x_start + x_end) / 2,
        y=y_position - 0.045,  # slightly below the scalebar
        s=f"{scalebar_duration:.0f} s",
        ha="center",
        va="top",
        color="white",
        fontsize=6,
    )

    axins = fig.add_axes(
        [  # noqa
            heatmap_pos.x0,  # exactly aligned with heatmap's left edge
            heatmap_pos.y0 - 0.03,  # slightly below the heatmap
            heatmap_pos.width * 0.1,  # 10% width of heatmap
            0.015,  # height of the colorbar
        ]
    )
    cbar = fig.colorbar(img, cax=axins, orientation="horizontal", ticks=[vmin, vmax])
    cbar.ax.tick_params(labelsize=5, colors="white", pad=2)
    cbar.outline.set_edgecolor("white")  # noqa

    fig.text(
        heatmap_pos.x0,
        heatmap_pos.y0 - 0.1,  # below the colorbar with spacing
        "z-scored",
        ha="left",
        va="top",
        color="white",
        fontsize=6,
    )

    scalebar_neurons = int(0.1 * current_neurons)
    x_position = heatmap_pos.x1 + 0.01  # slightly right of heatmap
    y_start = heatmap_pos.y0
    y_end = y_start + (heatmap_pos.height * scalebar_neurons / current_neurons)

    line = plt.Line2D(
        [x_position, x_position],
        [y_start, y_end],
        transform=fig.transFigure,
        color="white",
        linewidth=2,
    )
    line.set_figure(fig)
    fig.lines.append(line)

    ntype = "neuron" if scalebar_neurons == 1 else "neurons"
    fig.text(
        x=x_position + 0.008,
        y=y_start,
        s=f"{scalebar_neurons} {ntype}",
        ha="left",
        va="bottom",
        color="white",
        fontsize=6,
        rotation=90,
    )

    if fig_text is None:
        fig_text = (
            f"Neurons: {spks.shape[0]}, Superneurons: {sn.shape[0]}, "
            f"n_clusters: {model.n_clusters}, n_PCs: {model.n_PCs}, "
            f"locality: {model.locality}"
        )
    fig.text(
        x=(heatmap_pos.x0 + heatmap_pos.x1) / 2,
        y=y_start - 0.085,  # vertically between existing scalebars
        s=fig_text,
        ha="center",
        va="top",
        color="white",
        fontsize=6,
    )

    if title is not None:
        plt.suptitle(title, **title_kwargs)

    if save_path is not None:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path, dpi=200, facecolor="black", bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()
    return fig, ax
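

# Illustrative sketch, assuming the upstream rastermap package; constructor
# arguments are examples, and the fitted model exposes the isort ordering
# used above:
#
#     from rastermap import Rastermap
#     spks = np.load("suite2p/plane0/spks.npy")  # (n_neurons, n_timepoints)
#     model = Rastermap(n_PCs=128, n_clusters=64, locality=0.75).fit(spks)
#     plot_rastermap(spks, model, fps=17, save_path="rastermap.png")

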
def save_pc_panels_and_metrics(ops, savepath, pcs=(0, 1, 2, 3)): """ Save PC metrics in two forms: 1. Alternating TIFF (PC Low/High side-by-side per frame, press play in ImageJ to flip). 2. Panel TIFF (static figures for PC1/2 and PC3/4). Also saves summary metrics as CSV. Parameters ---------- ops : dict or str or Path Suite2p ops dict or path to ops.npy. Must contain "regPC" and "regDX". savepath : str or Path Output file stem (without extension). pcs : tuple of int PCs to include (default first four). """ if not isinstance(ops, dict): ops = np.load(ops, allow_pickle=True).item() if "nframes" in ops and ops["nframes"] < 1500: print( f"1500 frames needed for registration metrics, found {ops['nframes']}. Skipping PC metrics." ) return {} elif "regPC" not in ops or "regDX" not in ops: print("regPC or regDX not found in ops, skipping PC metrics.") return {} elif len(pcs) != 4 or any(p < 0 for p in pcs): raise ValueError( "pcs must be a tuple of four non-negative integers." " E.g., (0, 1, 2, 3) for the first four PCs." f" Got: {pcs}" ) regPC = ops["regPC"] # shape (2, nPC, Ly, Lx) regDX = ops["regDX"] # shape (nPC, 3) savepath = Path(savepath) alt_frames = [] alt_labels = [] for view, view_name in zip([0, 1], ["Low", "High"]): # side-by-side: PC1 | PC2 left = regPC[view, pcs[0]] right = regPC[view, pcs[1]] combined = np.hstack([left, right]) alt_frames.append(combined.astype(np.float32)) alt_labels.append(f"PC{pcs[0] + 1}/{pcs[1] + 1} {view_name}") # side-by-side: PC3 | PC4 left = regPC[view, pcs[2]] right = regPC[view, pcs[3]] combined = np.hstack([left, right]) alt_frames.append(combined.astype(np.float32)) alt_labels.append(f"PC{pcs[2] + 1}/{pcs[3] + 1} {view_name}") panel_frames = [] panel_labels = [] for left, right in [(pcs[0], pcs[1]), (pcs[2], pcs[3])]: for view, view_name in zip([0, 1], ["Low", "High"]): fig, axes = plt.subplots(1, 2, figsize=(10, 5)) axes[0].imshow(regPC[view, left], cmap="gray") axes[0].set_title(f"PC{left + 1} {view_name}") axes[0].axis("off") axes[1].imshow(regPC[view, right], cmap="gray") axes[1].set_title(f"PC{right + 1} {view_name}") axes[1].axis("off") fig.tight_layout() fig.canvas.draw() img = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8) # noqa w, h = fig.canvas.get_width_height() img = img.reshape((h, w, 4))[..., :3] panel_frames.append(img) panel_labels.append(f"PC{left + 1}/{right + 1} {view_name}") plt.close(fig) panel_tiff = savepath.with_name(savepath.stem + "_panels.tif") tifffile.imwrite( panel_tiff, np.stack(panel_frames, axis=0), imagej=True, metadata={"Labels": panel_labels}, ) df = pd.DataFrame(regDX, columns=["Rigid", "Avg_NR", "Max_NR"]) metrics = { "Avg_Rigid": df["Rigid"].mean(), "Avg_Average_NR": df["Avg_NR"].mean(), "Avg_Max_NR": df["Max_NR"].mean(), "Max_Rigid": df["Rigid"].max(), "Max_Average_NR": df["Avg_NR"].max(), "Max_Max_NR": df["Max_NR"].max(), } csv_path = savepath.with_suffix(".csv") pd.DataFrame([metrics]).to_csv(csv_path, index=False) return { "panel_tiff": panel_tiff, "metrics_csv": csv_path, } def plot_multiplane_masks( suite2p_path: str | Path, stat: np.ndarray, iscell: np.ndarray, nrows: int = None, ncols: int = None, figsize: tuple = None, save_path: str | Path = None, cmap: str = "gray", ) -> plt.Figure: """ Plot ROI masks from all planes in a publication-quality grid layout. Creates a multi-panel figure showing detected ROIs overlaid on reference images for each z-plane, with accepted cells in green and rejected cells in red. Background image is selected based on anatomical_only setting. 
Parameters ---------- suite2p_path : str or Path Path to suite2p directory containing plane folders (e.g., plane01_stitched/). stat : np.ndarray Consolidated stat array with 'iplane' field indicating plane assignment. iscell : np.ndarray Cell classification array (n_rois, 2) where column 0 is binary classification. nrows : int, optional Number of rows in the figure grid. Auto-calculated if None. ncols : int, optional Number of columns in the figure grid. Auto-calculated if None. figsize : tuple, optional Figure size in inches (width, height). Auto-calculated if None. save_path : str or Path, optional If provided, save figure to this path. Otherwise display interactively. cmap : str, default "gray" Colormap for background images. Returns ------- fig : matplotlib.figure.Figure The generated figure object. """ from scipy import ndimage suite2p_path = Path(suite2p_path) plane_dirs = sorted(suite2p_path.glob("plane*_stitched")) if not plane_dirs: plane_dirs = sorted(suite2p_path.glob("plane*")) nplanes = len(plane_dirs) if nplanes == 0: fig = plt.figure(figsize=(8, 6), facecolor="black") fig.text(0.5, 0.5, "No plane directories found", ha="center", va="center", fontsize=14, color="white") return fig # auto-calculate grid size if ncols is None: ncols = min(5, nplanes) if nrows is None: nrows = int(np.ceil(nplanes / ncols)) # auto-calculate figure size (make panels ~5 inches each for better visibility) if figsize is None: figsize = (ncols * 5, nrows * 5) fig, axes = plt.subplots( nrows, ncols, figsize=figsize, facecolor="black", gridspec_kw={"wspace": 0.02, "hspace": 0.08} ) if nrows == 1 and ncols == 1: axes = np.array([axes]) axes = axes.flatten() for idx, plane_dir in enumerate(plane_dirs): if idx >= len(axes): break ax = axes[idx] ax.set_facecolor("black") # extract plane number from directory name for display plane_name = plane_dir.name digits = "".join(filter(str.isdigit, plane_name)) plane_num = int(digits) if digits else idx + 1 # load plane ops ops_file = plane_dir / "ops.npy" yshift, xshift = 0, 0 # coordinate shifts for cropped images if ops_file.exists(): plane_ops = np.load(ops_file, allow_pickle=True).item() Ly = plane_ops.get("Ly", 512) Lx = plane_ops.get("Lx", 512) # get crop ranges from registration yrange = plane_ops.get("yrange", [0, Ly]) xrange = plane_ops.get("xrange", [0, Lx]) # select background image based on anatomical_only anatomical_only = plane_ops.get("anatomical_only", 0) if anatomical_only >= 4: # max projection for anatomical mode (cropped space) img = plane_ops.get("max_proj", None) if img is not None: yshift, xshift = int(yrange[0]), int(xrange[0]) else: img = plane_ops.get("meanImg") elif anatomical_only == 0: # Vcorr for functional imaging (cropped space) img = plane_ops.get("Vcorr", None) if img is not None: yshift, xshift = int(yrange[0]), int(xrange[0]) else: img = plane_ops.get("meanImg") else: # meanImg for other modes (full space) img = plane_ops.get("meanImg", plane_ops.get("meanImgE")) if img is None: img = np.zeros((Ly, Lx)) else: img = np.zeros((512, 512)) Ly, Lx = 512, 512 # display background image with proper contrast img_h, img_w = img.shape[:2] vmin, vmax = np.nanpercentile(img, [1, 99.5]) ax.imshow(img, cmap=cmap, aspect="equal", vmin=vmin, vmax=vmax) # get ROIs for this plane (iplane is 0-indexed from enumeration) plane_mask = np.array([s.get("iplane", 0) == idx for s in stat]) plane_stat = stat[plane_mask] plane_iscell = iscell[plane_mask] # create mask images for accepted and rejected cells accepted_mask = np.zeros((img_h, img_w), dtype=bool) 
rejected_mask = np.zeros((img_h, img_w), dtype=bool) accepted_idx = plane_iscell[:, 0] == 1 rejected_idx = plane_iscell[:, 0] == 0 for s in plane_stat[accepted_idx]: # shift coordinates from full to cropped space ypix = s["ypix"] - yshift xpix = s["xpix"] - xshift valid = (ypix >= 0) & (ypix < img_h) & (xpix >= 0) & (xpix < img_w) accepted_mask[ypix[valid], xpix[valid]] = True for s in plane_stat[rejected_idx]: # shift coordinates from full to cropped space ypix = s["ypix"] - yshift xpix = s["xpix"] - xshift valid = (ypix >= 0) & (ypix < img_h) & (xpix >= 0) & (xpix < img_w) rejected_mask[ypix[valid], xpix[valid]] = True # compute outlines acc_outline = ndimage.binary_dilation(accepted_mask) & ~accepted_mask rej_outline = ndimage.binary_dilation(rejected_mask) & ~rejected_mask # create rgba overlay overlay = np.zeros((img_h, img_w, 4), dtype=np.float32) # accepted cells: green fill with outline overlay[accepted_mask, :] = [0.2, 0.8, 0.2, 0.3] # green fill overlay[acc_outline, :] = [0.4, 1.0, 0.4, 0.9] # bright green outline # rejected cells: red fill with outline overlay[rejected_mask, :] = [0.8, 0.2, 0.2, 0.2] # red fill overlay[rej_outline, :] = [1.0, 0.3, 0.3, 0.6] # red outline ax.imshow(overlay) n_acc = accepted_idx.sum() n_rej = rejected_idx.sum() # title with plane info (white on black) ax.set_title( f"Plane {plane_num:02d} ({n_acc}/{n_rej})", fontsize=14, fontweight="bold", color="white", pad=6 ) ax.axis("off") # hide unused subplots for idx in range(nplanes, len(axes)): axes[idx].set_visible(False) # add legend from matplotlib.lines import Line2D from matplotlib.patches import Patch legend_elements = [ Patch(facecolor=(0.2, 0.8, 0.2, 0.5), edgecolor=(0.4, 1.0, 0.4, 1.0), linewidth=2, label="Accepted"), Patch(facecolor=(0.8, 0.2, 0.2, 0.3), edgecolor=(1.0, 0.3, 0.3, 1.0), linewidth=2, label="Rejected"), ] fig.legend( handles=legend_elements, loc="lower center", ncol=2, fontsize=12, frameon=False, bbox_to_anchor=(0.5, 0.01), labelcolor="white" ) plt.tight_layout(rect=[0, 0.04, 1, 1]) if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=300, bbox_inches="tight", facecolor="black") plt.close(fig) return fig def plot_plane_quality_metrics( stat: np.ndarray, iscell: np.ndarray, save_path: str | Path = None, figsize: tuple = (14, 10), ) -> plt.Figure: """ Generate publication-quality ROI quality metrics across all planes. Creates a multi-panel figure with line plots showing mean ± std: - Compactness vs plane - Skewness vs plane - ROI size (npix) vs plane - Radius vs plane Parameters ---------- stat : np.ndarray Consolidated stat array with 'iplane', 'compact', 'npix' fields. iscell : np.ndarray Cell classification array (n_rois, 2). save_path : str or Path, optional If provided, save figure to this path. figsize : tuple, default (14, 10) Figure size in inches. Returns ------- fig : matplotlib.figure.Figure The generated figure object. 
Examples -------- >>> stat = np.load("merged/stat.npy", allow_pickle=True) >>> iscell = np.load("merged/iscell.npy") >>> fig = plot_plane_quality_metrics(stat, iscell, save_path="quality.png") """ # Extract metrics plane_nums = np.array([s.get("iplane", 0) for s in stat]) unique_planes = np.unique(plane_nums) n_planes = len(unique_planes) compactness = np.array([s.get("compact", np.nan) for s in stat]) skewness = np.array([s.get("skew", np.nan) for s in stat]) npix = np.array([s.get("npix", 0) for s in stat]) radius = np.array([s.get("radius", np.nan) for s in stat]) accepted = iscell[:, 0] == 1 # Dark theme colors (consistent with plot_volume_diagnostics) bg_color = "black" text_color = "white" colors = { "compactness": "#9b59b6", # Purple "skewness": "#e67e22", # Orange "size": "#3498db", # Blue "radius": "#2ecc71", # Green } mean_line_color = "#e74c3c" # Red for mean markers # Compute mean and std per plane for accepted cells def compute_stats_per_plane(values, plane_nums, accepted, unique_planes): means = [] stds = [] for p in unique_planes: mask = (plane_nums == p) & accepted & ~np.isnan(values) if mask.sum() > 0: means.append(np.mean(values[mask])) stds.append(np.std(values[mask])) else: means.append(np.nan) stds.append(np.nan) return np.array(means), np.array(stds) compact_mean, compact_std = compute_stats_per_plane(compactness, plane_nums, accepted, unique_planes) skew_mean, skew_std = compute_stats_per_plane(skewness, plane_nums, accepted, unique_planes) npix_mean, npix_std = compute_stats_per_plane(npix.astype(float), plane_nums, accepted, unique_planes) radius_mean, radius_std = compute_stats_per_plane(radius, plane_nums, accepted, unique_planes) with plt.style.context("default"): fig, axes = plt.subplots(2, 2, figsize=figsize, facecolor=bg_color) axes = axes.flatten() x = np.arange(n_planes) def style_axis(ax, xlabel, ylabel, title): ax.set_facecolor(bg_color) ax.set_xlabel(xlabel, fontweight="bold", fontsize=10, color=text_color) ax.set_ylabel(ylabel, fontweight="bold", fontsize=10, color=text_color) ax.set_title(title, fontweight="bold", fontsize=11, color=text_color) ax.tick_params(colors=text_color, labelsize=9) ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) ax.spines["bottom"].set_color(text_color) ax.spines["left"].set_color(text_color) # Set x-ticks to show plane numbers if n_planes <= 20: ax.set_xticks(x) ax.set_xticklabels([f"{int(p)}" for p in unique_planes]) else: step = max(1, n_planes // 10) ax.set_xticks(x[::step]) ax.set_xticklabels([f"{int(p)}" for p in unique_planes[::step]]) # Panel 1: Compactness ax = axes[0] valid = ~np.isnan(compact_mean) ax.fill_between(x[valid], (compact_mean - compact_std)[valid], (compact_mean + compact_std)[valid], alpha=0.3, color=colors["compactness"]) ax.plot(x[valid], compact_mean[valid], 'o-', color=colors["compactness"], linewidth=2, markersize=5) style_axis(ax, "Z-Plane", "Compactness", "ROI Compactness (Accepted)") # Panel 2: Skewness ax = axes[1] valid = ~np.isnan(skew_mean) ax.fill_between(x[valid], (skew_mean - skew_std)[valid], (skew_mean + skew_std)[valid], alpha=0.3, color=colors["skewness"]) ax.plot(x[valid], skew_mean[valid], 'o-', color=colors["skewness"], linewidth=2, markersize=5) style_axis(ax, "Z-Plane", "Skewness", "Trace Skewness (Accepted)") # Panel 3: ROI Size (npix) ax = axes[2] valid = ~np.isnan(npix_mean) ax.fill_between(x[valid], (npix_mean - npix_std)[valid], (npix_mean + npix_std)[valid], alpha=0.3, color=colors["size"]) ax.plot(x[valid], npix_mean[valid], 'o-', 
color=colors["size"], linewidth=2, markersize=5) style_axis(ax, "Z-Plane", "Number of Pixels", "ROI Size (Accepted)") # Panel 4: Radius ax = axes[3] valid = ~np.isnan(radius_mean) ax.fill_between(x[valid], (radius_mean - radius_std)[valid], (radius_mean + radius_std)[valid], alpha=0.3, color=colors["radius"]) ax.plot(x[valid], radius_mean[valid], 'o-', color=colors["radius"], linewidth=2, markersize=5) style_axis(ax, "Z-Plane", "Radius (pixels)", "ROI Radius (Accepted)") # Main title total_accepted = int(accepted.sum()) total_rois = len(stat) fig.suptitle( f"Volume Quality Metrics: {total_accepted} accepted / {total_rois} total ROIs", fontsize=12, fontweight="bold", color=text_color, y=0.98 ) plt.tight_layout(rect=[0, 0, 1, 0.96]) if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor=bg_color) plt.close(fig) return fig def plot_trace_analysis( F: np.ndarray, Fneu: np.ndarray, stat: np.ndarray, iscell: np.ndarray, ops: dict, save_path: str | Path = None, figsize: tuple = (16, 14), ) -> Tuple[plt.Figure, dict]: """ Generate trace analysis figure showing extreme examples by quality metrics. Creates a 6-panel figure showing example ΔF/F traces for: - Highest SNR / Lowest SNR - Lowest shot noise / Highest shot noise - Highest skewness / Lowest skewness Parameters ---------- F : np.ndarray Fluorescence traces array (n_rois, n_frames). Fneu : np.ndarray Neuropil fluorescence array (n_rois, n_frames). stat : np.ndarray Stat array with 'iplane' and 'skew' fields. iscell : np.ndarray Cell classification array. ops : dict Ops dictionary with 'fs' (frame rate) field. save_path : str or Path, optional If provided, save figure to this path. figsize : tuple, default (16, 14) Figure size in inches. Returns ------- fig : matplotlib.figure.Figure The generated figure object. metrics : dict Dictionary containing computed metrics (snr, shot_noise, skewness, dff). 
Examples -------- >>> fig, metrics = plot_trace_analysis(F, Fneu, stat, iscell, ops) >>> print(f"Mean SNR: {np.mean(metrics['snr']):.2f}") """ accepted = iscell[:, 0] == 1 n_accepted = int(np.sum(accepted)) if n_accepted == 0: fig = plt.figure(figsize=figsize, facecolor="black") fig.text(0.5, 0.5, "No accepted ROIs found", ha="center", va="center", fontsize=16, fontweight="bold", color="white") return fig, {} F_acc = F[accepted] Fneu_acc = Fneu[accepted] stat_acc = stat[accepted] plane_nums = np.array([s.get("iplane", 0) for s in stat_acc]) fs = ops.get("fs", 30.0) # Compute ΔF/F F_corrected = F_acc - 0.7 * Fneu_acc baseline = np.percentile(F_corrected, 20, axis=1, keepdims=True) baseline = np.maximum(baseline, 1e-6) dff = (F_corrected - baseline) / baseline # Compute metrics # SNR: signal / noise signal = np.std(dff, axis=1) noise = np.median(np.abs(np.diff(dff, axis=1)), axis=1) / 0.6745 # MAD estimator snr = signal / (noise + 1e-6) # Shot noise: noise level (MAD of diff) shot_noise = noise # Skewness: from stat or compute from trace skewness = np.array([s.get("skew", np.nan) for s in stat_acc]) # Fill NaNs with computed skewness if needed nan_mask = np.isnan(skewness) if nan_mask.any(): from scipy.stats import skew as scipy_skew for i in np.where(nan_mask)[0]: skewness[i] = scipy_skew(dff[i]) # Style configuration bg_color = "black" text_color = "white" # Colors for each metric type colors = { "snr_high": "#2ecc71", # Green - good "snr_low": "#e74c3c", # Red - bad "noise_low": "#3498db", # Blue - good "noise_high": "#e67e22", # Orange - bad "skew_high": "#9b59b6", # Purple - high activity "skew_low": "#95a5a6", # Gray - low activity } # Find indices for each category valid_mask = ~np.isnan(snr) & ~np.isnan(shot_noise) & ~np.isnan(skewness) valid_idx = np.where(valid_mask)[0] if len(valid_idx) == 0: fig = plt.figure(figsize=figsize, facecolor="black") fig.text(0.5, 0.5, "No valid ROIs with computed metrics", ha="center", va="center", fontsize=16, fontweight="bold", color="white") return fig, {} # Get indices for extremes snr_valid = snr[valid_mask] noise_valid = shot_noise[valid_mask] skew_valid = skewness[valid_mask] idx_snr_high = valid_idx[np.argmax(snr_valid)] idx_snr_low = valid_idx[np.argmin(snr_valid)] idx_noise_low = valid_idx[np.argmin(noise_valid)] idx_noise_high = valid_idx[np.argmax(noise_valid)] idx_skew_high = valid_idx[np.argmax(skew_valid)] idx_skew_low = valid_idx[np.argmin(skew_valid)] # Time axis - show up to 100s or full trace n_frames_show = min(int(100 * fs), dff.shape[1]) time = np.arange(n_frames_show) / fs with plt.style.context("default"): fig = plt.figure(figsize=figsize, facecolor=bg_color) gs = gridspec.GridSpec(3, 2, figure=fig, hspace=0.4, wspace=0.25, left=0.08, right=0.95, top=0.92, bottom=0.06) def plot_trace_panel(ax, idx, title, color, metric_name, metric_val): """Plot a single trace panel.""" ax.set_facecolor(bg_color) trace = dff[idx, :n_frames_show] ax.plot(time, trace, color=color, linewidth=0.8, alpha=0.9) # Add zero line ax.axhline(0, color="gray", linestyle="--", linewidth=0.5, alpha=0.5) # Get plane info plane = plane_nums[idx] # Style ax.set_xlabel("Time (s)", fontsize=10, fontweight="bold", color=text_color) ax.set_ylabel("ΔF/F", fontsize=10, fontweight="bold", color=text_color) ax.set_title(f"{title}\n{metric_name}={metric_val:.2f}, Plane {plane}", fontsize=11, fontweight="bold", color=text_color) ax.tick_params(colors=text_color, labelsize=9) ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) 
ax.spines["bottom"].set_color(text_color) ax.spines["left"].set_color(text_color) # Set reasonable y-limits y_max = np.percentile(trace, 99.5) y_min = np.percentile(trace, 0.5) margin = (y_max - y_min) * 0.1 ax.set_ylim(y_min - margin, y_max + margin) # Row 1: SNR extremes ax1 = fig.add_subplot(gs[0, 0]) plot_trace_panel(ax1, idx_snr_high, "Highest SNR", colors["snr_high"], "SNR", snr[idx_snr_high]) ax2 = fig.add_subplot(gs[0, 1]) plot_trace_panel(ax2, idx_snr_low, "Lowest SNR", colors["snr_low"], "SNR", snr[idx_snr_low]) # Row 2: Shot noise extremes ax3 = fig.add_subplot(gs[1, 0]) plot_trace_panel(ax3, idx_noise_low, "Lowest Shot Noise", colors["noise_low"], "Noise", shot_noise[idx_noise_low]) ax4 = fig.add_subplot(gs[1, 1]) plot_trace_panel(ax4, idx_noise_high, "Highest Shot Noise", colors["noise_high"], "Noise", shot_noise[idx_noise_high]) # Row 3: Skewness extremes ax5 = fig.add_subplot(gs[2, 0]) plot_trace_panel(ax5, idx_skew_high, "Highest Skewness", colors["skew_high"], "Skew", skewness[idx_skew_high]) ax6 = fig.add_subplot(gs[2, 1]) plot_trace_panel(ax6, idx_skew_low, "Lowest Skewness", colors["skew_low"], "Skew", skewness[idx_skew_low]) # Main title with summary stats fig.suptitle( f"Trace Quality Extremes: {n_accepted} accepted ROIs | " f"SNR: {np.nanmedian(snr):.1f} (median) | " f"Noise: {np.nanmedian(shot_noise):.3f} (median)", fontsize=12, fontweight="bold", color=text_color, y=0.98 ) if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor=bg_color) plt.close(fig) metrics = { "snr": snr, "shot_noise": shot_noise, "skewness": skewness, "dff": dff, } return fig, metrics def create_volume_summary_table( stat: np.ndarray, iscell: np.ndarray, F: np.ndarray = None, Fneu: np.ndarray = None, ops: dict = None, save_path: str | Path = None, ) -> pd.DataFrame: """ Generates per-plane and aggregate statistics including ROI counts, SNR metrics, and quality measures. Parameters ---------- stat : np.ndarray Consolidated stat array with plane assignments. iscell : np.ndarray Cell classification array. F : np.ndarray, optional Fluorescence traces for SNR calculation. Fneu : np.ndarray, optional Neuropil traces for SNR calculation. ops : dict, optional Ops dictionary with frame rate. save_path : str or Path, optional If provided, save CSV to this path. Returns ------- df : pd.DataFrame Summary statistics table. 
Examples -------- >>> df = create_volume_summary_table(stat, iscell, F, Fneu, ops) >>> print(df.to_string()) """ accepted = iscell[:, 0] == 1 plane_nums = np.array([s.get("iplane", 0) for s in stat]) unique_planes = np.unique(plane_nums) # Compute SNR if traces provided snr = None mean_F_arr = None if F is not None and Fneu is not None: F_acc = F[accepted] Fneu_acc = Fneu[accepted] F_corrected = F_acc - 0.7 * Fneu_acc baseline = np.percentile(F_corrected, 20, axis=1, keepdims=True) baseline = np.maximum(baseline, 1e-6) dff = (F_corrected - baseline) / baseline signal = np.std(dff, axis=1) noise = np.median(np.abs(np.diff(dff, axis=1)), axis=1) / 0.6745 snr = signal / (noise + 1e-6) mean_F_arr = np.mean(F_acc, axis=1) plane_nums_acc = plane_nums[accepted] else: plane_nums_acc = plane_nums[accepted] # Extract metrics compactness = np.array([s.get("compact", np.nan) for s in stat]) npix = np.array([s.get("npix", 0) for s in stat]) summary_data = [] for p in unique_planes: plane_mask = plane_nums == p plane_mask_acc = plane_nums_acc == p if snr is not None else plane_mask & accepted n_total = plane_mask.sum() n_accepted = (plane_mask & accepted).sum() row = { "Plane": int(p), "Total_ROIs": int(n_total), "Accepted": int(n_accepted), "Rejected": int(n_total - n_accepted), "Accept_Rate_%": f"{100 * n_accepted / max(1, n_total):.1f}", "Mean_Compact": f"{np.nanmean(compactness[plane_mask & accepted]):.2f}", "Mean_Size_px": f"{np.mean(npix[plane_mask & accepted]):.0f}", } if snr is not None and plane_mask_acc.sum() > 0: row["Mean_SNR"] = f"{np.mean(snr[plane_mask_acc]):.2f}" row["Median_SNR"] = f"{np.median(snr[plane_mask_acc]):.2f}" row["High_SNR_%"] = f"{100 * np.sum(snr[plane_mask_acc] > 2) / plane_mask_acc.sum():.1f}" row["Mean_F"] = f"{np.mean(mean_F_arr[plane_mask_acc]):.0f}" summary_data.append(row) df = pd.DataFrame(summary_data) # Add totals row totals = { "Plane": "ALL", "Total_ROIs": int(len(stat)), "Accepted": int(accepted.sum()), "Rejected": int((~accepted).sum()), "Accept_Rate_%": f"{100 * accepted.sum() / len(stat):.1f}", "Mean_Compact": f"{np.nanmean(compactness[accepted]):.2f}", "Mean_Size_px": f"{np.mean(npix[accepted]):.0f}", } if snr is not None: totals["Mean_SNR"] = f"{np.mean(snr):.2f}" totals["Median_SNR"] = f"{np.median(snr):.2f}" totals["High_SNR_%"] = f"{100 * np.sum(snr > 2) / len(snr):.1f}" totals["Mean_F"] = f"{np.mean(mean_F_arr):.0f}" df = pd.concat([df, pd.DataFrame([totals])], ignore_index=True) if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) df.to_csv(save_path, index=False) print(f"Summary table saved to: {save_path}") return df def plot_volume_filter_summary( suite2p_path: str | Path, save_path: str | Path = None, figsize: tuple = (14, 8), ) -> plt.Figure: """ Create a volumetric summary figure showing cell filtering across all planes. Shows bar chart of accepted/rejected cells per plane, plus summary stats. Parameters ---------- suite2p_path : str or Path Path to suite2p output directory containing plane subdirectories. save_path : str or Path, optional Path to save the figure. If None, displays with plt.show(). figsize : tuple, default (14, 8) Figure size. Returns ------- matplotlib.figure.Figure The generated figure. 
""" import matplotlib.pyplot as plt from matplotlib.patches import Patch suite2p_path = Path(suite2p_path) # find plane directories plane_dirs = sorted(suite2p_path.glob("plane*")) if not plane_dirs: # single plane case if (suite2p_path / "stat.npy").exists(): plane_dirs = [suite2p_path] else: raise ValueError(f"No plane directories or stat.npy found in {suite2p_path}") # collect stats per plane plane_stats = [] for pdir in plane_dirs: stat_file = pdir / "stat.npy" iscell_file = pdir / "iscell.npy" iscell_s2p_file = pdir / "iscell_suite2p.npy" if not stat_file.exists() or not iscell_file.exists(): continue stat = np.load(stat_file, allow_pickle=True) iscell = np.load(iscell_file, allow_pickle=True) if iscell.ndim == 2: iscell = iscell[:, 0] # load suite2p original if exists if iscell_s2p_file.exists(): iscell_s2p = np.load(iscell_s2p_file, allow_pickle=True) if iscell_s2p.ndim == 2: iscell_s2p = iscell_s2p[:, 0] else: iscell_s2p = iscell n_total = len(stat) n_final_accepted = int(iscell.astype(bool).sum()) n_s2p_accepted = int(iscell_s2p.astype(bool).sum()) n_s2p_rejected = n_total - n_s2p_accepted n_filter_rejected = n_s2p_accepted - n_final_accepted # get plane number from dir name plane_name = pdir.name try: plane_num = int(plane_name.replace("plane", "")) except ValueError: plane_num = len(plane_stats) plane_stats.append({ "plane": plane_num, "name": plane_name, "n_total": n_total, "n_s2p_accepted": n_s2p_accepted, "n_s2p_rejected": n_s2p_rejected, "n_filter_rejected": n_filter_rejected, "n_final_accepted": n_final_accepted, }) if not plane_stats: raise ValueError("No valid plane data found") # sort by plane number plane_stats = sorted(plane_stats, key=lambda x: x["plane"]) # create figure fig, axes = plt.subplots(1, 2, figsize=figsize, gridspec_kw={"width_ratios": [2, 1]}) # left panel: stacked bar chart per plane ax = axes[0] planes = [p["name"] for p in plane_stats] x = np.arange(len(planes)) width = 0.7 # stack: final_accepted (green) + filter_rejected (orange) + s2p_rejected (red) final_accepted = [p["n_final_accepted"] for p in plane_stats] filter_rejected = [p["n_filter_rejected"] for p in plane_stats] s2p_rejected = [p["n_s2p_rejected"] for p in plane_stats] bars1 = ax.bar(x, final_accepted, width, label="accepted", color="#33a02c") bars2 = ax.bar(x, filter_rejected, width, bottom=final_accepted, label="filter rejected", color="#ff7f00") bars3 = ax.bar(x, s2p_rejected, width, bottom=[f + r for f, r in zip(final_accepted, filter_rejected)], label="suite2p rejected", color="#e31a1c") ax.set_xlabel("plane", fontsize=11) ax.set_ylabel("ROI count", fontsize=11) ax.set_title("ROI filtering per plane", fontsize=12, fontweight="bold") ax.set_xticks(x) ax.set_xticklabels(planes, rotation=45, ha="right") ax.legend(loc="upper right") # right panel: summary pie chart and stats ax2 = axes[1] total_final = sum(final_accepted) total_filter_rej = sum(filter_rejected) total_s2p_rej = sum(s2p_rejected) total_all = total_final + total_filter_rej + total_s2p_rej # pie chart sizes = [total_final, total_filter_rej, total_s2p_rej] labels = ["accepted", "filter rejected", "suite2p rejected"] colors = ["#33a02c", "#ff7f00", "#e31a1c"] # filter out zero values for pie nonzero = [(s, l, c) for s, l, c in zip(sizes, labels, colors) if s > 0] if nonzero: sizes_nz, labels_nz, colors_nz = zip(*nonzero) wedges, texts, autotexts = ax2.pie( sizes_nz, labels=labels_nz, colors=colors_nz, autopct=lambda pct: f"{pct:.1f}%\n({int(pct/100*total_all)})", startangle=90, textprops={"fontsize": 9} ) 
ax2.set_title("overall summary", fontsize=12, fontweight="bold") # add text summary below pie summary_text = ( f"total ROIs: {total_all}\n" f"final accepted: {total_final} ({100*total_final/max(1,total_all):.1f}%)\n" f"filter rejected: {total_filter_rej} ({100*total_filter_rej/max(1,total_all):.1f}%)\n" f"suite2p rejected: {total_s2p_rej} ({100*total_s2p_rej/max(1,total_all):.1f}%)\n" f"planes: {len(plane_stats)}" ) ax2.text(0.5, -0.15, summary_text, transform=ax2.transAxes, ha="center", va="top", fontsize=10, family="monospace") plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches="tight") plt.close(fig) else: plt.show() return fig def plot_plane_diagnostics( plane_dir: str | Path, save_path: str | Path = None, figsize: tuple = (16, 14), n_examples: int = 4, ) -> plt.Figure: """ Generate a single-figure diagnostic summary for a processed plane. Creates a publication-quality figure showing: - ROI size distribution (accepted vs rejected) - SNR distribution with quality threshold - Compactness vs SNR scatter - Summary statistics text - Zoomed ROI examples: N highest SNR and N lowest SNR cells Robust to low/zero cell counts - will display informative messages when data is insufficient for certain visualizations. Parameters ---------- plane_dir : str or Path Path to the plane directory containing ops.npy, stat.npy, etc. save_path : str or Path, optional If provided, save figure to this path. figsize : tuple, default (16, 14) Figure size in inches. n_examples : int, default 4 Number of high/low SNR ROI examples to show. Returns ------- fig : matplotlib.figure.Figure The generated figure object. """ plane_dir = Path(plane_dir) # Load results res = load_planar_results(plane_dir) ops = load_ops(plane_dir / "ops.npy") stat = res["stat"] iscell = res["iscell"] F = res["F"] Fneu = res["Fneu"] # Handle edge case: no ROIs at all n_total = len(stat) if n_total == 0: fig = plt.figure(figsize=figsize, facecolor="black") fig.text(0.5, 0.5, "No ROIs detected\n\nCheck detection parameters:\n- threshold_scaling\n- cellprob_threshold\n- diameter", ha="center", va="center", fontsize=16, fontweight="bold", color="white") plane_name = plane_dir.name fig.suptitle(f"Quality Diagnostics: {plane_name}", fontsize=14, fontweight="bold", y=0.98, color="white") if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="black") plt.close(fig) return fig # iscell from load_planar_results is (n_rois, 2): [:, 0] is 0/1, [:, 1] is probability accepted = iscell[:, 0].astype(bool) cell_prob = iscell[:, 1] # classifier probability for each ROI n_accepted = int(accepted.sum()) n_rejected = int((~accepted).sum()) # Compute metrics for ALL ROIs (not just accepted) F_corr = F - 0.7 * Fneu baseline = np.percentile(F_corr, 20, axis=1, keepdims=True) baseline = np.maximum(baseline, 1e-6) dff = (F_corr - baseline) / baseline # SNR calculation for all ROIs signal = np.std(dff, axis=1) noise = np.median(np.abs(np.diff(dff, axis=1)), axis=1) / 0.6745 snr = signal / (noise + 1e-6) # Extract ROI properties npix = np.array([s.get("npix", 0) for s in stat]) compactness = np.array([s.get("compact", np.nan) for s in stat]) skewness = np.array([s.get("skew", np.nan) for s in stat]) fs = ops.get("fs", 30.0) # Compute stats with safe defaults snr_acc = snr[accepted] if n_accepted > 0 else np.array([np.nan]) npix_acc = npix[accepted] if n_accepted > 0 else np.array([0]) mean_snr = np.nanmean(snr_acc) if n_accepted > 0 else 0.0 
def plot_plane_diagnostics(
    plane_dir: str | Path,
    save_path: str | Path = None,
    figsize: tuple = (16, 14),
    n_examples: int = 4,
) -> plt.Figure:
    """
    Generate a single-figure diagnostic summary for a processed plane.

    Creates a publication-quality figure showing:

    - ROI size distribution (accepted vs rejected)
    - SNR distribution with quality threshold
    - Compactness vs SNR scatter
    - Summary statistics text
    - Zoomed ROI examples: N highest SNR and N lowest SNR cells

    Robust to low/zero cell counts - will display informative messages
    when data is insufficient for certain visualizations.

    Parameters
    ----------
    plane_dir : str or Path
        Path to the plane directory containing ops.npy, stat.npy, etc.
    save_path : str or Path, optional
        If provided, save figure to this path.
    figsize : tuple, default (16, 14)
        Figure size in inches.
    n_examples : int, default 4
        Number of high/low SNR ROI examples to show.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure object.
    """
    plane_dir = Path(plane_dir)

    # Load results
    res = load_planar_results(plane_dir)
    ops = load_ops(plane_dir / "ops.npy")
    stat = res["stat"]
    iscell = res["iscell"]
    F = res["F"]
    Fneu = res["Fneu"]

    # Handle edge case: no ROIs at all
    n_total = len(stat)
    if n_total == 0:
        fig = plt.figure(figsize=figsize, facecolor="black")
        fig.text(0.5, 0.5,
                 "No ROIs detected\n\nCheck detection parameters:\n"
                 "- threshold_scaling\n- cellprob_threshold\n- diameter",
                 ha="center", va="center", fontsize=16, fontweight="bold", color="white")
        plane_name = plane_dir.name
        fig.suptitle(f"Quality Diagnostics: {plane_name}",
                     fontsize=14, fontweight="bold", y=0.98, color="white")
        if save_path:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="black")
            plt.close(fig)
        return fig

    # iscell from load_planar_results is (n_rois, 2): [:, 0] is 0/1, [:, 1] is probability
    accepted = iscell[:, 0].astype(bool)
    cell_prob = iscell[:, 1]  # classifier probability for each ROI
    n_accepted = int(accepted.sum())
    n_rejected = int((~accepted).sum())

    # Compute metrics for ALL ROIs (not just accepted)
    F_corr = F - 0.7 * Fneu
    baseline = np.percentile(F_corr, 20, axis=1, keepdims=True)
    baseline = np.maximum(baseline, 1e-6)
    dff = (F_corr - baseline) / baseline

    # SNR calculation for all ROIs
    signal = np.std(dff, axis=1)
    noise = np.median(np.abs(np.diff(dff, axis=1)), axis=1) / 0.6745
    snr = signal / (noise + 1e-6)

    # Extract ROI properties
    npix = np.array([s.get("npix", 0) for s in stat])
    compactness = np.array([s.get("compact", np.nan) for s in stat])
    skewness = np.array([s.get("skew", np.nan) for s in stat])
    fs = ops.get("fs", 30.0)

    # Compute stats with safe defaults
    snr_acc = snr[accepted] if n_accepted > 0 else np.array([np.nan])
    npix_acc = npix[accepted] if n_accepted > 0 else np.array([0])
    mean_snr = np.nanmean(snr_acc) if n_accepted > 0 else 0.0
    median_snr = np.nanmedian(snr_acc) if n_accepted > 0 else 0.0
    high_snr_pct = 100 * np.sum(snr_acc > 2) / max(1, len(snr_acc)) if n_accepted > 0 else 0.0
    mean_size = np.mean(npix_acc) if n_accepted > 0 else 0.0

    # Get mean image for ROI zoom panels
    mean_img = ops.get("meanImgE", ops.get("meanImg"))

    # Create figure with custom layout - dark background like consolidate.ipynb
    # Row 0: Size dist, SNR dist, SNR vs Compactness, Activity vs SNR (4 panels)
    # Row 1: High SNR ROI zooms (n_examples panels)
    # Row 2: High SNR ROI traces
    # Row 3: Low SNR ROI zooms (n_examples panels)
    # Row 4: Low SNR ROI traces
    fig = plt.figure(figsize=(figsize[0], figsize[1] + 2), facecolor="black")

    # use nested gridspec: top row has more spacing, bottom rows are tight
    # Increased gap between top plots (bottom=0.62) and ROI images (top=0.48)
    # Added 5th row as spacer between high and low SNR groups
    gs_top = gridspec.GridSpec(1, 4, figure=fig, left=0.06, right=0.98,
                               top=0.95, bottom=0.62, wspace=0.35)
    gs_bottom = gridspec.GridSpec(5, max(4, n_examples), figure=fig,
                                  left=0.02, right=0.98, top=0.48, bottom=0.02,
                                  hspace=0.02, wspace=0.08,
                                  height_ratios=[1, 0.4, 0.15, 1, 0.4])

    # compute activity metric: number of transients (peaks above 2 std)
    if n_accepted > 0:
        dff_acc = dff[accepted]
        activity = np.sum(dff_acc > 2, axis=1)  # count frames above 2 std
    else:
        activity = np.array([])

    # compute shot noise per ROI (standardized noise metric)
    # shot_noise = median(|diff(dff)|) / sqrt(fs) * 100  (in %/sqrt(Hz))
    frame_diffs = np.abs(np.diff(dff, axis=1))
    shot_noise = np.median(frame_diffs, axis=1) / np.sqrt(fs) * 100

    # Panel 1: ROI size distribution - use step histogram for clarity
    ax_size = fig.add_subplot(gs_top[0, 0])
    ax_size.set_facecolor("black")
    all_npix = npix[npix > 0]
    if len(all_npix) > 0:
        bins = np.linspace(0, np.percentile(all_npix, 99), 40)
        # Use step histograms with distinct line styles for clear separation
        if n_accepted > 0:
            ax_size.hist(npix[accepted], bins=bins, histtype="stepfilled",
                         alpha=0.7, color="#2ecc71", edgecolor="#2ecc71",
                         linewidth=1.5, label=f"Accepted ({n_accepted})")
        if n_rejected > 0:
            ax_size.hist(npix[~accepted], bins=bins, histtype="step",
                         color="#e74c3c", linewidth=2, linestyle="-",
                         label=f"Rejected ({n_rejected})")
        ax_size.legend(fontsize=7, facecolor="#1a1a1a", edgecolor="white",
                       labelcolor="white", loc="upper right")
    else:
        ax_size.text(0.5, 0.5, "No ROI data", ha="center", va="center",
                     fontsize=12, color="white")
    ax_size.set_xlabel("Size (pixels)", fontweight="bold", fontsize=9, color="white")
    ax_size.set_ylabel("Count", fontweight="bold", fontsize=9, color="white")
    ax_size.set_title("ROI Size", fontweight="bold", fontsize=10, color="white")
    ax_size.tick_params(colors="white", labelsize=8)
    ax_size.spines["top"].set_visible(False)
    ax_size.spines["right"].set_visible(False)
    ax_size.spines["bottom"].set_color("white")
    ax_size.spines["left"].set_color("white")

    # Panel 2: SNR distribution - use step histogram for clarity
    ax_snr = fig.add_subplot(gs_top[0, 1])
    ax_snr.set_facecolor("black")
    all_snr = snr[~np.isnan(snr)]
    if len(all_snr) > 0:
        bins = np.linspace(0, np.percentile(all_snr, 99), 40)
        # Filled for accepted, outline for rejected - no overlap confusion
        if n_accepted > 0:
            ax_snr.hist(snr[accepted], bins=bins, histtype="stepfilled",
                        alpha=0.7, color="#2ecc71", edgecolor="#2ecc71",
                        linewidth=1.5, label=f"Accepted ({n_accepted})")
            ax_snr.axvline(median_snr, color="#ffe66d", linestyle="-",
                           linewidth=2, label=f"Median={median_snr:.1f}")
        if n_rejected > 0:
            ax_snr.hist(snr[~accepted], bins=bins, histtype="step",
                        color="#e74c3c", linewidth=2, linestyle="-",
                        label=f"Rejected ({n_rejected})")
        ax_snr.legend(fontsize=7, facecolor="#1a1a1a", edgecolor="white",
                      labelcolor="white", loc="upper right")
    else:
        ax_snr.text(0.5, 0.5, "No SNR data", ha="center", va="center",
                    fontsize=12, color="white")
    ax_snr.set_xlabel("SNR", fontweight="bold", fontsize=9, color="white")
    ax_snr.set_ylabel("Count", fontweight="bold", fontsize=9, color="white")
    ax_snr.set_title("SNR Distribution", fontweight="bold", fontsize=10, color="white")
    ax_snr.tick_params(colors="white", labelsize=8)
    ax_snr.spines["top"].set_visible(False)
    ax_snr.spines["right"].set_visible(False)
    ax_snr.spines["bottom"].set_color("white")
    ax_snr.spines["left"].set_color("white")

    # Panels 3 & 4: Compactness vs SNR and Activity vs SNR (shared Y-axis = SNR)
    # Color by skewness (activity pattern quality metric)
    ax_compact = fig.add_subplot(gs_top[0, 2])
    ax_activity = fig.add_subplot(gs_top[0, 3], sharey=ax_compact)
    ax_compact.set_facecolor("black")
    ax_activity.set_facecolor("black")

    has_scatter_data = False
    if n_accepted > 0:
        valid_compact = accepted & ~np.isnan(compactness) & ~np.isnan(skewness)
        valid_activity = accepted & ~np.isnan(skewness)
        snr_acc = snr[accepted]
        skew_acc = skewness[accepted]

        # Get shared color limits from skewness (more informative than SNR for color)
        valid_skew = skew_acc[~np.isnan(skew_acc)]
        if len(valid_skew) > 0:
            vmin, vmax = np.nanpercentile(valid_skew, [5, 95])
        else:
            vmin, vmax = 0, 1

        if valid_compact.sum() > 0:
            # Panel 3: Compactness vs SNR (SNR on y-axis)
            ax_compact.scatter(compactness[valid_compact], snr[valid_compact],
                               c=skewness[valid_compact], cmap="plasma",
                               alpha=0.7, s=20, vmin=vmin, vmax=vmax)
            has_scatter_data = True

        if len(activity) > 0 and valid_activity.sum() > 0:
            # Panel 4: Activity vs SNR (SNR on y-axis)
            sc2 = ax_activity.scatter(activity, snr_acc, c=skew_acc, cmap="plasma",
                                      alpha=0.7, s=20, vmin=vmin, vmax=vmax)
            # Add single colorbar for both plots (attached to activity plot)
            cbar = plt.colorbar(sc2, ax=ax_activity, shrink=0.8)
            cbar.set_label("Skewness", fontsize=8, color="white")
            cbar.ax.yaxis.set_tick_params(color="white")
            cbar.outline.set_edgecolor("white")
            plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="white")

    if not has_scatter_data:
        ax_compact.text(0.5, 0.5, "No data", ha="center", va="center",
                        fontsize=12, color="white")
        ax_activity.text(0.5, 0.5, "No data", ha="center", va="center",
                         fontsize=12, color="white")

    ax_compact.set_xlabel("Compactness", fontweight="bold", fontsize=9, color="white")
    ax_compact.set_ylabel("SNR", fontweight="bold", fontsize=9, color="white")
    ax_compact.set_title("Compactness vs SNR", fontweight="bold", fontsize=10, color="white")
    ax_compact.tick_params(colors="white", labelsize=8)
    ax_compact.spines["top"].set_visible(False)
    ax_compact.spines["right"].set_visible(False)
    ax_compact.spines["bottom"].set_color("white")
    ax_compact.spines["left"].set_color("white")

    ax_activity.set_xlabel("Active Frames", fontweight="bold", fontsize=9, color="white")
    ax_activity.set_ylabel("SNR", fontweight="bold", fontsize=9, color="white")
    ax_activity.set_title("Activity vs SNR", fontweight="bold", fontsize=10, color="white")
    ax_activity.tick_params(colors="white", labelsize=8)
    # Hide y-axis labels on right plot since it shares y-axis with left
    plt.setp(ax_activity.get_yticklabels(), visible=False)
    ax_activity.spines["top"].set_visible(False)
    ax_activity.spines["right"].set_visible(False)
    ax_activity.spines["bottom"].set_color("white")
    ax_activity.spines["left"].set_color("white")

    # Helper function to plot zoomed ROI
    def plot_roi_zoom(ax, roi_idx, img, stat_entry, snr_val, noise_val, color):
        """Plot a zoomed view of a single ROI with SNR and shot noise."""
        ax.set_facecolor("black")
        ypix = stat_entry["ypix"]
        xpix = stat_entry["xpix"]

        # Calculate bounding box with padding
        pad = 15
        y_min, y_max = max(0, ypix.min() - pad), min(img.shape[0], ypix.max() + pad)
        x_min, x_max = max(0, xpix.min() - pad), min(img.shape[1], xpix.max() + pad)

        # Extract ROI region
        roi_img = img[y_min:y_max, x_min:x_max]
        if roi_img.size == 0:
            ax.text(0.5, 0.5, "No image", ha="center", va="center",
                    fontsize=10, color="white")
            ax.axis("off")
            return

        vmin, vmax = np.nanpercentile(roi_img, [1, 99])
        ax.imshow(roi_img, cmap="gray", vmin=vmin, vmax=vmax, aspect="equal")

        # Draw ROI outline (shifted to local coordinates)
        local_y = ypix - y_min
        local_x = xpix - x_min
        ax.scatter(local_x, local_y, c=color, s=3, alpha=0.7, linewidths=0)

        # Title with SNR and shot noise
        ax.set_title(f"#{roi_idx} SNR={snr_val:.1f} σ={noise_val:.2f}",
                     fontsize=7, fontweight="bold", color=color)
        ax.axis("off")

    # Helper function to plot a trace snippet
    def plot_roi_trace(ax, trace, color, window_frames=500):
        """Plot a short trace snippet with shrunk Y axis."""
        ax.set_facecolor("black")
        # show first N frames, or all if shorter
        n_show = min(window_frames, len(trace))
        trace_segment = trace[:n_show]
        ax.plot(trace_segment, color=color, linewidth=0.8, alpha=0.9)
        ax.set_xlim(0, n_show)
        # Shrink Y axis to 5th-95th percentile to reduce whitespace
        if len(trace_segment) > 0:
            y_lo, y_hi = np.nanpercentile(trace_segment, [5, 95])
            y_range = y_hi - y_lo
            if y_range > 0:
                ax.set_ylim(y_lo - 0.1 * y_range, y_hi + 0.1 * y_range)
        ax.axis("off")

    # Rows 0-1: High SNR ROI examples with traces
    if n_accepted > 0 and mean_img is not None:
        accepted_idx = np.where(accepted)[0]
        snr_accepted = snr[accepted]
        n_show = min(n_examples, n_accepted)

        # Get indices of highest SNR cells
        top_snr_order = np.argsort(snr_accepted)[::-1][:n_show]
        for i in range(n_examples):
            # ROI image
            ax = fig.add_subplot(gs_bottom[0, i])
            if i < n_show:
                local_idx = top_snr_order[i]
                global_idx = accepted_idx[local_idx]
                plot_roi_zoom(ax, global_idx, mean_img, stat[global_idx],
                              snr[global_idx], shot_noise[global_idx], "#2ecc71")
                # trace below
                ax_trace = fig.add_subplot(gs_bottom[1, i])
                plot_roi_trace(ax_trace, dff[global_idx], "#2ecc71")
            else:
                ax.set_facecolor("black")
                ax.axis("off")
                ax_trace = fig.add_subplot(gs_bottom[1, i])
                ax_trace.set_facecolor("black")
                ax_trace.axis("off")

        # Row 2 is spacer (empty)

        # Rows 3-4: Low SNR ROI examples with traces
        bottom_snr_order = np.argsort(snr_accepted)[:n_show]
        for i in range(n_examples):
            # ROI image
            ax = fig.add_subplot(gs_bottom[3, i])
            if i < n_show:
                local_idx = bottom_snr_order[i]
                global_idx = accepted_idx[local_idx]
                plot_roi_zoom(ax, global_idx, mean_img, stat[global_idx],
                              snr[global_idx], shot_noise[global_idx], "#ff6b6b")
                # trace below
                ax_trace = fig.add_subplot(gs_bottom[4, i])
                plot_roi_trace(ax_trace, dff[global_idx], "#ff6b6b")
            else:
                ax.set_facecolor("black")
                ax.axis("off")
                ax_trace = fig.add_subplot(gs_bottom[4, i])
                ax_trace.set_facecolor("black")
                ax_trace.axis("off")

    elif n_rejected > 0 and mean_img is not None:
        # Show rejected ROIs for diagnostics
        rejected_idx = np.where(~accepted)[0]
        snr_rejected = snr[~accepted]
        n_show = min(n_examples, n_rejected)

        # High SNR rejected
        top_snr_order = np.argsort(snr_rejected)[::-1][:n_show]
        for i in range(n_examples):
            ax = fig.add_subplot(gs_bottom[0, i])
            if i < n_show:
                local_idx = top_snr_order[i]
                global_idx = rejected_idx[local_idx]
                plot_roi_zoom(ax, global_idx, mean_img, stat[global_idx],
                              snr[global_idx], shot_noise[global_idx], "#ff6b6b")
                ax_trace = fig.add_subplot(gs_bottom[1, i])
                plot_roi_trace(ax_trace, dff[global_idx], "#ff6b6b")
            else:
                ax.set_facecolor("black")
                ax.axis("off")
                ax_trace = fig.add_subplot(gs_bottom[1, i])
                ax_trace.set_facecolor("black")
                ax_trace.axis("off")

        # Row 2 is spacer

        # Low SNR rejected
        bottom_snr_order = np.argsort(snr_rejected)[:n_show]
        for i in range(n_examples):
            ax = fig.add_subplot(gs_bottom[3, i])
            if i < n_show:
                local_idx = bottom_snr_order[i]
                global_idx = rejected_idx[local_idx]
                plot_roi_zoom(ax, global_idx, mean_img, stat[global_idx],
                              snr[global_idx], shot_noise[global_idx], "#ff6b6b")
                ax_trace = fig.add_subplot(gs_bottom[4, i])
                plot_roi_trace(ax_trace, dff[global_idx], "#ff6b6b")
            else:
                ax.set_facecolor("black")
                ax.axis("off")
                ax_trace = fig.add_subplot(gs_bottom[4, i])
                ax_trace.set_facecolor("black")
                ax_trace.axis("off")

    else:
        # No image available
        for row in [0, 1, 3, 4]:  # Skip spacer row 2
            for i in range(n_examples):
                ax = fig.add_subplot(gs_bottom[row, i])
                ax.set_facecolor("black")
                if row in [0, 3]:
                    ax.text(0.5, 0.5, "No data", ha="center", va="center",
                            fontsize=8, color="white")
                ax.axis("off")

    # Main title
    plane_name = plane_dir.name
    fig.suptitle(f"Quality Diagnostics: {plane_name}",
                 fontsize=14, fontweight="bold", y=0.98, color="white")

    # No tight_layout - we use manual GridSpec positioning for precise control

    if save_path:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="black")
        plt.close(fig)
    else:
        plt.show()

    return fig


def mask_dead_zones_in_ops(ops, threshold=0.01):
    """
    Mask out dead zones from registration shifts in ops image arrays.

    Dead zones appear as very dark regions (near zero intensity) at the
    edges of images after Suite3D alignment shifts are applied.

    Parameters
    ----------
    ops : dict
        Suite2p ops dictionary containing image arrays.
    threshold : float
        Fraction of max intensity to use as cutoff (default 0.01 = 1%).

    Returns
    -------
    ops : dict
        Modified ops with dead zones set to NaN in image arrays.
    """
    if "meanImg" not in ops:
        return ops

    # Use meanImg to identify valid regions
    mean_img = ops["meanImg"]
    valid_mask = mean_img > (mean_img.max() * threshold)

    n_invalid = (~valid_mask).sum()
    if n_invalid > 0:
        pct_invalid = 100 * n_invalid / valid_mask.size
        print(f"[mask_dead_zones] Masking {n_invalid} ({pct_invalid:.1f}%) dead zone pixels")

        # Mask all image arrays in ops
        for key in ["meanImg", "meanImgE", "max_proj", "Vcorr"]:
            if key in ops and isinstance(ops[key], np.ndarray):
                img = ops[key]
                # Only apply mask if shapes match
                if img.shape == valid_mask.shape:
                    # Convert to float and set invalid regions to NaN
                    ops[key] = img.astype(float)
                    ops[key][~valid_mask] = np.nan
                else:
                    print(f"[mask_dead_zones] Skipping {key}: "
                          f"shape {img.shape} != meanImg shape {valid_mask.shape}")

    return ops
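
# Usage sketch (hypothetical path). Note that the main pipeline now crops
# dead zones via yrange/xrange instead of calling this helper, so it is
# mainly useful for ad-hoc cleanup of older outputs:
#
#   >>> ops = load_ops("D:/results/suite2p/plane01/ops.npy")
#   >>> ops = mask_dead_zones_in_ops(ops, threshold=0.01)
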
def plot_zplane_figures(
    plane_dir,
    dff_percentile=8,
    dff_window_size=None,
    dff_smooth_window=None,
    run_rastermap=False,
    **kwargs,
):
    """
    Re-generate Suite2p figures for a merged plane.

    Parameters
    ----------
    plane_dir : Path
        Path to the planeXX output directory (with ops.npy, stat.npy, etc.).
    dff_percentile : int, optional
        Percentile used for ΔF/F baseline.
    dff_window_size : int, optional
        Window size for ΔF/F rolling baseline. If None, auto-calculated
        as ~10 × tau × fs based on ops values.
    dff_smooth_window : int, optional
        Temporal smoothing window for dF/F traces (in frames). If None,
        auto-calculated as ~0.5 × tau × fs to emphasize transients while
        reducing noise. Set to 1 to disable.
    run_rastermap : bool, optional
        If True, compute and plot rastermap sorting of cells.
    kwargs : dict
        Extra keyword args (e.g. fig_label).
    """
    plane_dir = Path(plane_dir)

    # File naming convention: numbered prefixes ensure proper alphabetical ordering:
    # 01_correlation -> 02_max_projection -> 03_mean -> 04_mean_enhanced,
    # each image immediately followed by its _segmentation variant.
    expected_files = {
        "ops": plane_dir / "ops.npy",
        "stat": plane_dir / "stat.npy",
        "iscell": plane_dir / "iscell.npy",
        # Summary images with segmentation overlays - numbered for proper ordering
        "correlation_image": plane_dir / "01_correlation.png",
        "correlation_segmentation": plane_dir / "01_correlation_segmentation.png",
        "max_proj": plane_dir / "02_max_projection.png",
        "max_proj_segmentation": plane_dir / "02_max_projection_segmentation.png",
        "meanImg": plane_dir / "03_mean.png",
        "meanImg_segmentation": plane_dir / "03_mean_segmentation.png",
        "meanImgE": plane_dir / "04_mean_enhanced.png",
        "meanImgE_segmentation": plane_dir / "04_mean_enhanced_segmentation.png",
        # Diagnostics and analysis
        "quality_diagnostics": plane_dir / "05_quality_diagnostics.png",
        "registration": plane_dir / "06_registration.png",
        # Traces - multiple cell counts
        "traces_raw_20": plane_dir / "07a_traces_raw_20.png",
        "traces_raw_50": plane_dir / "07b_traces_raw_50.png",
        "traces_raw_100": plane_dir / "07c_traces_raw_100.png",
        "traces_dff_20": plane_dir / "08a_traces_dff_20.png",
        "traces_dff_50": plane_dir / "08b_traces_dff_50.png",
        "traces_dff_100": plane_dir / "08c_traces_dff_100.png",
        "traces_rejected": plane_dir / "09_traces_rejected.png",
        # Noise distributions
        "noise_acc": plane_dir / "10_shot_noise_accepted.png",
        "noise_rej": plane_dir / "11_shot_noise_rejected.png",
        # Rastermap
        "model": plane_dir / "model.npy",
        "rastermap": plane_dir / "12_rastermap.png",
        # Regional zoom
        "regional_zoom": plane_dir / "13_regional_zoom.png",
    }

    output_ops = load_ops(expected_files["ops"])

    # Dead zones are now handled via yrange/xrange cropping in run_lsp.py,
    # so we no longer need to mask them here:
    # output_ops = mask_dead_zones_in_ops(output_ops)

    # force remake of the heavy figures
    for key in [
        "registration",
        "traces_raw_20", "traces_raw_50", "traces_raw_100",
        "traces_dff_20", "traces_dff_50", "traces_dff_100",
        "traces_rejected",
        "noise_acc", "noise_rej",
        "rastermap",
    ]:
        if key in expected_files:
            if expected_files[key].exists():
                try:
                    expected_files[key].unlink()
                except PermissionError:
                    print(f"Error: Cannot delete {expected_files[key]}, it's open elsewhere.")

    if expected_files["stat"].is_file():
        res = load_planar_results(plane_dir)
        # iscell is (n_rois, 2): [:, 0] is 0/1, [:, 1] is classifier probability
        iscell_mask = res["iscell"][:, 0].astype(bool)
        cell_prob = res["iscell"][:, 1]
        spks = res["spks"]
        F = res["F"]

        # Split by accepted/rejected
        F_accepted = F[iscell_mask] if iscell_mask.sum() > 0 else np.zeros((0, F.shape[1]))
        F_rejected = F[~iscell_mask] if (~iscell_mask).sum() > 0 else np.zeros((0, F.shape[1]))
        spks_cells = spks[iscell_mask] if iscell_mask.sum() > 0 else np.zeros((0, spks.shape[1]))

        n_accepted = F_accepted.shape[0]
        n_rejected = F_rejected.shape[0]
        print(f"Plotting results for {n_accepted} accepted / {n_rejected} rejected ROIs")

        # Rastermap (only for sufficient cell counts).
        # Rastermap sorts neurons by activity similarity for visualization;
        # we cache the model to avoid recomputing, but validate it matches current data.
        model = None
        if run_rastermap and n_accepted >= 2:
            try:
                from lbm_suite2p_python.zplane import plot_rastermap
                import rastermap
                has_rastermap = True
            except ImportError:
                print("rastermap not found. Install via: pip install rastermap")
                print("  or: pip install mbo_utilities[rastermap]")
                has_rastermap = False
                rastermap, plot_rastermap = None, None

            if has_rastermap:
                model_file = expected_files["model"]
                plot_file = expected_files["rastermap"]
                need_recompute = True

                # check if cached model exists and is valid for current cell count
                if model_file.is_file():
                    try:
                        cached_model = np.load(model_file, allow_pickle=True).item()
                        # Handle both direct model objects and dict wrappers
                        if hasattr(cached_model, "isort"):
                            cached_isort = cached_model.isort
                        elif isinstance(cached_model, dict) and "isort" in cached_model:
                            cached_isort = cached_model["isort"]
                        else:
                            cached_isort = None

                        if cached_isort is not None and len(cached_isort) == n_accepted:
                            model = cached_model
                            need_recompute = False
                            print(f"  Using cached rastermap model ({n_accepted} cells)")
                        else:
                            # stale model - cell count changed since last run
                            cached_len = len(cached_isort) if cached_isort is not None else "?"
                            print(f"  Rastermap model stale (cached {cached_len} vs "
                                  f"current {n_accepted} cells), recomputing...")
                            model_file.unlink()
                    except Exception as e:
                        print(f"  Failed to load cached rastermap model: {e}, recomputing...")
                        model_file.unlink(missing_ok=True)

                # fit new model if needed
                if need_recompute:
                    print(f"  Computing rastermap model for {n_accepted} cells...")
                    params = {
                        "n_clusters": 100 if n_accepted >= 200 else None,
                        "n_PCs": min(128, max(2, n_accepted - 1)),
                        "locality": 0.0 if n_accepted >= 200 else 0.1,
                        "time_lag_window": 15,
                        "grid_upsample": 10 if n_accepted >= 200 else 0,
                    }
                    model = rastermap.Rastermap(**params).fit(spks_cells)
                    np.save(model_file, model)

                # regenerate plot if missing (even if model was cached)
                if model is not None and not plot_file.is_file():
                    plot_rastermap(
                        spks_cells,
                        model,
                        neuron_bin_size=0,
                        save_path=plot_file,
                        title_kwargs={"fontsize": 8, "y": 0.95},
                        title="Rastermap Sorted Activity",
                    )

        # apply sorting to traces for downstream plots
        if model is not None:
            # Handle both direct model objects and dict wrappers
            if hasattr(model, "isort"):
                isort = model.isort
            elif isinstance(model, dict) and "isort" in model:
                isort = model["isort"]
            else:
                isort = None
            if isort is not None:
                isort_global = np.where(iscell_mask)[0][isort]
                output_ops["isort"] = isort_global
                F_accepted = F_accepted[isort]

        # Compute dF/F
        fs = output_ops.get("fs", 1.0)
        tau = output_ops.get("tau", 1.0)

        # Compute unsmoothed dF/F for shot noise (smoothing reduces frame-to-frame variance)
        if n_accepted > 0:
            dffp_acc_unsmoothed = dff_rolling_percentile(
                F_accepted,
                percentile=dff_percentile,
                window_size=dff_window_size,
                smooth_window=1,  # No smoothing for shot noise
                fs=fs,
                tau=tau,
            ) * 100
            # Smoothed version for trace plotting
            dffp_acc = dff_rolling_percentile(
                F_accepted,
                percentile=dff_percentile,
                window_size=dff_window_size,
                smooth_window=dff_smooth_window,
                fs=fs,
                tau=tau,
            ) * 100
        else:
            dffp_acc_unsmoothed = np.zeros((0, F.shape[1]))
            dffp_acc = np.zeros((0, F.shape[1]))

        if n_rejected > 0:
            dffp_rej_unsmoothed = dff_rolling_percentile(
                F_rejected,
                percentile=dff_percentile,
                window_size=dff_window_size,
                smooth_window=1,  # No smoothing for shot noise
                fs=fs,
                tau=tau,
            ) * 100
            # Smoothed version for trace plotting
            dffp_rej = dff_rolling_percentile(
                F_rejected,
                percentile=dff_percentile,
                window_size=dff_window_size,
                smooth_window=dff_smooth_window,
                fs=fs,
                tau=tau,
            ) * 100
        else:
            dffp_rej_unsmoothed = np.zeros((0, F.shape[1]))
            dffp_rej = np.zeros((0, F.shape[1]))

        # Trace plots (robust to any cell count >= 1).
        # Sort traces by quality score (SNR, skewness, shot noise) for visualization;
        # generate plots with 20, 50, and 100 cells if available.
        if n_accepted > 0:
            # Get accepted cell stat for skewness
            stat_accepted = [s for s, m in zip(res["stat"], iscell_mask) if m]

            # Compute quality scores and sort
            quality = compute_trace_quality_score(
                F_accepted,
                Fneu=res["Fneu"][iscell_mask] if "Fneu" in res else None,
                stat=stat_accepted,
                fs=fs,
            )
            quality_sort_idx = quality["sort_idx"]

            # Sort traces by quality (best first)
            dffp_acc_sorted = dffp_acc[quality_sort_idx]
            F_accepted_sorted = F_accepted[quality_sort_idx]

            # Generate trace plots at multiple cell counts
            cell_counts = [20, 50, 100]
            for n_cells in cell_counts:
                if n_accepted >= n_cells:
                    # dF/F traces (percent)
                    plot_traces(
                        dffp_acc_sorted,
                        save_path=expected_files[f"traces_dff_{n_cells}"],
                        num_neurons=n_cells,
                        scale_bar_unit=r"% $\Delta$F/F$_0$",
                        title=rf"Top {n_cells} $\Delta$F/F Traces by Quality (n={n_accepted} total)",
                    )
                    # Raw traces
                    plot_traces(
                        F_accepted_sorted,
                        save_path=expected_files[f"traces_raw_{n_cells}"],
                        num_neurons=n_cells,
                        scale_bar_unit="a.u.",
                        title=f"Top {n_cells} Raw Traces by Quality (n={n_accepted} total)",
                    )
                elif n_cells == 20:
                    # Always generate the 20-cell plot even if fewer cells are available
                    plot_traces(
                        dffp_acc_sorted,
                        save_path=expected_files["traces_dff_20"],
                        num_neurons=min(20, n_accepted),
                        scale_bar_unit=r"% $\Delta$F/F$_0$",
                        title=rf"Top {min(20, n_accepted)} $\Delta$F/F Traces by Quality (n={n_accepted} total)",
                    )
                    plot_traces(
                        F_accepted_sorted,
                        save_path=expected_files["traces_raw_20"],
                        num_neurons=min(20, n_accepted),
                        scale_bar_unit="a.u.",
                        title=f"Top {min(20, n_accepted)} Raw Traces by Quality (n={n_accepted} total)",
                    )
        else:
            print("  No accepted cells - skipping accepted trace plots")

        if n_rejected > 0:
            plot_traces(
                dffp_rej,
                save_path=expected_files["traces_rejected"],
                num_neurons=min(20, n_rejected),
                scale_bar_unit=r"% $\Delta$F/F$_0$",
                title=rf"$\Delta$F/F Traces - Rejected ROIs (n={n_rejected})",
            )
        else:
            print("  No rejected ROIs - skipping rejected trace plots")

        # Noise distributions (robust to any cell count >= 1).
        # Use unsmoothed dF/F for shot noise (smoothing artificially reduces noise).
        if n_accepted > 0:
            dff_noise_acc = dff_shot_noise(dffp_acc_unsmoothed, fs)
            plot_noise_distribution(
                dff_noise_acc,
                output_filename=expected_files["noise_acc"],
                title=f"Shot-Noise Distribution (Accepted, n={n_accepted})",
            )
        if n_rejected > 0:
            dff_noise_rej = dff_shot_noise(dffp_rej_unsmoothed, fs)
            plot_noise_distribution(
                dff_noise_rej,
                output_filename=expected_files["noise_rej"],
                title=f"Shot-Noise Distribution (Rejected, n={n_rejected})",
            )

        # Segmentation overlays.
        # Suite2p stores images in two coordinate systems:
        #   - FULL space: refImg, meanImg, meanImgE (same size as original Ly x Lx)
        #   - CROPPED space: max_proj, Vcorr (size determined by yrange/xrange
        #     after registration)
        # The stat coordinates are in FULL image space.
        stat_full = res["stat"]  # stat coordinates in full image space

        # Helper to check if image is valid
        def _is_valid_image(img):
            if img is None:
                return False
            if isinstance(img, (int, float)) and img == 0:
                return False
            if isinstance(img, np.ndarray) and img.size == 0:
                return False
            return True

        # Get crop parameters for images in cropped space
        yrange = output_ops.get("yrange", [0, output_ops.get("Ly", 512)])
        xrange = output_ops.get("xrange", [0, output_ops.get("Lx", 512)])
        ymin, xmin = int(yrange[0]), int(xrange[0])

        # Create stat with adjusted coordinates for cropped image space
        if ymin > 0 or xmin > 0:
            stat_cropped = []
            for s in stat_full:
                s_adj = s.copy()
                s_adj["ypix"] = s["ypix"] - ymin
                s_adj["xpix"] = s["xpix"] - xmin
                stat_cropped.append(s_adj)
        else:
            stat_cropped = stat_full

        # Images in FULL space - use stat_full
        full_space_images = {
            "meanImg": ("Mean Image", expected_files["meanImg_segmentation"]),
            "meanImgE": ("Enhanced Mean Image", expected_files["meanImgE_segmentation"]),
        }
        for img_key, (title_name, save_file) in full_space_images.items():
            img = output_ops.get(img_key)
            if _is_valid_image(img):
                if n_accepted > 0:
                    plot_masks(
                        img=img,
                        stat=stat_full,
                        mask_idx=iscell_mask,
                        savepath=save_file,
                        title=f"{title_name} - Accepted ROIs (n={n_accepted})",
                    )
                else:
                    plot_projection(
                        output_ops,
                        save_file,
                        fig_label=kwargs.get("fig_label", plane_dir.stem),
                        display_masks=False,
                        add_scalebar=True,
                        proj=img_key,
                    )

        # Images in CROPPED space - use stat_cropped
        cropped_space_images = {
            "max_proj": ("Max Projection", expected_files["max_proj_segmentation"]),
        }
        for img_key, (title_name, save_file) in cropped_space_images.items():
            img = output_ops.get(img_key)
            if _is_valid_image(img):
                if n_accepted > 0:
                    plot_masks(
                        img=img,
                        stat=stat_cropped,
                        mask_idx=iscell_mask,
                        savepath=save_file,
                        title=f"{title_name} - Accepted ROIs (n={n_accepted})",
                    )
                else:
                    plot_projection(
                        output_ops,
                        save_file,
                        fig_label=kwargs.get("fig_label", plane_dir.stem),
                        display_masks=False,
                        add_scalebar=True,
                        proj=img_key,
                    )

        # Correlation image (Vcorr) - in CROPPED space
        vcorr = output_ops.get("Vcorr")
        if _is_valid_image(vcorr):
            # Save correlation image without masks
            fig, ax = plt.subplots(figsize=(8, 8), facecolor="black")
            ax.set_facecolor("black")
            ax.imshow(vcorr, cmap="gray")
            ax.set_title("Correlation Image", color="white", fontweight="bold")
            ax.axis("off")
            plt.tight_layout()
            plt.savefig(expected_files["correlation_image"], dpi=150, facecolor="black")
            plt.close(fig)

            # Correlation image with segmentation
            if n_accepted > 0:
                plot_masks(
                    img=vcorr,
                    stat=stat_cropped,
                    mask_idx=iscell_mask,
                    savepath=expected_files["correlation_segmentation"],
                    title=f"Correlation Image - Accepted ROIs (n={n_accepted})",
                )

    # Summary images (no masks) - always generated
    fig_label = kwargs.get("fig_label", plane_dir.stem)
    for key in ["meanImg", "max_proj", "meanImgE"]:
        if key in output_ops and output_ops[key] is not None:
            try:
                plot_projection(
                    output_ops,
                    expected_files[key],
                    fig_label=fig_label,
                    display_masks=False,
                    add_scalebar=True,
                    proj=key,
                )
            except Exception as e:
                print(f"  Failed to plot {key}: {e}")

    # Quality diagnostics
    try:
        plot_plane_diagnostics(plane_dir, save_path=expected_files["quality_diagnostics"])
    except Exception as e:
        print(f"  Failed to generate quality diagnostics: {e}")

    # Regional zoom
    try:
        plot_regional_zoom(
            plane_dir,
            zoom_size=150,
            img_key="meanImgE",
            save_path=expected_files["regional_zoom"],
        )
    except Exception as e:
        print(f"  Failed to generate regional zoom: {e}")

    return output_ops
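
# A minimal sketch of regenerating all figures for one plane (hypothetical
# path; run_rastermap=True additionally requires the optional rastermap
# package):
#
#   >>> output_ops = plot_zplane_figures(
#   ...     "D:/results/suite2p/plane01",
#   ...     dff_percentile=8,
#   ...     run_rastermap=False,
#   ... )
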
def normalize99(img):
    """
    Normalize image using 1st and 99th percentile values.

    This is a robust normalization that clips outliers and scales the image
    to the [0, 1] range based on the 1st and 99th percentile values.

    Parameters
    ----------
    img : numpy.ndarray
        Input image array of any shape.

    Returns
    -------
    numpy.ndarray
        Normalized image with values clipped to [0, 1].

    Examples
    --------
    >>> img = np.random.rand(100, 100) * 1000
    >>> normalized = normalize99(img)
    >>> assert 0 <= normalized.min() <= normalized.max() <= 1
    """
    p1, p99 = np.percentile(img, [1, 99])
    return np.clip((img - p1) / (p99 - p1 + 1e-8), 0, 1)


def apply_hp_filter(img, diameter, spatial_hp_cp):
    """
    Apply high-pass filter to image for Cellpose preprocessing.

    This replicates Suite2p's anatomical detection preprocessing, which
    normalizes the image and then subtracts a Gaussian-smoothed version
    to enhance cell boundaries.

    Parameters
    ----------
    img : numpy.ndarray
        Input 2D image (e.g., mean image or max projection).
    diameter : int or float
        Expected cell diameter in pixels. Used to calculate the Gaussian
        sigma as ``diameter * spatial_hp_cp``.
    spatial_hp_cp : float
        High-pass filter strength multiplier. Common values:

        - 0: No filtering (return normalized image)
        - 0.5: LBM default, mild filtering
        - 2.0: Strong filtering, enhances small features

    Returns
    -------
    numpy.ndarray
        High-pass filtered image.

    See Also
    --------
    normalize99 : Used internally for percentile normalization.

    Examples
    --------
    >>> from scipy.ndimage import gaussian_filter
    >>> img = np.random.rand(256, 256)
    >>> filtered = apply_hp_filter(img, diameter=6, spatial_hp_cp=0.5)
    """
    from scipy.ndimage import gaussian_filter

    img_norm = normalize99(img)
    if spatial_hp_cp > 0:
        sigma = diameter * spatial_hp_cp
        img_hp = img_norm - gaussian_filter(img_norm, sigma)
    else:
        img_hp = img_norm
    return img_hp
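
# Illustrative comparison of filter strengths, mirroring the docstring
# guidance above (the image here is synthetic):
#
#   >>> img = np.random.rand(256, 256)
#   >>> unfiltered = apply_hp_filter(img, diameter=6, spatial_hp_cp=0)  # normalize only
#   >>> mild = apply_hp_filter(img, diameter=6, spatial_hp_cp=0.5)      # LBM default
#   >>> strong = apply_hp_filter(img, diameter=6, spatial_hp_cp=2.0)    # strong filtering
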
def random_colors_for_mask(mask, seed=42):
    """
    Generate random distinct colors for each cell ID in a mask.

    Uses HSV color space to generate visually distinct colors for each
    unique cell label in the mask.

    Parameters
    ----------
    mask : numpy.ndarray
        2D integer array where each unique positive value represents a
        different cell. Background should be 0.
    seed : int, optional
        Random seed for reproducibility. Default is 42.

    Returns
    -------
    numpy.ndarray
        RGB image of shape ``(Ly, Lx, 3)`` with float32 values in [0, 1].

    See Also
    --------
    mask_overlay : Uses this function to colorize masks.
    stat_to_mask : Converts Suite2p stat to mask array.

    Examples
    --------
    >>> mask = np.zeros((100, 100), dtype=np.int32)
    >>> mask[10:20, 10:20] = 1
    >>> mask[30:40, 30:40] = 2
    >>> colors = random_colors_for_mask(mask)
    >>> assert colors.shape == (100, 100, 3)
    """
    from matplotlib.colors import hsv_to_rgb

    n_cells = mask.max()
    if n_cells == 0:
        return np.zeros((*mask.shape, 3), dtype=np.float32)

    # Generate random colors using HSV for better distinction
    np.random.seed(seed)
    hues = np.random.rand(n_cells + 1)
    saturations = 0.7 + 0.3 * np.random.rand(n_cells + 1)
    values = 0.8 + 0.2 * np.random.rand(n_cells + 1)

    # Convert HSV to RGB
    colors = np.zeros((n_cells + 1, 3))
    for i in range(1, n_cells + 1):
        colors[i] = hsv_to_rgb([hues[i], saturations[i], values[i]])

    # Map colors to mask
    rgb = colors[mask]
    return rgb.astype(np.float32)


def mask_overlay(img, mask, alpha=0.5):
    """
    Overlay colored masks on a grayscale image.

    Creates a visualization where detected cells are shown as colored
    regions blended with the underlying grayscale image.

    Parameters
    ----------
    img : numpy.ndarray
        2D grayscale background image (e.g., mean image, max projection).
    mask : numpy.ndarray
        2D integer mask where each positive value represents a different
        cell. Background should be 0.
    alpha : float, optional
        Blending factor for mask overlay. 0 = fully transparent,
        1 = fully opaque. Default is 0.5.

    Returns
    -------
    numpy.ndarray
        RGB image of shape ``(Ly, Lx, 3)`` with float32 values in [0, 1].

    See Also
    --------
    random_colors_for_mask : Generates colors for each cell.
    stat_to_mask : Converts Suite2p stat to mask.
    plot_mask_comparison : Uses this for multi-panel visualization.

    Examples
    --------
    >>> img = np.random.rand(256, 256)
    >>> mask = np.zeros((256, 256), dtype=np.int32)
    >>> mask[50:100, 50:100] = 1
    >>> overlay = mask_overlay(img, mask, alpha=0.5)
    >>> assert overlay.shape == (256, 256, 3)
    """
    img_norm = normalize99(img)
    rgb = np.stack([img_norm] * 3, axis=-1).astype(np.float32)
    if mask.max() > 0:
        colors = random_colors_for_mask(mask)
        mask_px = mask > 0
        rgb[mask_px] = (1 - alpha) * rgb[mask_px] + alpha * colors[mask_px]
    return rgb


def get_background_image(ops, img_key="max_proj"):
    """
    Get background image and coordinate offsets from ops.

    Handles the coordinate-space difference between full-FOV images
    (meanImg, meanImgE, refImg) and cropped images (max_proj, Vcorr).

    Parameters
    ----------
    ops : dict
        Suite2p ops dictionary.
    img_key : str
        Key for desired image: 'max_proj', 'Vcorr', 'meanImg', 'meanImgE'.

    Returns
    -------
    img : np.ndarray
        Background image.
    yoff : int
        Y offset to subtract from stat coordinates.
    xoff : int
        X offset to subtract from stat coordinates.
    """
    Ly = ops.get("Ly", 512)
    Lx = ops.get("Lx", 512)
    yrange = ops.get("yrange", [0, Ly])
    xrange = ops.get("xrange", [0, Lx])

    # cropped images need coordinate adjustment
    cropped_keys = {"max_proj", "Vcorr"}

    if img_key in ops:
        img = ops[img_key]
        if img_key in cropped_keys:
            yoff, xoff = int(yrange[0]), int(xrange[0])
        else:
            yoff, xoff = 0, 0
    else:
        # fallback to meanImg (full space)
        img = ops.get("meanImg", np.zeros((Ly, Lx)))
        yoff, xoff = 0, 0

    return img, yoff, xoff


def stat_to_mask(stat, Ly, Lx, yoff=0, xoff=0):
    """
    Convert Suite2p stat array to a 2D labeled mask.

    Each cell is assigned a unique integer label starting from 1.
    Background pixels are 0.

    Parameters
    ----------
    stat : numpy.ndarray or list
        Array of Suite2p stat dictionaries, each containing 'ypix' and
        'xpix' keys with pixel coordinates.
    Ly : int
        Image height in pixels.
    Lx : int
        Image width in pixels.
    yoff : int, optional
        Y offset to subtract from stat coordinates (for cropped images).
    xoff : int, optional
        X offset to subtract from stat coordinates (for cropped images).

    Returns
    -------
    numpy.ndarray
        2D mask of shape ``(Ly, Lx)`` with dtype uint16. Each cell has a
        unique integer label, background is 0.

    See Also
    --------
    mask_overlay : Uses masks for visualization.

    Examples
    --------
    >>> stat = [{'ypix': np.array([10, 11]), 'xpix': np.array([20, 21])}]
    >>> mask = stat_to_mask(stat, Ly=100, Lx=100)
    >>> assert mask[10, 20] == 1
    >>> assert mask[0, 0] == 0
    """
    mask = np.zeros((Ly, Lx), dtype=np.uint16)
    for i, s in enumerate(stat):
        ypix = s['ypix'] - yoff
        xpix = s['xpix'] - xoff
        valid = (ypix >= 0) & (ypix < Ly) & (xpix >= 0) & (xpix < Lx)
        mask[ypix[valid], xpix[valid]] = i + 1
    return mask
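
# The two helpers above are designed to be used together: the offsets
# returned by get_background_image translate full-space stat coordinates
# into the cropped space of max_proj/Vcorr. A minimal sketch (the ops and
# stat objects here are hypothetical):
#
#   >>> img, yoff, xoff = get_background_image(ops, "max_proj")
#   >>> mask = stat_to_mask(stat, img.shape[0], img.shape[1], yoff, xoff)
#   >>> overlay = mask_overlay(img, mask, alpha=0.5)
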
def plot_mask_comparison(
    img,
    results,
    zoom_levels=None,
    zoom_center=None,
    title=None,
    save_path=None,
    figsize=None,
):
    """
    Create a multi-panel comparison of detection results with zoom views.

    Generates a grid visualization comparing different parameter combinations
    (e.g., diameters) with full-image views and progressively zoomed regions.

    Parameters
    ----------
    img : numpy.ndarray
        2D background image for overlay (e.g., max projection).
    results : dict
        Dictionary mapping names to result dicts. Each result dict should
        contain either:

        - 'masks': 2D labeled mask array, OR
        - 'stat': Suite2p stat array (will be converted to a mask)

        And optionally:

        - 'n_cells': Number of cells (computed from the mask if not provided)
    zoom_levels : list of int, optional
        List of zoom region sizes in pixels. Default is [400, 200, 100].
    zoom_center : tuple of (int, int), optional
        Center point (cy, cx) for zoom regions. Default is the image center.
    title : str, optional
        Overall figure title.
    save_path : str or Path, optional
        Path to save the figure. If None, displays with plt.show().
    figsize : tuple, optional
        Figure size (width, height) in inches. Default is auto-calculated.

    Returns
    -------
    matplotlib.figure.Figure
        The generated figure.

    See Also
    --------
    mask_overlay : Creates individual overlays.
    stat_to_mask : Converts stat arrays to masks.

    Examples
    --------
    >>> img = ops['max_proj']
    >>> results = {
    ...     'd=2': {'masks': masks_d2, 'n_cells': 500},
    ...     'd=4': {'masks': masks_d4, 'n_cells': 350},
    ...     'd=6': {'masks': masks_d6, 'n_cells': 200},
    ... }
    >>> fig = plot_mask_comparison(img, results, zoom_levels=[200, 100])
    """
    Ly, Lx = img.shape[:2]

    # Default zoom levels
    if zoom_levels is None:
        zoom_levels = [400, 200, 100]

    # Default to image center
    if zoom_center is None:
        cy, cx = Ly // 2, Lx // 2
    else:
        cy, cx = zoom_center

    n_cols = len(results)
    n_rows = len(zoom_levels) + 1  # Full image + zoom levels
    if figsize is None:
        figsize = (5 * n_cols, 5 * n_rows)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)
    if n_cols == 1:
        axes = axes.reshape(-1, 1)

    # Color palette for zoom boxes
    box_colors = ['yellow', 'cyan', 'magenta', 'lime', 'orange']

    for col, (name, r) in enumerate(results.items()):
        # Get or create mask
        if 'masks' in r:
            mask = r['masks']
        elif 'stat' in r:
            mask = stat_to_mask(r['stat'], Ly, Lx)
        else:
            raise ValueError(f"Result '{name}' must contain 'masks' or 'stat'")

        # Get cell count
        if 'n_cells' in r:
            n_cells = r['n_cells']
        else:
            n_cells = mask.max()

        overlay = mask_overlay(img, mask)

        # Full image with zoom boxes
        axes[0, col].imshow(overlay)
        axes[0, col].set_title(f"{name}: {n_cells} cells\nFull image")
        axes[0, col].axis('off')

        # Draw zoom boxes
        for i, zs in enumerate(zoom_levels):
            color = box_colors[i % len(box_colors)]
            rect = Rectangle(
                (cx - zs // 2, cy - zs // 2), zs, zs,
                fill=False, edgecolor=color, linewidth=2,
            )
            axes[0, col].add_patch(rect)

        # Zoomed views
        for row, zs in enumerate(zoom_levels):
            y1, y2 = max(0, cy - zs // 2), min(Ly, cy + zs // 2)
            x1, x2 = max(0, cx - zs // 2), min(Lx, cx + zs // 2)
            zoom_overlay = overlay[y1:y2, x1:x2]
            zoom_mask = mask[y1:y2, x1:x2]
            n_cells_zoom = len(np.unique(zoom_mask)) - 1  # Exclude background
            axes[row + 1, col].imshow(zoom_overlay)
            axes[row + 1, col].set_title(f"{zs}x{zs} zoom: {n_cells_zoom} cells")
            axes[row + 1, col].axis('off')

    if title:
        fig.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()

    if save_path:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()

    return fig
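
# The results dict also accepts raw Suite2p stat arrays in place of
# precomputed masks; a minimal sketch (the stat variables are hypothetical):
#
#   >>> results = {
#   ...     "run A": {"stat": stat_a},
#   ...     "run B": {"stat": stat_b},
#   ... }
#   >>> fig = plot_mask_comparison(ops["max_proj"], results, zoom_levels=[200, 100])
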
def plot_regional_zoom(
    plane_dir,
    zoom_size: int = 150,
    img_key: str = "max_proj",
    alpha: float = 0.5,
    save_path=None,
    figsize: tuple = (15, 10),
    accepted_only: bool = True,
):
    """
    Plot corner and center zoom views of detection results.

    Creates a 2x3 grid visualization showing the full image with region
    boxes, plus zoomed views of each corner and the center. Useful for
    checking detection quality across different parts of the field of view.

    Parameters
    ----------
    plane_dir : str or Path
        Path to a Suite2p plane directory containing ops.npy, stat.npy,
        and optionally iscell.npy.
    zoom_size : int, optional
        Size of zoom regions in pixels. Default is 150.
    img_key : str, optional
        Key in ops to use as background image. Options: 'max_proj',
        'meanImg', 'meanImgE'. Default is 'max_proj'.
    alpha : float, optional
        Blending factor for mask overlay (0-1). Default is 0.5.
    save_path : str or Path, optional
        Path to save the figure. If None, displays with plt.show().
    figsize : tuple, optional
        Figure size (width, height) in inches. Default is (15, 10).
    accepted_only : bool, optional
        If True, only show cells marked as accepted (iscell[:, 0] == 1).
        Default is True.

    Returns
    -------
    matplotlib.figure.Figure
        The generated figure.

    Examples
    --------
    >>> # From pipeline results
    >>> ops_paths = lsp.pipeline(input_data=path, save_path=output_dir, ...)
    >>> for ops_path in ops_paths:
    ...     plot_regional_zoom(ops_path.parent, zoom_size=150)

    >>> # With custom settings
    >>> plot_regional_zoom(
    ...     "D:/results/plane01_vdaq0",
    ...     zoom_size=200,
    ...     img_key="meanImgE",
    ...     save_path="regional_zoom.png"
    ... )
    """
    plane_dir = Path(plane_dir)

    # load results
    res = load_planar_results(plane_dir)
    ops = load_ops(plane_dir)

    # get background image with coordinate offsets
    img, yoff, xoff = get_background_image(ops, img_key)
    img_h, img_w = img.shape[:2]

    # Get stat and optionally filter by iscell
    stat = res["stat"]
    if accepted_only and "iscell" in res:
        iscell_mask = res["iscell"][:, 0].astype(bool)
        stat = stat[iscell_mask]
    n_cells = len(stat)

    # Create mask from stat (with coordinate offset for cropped images)
    mask = stat_to_mask(stat, img_h, img_w, yoff, xoff)

    # Create overlay
    overlay = mask_overlay(img, mask, alpha=alpha)

    # Define corner and center regions
    cy, cx = img_h // 2, img_w // 2
    zs = zoom_size
    regions = {
        "Top-Left": (0, zs, 0, zs),
        "Top-Right": (0, zs, img_w - zs, img_w),
        "Bottom-Left": (img_h - zs, img_h, 0, zs),
        "Bottom-Right": (img_h - zs, img_h, img_w - zs, img_w),
        "Center": (cy - zs // 2, cy + zs // 2, cx - zs // 2, cx + zs // 2),
    }

    # Color palette for boxes
    box_colors = ['red', 'blue', 'green', 'orange', 'yellow']

    fig, axes = plt.subplots(2, 3, figsize=figsize)
    axes = axes.flatten()

    # Full image with boxes showing regions
    ax = axes[0]
    ax.imshow(overlay)
    for (name, (y1, y2, x1, x2)), c in zip(regions.items(), box_colors):
        rect = Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            fill=False, edgecolor=c, linewidth=2, label=name,
        )
        ax.add_patch(rect)
    ax.set_title(f"Full Image: {n_cells} cells\n(boxes show zoom regions)")
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.02), ncol=3, fontsize=8)
    ax.axis('off')

    # Zoomed views
    for ax, ((name, (y1, y2, x1, x2)), c) in zip(axes[1:], zip(regions.items(), box_colors)):
        zoom_mask = mask[y1:y2, x1:x2]
        n_zoom = len(np.unique(zoom_mask)) - 1  # Exclude background
        ax.imshow(overlay[y1:y2, x1:x2])
        ax.set_title(f"{name}: {n_zoom} cells", color=c, fontweight='bold')
        ax.axis('off')

    # Get plane name for title
    plane_name = plane_dir.name
    diameter = ops.get("diameter", "?")
    plt.suptitle(
        f"{plane_name} - Regional Comparison ({zs}x{zs}) - d={diameter}",
        fontsize=14, fontweight='bold',
    )
    plt.tight_layout()

    if save_path:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()

    return fig
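
# The overlay panels in the functions below use plain alpha compositing on
# mask pixels: out = (1 - alpha) * image + alpha * color. For example, with
# alpha=0.5 a white background pixel under a pure-green mask becomes:
#
#   >>> 0.5 * np.array([1.0, 1.0, 1.0]) + 0.5 * np.array([0.0, 1.0, 0.0])
#   array([0.5, 1. , 0.5])
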
def plot_filtered_cells(
    plane_dir,
    iscell_original,
    iscell_filtered,
    img_key: str = "max_proj",
    alpha: float = 0.5,
    save_path=None,
    figsize: tuple = (18, 6),
    title: str = None,
):
    """
    Plot side-by-side comparison of cells before and after filtering.

    Shows three panels: kept cells, removed cells, and both overlaid with
    different colors.

    Parameters
    ----------
    plane_dir : str or Path
        Path to a Suite2p plane directory containing ops.npy, stat.npy.
    iscell_original : np.ndarray
        Original iscell array before filtering (n_rois,) or (n_rois, 2).
    iscell_filtered : np.ndarray
        Filtered iscell array (n_rois,) or (n_rois, 2).
    img_key : str, optional
        Key in ops to use as background image. Default is 'max_proj'.
    alpha : float, optional
        Blending factor for mask overlay (0-1). Default is 0.5.
    save_path : str or Path, optional
        Path to save the figure. If None, displays with plt.show().
    figsize : tuple, optional
        Figure size (width, height) in inches. Default is (18, 6).
    title : str, optional
        Custom title for the figure.

    Returns
    -------
    matplotlib.figure.Figure
        The generated figure.

    Examples
    --------
    >>> from lbm_suite2p_python import filter_by_max_diameter, plot_filtered_cells
    >>> res = load_planar_results(plane_dir)
    >>> iscell_filtered = filter_by_max_diameter(
    ...     res["iscell"], res["stat"], max_diameter_px=15
    ... )
    >>> plot_filtered_cells(plane_dir, res["iscell"], iscell_filtered)
    """
    plane_dir = Path(plane_dir)

    # load results
    res = load_planar_results(plane_dir)
    ops = load_ops(plane_dir)

    # get background image with coordinate offsets
    img, yoff, xoff = get_background_image(ops, img_key)
    img_h, img_w = img.shape[:2]

    # normalize iscell arrays to 1D boolean
    if iscell_original.ndim == 2:
        iscell_original = iscell_original[:, 0]
    if iscell_filtered.ndim == 2:
        iscell_filtered = iscell_filtered[:, 0]
    iscell_original = iscell_original.astype(bool)
    iscell_filtered = iscell_filtered.astype(bool)

    stat = res["stat"]

    # Identify kept and removed cells
    kept_mask = iscell_filtered
    removed_mask = iscell_original & ~iscell_filtered
    n_kept = kept_mask.sum()
    n_removed = removed_mask.sum()
    n_original = iscell_original.sum()

    # Create masks for visualization (with coordinate offset for cropped images)
    mask_kept = stat_to_mask(stat[kept_mask], img_h, img_w, yoff, xoff)
    mask_removed = stat_to_mask(stat[removed_mask], img_h, img_w, yoff, xoff)

    # Normalize image
    img_norm = normalize99(img)
    img_rgb = np.stack([img_norm] * 3, axis=-1).astype(np.float32)

    fig, axes = plt.subplots(1, 3, figsize=figsize)

    # Panel 1: Kept cells (green)
    ax = axes[0]
    overlay_kept = img_rgb.copy()
    if mask_kept.max() > 0:
        mask_px = mask_kept > 0
        overlay_kept[mask_px] = (1 - alpha) * overlay_kept[mask_px] + alpha * np.array([0, 1, 0])
    ax.imshow(overlay_kept)
    ax.set_title(f"Kept: {n_kept} cells", fontsize=12, fontweight='bold', color='green')
    ax.axis('off')

    # Panel 2: Removed cells (red)
    ax = axes[1]
    overlay_removed = img_rgb.copy()
    if mask_removed.max() > 0:
        mask_px = mask_removed > 0
        overlay_removed[mask_px] = (1 - alpha) * overlay_removed[mask_px] + alpha * np.array([1, 0, 0])
    ax.imshow(overlay_removed)
    ax.set_title(f"Removed: {n_removed} cells", fontsize=12, fontweight='bold', color='red')
    ax.axis('off')

    # Panel 3: Both overlaid
    ax = axes[2]
    overlay_both = img_rgb.copy()
    if mask_kept.max() > 0:
        mask_px = mask_kept > 0
        overlay_both[mask_px] = (1 - alpha) * overlay_both[mask_px] + alpha * np.array([0, 1, 0])
    if mask_removed.max() > 0:
        mask_px = mask_removed > 0
        overlay_both[mask_px] = (1 - alpha) * overlay_both[mask_px] + alpha * np.array([1, 0, 0])
    ax.imshow(overlay_both)
    ax.set_title(f"Combined: {n_kept} kept (green) / {n_removed} removed (red)",
                 fontsize=12, fontweight='bold')
    ax.axis('off')

    # Add legend
    from matplotlib.patches import Patch
    legend_elements = [
        Patch(facecolor='green', alpha=0.7, label=f'Kept ({n_kept})'),
        Patch(facecolor='red', alpha=0.7, label=f'Removed ({n_removed})'),
    ]
    axes[2].legend(handles=legend_elements, loc='upper right', fontsize=10)

    # Title
    if title is None:
        plane_name = plane_dir.name
        title = f"{plane_name}: {n_original} → {n_kept} cells ({n_removed} removed)"
    plt.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()

    if save_path:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()

    return fig


def plot_filter_exclusions(
    plane_dir,
    iscell_filtered,
    filter_results: list,
    stat=None,
    ops=None,
    img_key: str = "max_proj",
    alpha: float = 0.5,
    save_dir=None,
    figsize: tuple = (12, 6),
):
    """
    Create one PNG per filter showing the cells it excluded.

    For each filter that rejected cells, creates a visualization with:

    - Accepted cells (green)
    - Cells rejected by this specific filter (red)
    - Title with filter name and parameters

    Parameters
    ----------
    plane_dir : str or Path
        Path to Suite2p plane directory.
    iscell_filtered : np.ndarray
        Final filtered iscell array (n_rois,) or (n_rois, 2).
    filter_results : list of dict
        Results from apply_filters(); each dict has 'name', 'removed_mask',
        and 'info' keys.
    stat : np.ndarray, optional
        Suite2p stat array. If None, loads from plane_dir.
    ops : dict, optional
        Suite2p ops dict. If None, loads from plane_dir.
    img_key : str, default "max_proj"
        Key in ops for background image.
    alpha : float, default 0.5
        Overlay transparency.
    save_dir : str or Path, optional
        Directory to save PNGs. Defaults to plane_dir.
    figsize : tuple, default (12, 6)
        Figure size.

    Returns
    -------
    dict
        Filter metadata: {filter_name: {params, n_rejected, n_remaining}}.
    """
    plane_dir = Path(plane_dir)
    save_dir = Path(save_dir) if save_dir else plane_dir

    # load data if needed
    if stat is None:
        stat = np.load(plane_dir / "stat.npy", allow_pickle=True)
    if ops is None:
        ops = load_ops(plane_dir)

    # get background image with coordinate offsets
    img, yoff, xoff = get_background_image(ops, img_key)
    img_h, img_w = img.shape[:2]

    # normalize iscell
    if iscell_filtered.ndim == 2:
        iscell_filtered = iscell_filtered[:, 0]
    accepted_mask = iscell_filtered.astype(bool)

    # normalize image
    img_norm = normalize99(img)
    img_rgb = np.stack([img_norm] * 3, axis=-1).astype(np.float32)

    # create accepted cells mask (used in all figures)
    mask_accepted = stat_to_mask(stat[accepted_mask], img_h, img_w, yoff, xoff)

    filter_metadata = {}

    for result in filter_results:
        name = result["name"]
        removed_mask = result["removed_mask"]
        info = result["info"]
        config = result.get("config", {})

        n_rejected = removed_mask.sum()
        if n_rejected == 0:
            continue

        # build params from user config first (more meaningful), then computed info
        params = {}
        # user-specified params
        for key in ["min_diameter_um", "max_diameter_um", "min_diameter_px",
                    "max_diameter_px", "min_area_px", "max_area_px",
                    "min_mult", "max_mult", "max_ratio"]:
            if key in config and config[key] is not None:
                val = config[key]
                params[key] = round(val, 1) if isinstance(val, float) else val
        # computed params (fallback if no user config)
        if not params:
            for key in ["min_px", "max_px", "min_ratio", "max_ratio",
                        "lower_px", "upper_px"]:
                if key in info and info[key] is not None:
                    params[key] = round(info[key], 1)

        # create mask for rejected cells
        mask_rejected = stat_to_mask(stat[removed_mask], img_h, img_w, yoff, xoff)

        # create figure
        fig, ax = plt.subplots(figsize=figsize)
        overlay = img_rgb.copy()

        # draw accepted cells (green)
        if mask_accepted.max() > 0:
            mask_px = mask_accepted > 0
            overlay[mask_px] = (1 - alpha) * overlay[mask_px] + alpha * np.array([0.2, 0.8, 0.2])

        # draw rejected cells (red)
        if mask_rejected.max() > 0:
            mask_px = mask_rejected > 0
            overlay[mask_px] = (1 - alpha) * overlay[mask_px] + alpha * np.array([0.9, 0.2, 0.2])

        ax.imshow(overlay)
        ax.axis("off")

        # title with filter info
        params_str = ", ".join(f"{k}={v}" for k, v in params.items())
        title = f"{name}: {n_rejected} excluded"
        if params_str:
            title += f" ({params_str})"
        ax.set_title(title, fontsize=12, fontweight="bold")

        # legend
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor=(0.2, 0.8, 0.2), alpha=0.7, label=f"Accepted ({accepted_mask.sum()})"),
            Patch(facecolor=(0.9, 0.2, 0.2), alpha=0.7, label=f"Excluded ({n_rejected})"),
        ]
        ax.legend(handles=legend_elements, loc="upper right", fontsize=10)

        plt.tight_layout()

        # save
        save_path = save_dir / f"14_filter_{name}.png"
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"  Saved {save_path.name}")

        # store metadata
        filter_metadata[name] = {
            "params": params,
            "n_rejected": int(n_rejected),
            "n_remaining": int(accepted_mask.sum()),
        }

    return filter_metadata
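
# A minimal sketch of the expected filter_results structure (keys follow the
# docstring above; the names and values here are hypothetical, and
# 'removed_mask' is a boolean array over all ROIs):
#
#   >>> filter_results = [
#   ...     {"name": "max_diameter",
#   ...      "removed_mask": removed,
#   ...      "info": {"max_px": 15.0},
#   ...      "config": {"max_diameter_px": 15}},
#   ... ]
#   >>> meta = plot_filter_exclusions(plane_dir, iscell_filtered, filter_results)
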
def plot_cell_filter_summary(
    plane_dir,
    iscell_suite2p=None,
    iscell_final=None,
    filter_results: list = None,
    stat=None,
    ops=None,
    img_key: str = "max_proj",
    alpha: float = 0.5,
    save_path=None,
    figsize: tuple = (16, 10),
):
    """
    Create a summary figure showing all filtering stages for a plane.

    Shows the suite2p classification, each filter's effect, and the final
    result in a single well-formatted figure.

    Parameters
    ----------
    plane_dir : str or Path
        Path to Suite2p plane directory.
    iscell_suite2p : np.ndarray, optional
        Original suite2p iscell (before accept_all_cells). If None, loads
        from iscell_suite2p.npy if it exists, otherwise from iscell.npy.
    iscell_final : np.ndarray, optional
        Final iscell after all filters. If None, loads from iscell.npy.
    filter_results : list of dict, optional
        Results from apply_filters(). If None, attempts to reconstruct
        from ops['filter_metadata'].
    stat : np.ndarray, optional
        Suite2p stat array. If None, loads from plane_dir.
    ops : dict, optional
        Suite2p ops dict. If None, loads from plane_dir.
    img_key : str, default "max_proj"
        Key in ops for the background image.
    alpha : float, default 0.5
        Overlay transparency.
    save_path : str or Path, optional
        Path to save the figure. If None, displays with plt.show().
    figsize : tuple, default (16, 10)
        Figure size.

    Returns
    -------
    matplotlib.figure.Figure
        The generated figure.
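
    Examples
    --------
    Minimal sketch, assuming a standard Suite2p plane directory; the path
    and output filename here are hypothetical::

        fig = plot_cell_filter_summary(
            "suite2p/plane0",
            img_key="max_proj",
            save_path="suite2p/plane0/filter_summary.png",
        )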
    """
    import matplotlib.pyplot as plt
    from matplotlib.patches import Patch

    plane_dir = Path(plane_dir)

    # load data
    if stat is None:
        stat = np.load(plane_dir / "stat.npy", allow_pickle=True)
    if ops is None:
        ops = load_ops(plane_dir)

    # load iscell arrays
    if iscell_suite2p is None:
        s2p_file = plane_dir / "iscell_suite2p.npy"
        if s2p_file.exists():
            iscell_suite2p = np.load(s2p_file, allow_pickle=True)
        else:
            iscell_suite2p = np.load(plane_dir / "iscell.npy", allow_pickle=True)
    if iscell_final is None:
        iscell_final = np.load(plane_dir / "iscell.npy", allow_pickle=True)

    # normalize to 1d
    if iscell_suite2p.ndim == 2:
        iscell_suite2p = iscell_suite2p[:, 0]
    if iscell_final.ndim == 2:
        iscell_final = iscell_final[:, 0]

    # get filter metadata from ops if not provided
    filter_metadata = ops.get("filter_metadata", {})

    # get background image
    img, yoff, xoff = get_background_image(ops, img_key)
    img_h, img_w = img.shape[:2]
    img_norm = normalize99(img)
    img_rgb = np.stack([img_norm] * 3, axis=-1).astype(np.float32)

    # compute masks
    n_rois = len(stat)
    suite2p_accepted = iscell_suite2p.astype(bool)
    suite2p_rejected = ~suite2p_accepted
    final_accepted = iscell_final.astype(bool)

    # determine what suite2p rejected vs what filters rejected
    n_suite2p_rejected = suite2p_rejected.sum()
    n_filter_rejected = (suite2p_accepted & ~final_accepted).sum()
    n_final_accepted = final_accepted.sum()

    # build panels: suite2p classification + filters + final
    panels = []

    # panel 1: suite2p classification
    panels.append({
        "title": "suite2p classification",
        "accepted_mask": suite2p_accepted,
        "rejected_mask": suite2p_rejected,
        "n_accepted": int(suite2p_accepted.sum()),
        "n_rejected": int(n_suite2p_rejected),
        "subtitle": f"{n_suite2p_rejected} rejected by suite2p",
    })

    # panels for each filter that rejected cells
    if filter_metadata:
        for name, meta in filter_metadata.items():
            n_rejected = meta.get("n_rejected", 0)
            if n_rejected > 0:
                params = meta.get("params", {})
                params_str = ", ".join(f"{k}={v}" for k, v in params.items())
                panels.append({
                    "title": f"filter: {name}",
                    "filter_name": name,
                    "n_rejected": n_rejected,
                    "subtitle": f"{n_rejected} excluded"
                                + (f" ({params_str})" if params_str else ""),
                })

    # final panel: result
    panels.append({
        "title": "final result",
        "accepted_mask": final_accepted,
        "rejected_mask": ~final_accepted,
        "n_accepted": int(n_final_accepted),
        "n_rejected": int(n_rois - n_final_accepted),
        "subtitle": f"{n_final_accepted} accepted, {n_rois - n_final_accepted} rejected",
    })

    # create figure: 1-2 panels in one row, 3-4 in a 2x2 grid, 5+ in 3 columns
    n_panels = len(panels)
    if n_panels <= 2:
        ncols = n_panels
        nrows = 1
    elif n_panels <= 4:
        ncols = 2
        nrows = 2
    else:
        ncols = 3
        nrows = (n_panels + 2) // 3

    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
    if n_panels == 1:
        axes = np.array([axes])
    axes = axes.flatten()

    # hide unused axes
    for i in range(n_panels, len(axes)):
        axes[i].axis("off")

    # color scheme
    color_accepted = np.array([0.2, 0.8, 0.2])  # green
    color_rejected = np.array([0.9, 0.2, 0.2])  # red
    color_filtered = np.array([1.0, 0.6, 0.0])  # orange, reserved for filter-rejected cells (currently unused)

    for i, panel in enumerate(panels):
        ax = axes[i]
        overlay = img_rgb.copy()

        if "accepted_mask" in panel:
            # draw accepted cells
            accepted_mask = panel["accepted_mask"]
            if accepted_mask.sum() > 0:
                mask_px = stat_to_mask(stat[accepted_mask], img_h, img_w, yoff, xoff)
                if mask_px.max() > 0:
                    px = mask_px > 0
                    overlay[px] = (1 - alpha) * overlay[px] + alpha * color_accepted

            # draw rejected cells
            rejected_mask = panel["rejected_mask"]
            if rejected_mask.sum() > 0:
                mask_px = stat_to_mask(stat[rejected_mask], img_h, img_w, yoff, xoff)
                if mask_px.max() > 0:
                    px = mask_px > 0
                    overlay[px] = (1 - alpha) * overlay[px] + alpha * color_rejected
        elif "filter_name" in panel:
            # filter panels: overlay the final accepted cells (green); the
            # exact removed_mask for this filter is not saved in the metadata,
            # so its rejected count is reported in the subtitle only
            if final_accepted.sum() > 0:
                mask_px = stat_to_mask(stat[final_accepted], img_h, img_w, yoff, xoff)
                if mask_px.max() > 0:
                    px = mask_px > 0
                    overlay[px] = (1 - alpha) * overlay[px] + alpha * color_accepted

        ax.imshow(overlay)
        ax.axis("off")
        ax.set_title(panel["title"], fontsize=11, fontweight="bold")

        # add subtitle
        if "subtitle" in panel:
            ax.text(
                0.5, -0.02, panel["subtitle"],
                transform=ax.transAxes,
                ha="center", va="top", fontsize=9, color="gray",
            )

    # add legend to last panel
    legend_elements = [
        Patch(facecolor=color_accepted, alpha=0.7, label="accepted"),
        Patch(facecolor=color_rejected, alpha=0.7, label="rejected"),
    ]
    axes[n_panels - 1].legend(
        handles=legend_elements, loc="upper right", fontsize=9
    )

    # overall title
    fig.suptitle(
        f"Cell Filter Summary: {n_rois} total ROIs → {n_final_accepted} accepted",
        fontsize=13, fontweight="bold", y=0.98,
    )

    plt.tight_layout(rect=[0, 0, 1, 0.96])

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()

    return fig


def plot_diameter_histogram(
    stat,
    iscell=None,
    max_diameter_px: float = None,
    pixel_size_um: float = None,
    bins: int = 50,
    save_path=None,
    figsize: tuple = (10, 6),
):
    """
    Plot a histogram of cell diameters with an optional threshold line.

    Parameters
    ----------
    stat : np.ndarray or list
        Suite2p stat array with ROI statistics.
    iscell : np.ndarray, optional
        Cell classification array. If provided, only accepted cells are plotted.
    max_diameter_px : float, optional
        Threshold diameter in pixels to show as a vertical line.
    pixel_size_um : float, optional
        Pixel size in microns. If provided, adds a micron scale to the x-axis.
    bins : int, optional
        Number of histogram bins. Default is 50.
    save_path : str or Path, optional
        Path to save the figure.
    figsize : tuple, optional
        Figure size. Default is (10, 6).

    Returns
    -------
    matplotlib.figure.Figure
        The generated figure.
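
    Examples
    --------
    Minimal sketch, assuming Suite2p outputs on disk; the paths and the
    20-pixel threshold here are hypothetical::

        stat = np.load("suite2p/plane0/stat.npy", allow_pickle=True)
        iscell = np.load("suite2p/plane0/iscell.npy", allow_pickle=True)
        fig = plot_diameter_histogram(
            stat,
            iscell=iscell,
            max_diameter_px=20.0,
            pixel_size_um=1.0,
            save_path="suite2p/plane0/diameter_hist.png",
        )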
""" # Filter by iscell if provided if iscell is not None: if iscell.ndim == 2: iscell = iscell[:, 0] iscell = iscell.astype(bool) stat = stat[iscell] # Get radii if len(stat) == 0: print("No cells to plot") return None if "radius" not in stat[0]: radii = np.array([np.sqrt(len(s["xpix"]) / np.pi) for s in stat]) else: radii = np.array([s["radius"] for s in stat]) diameters_px = 2 * radii fig, ax = plt.subplots(figsize=figsize) # Plot histogram counts, bin_edges, patches = ax.hist( diameters_px, bins=bins, color='steelblue', edgecolor='white', alpha=0.7 ) # Color bars above threshold red if max_diameter_px is not None: for patch, left_edge in zip(patches, bin_edges[:-1]): if left_edge >= max_diameter_px: patch.set_facecolor('red') patch.set_alpha(0.7) # Add threshold line ax.axvline(max_diameter_px, color='red', linestyle='--', linewidth=2, label=f'Threshold: {max_diameter_px:.1f} px') # Count cells above threshold n_above = (diameters_px > max_diameter_px).sum() n_total = len(diameters_px) ax.legend(title=f'{n_above}/{n_total} cells above threshold') ax.set_xlabel('Diameter (pixels)', fontsize=12) ax.set_ylabel('Count', fontsize=12) ax.set_title(f'Cell Diameter Distribution (n={len(diameters_px)})', fontsize=14) # Add micron scale if pixel size provided if pixel_size_um is not None: ax2 = ax.twiny() ax2.set_xlim(ax.get_xlim()[0] * pixel_size_um, ax.get_xlim()[1] * pixel_size_um) ax2.set_xlabel('Diameter (µm)', fontsize=12) if max_diameter_px is not None: max_um = max_diameter_px * pixel_size_um ax2.axvline(max_um, color='red', linestyle='--', linewidth=2, alpha=0.5) # Add statistics median_d = np.median(diameters_px) mean_d = np.mean(diameters_px) stats_text = f'Median: {median_d:.1f} px\nMean: {mean_d:.1f} px' if pixel_size_um: stats_text += f'\n({median_d * pixel_size_um:.1f} / {mean_d * pixel_size_um:.1f} µm)' ax.text(0.95, 0.95, stats_text, transform=ax.transAxes, verticalalignment='top', horizontalalignment='right', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5), fontsize=10) plt.tight_layout() if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches='tight') plt.close(fig) else: plt.show() return fig