"""Source code for lbm_suite2p_python.zplane."""

from pathlib import Path
from typing import Tuple

import numpy as np
import pandas as pd
import tifffile
import math

import matplotlib.offsetbox
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
from matplotlib.offsetbox import VPacker, HPacker, DrawingArea
import matplotlib.gridspec as gridspec

from scipy.ndimage import distance_transform_edt

from lbm_suite2p_python.postprocessing import (
    load_ops,
    load_planar_results,
    dff_rolling_percentile,
    dff_shot_noise,
    compute_trace_quality_score,
)
from lbm_suite2p_python.utils import (
    _resize_masks_fit_crop,
    bin1d,
)


def infer_units(f: np.ndarray) -> str:
    """
    Infer calcium imaging signal type from array values.

    Heuristic based on the 1st, 50th and 99th percentiles (NaNs ignored):

    - ``"raw"``: integer dtype, or values in the hundreds/thousands
      (99th percentile > 500 or median > 100).
    - ``"dffp"``: ΔF/F₀ in percent, typically ~10–100.
    - ``"dff"``: unitless ΔF/F₀, typically ~0–1.
    - ``"unknown"``: anything else, including empty float input.

    Parameters
    ----------
    f : array-like
        Fluorescence values of any shape.

    Returns
    -------
    str
        One of ``"raw"``, ``"dffp"``, ``"dff"``, ``"unknown"``.
    """
    f = np.asarray(f)
    if np.issubdtype(f.dtype, np.integer):
        return "raw"
    if f.size == 0:
        # nanpercentile returns a scalar NaN for empty input, which cannot be
        # unpacked into three percentiles below.
        return "unknown"

    p1, p50, p99 = np.nanpercentile(f, [1, 50, 99])

    if p99 > 500 or p50 > 100:
        return "raw"
    elif 5 < p1 < 30 and 20 < p50 < 60 and 40 < p99 < 100:
        return "dffp"
    elif 0.1 < p1 < 0.2 < p50 < 0.5 < p99 < 1.0:
        return "dff"
    else:
        return "unknown"


def format_time(t):
    """
    Render a duration given in seconds as a short label.

    Under a minute the value is rounded up to whole seconds, so a small
    positive duration never reads as "0 s". Under an hour it is rounded to
    minutes; anything longer is rounded to hours.
    """
    if t < 60:
        value, unit = int(np.ceil(t)), "s"
    elif t < 3600:
        value, unit = int(round(t / 60)), "min"
    else:
        value, unit = int(round(t / 3600)), "h"
    return f"{value} {unit}"


def get_color_permutation(n):
    """
    Return a permutation of ``range(n)`` that spreads neighbouring indices apart.

    The permutation is ``i -> (i * step) % n``. Here ``step`` is the smallest
    value in ``[n // 2 + 1, n)`` that is coprime with ``n``. If no such step
    exists, the identity ordering is returned.
    """
    step = next(
        (cand for cand in range(n // 2 + 1, n) if math.gcd(cand, n) == 1),
        1,  # step of 1 reproduces the identity ordering
    )
    return [(idx * step) % n for idx in range(n)]


class AnchoredHScaleBar(matplotlib.offsetbox.AnchoredOffsetbox):
    """
    Create an anchored horizontal scale bar.

    The bar is drawn in axes-fraction units (``ax.transAxes``), so
    ``size=0.1`` spans 10% of the axes width regardless of the data limits.
    The caller must pick a ``label`` that matches that fraction of the
    displayed x-range.

    Parameters
    ----------
    size : float, optional
        Bar length as a fraction of the axes width (default is 1).
    label : str, optional
        Text label drawn below the bar (default is "").
    loc : int or str, optional
        Location code passed to ``AnchoredOffsetbox`` (default is 2).
    ax : Axes, optional
        Axes whose ``transAxes`` sizes the bar (default uses current axes).
    pad, borderpad : float, optional
        Padding passed to ``AnchoredOffsetbox``.
    ppad, sep : float, optional
        Padding and separation of the ``VPacker`` stacking bar and label.
    prop : dict, optional
        Font properties passed to ``AnchoredOffsetbox``.
    frameon : bool, optional
        Whether to draw a frame around the box (default True).
    linekw : dict, optional
        Keyword arguments for the bar's ``Line2D``.
    **kwargs
        Forwarded to ``AnchoredOffsetbox``.

    Attributes
    ----------
    txt : TextArea
        The label artist, exposed so callers can restyle or update its text.
    vpac : VPacker
        The packer holding the bar above the label.
    """

    def __init__(
        self,
        size=1,
        label="",
        loc=2,
        ax=None,
        pad=0.4,
        borderpad=0.5,
        ppad=0,
        sep=2,
        prop=None,
        frameon=True,
        linekw=None,
        **kwargs,
    ):
        if linekw is None:
            linekw = {}
        if ax is None:
            ax = plt.gca()
        # transAxes: bar length is a fraction of the axes width, not data units.
        trans = ax.transAxes

        size_bar = matplotlib.offsetbox.AuxTransformBox(trans)
        line = Line2D([0, size], [0, 0], **linekw)
        size_bar.add_artist(line)
        txt = matplotlib.offsetbox.TextArea(label)
        self.txt = txt
        self.vpac = VPacker(children=[size_bar, txt], align="center", pad=ppad, sep=sep)
        super().__init__(
            loc,  # noqa
            pad=pad,
            borderpad=borderpad,
            child=self.vpac,
            prop=prop,
            frameon=frameon,
            **kwargs,
        )


class AnchoredVScaleBar(matplotlib.offsetbox.AnchoredOffsetbox):
    """
    Create an anchored vertical scale bar.

    The bar is drawn in axes-fraction units (``ax.transAxes``), so
    ``height=0.1`` spans 10% of the axes height regardless of the data
    limits. The caller must pick a ``label`` that matches that fraction of
    the displayed y-range. The label is rotated 90° and placed to the right
    of the bar.

    Parameters
    ----------
    height : float, optional
        Bar height as a fraction of the axes height (default is 1).
    label : str, optional
        Text label (default is "").
    loc : int or str, optional
        Location code passed to ``AnchoredOffsetbox`` (default is 2).
    ax : Axes, optional
        Axes whose ``transAxes`` sizes the bar (default uses current axes).
    pad, borderpad : float, optional
        Padding passed to ``AnchoredOffsetbox``.
    ppad, sep : float, optional
        Padding and separation of the ``HPacker`` holding bar, spacer and label.
    prop : dict, optional
        Font properties passed to ``AnchoredOffsetbox``.
    frameon : bool, optional
        Whether to draw a frame around the box (default True).
    linekw : dict, optional
        Keyword arguments for the bar's ``Line2D``.
    spacer_width : float, optional
        Width of the empty ``DrawingArea`` placed between bar and text.
    **kwargs
        Forwarded to ``AnchoredOffsetbox``.

    Attributes
    ----------
    txt : TextArea
        The label artist, exposed so callers can restyle or update its text.
    hpac : HPacker
        The packer holding bar, spacer and label.
    """

    def __init__(
        self,
        height=1,
        label="",
        loc=2,
        ax=None,
        pad=0.4,
        borderpad=0.5,
        ppad=0,
        sep=2,
        prop=None,
        frameon=True,
        linekw=None,
        spacer_width=6,
        **kwargs,
    ):
        if ax is None:
            ax = plt.gca()
        if linekw is None:
            linekw = {}
        # transAxes: bar height is a fraction of the axes height, not data units.
        trans = ax.transAxes

        size_bar = matplotlib.offsetbox.AuxTransformBox(trans)
        line = Line2D([0, 0], [0, height], **linekw)
        size_bar.add_artist(line)

        txt = matplotlib.offsetbox.TextArea(
            label, textprops=dict(rotation=90, ha="left", va="bottom")
        )
        self.txt = txt

        spacer = DrawingArea(spacer_width, 0, 0, 0)
        self.hpac = HPacker(
            children=[size_bar, spacer, txt], align="bottom", pad=ppad, sep=sep
        )
        super().__init__(
            loc,  # noqa
            pad=pad,
            borderpad=borderpad,
            child=self.hpac,
            prop=prop,
            frameon=frameon,
            **kwargs,
        )


def plot_traces_noise(
    dff_noise,
    colors,
    fps=17.0,
    window=220,
    savepath=None,
    title="Trace Noise",
    lw=0.5,
):
    """
    Plot stacked noise traces on a black background, similar to plot_traces.

    Trace ``i`` is shifted up by ``i * offset``. ``offset`` is 1.2x the
    median 10th-to-90th percentile spread of the displayed window. Unlike
    plot_traces, no per-trace baseline is subtracted.

    Parameters
    ----------
    dff_noise : ndarray
        Noise traces, shape (n_neurons, n_timepoints).
    colors : sequence
        One matplotlib color per neuron (e.g. rows of an RGBA array sampled
        from a colormap); ``colors[i]`` is used for ``dff_noise[i]``.
    fps : float
        Sampling rate, Hz.
    window : float
        Time window (seconds) to display, starting at t = 0 and clipped to
        the recording length.
    savepath : str or Path, optional
        If given, save the figure there (200 dpi) and close it; otherwise
        show it.
    title : str
        Figure suptitle; skipped when empty.
    lw : float
        Line width.

    Returns
    -------
    None
        Returns early without plotting when there are no timepoints.
    """
    n_neurons, n_timepoints = dff_noise.shape
    if n_timepoints == 0:
        # np.percentile would raise on an empty time window.
        return None
    data_time = np.arange(n_timepoints) / fps
    current_frame = min(int(window * fps), n_timepoints - 1)
    shown = dff_noise[:, : current_frame + 1]

    # auto offset based on noise traces
    p10 = np.percentile(shown, 10, axis=1)
    p90 = np.percentile(shown, 90, axis=1)
    offset = np.median(p90 - p10) * 1.2

    fig, ax = plt.subplots(figsize=(10, 6), facecolor="black")
    ax.set_facecolor("black")
    ax.tick_params(axis="x", which="both", labelbottom=False, length=0, colors="white")
    ax.tick_params(axis="y", which="both", labelleft=False, length=0, colors="white")
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Draw top trace first; lower-indexed traces get higher zorder (on top).
    for i in reversed(range(n_neurons)):
        ax.plot(
            data_time[: current_frame + 1],
            shown[i] + i * offset,
            color=colors[i],
            lw=lw,
            zorder=-i,
        )

    if title:
        fig.suptitle(title, fontsize=16, fontweight="bold", color="white")

    if savepath:
        fig.savefig(savepath, dpi=200, facecolor=fig.get_facecolor())
        plt.close(fig)
    else:
        plt.show()


[docs] def plot_traces( f, save_path: str | Path = "", cell_indices: np.ndarray | list[int] | None = None, fps=17.0, num_neurons=20, window=220, title="", offset=None, lw=0.5, cmap="tab10", scale_bar_unit: str = None, mask_overlap: bool = True, ) -> None: """ Plot stacked fluorescence traces with automatic offset and scale bars. Parameters ---------- f : ndarray 2d array of fluorescence traces (n_neurons x n_timepoints). save_path : str, optional Path to save the output plot. fps : float Sampling rate in frames per second. num_neurons : int Number of neurons to display if cell_indices is None. window : float Time window (in seconds) to display. title : str Title of the figure. offset : float or None Vertical offset between traces; if None, computed automatically. lw : float Line width for data points. cmap : str Matplotlib colormap string. scale_bar_unit : str, optional Unit suffix for the vertical scale bar (e.g., "% ΔF/F₀", "a.u."). The numeric value is computed automatically based on the plot's vertical scale. If None, inferred from data range. cell_indices : array-like or None Specific cell indices to plot. If provided, overrides num_neurons. mask_overlap : bool, default True If True, lower traces mask (occlude) traces above them, creating a layered effect where each trace has a black background. 
""" if isinstance(f, dict): raise ValueError("f must be a numpy array, not a dictionary") n_timepoints = f.shape[-1] data_time = np.arange(n_timepoints) / fps current_frame = min(int(window * fps), n_timepoints - 1) if cell_indices is None: displayed_neurons = min(num_neurons, f.shape[0]) indices = np.arange(displayed_neurons) else: indices = np.array(cell_indices) if indices.dtype == bool: indices = np.where(indices)[0] # convert boolean mask to int indices displayed_neurons = len(indices) if len(indices) == 0: return None if offset is None: p10 = np.percentile(f[indices, : current_frame + 1], 10, axis=1) p90 = np.percentile(f[indices, : current_frame + 1], 90, axis=1) offset = np.median(p90 - p10) * 1.2 # Ensure minimum offset to prevent trace overlap min_offset = np.percentile(p90 - p10, 75) * 0.8 offset = max(offset, min_offset, 1e-6) # Absolute minimum to prevent divide-by-zero cmap_inst = plt.get_cmap(cmap) colors = cmap_inst(np.linspace(0, 1, displayed_neurons)) perm = get_color_permutation(displayed_neurons) colors = colors[perm] # fig, ax = plt.subplots(figsize=(10, 6), facecolor="black") # ax.set_facecolor("black") # Build shifted traces array (no masking - let z-order handle overlap) shifted_traces = np.zeros((displayed_neurons, current_frame + 1)) for i in range(displayed_neurons): trace = f[indices[i], : current_frame + 1] baseline = np.percentile(trace, 8) shifted_traces[i] = (trace - baseline) + i * offset # Plot traces with z-ordering (lower traces on top via higher zorder) fig, ax = plt.subplots(figsize=(10, 6), facecolor="black") ax.set_facecolor("black") ax.tick_params(axis="x", which="both", labelbottom=False, length=0, colors="white") ax.tick_params(axis="y", which="both", labelleft=False, length=0, colors="white") for spine in ax.spines.values(): spine.set_visible(False) # Plot from top to bottom so lower-indexed traces appear on top time_slice = data_time[: current_frame + 1] for i in range(displayed_neurons - 1, -1, -1): z = 
displayed_neurons - i # Lower index = higher zorder = on top if mask_overlap: # Fill below trace with black to mask traces above ax.fill_between( time_slice, shifted_traces[i], y2=shifted_traces[i].min() - offset, color="black", zorder=z - 0.5, ) ax.plot( time_slice, shifted_traces[i], color=colors[i], lw=lw, zorder=z, ) time_bar_length = 0.1 * window if time_bar_length < 60: time_label = f"{time_bar_length:.0f} s" elif time_bar_length < 3600: time_label = f"{time_bar_length / 60:.0f} min" else: time_label = f"{time_bar_length / 3600:.1f} hr" # Set y-limits with small padding (no extra space for scalebars - they go outside) y_min = np.min(shifted_traces) y_max = np.max(shifted_traces) y_range = y_max - y_min ax.set_ylim(y_min - y_range * 0.02, y_max + y_range * 0.02) # Compute vertical scale bar value (10% of y-range in data units) scale_bar_height_frac = 0.10 # 10% of axes height scale_bar_data_value = y_range * scale_bar_height_frac # Use provided unit or default to "a.u." if scale_bar_unit is None: scale_bar_unit = "a.u." 
# Format the scale bar label with computed value if scale_bar_data_value >= 100: scale_bar_label = f"{int(round(scale_bar_data_value, -1))} {scale_bar_unit}" elif scale_bar_data_value >= 10: scale_bar_label = f"{int(round(scale_bar_data_value))} {scale_bar_unit}" elif scale_bar_data_value >= 1: scale_bar_label = f"{scale_bar_data_value:.0f} {scale_bar_unit}" else: scale_bar_label = f"{scale_bar_data_value:.2f} {scale_bar_unit}" # Adjust subplot to make room for scalebars at bottom and right fig.subplots_adjust(bottom=0.12, right=0.88) linekw = dict(color="white", linewidth=3) # Time scale bar - use fig.text for fixed position below axes # Get axes position in figure coordinates ax_pos = ax.get_position() time_bar_x = ax_pos.x1 - 0.02 # right side of axes time_bar_y = 0.07 # fixed position just below axes # Draw horizontal line for time scale bar line_width_fig = 0.08 # width in figure coords fig.add_artist(plt.Line2D( [time_bar_x - line_width_fig, time_bar_x], [time_bar_y, time_bar_y], transform=fig.transFigure, color="white", linewidth=3, clip_on=False, )) # Add time label fig.text( time_bar_x - line_width_fig / 2, time_bar_y - 0.02, time_label, ha="center", va="top", color="white", fontsize=10, transform=fig.transFigure, ) # Vertical scale bar - positioned just outside right edge, bottom aligned with x-axis vsb = AnchoredVScaleBar( height=scale_bar_height_frac, label=scale_bar_label, loc="lower right", frameon=False, pad=0.5, sep=4, linekw=linekw, ax=ax, spacer_width=0, ) # Position just outside right edge of axes, bottom at y=0 vsb.set_bbox_to_anchor((1.02, 0.0), transform=ax.transAxes) vsb.txt._text.set_color("white") ax.add_artist(vsb) if title: fig.suptitle(title, fontsize=16, fontweight="bold", color="white") ax.set_ylabel( f"Neuron Count: {displayed_neurons}", fontsize=10, fontweight="bold", color="white", labelpad=5, ) if save_path: plt.savefig(save_path, dpi=200, facecolor=fig.get_facecolor()) plt.close(fig) else: plt.show() return None
def animate_traces( f, save_path="./scrolling.mp4", fps=17.0, start_neurons=20, window=120, title="", gap=None, lw=0.5, cmap="tab10", anim_fps=60, expand_after=5, speed_factor=1.0, expansion_factor=2.0, smooth_factor=1, ): """WIP""" n_neurons, n_timepoints = f.shape data_time = np.arange(n_timepoints) / fps T_data = data_time[-1] current_frame = min(int(window * fps), n_timepoints - 1) t_f_local = (T_data - window + expansion_factor * expand_after) / ( 1 + expansion_factor ) if gap is None: p10 = np.percentile(f[:start_neurons, : current_frame + 1], 10, axis=1) p90 = np.percentile(f[:start_neurons, : current_frame + 1], 90, axis=1) gap = np.median(p90 - p10) * 1.2 cmap_inst = plt.get_cmap(cmap) colors = cmap_inst(np.linspace(0, 1, n_neurons)) perm = np.random.permutation(n_neurons) colors = colors[perm] all_shifted = [] for i in range(start_neurons): trace = f[i, : current_frame + 1] baseline = np.percentile(trace, 8) shifted = (trace - baseline) + i * gap all_shifted.append(shifted) all_y = np.concatenate(all_shifted) y_min = np.min(all_y) y_max = np.max(all_y) rounded_dff = np.round(y_max - y_min) * 0.1 dff_label = f"{rounded_dff:.0f} % ΔF/F₀" fig, ax = plt.subplots(figsize=(10, 6), facecolor="black") ax.set_facecolor("black") ax.tick_params(axis="x", labelbottom=False, length=0) ax.tick_params(axis="y", labelleft=False, length=0) for spine in ax.spines.values(): spine.set_visible(False) fills = [] linekw = dict(color="white", linewidth=3) hsb = AnchoredHScaleBar( size=0.1, label=format_time(0.1 * window), loc=4, frameon=False, pad=0.6, sep=4, linekw=linekw, ax=ax, ) hsb.set_bbox_to_anchor((0.97, -0.1), transform=ax.transAxes) # noqa ax.add_artist(hsb) vsb = AnchoredVScaleBar( height=0.1, label=dff_label, loc="lower right", # noqa frameon=False, pad=0, sep=4, linekw=linekw, ax=ax, spacer_width=0, ) ax.add_artist(vsb) lines = [] for i in range(n_neurons): (line,) = ax.plot([], [], color=colors[i], lw=lw, zorder=-i) lines.append(line) def init(): for ix in 
range(n_neurons): if ix < start_neurons: _trace = f[ix, : current_frame + 1] _baseline = np.percentile(_trace, 8) _shifted = (_trace - _baseline) + ix * gap lines[ix].set_data(data_time[: current_frame + 1], _shifted) else: lines[ix].set_data([], []) extra = 0.05 * window ax.set_xlim(0, window + extra) ax.set_ylim(y_min - 0.05 * abs(y_min), y_max + 0.05 * abs(y_max)) return lines + [hsb, vsb] def update(frame): t = speed_factor * frame / anim_fps if t < expand_after: x_min = t x_max = t + window n_visible = start_neurons else: u = min(1.0, (t - expand_after) / (t_f_local - expand_after)) ease = 3 * u**2 - 2 * u**3 # smoothstep easing x_min = t window_start = window window_end = window + expansion_factor * (T_data - window - expand_after) current_window = window_start + (window_end - window_start) * ease x_max = x_min + current_window n_visible = start_neurons + int((n_neurons - start_neurons) * ease) n_visible = min(n_neurons, n_visible) i_lower = int(x_min * fps) i_upper = int(x_max * fps) i_upper = max(i_upper, i_lower + 1) for ix in range(n_neurons): if ix < n_visible: _trace = f[ix, i_lower:i_upper] _baseline = np.percentile(_trace, 8) _shifted = (_trace - _baseline) + ix * gap lines[ix].set_data(data_time[i_lower:i_upper], _shifted) else: lines[ix].set_data([], []) for fill in fills: fill.remove() fills.clear() for ix in range(n_visible - 1): trace1 = f[ix, i_lower:i_upper] baseline1 = np.percentile(trace1, 8) shifted1 = (trace1 - baseline1) + ix * gap trace2 = f[ix + 1, i_lower:i_upper] baseline2 = np.percentile(trace2, 8) shifted2 = (trace2 - baseline2) + (ix + 1) * gap fill = ax.fill_between( data_time[i_lower:i_upper], shifted1, shifted2, where=shifted1 > shifted2, color="black", zorder=-ix - 1, ) fills.append(fill) _all_shifted = [ (f[ix, i_lower:i_upper] - np.percentile(f[ix, i_lower:i_upper], 8)) + ix * gap for ix in range(n_visible) ] _all_y = np.concatenate(_all_shifted) y_min_new, y_max_new = np.min(_all_y), np.max(_all_y) extra_axis = 0.05 * (x_max 
- x_min) ax.set_xlim(x_min, x_max + extra_axis) ax.set_ylim( y_min_new - 0.05 * abs(y_min_new), y_max_new + 0.05 * abs(y_max_new) ) if title: ax.set_title(title, fontsize=16, fontweight="bold", color="white") _dff_rounded = np.round(y_max_new - y_min_new) * 0.1 if _dff_rounded > 300: vsb.set_visible(False) else: _dff_label = f"{_dff_rounded:.0f} % ΔF/F₀" vsb.txt.set_text(_dff_label) hsb.txt.set_text(format_time(0.1 * (x_max - x_min))) ax.set_ylabel( f"Neuron Count: {n_visible}", fontsize=8, fontweight="bold", labelpad=2 ) return lines + [hsb, vsb] + fills effective_anim_fps = anim_fps * smooth_factor total_frames = int(np.ceil((T_data / speed_factor))) ani = FuncAnimation( fig, update, frames=total_frames, init_func=init, interval=1000 / effective_anim_fps, blit=True, ) ani.save(save_path, fps=anim_fps) plt.show() def feather_mask(mask, max_alpha=0.75, edge_width=3): # mask alpha using distance transform dist_out = distance_transform_edt(mask == 0) alpha = np.clip((edge_width - dist_out) / edge_width, 0, 1) return alpha * max_alpha def plot_masks( img: np.ndarray, stat: list[dict] | dict, mask_idx: np.ndarray, savepath: str | Path = None, colors=None, title=None, ): """ Draw ROI overlays onto the mean image. Parameters ---------- img : ndarray (Ly x Lx) Background image to overlay on. stat : list[dict] Suite2p ROI stat dictionaries (with "ypix", "xpix", "lam"). mask_idx : ndarray[bool] Boolean array selecting which ROIs to plot. savepath : str or Path, optional Path to save the figure. If None, displays with plt.show(). colors : ndarray or list, optional Array/list of RGB tuples for each ROI selected. If None, colors are assigned via HSV colormap. title : str, optional Title string to place on the figure. 
""" # Normalize background image using percentile stretch for better contrast # this prevents dark images when min/max are extreme outliers vmin = np.nanpercentile(img, 1) vmax = np.nanpercentile(img, 99) normalized = (img - vmin) / (vmax - vmin + 1e-6) normalized = np.clip(normalized, 0, 1) # Set NaN regions to 0 (black background) normalized = np.nan_to_num(normalized, nan=0.0) canvas = np.tile(normalized, (3, 1, 1)).transpose(1, 2, 0) # Get image dimensions for bounds checking Ly, Lx = img.shape[:2] # Assign colors if not provided n_masks = mask_idx.sum() if colors is None: colors = plt.cm.hsv(np.linspace(0, 1, n_masks + 1))[:, :3] # noqa c = 0 for n, s in enumerate(stat): if mask_idx[n]: ypix, xpix, lam = s["ypix"], s["xpix"], s["lam"] # Bounds checking - only keep pixels within image dimensions valid_mask = (ypix >= 0) & (ypix < Ly) & (xpix >= 0) & (xpix < Lx) if not np.any(valid_mask): c += 1 continue # Skip ROI if no valid pixels ypix = ypix[valid_mask] xpix = xpix[valid_mask] lam = lam[valid_mask] lam = lam / (lam.max() + 1e-10) col = colors[c] c += 1 for k in range(3): canvas[ypix, xpix, k] = ( 0.5 * canvas[ypix, xpix, k] + 0.5 * col[k] * lam ) fig, ax = plt.subplots(figsize=(10, 10), facecolor="black") ax.set_facecolor("black") ax.imshow(canvas, interpolation="nearest") if title is not None: ax.set_title(title, fontsize=10, color="white", fontweight="bold") ax.axis("off") plt.tight_layout() if savepath: if Path(savepath).is_dir(): raise ValueError("savepath must be a file path, not a directory.") plt.savefig(savepath, dpi=300, facecolor="black") plt.close(fig) else: plt.show()
def plot_projection(
    ops,
    output_directory=None,
    fig_label=None,
    vmin=None,
    vmax=None,
    add_scalebar=False,
    proj="meanImg",
    display_masks=False,
    accepted_only=False,
):
    """
    Plot a Suite2p projection image with optional ROI overlays and scale bar.

    Parameters
    ----------
    ops : dict
        Suite2p ops dict containing the ``proj`` image and, for masks,
        "Ly" and "Lx".
    output_directory : str or Path, optional
        Despite the name, this is the full output FILE path; its parent
        directory is created if needed. If None/empty, the figure is shown.
    fig_label : str, optional
        Y-axis label; "_", "-" and "." are replaced by spaces.
    vmin, vmax : float, optional
        Display range; defaults are the 2nd / 98th nan-percentiles.
    add_scalebar : bool
        Draw a "100 μm" scale bar when ``ops["dx"]`` is present and positive.
        ``100 / ops["dx"]`` is used as the bar length in pixels.
        NOTE(review): this assumes ``dx`` is μm per pixel; confirm.
    proj : {"meanImg", "max_proj", "meanImgE"}
        Which projection in ``ops`` to display.
    display_masks : bool
        Overlay accepted ROIs in green (feathered edges) and, unless
        ``accepted_only``, rejected ROIs in magenta, with accepted/rejected
        counts above the image. Requires suite2p.
    accepted_only : bool
        Only overlay accepted ROIs (the rejected count is still shown).

    Raises
    ------
    ValueError
        If ``proj`` is not a supported projection name.
    """
    titles = {
        "meanImg": "Mean-Image",
        "max_proj": "Max-Projection",
        "meanImgE": "Mean-Image (Enhanced)",
    }
    if proj not in titles:
        raise ValueError(
            "Unknown projection type. Options are ['meanImg', 'max_proj', 'meanImgE']"
        )
    txt = titles[proj]

    if output_directory:
        output_directory = Path(output_directory)

    data = ops[proj]
    shape = data.shape
    fig, ax = plt.subplots(figsize=(6, 6), facecolor="black")
    vmin = np.nanpercentile(data, 2) if vmin is None else vmin
    vmax = np.nanpercentile(data, 98) if vmax is None else vmax
    if vmax - vmin < 1e-6:
        vmax = vmin + 1e-6
    ax.imshow(data, cmap="gray", vmin=vmin, vmax=vmax)

    # move projection title higher if masks are displayed to avoid overlap.
    proj_title_y = 1.07 if display_masks else 1.02
    ax.text(
        0.5,
        proj_title_y,
        txt,
        transform=ax.transAxes,
        fontsize=14,
        fontweight="bold",
        fontname="Courier New",
        color="white",
        ha="center",
        va="bottom",
    )
    if fig_label:
        fig_label = fig_label.replace("_", " ").replace("-", " ").replace(".", " ")
        ax.set_ylabel(fig_label, color="white", fontweight="bold", fontsize=12)
    ax.set_xticks([])
    ax.set_yticks([])

    if display_masks:
        # Imported lazily so plain projections don't require suite2p.
        from suite2p.detection.stats import ROI

        res = load_planar_results(ops)
        stat = res["stat"]
        iscell_mask = res["iscell"][:, 0].astype(bool)
        if len(stat) > 0:
            im = ROI.stats_dicts_to_3d_array(
                stat, Ly=ops["Ly"], Lx=ops["Lx"], label_id=True
            )
            im[im == 0] = np.nan
        else:
            # No ROIs at all: stacking an empty list of masks would fail.
            im = np.full((0, ops["Ly"], ops["Lx"]), np.nan)
        blank = np.zeros(im.shape[1:], dtype=im.dtype)
        accepted_cells = np.sum(iscell_mask)
        rejected_cells = np.sum(~iscell_mask)

        cell_rois = _resize_masks_fit_crop(
            np.nanmax(im[iscell_mask], axis=0) if np.any(iscell_mask) else blank,
            shape,
        )
        green_overlay = np.zeros((*shape, 4), dtype=np.float32)
        green_overlay[..., 3] = feather_mask(cell_rois > 0, max_alpha=0.9)
        green_overlay[..., 1] = 1
        ax.imshow(green_overlay)

        if not accepted_only:
            non_cell_rois = _resize_masks_fit_crop(
                np.nanmax(im[~iscell_mask], axis=0) if np.any(~iscell_mask) else blank,
                shape,
            )
            magenta_overlay = np.zeros((*shape, 4), dtype=np.float32)
            magenta_overlay[..., 0] = 1
            magenta_overlay[..., 2] = 1
            magenta_overlay[..., 3] = (non_cell_rois > 0) * 0.5
            ax.imshow(magenta_overlay)

        ax.text(
            0.37,
            1.02,
            f"Accepted: {accepted_cells:03d}",
            transform=ax.transAxes,
            fontsize=14,
            fontweight="bold",
            fontname="Courier New",
            color="lime",
            ha="right",
            va="bottom",
        )
        ax.text(
            0.63,
            1.02,
            f"Rejected: {rejected_cells:03d}",
            transform=ax.transAxes,
            fontsize=14,
            fontweight="bold",
            fontname="Courier New",
            color="magenta",
            ha="left",
            va="bottom",
        )

    # dx <= 0 would give a zero division / nonsensical bar; skip it then.
    if add_scalebar and "dx" in ops and ops["dx"] > 0:
        pixel_size = ops["dx"]
        scale_bar_length = 100 / pixel_size
        scalebar_x = shape[1] * 0.05
        scalebar_y = shape[0] * 0.90
        ax.add_patch(
            Rectangle(
                (scalebar_x, scalebar_y),
                scale_bar_length,
                5,
                edgecolor="white",
                facecolor="white",
            )
        )
        ax.text(
            scalebar_x + scale_bar_length / 2,
            scalebar_y - 10,
            "100 μm",
            color="white",
            fontsize=10,
            ha="center",
            fontweight="bold",
        )

    # remove the spines that will show up as white bars
    for spine in ax.spines.values():
        spine.set_visible(False)

    plt.tight_layout()
    if output_directory:
        output_directory.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_directory, dpi=300, facecolor="black")
        plt.close(fig)
    else:
        plt.show()
def plot_noise_distribution(
    noise_levels: np.ndarray, output_filename=None, title="Noise Level Distribution"
):
    """
    Plot the distribution of per-neuron noise levels as a 50-bin histogram.

    A dashed red line marks the mean. Non-finite values (NaN/inf) are
    dropped before plotting, because they break the histogram's automatic
    range.

    Parameters
    ----------
    noise_levels : np.ndarray
        1D array of noise levels for each neuron.
    output_filename : str or Path, optional
        Path to save the plot (200 dpi). If empty, the plot is displayed
        instead.
    title : str, optional
        Title for the plot, default is "Noise Level Distribution".

    Raises
    ------
    AttributeError
        If ``output_filename`` is an existing directory.

    See Also
    --------
    lbm_suite2p_python.dff_shot_noise
    """
    if output_filename:
        output_filename = Path(output_filename)
        if output_filename.is_dir():
            raise AttributeError(
                f"output_filename should be a fully qualified file path, not a directory: {output_filename}"
            )

    levels = np.asarray(noise_levels, dtype=float).ravel()
    levels = levels[np.isfinite(levels)]

    fig = plt.figure(figsize=(8, 5))
    plt.hist(levels, bins=50, color="gray", alpha=0.7, edgecolor="black")
    mean_noise: float = float(np.mean(levels)) if levels.size else float("nan")
    plt.axvline(
        mean_noise,
        color="r",
        linestyle="dashed",
        linewidth=2,
        label=f"Mean: {mean_noise:.2f}",
    )
    plt.xlabel("Noise Level", fontsize=14, fontweight="bold")
    plt.ylabel("Number of Neurons", fontsize=14, fontweight="bold")
    plt.title(title, fontsize=16, fontweight="bold")
    plt.legend(fontsize=12)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)

    if output_filename:
        fig.savefig(output_filename, dpi=200, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()
def _resolve_neuron_bin_size(n_neurons, neuron_bin_size=None):
    """Return an integer neurons-per-row bin size (default: at most ~500 rows)."""
    if neuron_bin_size is None:
        # True division: ceil(n / 500) rows-per-bin keeps the raster <= ~500 rows.
        return max(1, int(np.ceil(n_neurons / 500)))
    return max(1, min(int(neuron_bin_size), n_neurons))


def plot_rastermap(
    spks,
    model,
    neuron_bin_size=None,
    fps=17,
    vmin=0,
    vmax=0.8,
    xmin=0,
    xmax=None,
    save_path=None,
    title=None,
    title_kwargs=None,
    fig_text=None,
):
    """
    Plot a Rastermap-sorted, neuron-binned activity heatmap with scale bars.

    Rows of ``spks`` are reordered by ``model.isort`` and binned into
    superneurons with ``bin1d``. Columns ``xmin:xmax`` are shown. The figure
    carries:
    - a time bar (10% of the displayed duration);
    - a colorbar labelled "z-scored";
    - a vertical bar covering 10% of the displayed rows;
    - a summary line of model parameters.

    Parameters
    ----------
    spks : ndarray
        Activity, shape (n_neurons, n_timepoints).
    model
        Fitted Rastermap-like object exposing ``isort``, ``n_PCs``,
        ``n_clusters`` and ``locality``.
    neuron_bin_size : int, optional
        Neurons per superneuron row; default gives at most ~500 rows.
    fps : float
        Sampling rate, used for the time scale bar.
    vmin, vmax : float
        Color limits.
    xmin, xmax : int
        Column range to display; an invalid ``xmax`` falls back to all columns.
    save_path : str or Path, optional
        If given, save there (parent dirs created) and close the figure.
    title : str, optional
        Figure suptitle.
    title_kwargs : dict, optional
        Keyword arguments for the suptitle.
    fig_text : str, optional
        Summary text; defaults to neuron/superneuron counts and model params.

    Returns
    -------
    (Figure, Axes)
    """
    n_neurons, n_timepoints = spks.shape
    if title_kwargs is None:
        title_kwargs = dict(fontsize=14, fontweight="bold", color="white")

    neuron_bin_size = _resolve_neuron_bin_size(n_neurons, neuron_bin_size)
    sn = bin1d(spks[model.isort], neuron_bin_size, axis=0)
    if xmax is None or xmax < xmin or xmax > sn.shape[1]:
        xmax = sn.shape[1]
    sn = sn[:, xmin:xmax]

    current_time = np.round((xmax - xmin) / fps, 1)
    current_neurons = sn.shape[0]

    fig, ax = plt.subplots(figsize=(6, 3), dpi=200)
    img = ax.imshow(sn, cmap="gray_r", vmin=vmin, vmax=vmax, aspect="auto")
    fig.patch.set_facecolor("black")
    ax.set_facecolor("black")
    ax.tick_params(axis="both", labelbottom=False, labelleft=False, length=0)
    for spine in ax.spines.values():
        spine.set_visible(False)

    heatmap_pos = ax.get_position()

    # Time scale bar: 10% of the heatmap width == 10% of the displayed time.
    scalebar_length = heatmap_pos.width * 0.1
    scalebar_duration = np.round(current_time * 0.1)  # noqa
    x_start = heatmap_pos.x1 - scalebar_length
    x_end = heatmap_pos.x1
    y_position = heatmap_pos.y0
    fig.add_artist(
        plt.Line2D(
            [x_start, x_end],
            [y_position - 0.03, y_position - 0.03],
            transform=fig.transFigure,
            color="white",
            linewidth=2,
            solid_capstyle="butt",
        )
    )
    fig.text(
        x=(x_start + x_end) / 2,
        y=y_position - 0.045,  # slightly below the scalebar
        s=f"{scalebar_duration:.0f} s",
        ha="center",
        va="top",
        color="white",
        fontsize=6,
    )

    axins = fig.add_axes(
        [  # noqa
            heatmap_pos.x0,  # aligned with the heatmap's left edge
            heatmap_pos.y0 - 0.03,  # slightly below the heatmap
            heatmap_pos.width * 0.1,  # 10% width of heatmap
            0.015,  # height of the colorbar
        ]
    )
    cbar = fig.colorbar(img, cax=axins, orientation="horizontal", ticks=[vmin, vmax])
    cbar.ax.tick_params(labelsize=5, colors="white", pad=2)
    cbar.outline.set_edgecolor("white")  # noqa
    fig.text(
        heatmap_pos.x0,
        heatmap_pos.y0 - 0.1,  # below the colorbar with spacing
        "z-scored",
        ha="left",
        va="top",
        color="white",
        fontsize=6,
    )

    # Vertical bar covering 10% of the displayed (superneuron) rows.
    scalebar_neurons = int(0.1 * current_neurons)
    x_position = heatmap_pos.x1 + 0.01  # slightly right of heatmap
    y_start = heatmap_pos.y0
    y_end = y_start + (heatmap_pos.height * scalebar_neurons / current_neurons)
    fig.add_artist(
        plt.Line2D(
            [x_position, x_position],
            [y_start, y_end],
            transform=fig.transFigure,
            color="white",
            linewidth=2,
        )
    )
    ntype = "neuron" if scalebar_neurons == 1 else "neurons"
    fig.text(
        x=x_position + 0.008,
        y=y_start,
        s=f"{scalebar_neurons} {ntype}",
        ha="left",
        va="bottom",
        color="white",
        fontsize=6,
        rotation=90,
    )

    if fig_text is None:
        fig_text = (
            f"Neurons: {spks.shape[0]}, Superneurons: {sn.shape[0]}, "
            f"n_clusters: {model.n_clusters}, n_PCs: {model.n_PCs}, "
            f"locality: {model.locality}"
        )
    fig.text(
        x=(heatmap_pos.x0 + heatmap_pos.x1) / 2,
        y=y_start - 0.085,  # vertically between existing scalebars
        s=fig_text,
        ha="center",
        va="top",
        color="white",
        fontsize=6,
    )

    if title is not None:
        fig.suptitle(title, **title_kwargs)

    if save_path is not None:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=200, facecolor="black", bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()
    return fig, ax
def save_pc_panels_and_metrics(ops, savepath, pcs=(0, 1, 2, 3)): """ Save PC metrics in two forms: 1. Alternating TIFF (PC Low/High side-by-side per frame, press play in ImageJ to flip). 2. Panel TIFF (static figures for PC1/2 and PC3/4). Also saves summary metrics as CSV. Parameters ---------- ops : dict or str or Path Suite2p ops dict or path to ops.npy. Must contain "regPC" and "regDX". savepath : str or Path Output file stem (without extension). pcs : tuple of int PCs to include (default first four). """ if not isinstance(ops, dict): ops = np.load(ops, allow_pickle=True).item() if "nframes" in ops and ops["nframes"] < 1500: print( f"1500 frames needed for registration metrics, found {ops['nframes']}. Skipping PC metrics." ) return {} elif "regPC" not in ops or "regDX" not in ops: print("regPC or regDX not found in ops, skipping PC metrics.") return {} elif len(pcs) != 4 or any(p < 0 for p in pcs): raise ValueError( "pcs must be a tuple of four non-negative integers." " E.g., (0, 1, 2, 3) for the first four PCs." 
f" Got: {pcs}" ) regPC = ops["regPC"] # shape (2, nPC, Ly, Lx) regDX = ops["regDX"] # shape (nPC, 3) savepath = Path(savepath) alt_frames = [] alt_labels = [] for view, view_name in zip([0, 1], ["Low", "High"]): # side-by-side: PC1 | PC2 left = regPC[view, pcs[0]] right = regPC[view, pcs[1]] combined = np.hstack([left, right]) alt_frames.append(combined.astype(np.float32)) alt_labels.append(f"PC{pcs[0] + 1}/{pcs[1] + 1} {view_name}") # side-by-side: PC3 | PC4 left = regPC[view, pcs[2]] right = regPC[view, pcs[3]] combined = np.hstack([left, right]) alt_frames.append(combined.astype(np.float32)) alt_labels.append(f"PC{pcs[2] + 1}/{pcs[3] + 1} {view_name}") panel_frames = [] panel_labels = [] for left, right in [(pcs[0], pcs[1]), (pcs[2], pcs[3])]: for view, view_name in zip([0, 1], ["Low", "High"]): fig, axes = plt.subplots(1, 2, figsize=(10, 5)) axes[0].imshow(regPC[view, left], cmap="gray") axes[0].set_title(f"PC{left + 1} {view_name}") axes[0].axis("off") axes[1].imshow(regPC[view, right], cmap="gray") axes[1].set_title(f"PC{right + 1} {view_name}") axes[1].axis("off") fig.tight_layout() fig.canvas.draw() img = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8) # noqa w, h = fig.canvas.get_width_height() img = img.reshape((h, w, 4))[..., :3] panel_frames.append(img) panel_labels.append(f"PC{left + 1}/{right + 1} {view_name}") plt.close(fig) panel_tiff = savepath.with_name(savepath.stem + "_panels.tif") tifffile.imwrite( panel_tiff, np.stack(panel_frames, axis=0), imagej=True, metadata={"Labels": panel_labels}, ) df = pd.DataFrame(regDX, columns=["Rigid", "Avg_NR", "Max_NR"]) metrics = { "Avg_Rigid": df["Rigid"].mean(), "Avg_Average_NR": df["Avg_NR"].mean(), "Avg_Max_NR": df["Max_NR"].mean(), "Max_Rigid": df["Rigid"].max(), "Max_Average_NR": df["Avg_NR"].max(), "Max_Max_NR": df["Max_NR"].max(), } csv_path = savepath.with_suffix(".csv") pd.DataFrame([metrics]).to_csv(csv_path, index=False) return { "panel_tiff": panel_tiff, "metrics_csv": csv_path, } def 
plot_multiplane_masks( suite2p_path: str | Path, stat: np.ndarray, iscell: np.ndarray, nrows: int = 3, ncols: int = 5, figsize: tuple = (20, 12), save_path: str | Path = None, cmap: str = "gray", ) -> plt.Figure: """ Plot ROI masks from all planes in a publication-quality grid layout. Creates a multi-panel figure showing detected ROIs overlaid on mean images for each z-plane, with accepted cells in green and rejected cells in red. Parameters ---------- suite2p_path : str or Path Path to suite2p directory containing plane folders (e.g., plane01_stitched/). stat : np.ndarray Consolidated stat array with 'iplane' field indicating plane assignment. iscell : np.ndarray Cell classification array (n_rois, 2) where column 0 is binary classification. nrows : int, default 3 Number of rows in the figure grid. ncols : int, default 5 Number of columns in the figure grid. figsize : tuple, default (20, 12) Figure size in inches (width, height). save_path : str or Path, optional If provided, save figure to this path. Otherwise display interactively. cmap : str, default "gray" Colormap for background images. Returns ------- fig : matplotlib.figure.Figure The generated figure object. 
Examples -------- >>> stat = np.load("merged/stat.npy", allow_pickle=True) >>> iscell = np.load("merged/iscell.npy") >>> fig = plot_multiplane_masks("path/to/suite2p", stat, iscell) """ suite2p_path = Path(suite2p_path) plane_dirs = sorted(suite2p_path.glob("plane*_stitched")) if not plane_dirs: plane_dirs = sorted(suite2p_path.glob("plane*")) nplanes = len(plane_dirs) # Use a clean, publication-ready style with plt.style.context("default"): fig, axes = plt.subplots( nrows, ncols, figsize=figsize, facecolor="white", gridspec_kw={"wspace": 0.05, "hspace": 0.15} ) axes = axes.flatten() for idx, plane_dir in enumerate(plane_dirs): if idx >= len(axes): break ax = axes[idx] # Extract plane number from directory name plane_name = plane_dir.name digits = "".join(filter(str.isdigit, plane_name)) plane_num = int(digits) if digits else idx + 1 # Load plane ops for mean image ops_file = plane_dir / "ops.npy" if ops_file.exists(): plane_ops = np.load(ops_file, allow_pickle=True).item() img = plane_ops.get("meanImg", plane_ops.get("meanImgE")) if img is None: img = np.zeros((plane_ops.get("Ly", 512), plane_ops.get("Lx", 512))) else: img = np.zeros((512, 512)) # Display background image with proper contrast vmin, vmax = np.nanpercentile(img, [1, 99]) ax.imshow(img, cmap=cmap, aspect="equal", vmin=vmin, vmax=vmax) # Get ROIs for this plane plane_mask = np.array([s.get("iplane", 0) == plane_num for s in stat]) plane_stat = stat[plane_mask] plane_iscell = iscell[plane_mask] # Draw accepted cells (green) accepted_idx = plane_iscell[:, 0] == 1 for s in plane_stat[accepted_idx]: ypix, xpix = s["ypix"], s["xpix"] ax.scatter(xpix, ypix, c="lime", s=0.3, alpha=0.7, linewidths=0) # Draw rejected cells (red, more transparent) rejected_idx = plane_iscell[:, 0] == 0 for s in plane_stat[rejected_idx]: ypix, xpix = s["ypix"], s["xpix"] ax.scatter(xpix, ypix, c="red", s=0.2, alpha=0.4, linewidths=0) n_acc = accepted_idx.sum() n_rej = rejected_idx.sum() # Clean title with plane info 
ax.set_title( f"Plane {plane_num:02d}\n{n_acc} / {n_rej}", fontsize=9, fontweight="bold", pad=3 ) ax.axis("off") # Hide unused subplots for idx in range(nplanes, len(axes)): axes[idx].axis("off") # Add legend from matplotlib.lines import Line2D legend_elements = [ Line2D([0], [0], marker="o", color="w", markerfacecolor="lime", markersize=8, label="Accepted"), Line2D([0], [0], marker="o", color="w", markerfacecolor="red", markersize=8, label="Rejected"), ] fig.legend( handles=legend_elements, loc="lower center", ncol=2, fontsize=10, frameon=False, bbox_to_anchor=(0.5, 0.02) ) plt.tight_layout(rect=[0, 0.05, 1, 1]) if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=200, bbox_inches="tight", facecolor="white") plt.close(fig) else: plt.show() return fig def plot_plane_quality_metrics( stat: np.ndarray, iscell: np.ndarray, save_path: str | Path = None, figsize: tuple = (14, 10), ) -> plt.Figure: """ Generate publication-quality ROI quality metrics across all planes. Creates a multi-panel figure with line plots showing mean ± std: - Compactness vs plane - Skewness vs plane - ROI size (npix) vs plane - Radius vs plane Parameters ---------- stat : np.ndarray Consolidated stat array with 'iplane', 'compact', 'npix' fields. iscell : np.ndarray Cell classification array (n_rois, 2). save_path : str or Path, optional If provided, save figure to this path. figsize : tuple, default (14, 10) Figure size in inches. Returns ------- fig : matplotlib.figure.Figure The generated figure object. 
Examples -------- >>> stat = np.load("merged/stat.npy", allow_pickle=True) >>> iscell = np.load("merged/iscell.npy") >>> fig = plot_plane_quality_metrics(stat, iscell, save_path="quality.png") """ # Extract metrics plane_nums = np.array([s.get("iplane", 0) for s in stat]) unique_planes = np.unique(plane_nums) n_planes = len(unique_planes) compactness = np.array([s.get("compact", np.nan) for s in stat]) skewness = np.array([s.get("skew", np.nan) for s in stat]) npix = np.array([s.get("npix", 0) for s in stat]) radius = np.array([s.get("radius", np.nan) for s in stat]) accepted = iscell[:, 0] == 1 # Dark theme colors (consistent with plot_volume_diagnostics) bg_color = "black" text_color = "white" colors = { "compactness": "#9b59b6", # Purple "skewness": "#e67e22", # Orange "size": "#3498db", # Blue "radius": "#2ecc71", # Green } mean_line_color = "#e74c3c" # Red for mean markers # Compute mean and std per plane for accepted cells def compute_stats_per_plane(values, plane_nums, accepted, unique_planes): means = [] stds = [] for p in unique_planes: mask = (plane_nums == p) & accepted & ~np.isnan(values) if mask.sum() > 0: means.append(np.mean(values[mask])) stds.append(np.std(values[mask])) else: means.append(np.nan) stds.append(np.nan) return np.array(means), np.array(stds) compact_mean, compact_std = compute_stats_per_plane(compactness, plane_nums, accepted, unique_planes) skew_mean, skew_std = compute_stats_per_plane(skewness, plane_nums, accepted, unique_planes) npix_mean, npix_std = compute_stats_per_plane(npix.astype(float), plane_nums, accepted, unique_planes) radius_mean, radius_std = compute_stats_per_plane(radius, plane_nums, accepted, unique_planes) with plt.style.context("default"): fig, axes = plt.subplots(2, 2, figsize=figsize, facecolor=bg_color) axes = axes.flatten() x = np.arange(n_planes) def style_axis(ax, xlabel, ylabel, title): ax.set_facecolor(bg_color) ax.set_xlabel(xlabel, fontweight="bold", fontsize=10, color=text_color) ax.set_ylabel(ylabel, 
fontweight="bold", fontsize=10, color=text_color) ax.set_title(title, fontweight="bold", fontsize=11, color=text_color) ax.tick_params(colors=text_color, labelsize=9) ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) ax.spines["bottom"].set_color(text_color) ax.spines["left"].set_color(text_color) # Set x-ticks to show plane numbers if n_planes <= 20: ax.set_xticks(x) ax.set_xticklabels([f"{int(p)}" for p in unique_planes]) else: step = max(1, n_planes // 10) ax.set_xticks(x[::step]) ax.set_xticklabels([f"{int(p)}" for p in unique_planes[::step]]) # Panel 1: Compactness ax = axes[0] valid = ~np.isnan(compact_mean) ax.fill_between(x[valid], (compact_mean - compact_std)[valid], (compact_mean + compact_std)[valid], alpha=0.3, color=colors["compactness"]) ax.plot(x[valid], compact_mean[valid], 'o-', color=colors["compactness"], linewidth=2, markersize=5) style_axis(ax, "Z-Plane", "Compactness", "ROI Compactness (Accepted)") # Panel 2: Skewness ax = axes[1] valid = ~np.isnan(skew_mean) ax.fill_between(x[valid], (skew_mean - skew_std)[valid], (skew_mean + skew_std)[valid], alpha=0.3, color=colors["skewness"]) ax.plot(x[valid], skew_mean[valid], 'o-', color=colors["skewness"], linewidth=2, markersize=5) style_axis(ax, "Z-Plane", "Skewness", "Trace Skewness (Accepted)") # Panel 3: ROI Size (npix) ax = axes[2] valid = ~np.isnan(npix_mean) ax.fill_between(x[valid], (npix_mean - npix_std)[valid], (npix_mean + npix_std)[valid], alpha=0.3, color=colors["size"]) ax.plot(x[valid], npix_mean[valid], 'o-', color=colors["size"], linewidth=2, markersize=5) style_axis(ax, "Z-Plane", "Number of Pixels", "ROI Size (Accepted)") # Panel 4: Radius ax = axes[3] valid = ~np.isnan(radius_mean) ax.fill_between(x[valid], (radius_mean - radius_std)[valid], (radius_mean + radius_std)[valid], alpha=0.3, color=colors["radius"]) ax.plot(x[valid], radius_mean[valid], 'o-', color=colors["radius"], linewidth=2, markersize=5) style_axis(ax, "Z-Plane", "Radius (pixels)", "ROI 
Radius (Accepted)") # Main title total_accepted = int(accepted.sum()) total_rois = len(stat) fig.suptitle( f"Volume Quality Metrics: {total_accepted} accepted / {total_rois} total ROIs", fontsize=12, fontweight="bold", color=text_color, y=0.98 ) plt.tight_layout(rect=[0, 0, 1, 0.96]) if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor=bg_color) plt.close(fig) return fig def plot_trace_analysis( F: np.ndarray, Fneu: np.ndarray, stat: np.ndarray, iscell: np.ndarray, ops: dict, save_path: str | Path = None, figsize: tuple = (16, 14), ) -> Tuple[plt.Figure, dict]: """ Generate trace analysis figure showing extreme examples by quality metrics. Creates a 6-panel figure showing example ΔF/F traces for: - Highest SNR / Lowest SNR - Lowest shot noise / Highest shot noise - Highest skewness / Lowest skewness Parameters ---------- F : np.ndarray Fluorescence traces array (n_rois, n_frames). Fneu : np.ndarray Neuropil fluorescence array (n_rois, n_frames). stat : np.ndarray Stat array with 'iplane' and 'skew' fields. iscell : np.ndarray Cell classification array. ops : dict Ops dictionary with 'fs' (frame rate) field. save_path : str or Path, optional If provided, save figure to this path. figsize : tuple, default (16, 14) Figure size in inches. Returns ------- fig : matplotlib.figure.Figure The generated figure object. metrics : dict Dictionary containing computed metrics (snr, shot_noise, skewness, dff). 
Examples -------- >>> fig, metrics = plot_trace_analysis(F, Fneu, stat, iscell, ops) >>> print(f"Mean SNR: {np.mean(metrics['snr']):.2f}") """ accepted = iscell[:, 0] == 1 n_accepted = int(np.sum(accepted)) if n_accepted == 0: fig = plt.figure(figsize=figsize, facecolor="black") fig.text(0.5, 0.5, "No accepted ROIs found", ha="center", va="center", fontsize=16, fontweight="bold", color="white") return fig, {} F_acc = F[accepted] Fneu_acc = Fneu[accepted] stat_acc = stat[accepted] plane_nums = np.array([s.get("iplane", 0) for s in stat_acc]) fs = ops.get("fs", 30.0) # Compute ΔF/F F_corrected = F_acc - 0.7 * Fneu_acc baseline = np.percentile(F_corrected, 20, axis=1, keepdims=True) baseline = np.maximum(baseline, 1e-6) dff = (F_corrected - baseline) / baseline # Compute metrics # SNR: signal / noise signal = np.std(dff, axis=1) noise = np.median(np.abs(np.diff(dff, axis=1)), axis=1) / 0.6745 # MAD estimator snr = signal / (noise + 1e-6) # Shot noise: noise level (MAD of diff) shot_noise = noise # Skewness: from stat or compute from trace skewness = np.array([s.get("skew", np.nan) for s in stat_acc]) # Fill NaNs with computed skewness if needed nan_mask = np.isnan(skewness) if nan_mask.any(): from scipy.stats import skew as scipy_skew for i in np.where(nan_mask)[0]: skewness[i] = scipy_skew(dff[i]) # Style configuration bg_color = "black" text_color = "white" # Colors for each metric type colors = { "snr_high": "#2ecc71", # Green - good "snr_low": "#e74c3c", # Red - bad "noise_low": "#3498db", # Blue - good "noise_high": "#e67e22", # Orange - bad "skew_high": "#9b59b6", # Purple - high activity "skew_low": "#95a5a6", # Gray - low activity } # Find indices for each category valid_mask = ~np.isnan(snr) & ~np.isnan(shot_noise) & ~np.isnan(skewness) valid_idx = np.where(valid_mask)[0] if len(valid_idx) == 0: fig = plt.figure(figsize=figsize, facecolor="black") fig.text(0.5, 0.5, "No valid ROIs with computed metrics", ha="center", va="center", fontsize=16, 
fontweight="bold", color="white") return fig, {} # Get indices for extremes snr_valid = snr[valid_mask] noise_valid = shot_noise[valid_mask] skew_valid = skewness[valid_mask] idx_snr_high = valid_idx[np.argmax(snr_valid)] idx_snr_low = valid_idx[np.argmin(snr_valid)] idx_noise_low = valid_idx[np.argmin(noise_valid)] idx_noise_high = valid_idx[np.argmax(noise_valid)] idx_skew_high = valid_idx[np.argmax(skew_valid)] idx_skew_low = valid_idx[np.argmin(skew_valid)] # Time axis - show up to 100s or full trace n_frames_show = min(int(100 * fs), dff.shape[1]) time = np.arange(n_frames_show) / fs with plt.style.context("default"): fig = plt.figure(figsize=figsize, facecolor=bg_color) gs = gridspec.GridSpec(3, 2, figure=fig, hspace=0.4, wspace=0.25, left=0.08, right=0.95, top=0.92, bottom=0.06) def plot_trace_panel(ax, idx, title, color, metric_name, metric_val): """Plot a single trace panel.""" ax.set_facecolor(bg_color) trace = dff[idx, :n_frames_show] ax.plot(time, trace, color=color, linewidth=0.8, alpha=0.9) # Add zero line ax.axhline(0, color="gray", linestyle="--", linewidth=0.5, alpha=0.5) # Get plane info plane = plane_nums[idx] # Style ax.set_xlabel("Time (s)", fontsize=10, fontweight="bold", color=text_color) ax.set_ylabel("ΔF/F", fontsize=10, fontweight="bold", color=text_color) ax.set_title(f"{title}\n{metric_name}={metric_val:.2f}, Plane {plane}", fontsize=11, fontweight="bold", color=text_color) ax.tick_params(colors=text_color, labelsize=9) ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) ax.spines["bottom"].set_color(text_color) ax.spines["left"].set_color(text_color) # Set reasonable y-limits y_max = np.percentile(trace, 99.5) y_min = np.percentile(trace, 0.5) margin = (y_max - y_min) * 0.1 ax.set_ylim(y_min - margin, y_max + margin) # Row 1: SNR extremes ax1 = fig.add_subplot(gs[0, 0]) plot_trace_panel(ax1, idx_snr_high, "Highest SNR", colors["snr_high"], "SNR", snr[idx_snr_high]) ax2 = fig.add_subplot(gs[0, 1]) 
plot_trace_panel(ax2, idx_snr_low, "Lowest SNR", colors["snr_low"], "SNR", snr[idx_snr_low]) # Row 2: Shot noise extremes ax3 = fig.add_subplot(gs[1, 0]) plot_trace_panel(ax3, idx_noise_low, "Lowest Shot Noise", colors["noise_low"], "Noise", shot_noise[idx_noise_low]) ax4 = fig.add_subplot(gs[1, 1]) plot_trace_panel(ax4, idx_noise_high, "Highest Shot Noise", colors["noise_high"], "Noise", shot_noise[idx_noise_high]) # Row 3: Skewness extremes ax5 = fig.add_subplot(gs[2, 0]) plot_trace_panel(ax5, idx_skew_high, "Highest Skewness", colors["skew_high"], "Skew", skewness[idx_skew_high]) ax6 = fig.add_subplot(gs[2, 1]) plot_trace_panel(ax6, idx_skew_low, "Lowest Skewness", colors["skew_low"], "Skew", skewness[idx_skew_low]) # Main title with summary stats fig.suptitle( f"Trace Quality Extremes: {n_accepted} accepted ROIs | " f"SNR: {np.nanmedian(snr):.1f} (median) | " f"Noise: {np.nanmedian(shot_noise):.3f} (median)", fontsize=12, fontweight="bold", color=text_color, y=0.98 ) if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor=bg_color) plt.close(fig) metrics = { "snr": snr, "shot_noise": shot_noise, "skewness": skewness, "dff": dff, } return fig, metrics def create_volume_summary_table( stat: np.ndarray, iscell: np.ndarray, F: np.ndarray = None, Fneu: np.ndarray = None, ops: dict = None, save_path: str | Path = None, ) -> pd.DataFrame: """ Generates per-plane and aggregate statistics including ROI counts, SNR metrics, and quality measures. Parameters ---------- stat : np.ndarray Consolidated stat array with plane assignments. iscell : np.ndarray Cell classification array. F : np.ndarray, optional Fluorescence traces for SNR calculation. Fneu : np.ndarray, optional Neuropil traces for SNR calculation. ops : dict, optional Ops dictionary with frame rate. save_path : str or Path, optional If provided, save CSV to this path. 
Returns ------- df : pd.DataFrame Summary statistics table. Examples -------- >>> df = create_volume_summary_table(stat, iscell, F, Fneu, ops) >>> print(df.to_string()) """ accepted = iscell[:, 0] == 1 plane_nums = np.array([s.get("iplane", 0) for s in stat]) unique_planes = np.unique(plane_nums) # Compute SNR if traces provided snr = None mean_F_arr = None if F is not None and Fneu is not None: F_acc = F[accepted] Fneu_acc = Fneu[accepted] F_corrected = F_acc - 0.7 * Fneu_acc baseline = np.percentile(F_corrected, 20, axis=1, keepdims=True) baseline = np.maximum(baseline, 1e-6) dff = (F_corrected - baseline) / baseline signal = np.std(dff, axis=1) noise = np.median(np.abs(np.diff(dff, axis=1)), axis=1) / 0.6745 snr = signal / (noise + 1e-6) mean_F_arr = np.mean(F_acc, axis=1) plane_nums_acc = plane_nums[accepted] else: plane_nums_acc = plane_nums[accepted] # Extract metrics compactness = np.array([s.get("compact", np.nan) for s in stat]) npix = np.array([s.get("npix", 0) for s in stat]) summary_data = [] for p in unique_planes: plane_mask = plane_nums == p plane_mask_acc = plane_nums_acc == p if snr is not None else plane_mask & accepted n_total = plane_mask.sum() n_accepted = (plane_mask & accepted).sum() row = { "Plane": int(p), "Total_ROIs": int(n_total), "Accepted": int(n_accepted), "Rejected": int(n_total - n_accepted), "Accept_Rate_%": f"{100 * n_accepted / max(1, n_total):.1f}", "Mean_Compact": f"{np.nanmean(compactness[plane_mask & accepted]):.2f}", "Mean_Size_px": f"{np.mean(npix[plane_mask & accepted]):.0f}", } if snr is not None and plane_mask_acc.sum() > 0: row["Mean_SNR"] = f"{np.mean(snr[plane_mask_acc]):.2f}" row["Median_SNR"] = f"{np.median(snr[plane_mask_acc]):.2f}" row["High_SNR_%"] = f"{100 * np.sum(snr[plane_mask_acc] > 2) / plane_mask_acc.sum():.1f}" row["Mean_F"] = f"{np.mean(mean_F_arr[plane_mask_acc]):.0f}" summary_data.append(row) df = pd.DataFrame(summary_data) # Add totals row totals = { "Plane": "ALL", "Total_ROIs": int(len(stat)), 
"Accepted": int(accepted.sum()), "Rejected": int((~accepted).sum()), "Accept_Rate_%": f"{100 * accepted.sum() / len(stat):.1f}", "Mean_Compact": f"{np.nanmean(compactness[accepted]):.2f}", "Mean_Size_px": f"{np.mean(npix[accepted]):.0f}", } if snr is not None: totals["Mean_SNR"] = f"{np.mean(snr):.2f}" totals["Median_SNR"] = f"{np.median(snr):.2f}" totals["High_SNR_%"] = f"{100 * np.sum(snr > 2) / len(snr):.1f}" totals["Mean_F"] = f"{np.mean(mean_F_arr):.0f}" df = pd.concat([df, pd.DataFrame([totals])], ignore_index=True) if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) df.to_csv(save_path, index=False) print(f"Summary table saved to: {save_path}") return df def plot_plane_diagnostics( plane_dir: str | Path, save_path: str | Path = None, figsize: tuple = (16, 14), n_examples: int = 4, ) -> plt.Figure: """ Generate a single-figure diagnostic summary for a processed plane. Creates a publication-quality figure showing: - ROI size distribution (accepted vs rejected) - SNR distribution with quality threshold - Compactness vs SNR scatter - Summary statistics text - Zoomed ROI examples: N highest SNR and N lowest SNR cells Robust to low/zero cell counts - will display informative messages when data is insufficient for certain visualizations. Parameters ---------- plane_dir : str or Path Path to the plane directory containing ops.npy, stat.npy, etc. save_path : str or Path, optional If provided, save figure to this path. figsize : tuple, default (16, 14) Figure size in inches. n_examples : int, default 4 Number of high/low SNR ROI examples to show. Returns ------- fig : matplotlib.figure.Figure The generated figure object. 
""" plane_dir = Path(plane_dir) # Load results res = load_planar_results(plane_dir) ops = load_ops(plane_dir / "ops.npy") stat = res["stat"] iscell = res["iscell"] F = res["F"] Fneu = res["Fneu"] # Handle edge case: no ROIs at all n_total = len(stat) if n_total == 0: fig = plt.figure(figsize=figsize, facecolor="black") fig.text(0.5, 0.5, "No ROIs detected\n\nCheck detection parameters:\n- threshold_scaling\n- cellprob_threshold\n- diameter", ha="center", va="center", fontsize=16, fontweight="bold", color="white") plane_name = plane_dir.name fig.suptitle(f"Quality Diagnostics: {plane_name}", fontsize=14, fontweight="bold", y=0.98, color="white") if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="black") plt.close(fig) return fig # iscell from load_planar_results is (n_rois, 2): [:, 0] is 0/1, [:, 1] is probability accepted = iscell[:, 0].astype(bool) cell_prob = iscell[:, 1] # classifier probability for each ROI n_accepted = int(accepted.sum()) n_rejected = int((~accepted).sum()) # Compute metrics for ALL ROIs (not just accepted) F_corr = F - 0.7 * Fneu baseline = np.percentile(F_corr, 20, axis=1, keepdims=True) baseline = np.maximum(baseline, 1e-6) dff = (F_corr - baseline) / baseline # SNR calculation for all ROIs signal = np.std(dff, axis=1) noise = np.median(np.abs(np.diff(dff, axis=1)), axis=1) / 0.6745 snr = signal / (noise + 1e-6) # Extract ROI properties npix = np.array([s.get("npix", 0) for s in stat]) compactness = np.array([s.get("compact", np.nan) for s in stat]) skewness = np.array([s.get("skew", np.nan) for s in stat]) fs = ops.get("fs", 30.0) # Compute stats with safe defaults snr_acc = snr[accepted] if n_accepted > 0 else np.array([np.nan]) npix_acc = npix[accepted] if n_accepted > 0 else np.array([0]) mean_snr = np.nanmean(snr_acc) if n_accepted > 0 else 0.0 median_snr = np.nanmedian(snr_acc) if n_accepted > 0 else 0.0 high_snr_pct = 100 * 
np.sum(snr_acc > 2) / max(1, len(snr_acc)) if n_accepted > 0 else 0.0 mean_size = np.mean(npix_acc) if n_accepted > 0 else 0.0 # Get mean image for ROI zoom panels mean_img = ops.get("meanImgE", ops.get("meanImg")) # Create figure with custom layout - dark background like consolidate.ipynb # Row 0: Size dist, SNR dist, SNR vs Compactness, Activity vs SNR (4 panels) # Row 1: High SNR ROI zooms (n_examples panels) # Row 2: High SNR ROI traces # Row 3: Low SNR ROI zooms (n_examples panels) # Row 4: Low SNR ROI traces fig = plt.figure(figsize=(figsize[0], figsize[1] + 2), facecolor="black") # use nested gridspec: top row has more spacing, bottom rows are tight # Increased gap between top plots (bottom=0.62) and ROI images (top=0.48) # Added 5th row as spacer between high and low SNR groups gs_top = gridspec.GridSpec(1, 4, figure=fig, left=0.06, right=0.98, top=0.95, bottom=0.62, wspace=0.35) gs_bottom = gridspec.GridSpec(5, max(4, n_examples), figure=fig, left=0.02, right=0.98, top=0.48, bottom=0.02, hspace=0.02, wspace=0.08, height_ratios=[1, 0.4, 0.15, 1, 0.4]) # compute activity metric: number of transients (peaks above 2 std) if n_accepted > 0: dff_acc = dff[accepted] activity = np.sum(dff_acc > 2, axis=1) # count frames above 2 std else: activity = np.array([]) # compute shot noise per ROI (standardized noise metric) # shot_noise = median(|diff(dff)|) / sqrt(fs) * 100 (in %/sqrt(Hz)) frame_diffs = np.abs(np.diff(dff, axis=1)) shot_noise = np.median(frame_diffs, axis=1) / np.sqrt(fs) * 100 # Panel 1: ROI size distribution - use step histogram for clarity ax_size = fig.add_subplot(gs_top[0, 0]) ax_size.set_facecolor("black") all_npix = npix[npix > 0] if len(all_npix) > 0: bins = np.linspace(0, np.percentile(all_npix, 99), 40) # Use step histograms with distinct line styles for clear separation if n_accepted > 0: ax_size.hist(npix[accepted], bins=bins, histtype="stepfilled", alpha=0.7, color="#2ecc71", edgecolor="#2ecc71", linewidth=1.5, label=f"Accepted 
({n_accepted})") if n_rejected > 0: ax_size.hist(npix[~accepted], bins=bins, histtype="step", color="#e74c3c", linewidth=2, linestyle="-", label=f"Rejected ({n_rejected})") ax_size.legend(fontsize=7, facecolor="#1a1a1a", edgecolor="white", labelcolor="white", loc="upper right") else: ax_size.text(0.5, 0.5, "No ROI data", ha="center", va="center", fontsize=12, color="white") ax_size.set_xlabel("Size (pixels)", fontweight="bold", fontsize=9, color="white") ax_size.set_ylabel("Count", fontweight="bold", fontsize=9, color="white") ax_size.set_title("ROI Size", fontweight="bold", fontsize=10, color="white") ax_size.tick_params(colors="white", labelsize=8) ax_size.spines["top"].set_visible(False) ax_size.spines["right"].set_visible(False) ax_size.spines["bottom"].set_color("white") ax_size.spines["left"].set_color("white") # Panel 2: SNR distribution - use step histogram for clarity ax_snr = fig.add_subplot(gs_top[0, 1]) ax_snr.set_facecolor("black") all_snr = snr[~np.isnan(snr)] if len(all_snr) > 0: bins = np.linspace(0, np.percentile(all_snr, 99), 40) # Filled for accepted, outline for rejected - no overlap confusion if n_accepted > 0: ax_snr.hist(snr[accepted], bins=bins, histtype="stepfilled", alpha=0.7, color="#2ecc71", edgecolor="#2ecc71", linewidth=1.5, label=f"Accepted ({n_accepted})") ax_snr.axvline(median_snr, color="#ffe66d", linestyle="-", linewidth=2, label=f"Median={median_snr:.1f}") if n_rejected > 0: ax_snr.hist(snr[~accepted], bins=bins, histtype="step", color="#e74c3c", linewidth=2, linestyle="-", label=f"Rejected ({n_rejected})") ax_snr.legend(fontsize=7, facecolor="#1a1a1a", edgecolor="white", labelcolor="white", loc="upper right") else: ax_snr.text(0.5, 0.5, "No SNR data", ha="center", va="center", fontsize=12, color="white") ax_snr.set_xlabel("SNR", fontweight="bold", fontsize=9, color="white") ax_snr.set_ylabel("Count", fontweight="bold", fontsize=9, color="white") ax_snr.set_title("SNR Distribution", fontweight="bold", fontsize=10, color="white") 
ax_snr.tick_params(colors="white", labelsize=8) ax_snr.spines["top"].set_visible(False) ax_snr.spines["right"].set_visible(False) ax_snr.spines["bottom"].set_color("white") ax_snr.spines["left"].set_color("white") # Panels 3 & 4: Compactness vs SNR and Activity vs SNR (shared Y-axis = SNR) # Color by skewness (activity pattern quality metric) ax_compact = fig.add_subplot(gs_top[0, 2]) ax_activity = fig.add_subplot(gs_top[0, 3], sharey=ax_compact) ax_compact.set_facecolor("black") ax_activity.set_facecolor("black") has_scatter_data = False if n_accepted > 0: valid_compact = accepted & ~np.isnan(compactness) & ~np.isnan(skewness) valid_activity = accepted & ~np.isnan(skewness) snr_acc = snr[accepted] skew_acc = skewness[accepted] # Get shared color limits from skewness (more informative than SNR for color) valid_skew = skew_acc[~np.isnan(skew_acc)] if len(valid_skew) > 0: vmin, vmax = np.nanpercentile(valid_skew, [5, 95]) else: vmin, vmax = 0, 1 if valid_compact.sum() > 0: # Panel 3: Compactness vs SNR (SNR on y-axis) sc1 = ax_compact.scatter(compactness[valid_compact], snr[valid_compact], c=skewness[valid_compact], cmap="plasma", alpha=0.7, s=20, vmin=vmin, vmax=vmax) has_scatter_data = True if len(activity) > 0 and valid_activity.sum() > 0: # Panel 4: Activity vs SNR (SNR on y-axis) sc2 = ax_activity.scatter(activity, snr_acc, c=skew_acc, cmap="plasma", alpha=0.7, s=20, vmin=vmin, vmax=vmax) # Add single colorbar for both plots (attached to activity plot) cbar = plt.colorbar(sc2, ax=ax_activity, shrink=0.8) cbar.set_label("Skewness", fontsize=8, color="white") cbar.ax.yaxis.set_tick_params(color="white") cbar.outline.set_edgecolor("white") plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="white") if not has_scatter_data: ax_compact.text(0.5, 0.5, "No data", ha="center", va="center", fontsize=12, color="white") ax_activity.text(0.5, 0.5, "No data", ha="center", va="center", fontsize=12, color="white") ax_compact.set_xlabel("Compactness", fontweight="bold", 
fontsize=9, color="white") ax_compact.set_ylabel("SNR", fontweight="bold", fontsize=9, color="white") ax_compact.set_title("Compactness vs SNR", fontweight="bold", fontsize=10, color="white") ax_compact.tick_params(colors="white", labelsize=8) ax_compact.spines["top"].set_visible(False) ax_compact.spines["right"].set_visible(False) ax_compact.spines["bottom"].set_color("white") ax_compact.spines["left"].set_color("white") ax_activity.set_xlabel("Active Frames", fontweight="bold", fontsize=9, color="white") ax_activity.set_ylabel("SNR", fontweight="bold", fontsize=9, color="white") ax_activity.set_title("Activity vs SNR", fontweight="bold", fontsize=10, color="white") ax_activity.tick_params(colors="white", labelsize=8) # Hide y-axis labels on right plot since it shares y-axis with left plt.setp(ax_activity.get_yticklabels(), visible=False) ax_activity.spines["top"].set_visible(False) ax_activity.spines["right"].set_visible(False) ax_activity.spines["bottom"].set_color("white") ax_activity.spines["left"].set_color("white") # Helper function to plot zoomed ROI def plot_roi_zoom(ax, roi_idx, img, stat_entry, snr_val, noise_val, color): """Plot a zoomed view of a single ROI with SNR and shot noise.""" ax.set_facecolor("black") ypix = stat_entry["ypix"] xpix = stat_entry["xpix"] # Calculate bounding box with padding pad = 15 y_min, y_max = max(0, ypix.min() - pad), min(img.shape[0], ypix.max() + pad) x_min, x_max = max(0, xpix.min() - pad), min(img.shape[1], xpix.max() + pad) # Extract ROI region roi_img = img[y_min:y_max, x_min:x_max] if roi_img.size == 0: ax.text(0.5, 0.5, "No image", ha="center", va="center", fontsize=10, color="white") ax.axis("off") return vmin, vmax = np.nanpercentile(roi_img, [1, 99]) ax.imshow(roi_img, cmap="gray", vmin=vmin, vmax=vmax, aspect="equal") # Draw ROI outline (shifted to local coordinates) local_y = ypix - y_min local_x = xpix - x_min ax.scatter(local_x, local_y, c=color, s=3, alpha=0.7, linewidths=0) # Title with SNR and shot noise 
ax.set_title(f"#{roi_idx} SNR={snr_val:.1f} σ={noise_val:.2f}", fontsize=7, fontweight="bold", color=color) ax.axis("off") # Helper function to plot a trace snippet def plot_roi_trace(ax, trace, color, window_frames=500): """Plot a short trace snippet with shrunk Y axis.""" ax.set_facecolor("black") # show first N frames or all if shorter n_show = min(window_frames, len(trace)) trace_segment = trace[:n_show] ax.plot(trace_segment, color=color, linewidth=0.8, alpha=0.9) ax.set_xlim(0, n_show) # Shrink Y axis to 5th-95th percentile to reduce whitespace if len(trace_segment) > 0: y_lo, y_hi = np.nanpercentile(trace_segment, [5, 95]) y_range = y_hi - y_lo if y_range > 0: ax.set_ylim(y_lo - 0.1 * y_range, y_hi + 0.1 * y_range) ax.axis("off") # Row 0-1: High SNR ROI examples with traces if n_accepted > 0 and mean_img is not None: accepted_idx = np.where(accepted)[0] snr_accepted = snr[accepted] n_show = min(n_examples, n_accepted) # Get indices of highest SNR cells top_snr_order = np.argsort(snr_accepted)[::-1][:n_show] for i in range(n_examples): # ROI image ax = fig.add_subplot(gs_bottom[0, i]) if i < n_show: local_idx = top_snr_order[i] global_idx = accepted_idx[local_idx] plot_roi_zoom(ax, global_idx, mean_img, stat[global_idx], snr[global_idx], shot_noise[global_idx], "#2ecc71") # trace below ax_trace = fig.add_subplot(gs_bottom[1, i]) plot_roi_trace(ax_trace, dff[global_idx], "#2ecc71") else: ax.set_facecolor("black") ax.axis("off") ax_trace = fig.add_subplot(gs_bottom[1, i]) ax_trace.set_facecolor("black") ax_trace.axis("off") # Row 2 is spacer (empty) # Row 3-4: Low SNR ROI examples with traces bottom_snr_order = np.argsort(snr_accepted)[:n_show] for i in range(n_examples): # ROI image ax = fig.add_subplot(gs_bottom[3, i]) if i < n_show: local_idx = bottom_snr_order[i] global_idx = accepted_idx[local_idx] plot_roi_zoom(ax, global_idx, mean_img, stat[global_idx], snr[global_idx], shot_noise[global_idx], "#ff6b6b") # trace below ax_trace = 
fig.add_subplot(gs_bottom[4, i]) plot_roi_trace(ax_trace, dff[global_idx], "#ff6b6b") else: ax.set_facecolor("black") ax.axis("off") ax_trace = fig.add_subplot(gs_bottom[4, i]) ax_trace.set_facecolor("black") ax_trace.axis("off") elif n_rejected > 0 and mean_img is not None: # Show rejected ROIs for diagnostics rejected_idx = np.where(~accepted)[0] snr_rejected = snr[~accepted] n_show = min(n_examples, n_rejected) # High SNR rejected top_snr_order = np.argsort(snr_rejected)[::-1][:n_show] for i in range(n_examples): ax = fig.add_subplot(gs_bottom[0, i]) if i < n_show: local_idx = top_snr_order[i] global_idx = rejected_idx[local_idx] plot_roi_zoom(ax, global_idx, mean_img, stat[global_idx], snr[global_idx], shot_noise[global_idx], "#ff6b6b") ax_trace = fig.add_subplot(gs_bottom[1, i]) plot_roi_trace(ax_trace, dff[global_idx], "#ff6b6b") else: ax.set_facecolor("black") ax.axis("off") ax_trace = fig.add_subplot(gs_bottom[1, i]) ax_trace.set_facecolor("black") ax_trace.axis("off") # Row 2 is spacer # Low SNR rejected bottom_snr_order = np.argsort(snr_rejected)[:n_show] for i in range(n_examples): ax = fig.add_subplot(gs_bottom[3, i]) if i < n_show: local_idx = bottom_snr_order[i] global_idx = rejected_idx[local_idx] plot_roi_zoom(ax, global_idx, mean_img, stat[global_idx], snr[global_idx], shot_noise[global_idx], "#ff6b6b") ax_trace = fig.add_subplot(gs_bottom[4, i]) plot_roi_trace(ax_trace, dff[global_idx], "#ff6b6b") else: ax.set_facecolor("black") ax.axis("off") ax_trace = fig.add_subplot(gs_bottom[4, i]) ax_trace.set_facecolor("black") ax_trace.axis("off") else: # No image available for row in [0, 1, 3, 4]: # Skip spacer row 2 for i in range(n_examples): ax = fig.add_subplot(gs_bottom[row, i]) ax.set_facecolor("black") if row in [0, 3]: ax.text(0.5, 0.5, "No data", ha="center", va="center", fontsize=8, color="white") ax.axis("off") # Main title plane_name = plane_dir.name fig.suptitle(f"Quality Diagnostics: {plane_name}", fontsize=14, fontweight="bold", y=0.98, 
color="white") # No tight_layout - we use manual GridSpec positioning for precise control if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="black") plt.close(fig) else: plt.show() return fig def mask_dead_zones_in_ops(ops, threshold=0.01): """ Mask out dead zones from registration shifts in ops image arrays. Dead zones appear as very dark regions (near zero intensity) at the edges of images after suite3D alignment shifts are applied. Parameters ---------- ops : dict Suite2p ops dictionary containing image arrays threshold : float Fraction of max intensity to use as cutoff (default 0.01 = 1%) Returns ------- ops : dict Modified ops with dead zones set to NaN in image arrays """ if "meanImg" not in ops: return ops # Use meanImg to identify valid regions mean_img = ops["meanImg"] valid_mask = mean_img > (mean_img.max() * threshold) n_invalid = (~valid_mask).sum() if n_invalid > 0: pct_invalid = 100 * n_invalid / valid_mask.size print(f"[mask_dead_zones] Masking {n_invalid} ({pct_invalid:.1f}%) dead zone pixels") # Mask all image arrays in ops for key in ["meanImg", "meanImgE", "max_proj", "Vcorr"]: if key in ops and isinstance(ops[key], np.ndarray): img = ops[key] # Only apply mask if shapes match if img.shape == valid_mask.shape: # Convert to float and set invalid regions to NaN ops[key] = img.astype(float) ops[key][~valid_mask] = np.nan else: print(f"[mask_dead_zones] Skipping {key}: shape {img.shape} != meanImg shape {valid_mask.shape}") return ops def plot_zplane_figures( plane_dir, dff_percentile=8, dff_window_size=None, dff_smooth_window=None, run_rastermap=False, **kwargs ): """ Re-generate Suite2p figures for a merged plane. Parameters ---------- plane_dir : Path Path to the planeXX output directory (with ops.npy, stat.npy, etc.). dff_percentile : int, optional Percentile used for ΔF/F baseline. 
dff_window_size : int, optional Window size for ΔF/F rolling baseline. If None, auto-calculated as ~10 × tau × fs based on ops values. dff_smooth_window : int, optional Temporal smoothing window for dF/F traces (in frames). If None, auto-calculated as ~0.5 × tau × fs to emphasize transients while reducing noise. Set to 1 to disable. run_rastermap : bool, optional If True, compute and plot rastermap sorting of cells. kwargs : dict Extra keyword args (e.g. fig_label). """ plane_dir = Path(plane_dir) # File naming convention: numbered prefixes ensure proper alphabetical ordering # 01_correlation -> 02_max_projection -> 03_mean -> 04_mean_enhanced # each image immediately followed by its _segmentation variant expected_files = { "ops": plane_dir / "ops.npy", "stat": plane_dir / "stat.npy", "iscell": plane_dir / "iscell.npy", # Summary images with segmentation overlays - numbered for proper ordering "correlation_image": plane_dir / "01_correlation.png", "correlation_segmentation": plane_dir / "01_correlation_segmentation.png", "max_proj": plane_dir / "02_max_projection.png", "max_proj_segmentation": plane_dir / "02_max_projection_segmentation.png", "meanImg": plane_dir / "03_mean.png", "meanImg_segmentation": plane_dir / "03_mean_segmentation.png", "meanImgE": plane_dir / "04_mean_enhanced.png", "meanImgE_segmentation": plane_dir / "04_mean_enhanced_segmentation.png", # Diagnostics and analysis "quality_diagnostics": plane_dir / "05_quality_diagnostics.png", "registration": plane_dir / "06_registration.png", # Traces - multiple cell counts "traces_raw_20": plane_dir / "07a_traces_raw_20.png", "traces_raw_50": plane_dir / "07b_traces_raw_50.png", "traces_raw_100": plane_dir / "07c_traces_raw_100.png", "traces_dff_20": plane_dir / "08a_traces_dff_20.png", "traces_dff_50": plane_dir / "08b_traces_dff_50.png", "traces_dff_100": plane_dir / "08c_traces_dff_100.png", "traces_rejected": plane_dir / "09_traces_rejected.png", # Noise distributions "noise_acc": plane_dir / 
"10_shot_noise_accepted.png", "noise_rej": plane_dir / "11_shot_noise_rejected.png", # Rastermap "model": plane_dir / "model.npy", "rastermap": plane_dir / "12_rastermap.png", } output_ops = load_ops(expected_files["ops"]) # Dead zones are now handled via yrange/xrange cropping in run_lsp.py # so we don't need to mask them here anymore # output_ops = mask_dead_zones_in_ops(output_ops) # force remake of the heavy figures for key in [ "registration", "traces_raw_20", "traces_raw_50", "traces_raw_100", "traces_dff_20", "traces_dff_50", "traces_dff_100", "traces_rejected", "noise_acc", "noise_rej", "rastermap", ]: if key in expected_files: if expected_files[key].exists(): try: expected_files[key].unlink() except PermissionError: print(f"Error: Cannot delete {expected_files[key]}, it's open elsewhere.") if expected_files["stat"].is_file(): res = load_planar_results(plane_dir) # iscell is (n_rois, 2): [:, 0] is 0/1, [:, 1] is classifier probability iscell_mask = res["iscell"][:, 0].astype(bool) cell_prob = res["iscell"][:, 1] spks = res["spks"] F = res["F"] # Split by accepted/rejected F_accepted = F[iscell_mask] if iscell_mask.sum() > 0 else np.zeros((0, F.shape[1])) F_rejected = F[~iscell_mask] if (~iscell_mask).sum() > 0 else np.zeros((0, F.shape[1])) spks_cells = spks[iscell_mask] if iscell_mask.sum() > 0 else np.zeros((0, spks.shape[1])) n_accepted = F_accepted.shape[0] n_rejected = F_rejected.shape[0] print(f"Plotting results for {n_accepted} accepted / {n_rejected} rejected ROIs") # --- RASTERMAP (only for sufficient cell counts) --- # rastermap sorts neurons by activity similarity for visualization # we cache the model to avoid recomputing, but validate it matches current data model = None if run_rastermap and n_accepted >= 2: try: from lbm_suite2p_python.zplane import plot_rastermap import rastermap has_rastermap = True except ImportError: print("rastermap not found. 
Install via: pip install rastermap") print(" or: pip install mbo_utilities[rastermap]") has_rastermap = False rastermap, plot_rastermap = None, None if has_rastermap: model_file = expected_files["model"] plot_file = expected_files["rastermap"] need_recompute = True # check if cached model exists and is valid for current cell count if model_file.is_file(): try: cached_model = np.load(model_file, allow_pickle=True).item() # Handle both direct model objects and dict wrappers if hasattr(cached_model, "isort"): cached_isort = cached_model.isort elif isinstance(cached_model, dict) and "isort" in cached_model: cached_isort = cached_model["isort"] else: cached_isort = None if cached_isort is not None and len(cached_isort) == n_accepted: model = cached_model need_recompute = False print(f" Using cached rastermap model ({n_accepted} cells)") else: # stale model - cell count changed since last run cached_len = len(cached_isort) if cached_isort is not None else "?" print(f" Rastermap model stale (cached {cached_len} vs current {n_accepted} cells), recomputing...") model_file.unlink() except Exception as e: print(f" Failed to load cached rastermap model: {e}, recomputing...") model_file.unlink(missing_ok=True) # fit new model if needed if need_recompute: print(f" Computing rastermap model for {n_accepted} cells...") params = { "n_clusters": 100 if n_accepted >= 200 else None, "n_PCs": min(128, max(2, n_accepted - 1)), "locality": 0.0 if n_accepted >= 200 else 0.1, "time_lag_window": 15, "grid_upsample": 10 if n_accepted >= 200 else 0, } model = rastermap.Rastermap(**params).fit(spks_cells) np.save(model_file, model) # regenerate plot if missing (even if model was cached) if model is not None and not plot_file.is_file(): plot_rastermap( spks_cells, model, neuron_bin_size=0, save_path=plot_file, title_kwargs={"fontsize": 8, "y": 0.95}, title="Rastermap Sorted Activity", ) # apply sorting to traces for downstream plots if model is not None: # Handle both direct model objects and 
dict wrappers if hasattr(model, "isort"): isort = model.isort elif isinstance(model, dict) and "isort" in model: isort = model["isort"] else: isort = None if isort is not None: isort_global = np.where(iscell_mask)[0][isort] output_ops["isort"] = isort_global F_accepted = F_accepted[isort] # --- COMPUTE ΔF/F --- fs = output_ops.get("fs", 1.0) tau = output_ops.get("tau", 1.0) # Compute unsmoothed dF/F for shot noise (smoothing reduces frame-to-frame variance) if n_accepted > 0: dffp_acc_unsmoothed = dff_rolling_percentile( F_accepted, percentile=dff_percentile, window_size=dff_window_size, smooth_window=1, # No smoothing for shot noise fs=fs, tau=tau, ) * 100 # Smoothed version for trace plotting dffp_acc = dff_rolling_percentile( F_accepted, percentile=dff_percentile, window_size=dff_window_size, smooth_window=dff_smooth_window, fs=fs, tau=tau, ) * 100 else: dffp_acc_unsmoothed = np.zeros((0, F.shape[1])) dffp_acc = np.zeros((0, F.shape[1])) if n_rejected > 0: dffp_rej_unsmoothed = dff_rolling_percentile( F_rejected, percentile=dff_percentile, window_size=dff_window_size, smooth_window=1, # No smoothing for shot noise fs=fs, tau=tau, ) * 100 # Smoothed version for trace plotting dffp_rej = dff_rolling_percentile( F_rejected, percentile=dff_percentile, window_size=dff_window_size, smooth_window=dff_smooth_window, fs=fs, tau=tau, ) * 100 else: dffp_rej_unsmoothed = np.zeros((0, F.shape[1])) dffp_rej = np.zeros((0, F.shape[1])) # --- TRACE PLOTS (robust to any cell count >= 1) --- # Sort traces by quality score (SNR, skewness, shot noise) for visualization # Generate plots with 20, 50, and 100 cells if available if n_accepted > 0: # Get accepted cell stat for skewness stat_accepted = [s for s, m in zip(res["stat"], iscell_mask) if m] # Compute quality scores and sort quality = compute_trace_quality_score( F_accepted, Fneu=res["Fneu"][iscell_mask] if "Fneu" in res else None, stat=stat_accepted, fs=fs, ) quality_sort_idx = quality["sort_idx"] # Sort traces by quality 
(best first) dffp_acc_sorted = dffp_acc[quality_sort_idx] F_accepted_sorted = F_accepted[quality_sort_idx] # Generate trace plots at multiple cell counts cell_counts = [20, 50, 100] for n_cells in cell_counts: if n_accepted >= n_cells: # dF/F traces (percent) plot_traces( dffp_acc_sorted, save_path=expected_files[f"traces_dff_{n_cells}"], num_neurons=n_cells, scale_bar_unit=r"% $\Delta$F/F$_0$", title=rf"Top {n_cells} $\Delta$F/F Traces by Quality (n={n_accepted} total)", ) # Raw traces plot_traces( F_accepted_sorted, save_path=expected_files[f"traces_raw_{n_cells}"], num_neurons=n_cells, scale_bar_unit="a.u.", title=f"Top {n_cells} Raw Traces by Quality (n={n_accepted} total)", ) elif n_cells == 20: # Always generate 20-cell plot even if fewer cells available plot_traces( dffp_acc_sorted, save_path=expected_files["traces_dff_20"], num_neurons=min(20, n_accepted), scale_bar_unit=r"% $\Delta$F/F$_0$", title=rf"Top {min(20, n_accepted)} $\Delta$F/F Traces by Quality (n={n_accepted} total)", ) plot_traces( F_accepted_sorted, save_path=expected_files["traces_raw_20"], num_neurons=min(20, n_accepted), scale_bar_unit="a.u.", title=f"Top {min(20, n_accepted)} Raw Traces by Quality (n={n_accepted} total)", ) else: print(" No accepted cells - skipping accepted trace plots") if n_rejected > 0: plot_traces( dffp_rej, save_path=expected_files["traces_rejected"], num_neurons=min(20, n_rejected), scale_bar_unit=r"% $\Delta$F/F$_0$", title=rf"$\Delta$F/F Traces - Rejected ROIs (n={n_rejected})", ) else: print(" No rejected ROIs - skipping rejected trace plots") # --- NOISE DISTRIBUTIONS (robust to any cell count >= 1) --- # Use unsmoothed dF/F for shot noise (smoothing artificially reduces noise) if n_accepted > 0: dff_noise_acc = dff_shot_noise(dffp_acc_unsmoothed, fs) plot_noise_distribution( dff_noise_acc, output_filename=expected_files["noise_acc"], title=f"Shot-Noise Distribution (Accepted, n={n_accepted})", ) if n_rejected > 0: dff_noise_rej = 
dff_shot_noise(dffp_rej_unsmoothed, fs) plot_noise_distribution( dff_noise_rej, output_filename=expected_files["noise_rej"], title=f"Shot-Noise Distribution (Rejected, n={n_rejected})", ) # --- SEGMENTATION OVERLAYS --- # Suite2p stores images in two coordinate systems: # - FULL space: refImg, meanImg, meanImgE (same size as original Ly x Lx) # - CROPPED space: max_proj, Vcorr (size determined by yrange/xrange after registration) # The stat coordinates are in FULL image space. stat_full = res["stat"] # stat coordinates in full image space # Helper to check if image is valid def _is_valid_image(img): if img is None: return False if isinstance(img, (int, float)) and img == 0: return False if isinstance(img, np.ndarray) and img.size == 0: return False return True # Get crop parameters for images in cropped space yrange = output_ops.get("yrange", [0, output_ops.get("Ly", 512)]) xrange = output_ops.get("xrange", [0, output_ops.get("Lx", 512)]) ymin, xmin = int(yrange[0]), int(xrange[0]) # Create stat with adjusted coordinates for cropped image space if ymin > 0 or xmin > 0: stat_cropped = [] for s in stat_full: s_adj = s.copy() s_adj["ypix"] = s["ypix"] - ymin s_adj["xpix"] = s["xpix"] - xmin stat_cropped.append(s_adj) else: stat_cropped = stat_full # Images in FULL space - use stat_full full_space_images = { "meanImg": ("Mean Image", expected_files["meanImg_segmentation"]), "meanImgE": ("Enhanced Mean Image", expected_files["meanImgE_segmentation"]), } for img_key, (title_name, save_file) in full_space_images.items(): img = output_ops.get(img_key) if _is_valid_image(img): if n_accepted > 0: plot_masks( img=img, stat=stat_full, mask_idx=iscell_mask, savepath=save_file, title=f"{title_name} - Accepted ROIs (n={n_accepted})" ) else: plot_projection( output_ops, save_file, fig_label=kwargs.get("fig_label", plane_dir.stem), display_masks=False, add_scalebar=True, proj=img_key, ) # Images in CROPPED space - use stat_cropped cropped_space_images = { "max_proj": ("Max 
Projection", expected_files["max_proj_segmentation"]), } for img_key, (title_name, save_file) in cropped_space_images.items(): img = output_ops.get(img_key) if _is_valid_image(img): if n_accepted > 0: plot_masks( img=img, stat=stat_cropped, mask_idx=iscell_mask, savepath=save_file, title=f"{title_name} - Accepted ROIs (n={n_accepted})" ) else: plot_projection( output_ops, save_file, fig_label=kwargs.get("fig_label", plane_dir.stem), display_masks=False, add_scalebar=True, proj=img_key, ) # Correlation image (Vcorr) - in CROPPED space vcorr = output_ops.get("Vcorr") if _is_valid_image(vcorr): # Save correlation image without masks fig, ax = plt.subplots(figsize=(8, 8), facecolor="black") ax.set_facecolor("black") ax.imshow(vcorr, cmap="gray") ax.set_title("Correlation Image", color="white", fontweight="bold") ax.axis("off") plt.tight_layout() plt.savefig(expected_files["correlation_image"], dpi=150, facecolor="black") plt.close(fig) # Correlation image with segmentation if n_accepted > 0: plot_masks( img=vcorr, stat=stat_cropped, mask_idx=iscell_mask, savepath=expected_files["correlation_segmentation"], title=f"Correlation Image - Accepted ROIs (n={n_accepted})" ) # --- SUMMARY IMAGES (no masks) - always generated --- fig_label = kwargs.get("fig_label", plane_dir.stem) for key in ["meanImg", "max_proj", "meanImgE"]: if key in output_ops and output_ops[key] is not None: try: plot_projection( output_ops, expected_files[key], fig_label=fig_label, display_masks=False, add_scalebar=True, proj=key, ) except Exception as e: print(f" Failed to plot {key}: {e}") # --- QUALITY DIAGNOSTICS --- try: plot_plane_diagnostics(plane_dir, save_path=expected_files["quality_diagnostics"]) except Exception as e: print(f" Failed to generate quality diagnostics: {e}") return output_ops