"""lbm_suite2p_python.volume: volume-level (multi-plane) statistics, plotting, movie export and consolidation helpers."""

import glob
import subprocess
from pathlib import Path

import cv2
import numpy as np
from matplotlib import pyplot as plt

from lbm_suite2p_python.utils import get_common_path
from lbm_suite2p_python.postprocessing import load_ops


def update_ops_paths(ops_files: str | list):
    """
    Update save_path, save_path0, and save_folder in an ops dictionary based on its current location. Use after moving an ops_file or batch of ops_files.
    """
    if isinstance(ops_files, (str, Path)):
        ops_files = [ops_files]

    for ops_file in ops_files:
        ops = np.load(ops_file, allow_pickle=True).item()

        ops_path = Path(ops_file)
        plane0_folder = ops_path.parent
        plane_folder = plane0_folder.parent

        ops["save_path"] = str(plane0_folder)
        ops["save_path0"] = str(plane_folder)
        ops["save_folder"] = plane_folder.name
        ops["ops_path"] = ops_path

        np.save(ops_file, ops)


def plot_execution_time(filepath, savepath):
    """
    Plot the execution time of each processing step per z-plane.

    Loads timing data from a structured ``.npy`` file and draws registration,
    detection and extraction times as a stacked bar plot on a black background.
    Each stack is annotated with the plane's total runtime (when > 1 s).

    Parameters
    ----------
    filepath : str or Path
        Path to the `.npy` file containing the volume timing stats.
    savepath : str or Path
        Path to save the generated figure.

    Notes
    -----
    - The `.npy` file should contain structured data with `plane`, `registration`,
      `detection`, `extraction`, `classification`, `deconvolution`, and
      `total_plane_runtime` fields. Only `registration`, `detection` and
      `extraction` are drawn as bars; `total_plane_runtime` is used for labels.
    - Missing (NaN) step timings are drawn as zero-height segments so the other
      segments of that plane still stack and remain visible.
    - The figure is displayed with ``plt.show()`` and then closed.
    """
    plane_stats = np.load(filepath)
    planes = plane_stats["plane"]
    # get_volume_stats stores NaN for steps whose timing was not recorded. A NaN
    # segment would also make every ``bottom`` stacked on it NaN (hiding those
    # bars), so draw missing steps as zero instead.
    reg_time = np.nan_to_num(plane_stats["registration"], nan=0.0)
    detect_time = np.nan_to_num(plane_stats["detection"], nan=0.0)
    extract_time = np.nan_to_num(plane_stats["extraction"], nan=0.0)
    total_time = plane_stats["total_plane_runtime"]

    fig = plt.figure(figsize=(10, 6), facecolor="black")
    ax = plt.gca()
    ax.set_facecolor("black")

    plt.xlabel("Z-Plane", fontsize=14, fontweight="bold", color="white")
    plt.ylabel("Execution Time (s)", fontsize=14, fontweight="bold", color="white")
    plt.title(
        "Execution Time per Processing Step",
        fontsize=16,
        fontweight="bold",
        color="white",
    )

    plt.bar(planes, reg_time, label="Registration", alpha=0.8, color="#FF5733")
    plt.bar(
        planes,
        detect_time,
        label="Detection",
        alpha=0.8,
        bottom=reg_time,
        color="#33FF57",
    )
    bars3 = plt.bar(
        planes,
        extract_time,
        label="Extraction",
        alpha=0.8,
        bottom=reg_time + detect_time,
        color="#3357FF",
    )

    for bar, total in zip(bars3, total_time):
        height = bar.get_y() + bar.get_height()
        # Only label totals large enough to be visible; NaN totals compare False.
        if total > 1:
            plt.text(
                bar.get_x() + bar.get_width() / 2,
                height + 2,
                f"{int(total)}",
                ha="center",
                va="bottom",
                fontsize=12,
                color="white",
                fontweight="bold",
            )

    plt.xticks(planes, fontsize=12, fontweight="bold", color="white")
    plt.yticks(fontsize=12, fontweight="bold", color="white")
    plt.grid(axis="y", linestyle="--", alpha=0.4, color="white")
    for spine in ax.spines.values():
        spine.set_color("white")

    plt.legend(
        fontsize=12,
        facecolor="black",
        edgecolor="white",
        labelcolor="white",
        loc="upper left",
        bbox_to_anchor=(1, 1),
    )

    plt.savefig(savepath, bbox_inches="tight", facecolor="black")
    plt.show()
    # With a non-interactive backend show() does nothing, so release the figure
    # explicitly instead of leaving it registered with pyplot.
    plt.close(fig)
def plot_volume_signal(zstats, savepath):
    """
    Plot the mean raw fluorescence of each z-plane with ±1 standard deviation bars.

    Parameters
    ----------
    zstats : str or Path
        Path to the structured ``.npy`` volume-stats file (such as the
        ``zstats.npy`` written by ``get_volume_stats()``). It must provide the
        ``plane``, ``mean_trace`` and ``std_trace`` fields.
    savepath : str or Path
        Destination path of the saved figure.

    Notes
    -----
    The figure is saved at 150 dpi on a black background and then closed; the
    error bars show the per-plane standard deviation of the raw signal.
    """
    stats = np.load(zstats)
    z_planes, signal_mean, signal_std = (
        stats["plane"],
        stats["mean_trace"],
        stats["std_trace"],
    )

    axis_label_style = {"fontsize": 10, "fontweight": "bold", "color": "white"}

    fig, ax = plt.subplots(figsize=(10, 5), facecolor="black")
    ax.set_facecolor("black")
    ax.errorbar(
        z_planes,
        signal_mean,
        yerr=signal_std,
        fmt="o-",
        color="#3498db",
        ecolor="#85c1e9",
        elinewidth=1.5,
        capsize=3,
        markersize=5,
        alpha=0.9,
        label="Mean ± STD",
    )

    ax.set_xlabel("Z-Plane", **axis_label_style)
    ax.set_ylabel("Mean Raw Signal", **axis_label_style)
    ax.set_title(
        "Mean Fluorescence Signal per Z-Plane",
        fontsize=11,
        fontweight="bold",
        color="white",
    )
    ax.tick_params(colors="white", labelsize=9)

    # Open frame: hide the top/right borders, whiten the remaining two.
    for hidden_side in ("top", "right"):
        ax.spines[hidden_side].set_visible(False)
    for shown_side in ("bottom", "left"):
        ax.spines[shown_side].set_color("white")

    ax.legend(fontsize=8, facecolor="#1a1a1a", edgecolor="white", labelcolor="white")
    fig.savefig(savepath, bbox_inches="tight", facecolor="black", dpi=150)
    plt.close(fig)
def plot_volume_neuron_counts(zstats, savepath): """ Plots the number of accepted and rejected neurons per z-plane. This function loads neuron count data from a `.npy` file and visualizes the accepted vs. rejected neurons as a stacked bar plot with a black background. Parameters ---------- zstats : str, Path Full path to the zstats.npy file. savepath : str or Path Path to directory where generated figure will be saved. Notes ----- - The `.npy` file should contain structured data with `plane`, `accepted`, and `rejected` fields. """ zstats = Path(zstats) if not zstats.is_file(): raise FileNotFoundError(f"{zstats} is not a valid zstats.npy file.") plane_stats = np.load(zstats) savepath = Path(savepath) planes = plane_stats["plane"] accepted = plane_stats["accepted"] rejected = plane_stats["rejected"] savename = savepath.joinpath( f"all_neurons_{accepted.sum()}acc_{rejected.sum()}rej.png" ) fig, ax = plt.subplots(figsize=(10, 5), facecolor="black") ax.set_facecolor("black") bar_width = 0.8 bars1 = ax.bar( planes, accepted, width=bar_width, label=f"Accepted ({accepted.sum()})", alpha=0.85, color="#2ecc71", edgecolor="#27ae60", linewidth=0.5 ) bars2 = ax.bar( planes, rejected, width=bar_width, bottom=accepted, label=f"Rejected ({rejected.sum()})", alpha=0.85, color="#e74c3c", edgecolor="#c0392b", linewidth=0.5 ) # Add count labels inside bars (only if tall enough) for bar in bars1: height = bar.get_height() if height > 5: # Only label if tall enough to fit text ax.text( bar.get_x() + bar.get_width() / 2, height / 2, f"{int(height)}", ha="center", va="center", fontsize=8, color="white", fontweight="bold" ) for bar1, bar2 in zip(bars1, bars2): height1 = bar1.get_height() height2 = bar2.get_height() if height2 > 5: ax.text( bar2.get_x() + bar2.get_width() / 2, height1 + height2 / 2, f"{int(height2)}", ha="center", va="center", fontsize=8, color="white", fontweight="bold" ) ax.set_xlabel("Z-Plane", fontsize=10, fontweight="bold", color="white") ax.set_ylabel("Number of 
ROIs", fontsize=10, fontweight="bold", color="white") ax.set_title("Accepted vs Rejected ROIs per Z-Plane", fontsize=11, fontweight="bold", color="white") ax.tick_params(colors="white", labelsize=9) ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) ax.spines["bottom"].set_color("white") ax.spines["left"].set_color("white") ax.legend(fontsize=8, facecolor="#1a1a1a", edgecolor="white", labelcolor="white", loc="upper right") plt.savefig(savename, bbox_inches="tight", facecolor="black", dpi=150) plt.close(fig) def get_volume_stats(ops_files: list[str | Path], overwrite: bool = True): """ Given a list of ops.npy files, accumulate common statistics for assessing zplane quality. Parameters ---------- ops_files : list of str or Path Each item in the list should be a path pointing to a z-lanes `ops.npy` file. The number of items in this list should match the number of z-planes in your session. overwrite : bool If a file already exists, it will be overwritten. Defaults to True. Notes ----- - The `.npy` file should contain structured data with `plane`, `accepted`, and `rejected` fields. 
""" if not ops_files: print("No ops files found.") return None plane_stats = {} for i, file in enumerate(ops_files): output_ops = load_ops(file) raw_z = output_ops.get("plane", None) if raw_z is None: zplane_num = i else: if isinstance(raw_z, (int, np.integer)): zplane_num = int(raw_z) else: s = str(raw_z) digits = "".join([c for c in s if c.isdigit()]) zplane_num = int(digits) if digits else i save_path = Path(output_ops["save_path"]) timing = output_ops.get("timing", {}) # Check if required files exist iscell_file = save_path / "iscell.npy" traces_file = save_path / "F.npy" if not iscell_file.exists(): print(f"Skipping plane {zplane_num}: iscell.npy not found at {save_path}") continue if not traces_file.exists(): print(f"Skipping plane {zplane_num}: F.npy not found at {save_path}") continue # Load files try: iscell_raw = np.load(iscell_file, allow_pickle=True) traces = np.load(traces_file, allow_pickle=True) # Validate iscell data - check for _NoValue or other invalid types if not isinstance(iscell_raw, np.ndarray) or iscell_raw.size == 0: print(f"Skipping plane {zplane_num}: iscell.npy is empty or invalid") continue # Handle potential _NoValue entries by filtering to valid numeric data try: iscell = iscell_raw[:, 0].astype(bool) except (TypeError, ValueError) as e: print(f"Skipping plane {zplane_num}: iscell data conversion failed - {e}") continue # Validate traces if not isinstance(traces, np.ndarray) or traces.size == 0: print(f"Skipping plane {zplane_num}: F.npy is empty or invalid") continue except Exception as e: print(f"Skipping plane {zplane_num}: Error loading files - {e}") continue # Safe stat computations with explicit conversion try: n_accepted = int(np.sum(iscell)) n_rejected = int(np.sum(~iscell)) trace_mean = float(np.nanmean(traces)) trace_std = float(np.nanstd(traces)) except (TypeError, ValueError) as e: print(f"Skipping plane {zplane_num}: Error computing statistics - {e}") continue plane_stats[zplane_num] = { "accepted": n_accepted, 
"rejected": n_rejected, "mean": trace_mean, "std": trace_std, "registration": timing.get("registration", np.nan), "detection": timing.get("detection", timing.get("detect", np.nan)), "extraction": timing.get("extraction", np.nan), "classification": timing.get("classification", np.nan), "deconvolution": timing.get("deconvolution", np.nan), "total_runtime": timing.get("total_plane_runtime", np.nan), "filepath": str(file), "zplane": zplane_num, } # Check if any planes had valid statistics if not plane_stats: print("No valid plane statistics collected - all planes skipped or missing files") return None common = get_common_path(ops_files) out = [] for p, stats in sorted(plane_stats.items()): out.append( ( p, stats["accepted"], stats["rejected"], stats["mean"], stats["std"], stats["registration"], stats["detection"], stats["extraction"], stats["classification"], stats["deconvolution"], stats["total_runtime"], stats["filepath"], stats["zplane"], ) ) dtype = [ ("plane", "i4"), ("accepted", "i4"), ("rejected", "i4"), ("mean_trace", "f8"), ("std_trace", "f8"), ("registration", "f8"), ("detection", "f8"), ("extraction", "f8"), ("classification", "f8"), ("deconvolution", "f8"), ("total_plane_runtime", "f8"), ("filepath", "U255"), ("zplane", "i4"), ] arr = np.array(out, dtype=dtype) save_path = Path(common) / "zstats.npy" if overwrite or not save_path.exists(): np.save(save_path, arr) return str(save_path) def save_images_to_movie(image_input, savepath, duration=None, format=".mp4"): """ Convert a sequence of saved images into a movie. TODO: move to mbo_utilities. Parameters ---------- image_input : str, Path, or list Directory containing saved segmentation images or a list of image file paths. savepath : str or Path Path to save the video file. duration : int, optional Desired total video duration in seconds. If None, defaults to 1 FPS (1 image per second). format : str, optional Video format: ".mp4" (PowerPoint-compatible), ".avi" (lossless), ".mov" (ProRes). 
Default is ".mp4". Examples -------- >>> import mbo_utilities as mbo >>> import lbm_suite2p_python as lsp Get all png files autosaved during LBM-Suite2p-Python `run_volume()` >>> segmentation_pngs = mbo.get_files("path/suite3d/results/", "segmentation.png", max_depth=3) >>> lsp.save_images_to_movie(segmentation_pngs, "path/to/save/segmentation.png", format=".mp4") """ savepath = Path(savepath).with_suffix(format) # Ensure correct file extension temp_video = savepath.with_suffix(".avi") # Temporary AVI file for MOV conversion savepath.parent.mkdir(parents=True, exist_ok=True) if isinstance(image_input, (str, Path)): image_dir = Path(image_input) image_files = sorted( glob.glob(str(image_dir / "*.png")) + glob.glob(str(image_dir / "*.jpg")) + glob.glob(str(image_dir / "*.tif")) ) elif isinstance(image_input, list): image_files = sorted(map(str, image_input)) else: raise ValueError( "image_input must be a directory path or a list of file paths." ) if not image_files: return first_image = cv2.imread(image_files[0]) height, width, _ = first_image.shape fps = len(image_files) / duration if duration else 1 if format == ".mp4": fourcc = cv2.VideoWriter_fourcc(*"XVID") video_path = savepath elif format == ".avi": fourcc = cv2.VideoWriter_fourcc(*"HFYU") video_path = savepath elif format == ".mov": fourcc = cv2.VideoWriter_fourcc(*"HFYU") video_path = temp_video else: raise ValueError("Invalid format. 
Use '.mp4', '.avi', or '.mov'.") video_writer = cv2.VideoWriter( str(video_path), fourcc, max(fps, 1), (width, height) ) for image_file in image_files: frame = cv2.imread(image_file) video_writer.write(frame) video_writer.release() if format == ".mp4": ffmpeg_cmd = [ "ffmpeg", "-y", "-i", str(video_path), "-vcodec", "libx264", "-acodec", "aac", "-preset", "slow", "-crf", "18", str(savepath), # Save directly to `savepath` ] subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) print(f"MP4 saved at {savepath}") elif format == ".mov": ffmpeg_cmd = [ "ffmpeg", "-y", "-i", str(temp_video), "-c:v", "prores_ks", # Use Apple ProRes codec "-profile:v", "3", # ProRes 422 LT "-pix_fmt", "yuv422p10le", str(savepath), ] subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) temp_video.unlink() def consolidate_volume(suite2p_path, merged_dir="merged", overwrite=False): """ Consolidate all plane results into a single merged directory. Combines ops.npy, stat.npy, iscell.npy, F.npy, Fneu.npy, and spks.npy from all planeXX_stitched folders into a single merged/ folder with plane-indexed arrays. 
Parameters ---------- suite2p_path : str or Path Path to suite2p directory containing planeXX_stitched folders merged_dir : str, optional Name of merged directory to create (default: "merged") overwrite : bool, optional Whether to overwrite existing merged directory (default: False) Returns ------- merged_path : Path Path to the created merged directory Examples -------- >>> import lbm_suite2p_python as lsp >>> merged = lsp.consolidate_volume("path/to/suite2p") >>> # Load consolidated results >>> ops = np.load(merged / "ops.npy", allow_pickle=True).item() >>> stat = np.load(merged / "stat.npy", allow_pickle=True) >>> iscell = np.load(merged / "iscell.npy") """ suite2p_path = Path(suite2p_path) merged_path = suite2p_path / merged_dir # Find all plane directories plane_dirs = sorted(suite2p_path.glob("plane*_stitched")) if not plane_dirs: raise ValueError(f"No plane*_stitched directories found in {suite2p_path}") print(f"Found {len(plane_dirs)} planes to consolidate") # Check if merged directory exists if merged_path.exists() and not overwrite: print(f"Merged directory already exists: {merged_path}") print("Use overwrite=True to recreate") return merged_path # Create merged directory merged_path.mkdir(exist_ok=True) # Initialize lists for consolidation all_stats = [] all_iscell = [] all_F = [] all_Fneu = [] all_spks = [] all_ops = [] # Track ROI offsets for each plane n_cells_per_plane = [] for plane_dir in plane_dirs: plane_num = int(''.join(filter(str.isdigit, plane_dir.name))) print(f" Processing plane {plane_num}: {plane_dir.name}") # Load required files stat_file = plane_dir / "stat.npy" iscell_file = plane_dir / "iscell.npy" ops_file = plane_dir / "ops.npy" if not all([stat_file.exists(), iscell_file.exists(), ops_file.exists()]): print(f" WARNING: Missing required files, skipping plane {plane_num}") continue # Load data stat = np.load(stat_file, allow_pickle=True) iscell = np.load(iscell_file) ops = np.load(ops_file, allow_pickle=True).item() # Add plane 
number to each stat entry for s in stat: s['iplane'] = plane_num all_stats.extend(stat) all_iscell.append(iscell) all_ops.append(ops) n_cells_per_plane.append(len(stat)) # Load optional trace files F_file = plane_dir / "F.npy" Fneu_file = plane_dir / "Fneu.npy" spks_file = plane_dir / "spks.npy" if F_file.exists(): F = np.load(F_file) all_F.append(F) else: print(f" WARNING: F.npy not found for plane {plane_num}") if Fneu_file.exists(): Fneu = np.load(Fneu_file) all_Fneu.append(Fneu) else: print(f" WARNING: Fneu.npy not found for plane {plane_num}") if spks_file.exists(): spks = np.load(spks_file) all_spks.append(spks) else: print(f" WARNING: spks.npy not found for plane {plane_num}") # Save consolidated files print(f"\nSaving consolidated results to {merged_path}") # Save stat.npy stat_array = np.array(all_stats, dtype=object) np.save(merged_path / "stat.npy", stat_array) print(f" Saved stat.npy: {len(stat_array)} total ROIs") # Save iscell.npy if all_iscell: iscell_array = np.vstack(all_iscell) np.save(merged_path / "iscell.npy", iscell_array) n_accepted = (iscell_array[:, 0] == 1).sum() print(f" Saved iscell.npy: {n_accepted} accepted, {len(iscell_array) - n_accepted} rejected") # Save trace files if all_F: F_array = np.vstack(all_F) np.save(merged_path / "F.npy", F_array) print(f" Saved F.npy: shape {F_array.shape}") if all_Fneu: Fneu_array = np.vstack(all_Fneu) np.save(merged_path / "Fneu.npy", Fneu_array) print(f" Saved Fneu.npy: shape {Fneu_array.shape}") if all_spks: spks_array = np.vstack(all_spks) np.save(merged_path / "spks.npy", spks_array) print(f" Saved spks.npy: shape {spks_array.shape}") # Save consolidated ops # Use first plane's ops as template and add plane-specific info consolidated_ops = all_ops[0].copy() if all_ops else {} consolidated_ops['nplanes'] = len(plane_dirs) consolidated_ops['n_cells_per_plane'] = n_cells_per_plane consolidated_ops['plane_dirs'] = [str(d) for d in plane_dirs] consolidated_ops['save_path'] = str(merged_path) 
np.save(merged_path / "ops.npy", consolidated_ops) print(f" Saved ops.npy: {len(plane_dirs)} planes consolidated") print(f"\nConsolidation complete!") print(f"Total ROIs: {len(stat_array)}") print(f"Planes: {len(plane_dirs)}") return merged_path def plot_volume_diagnostics( ops_files: list[str | Path], save_path: str | Path = None, figsize: tuple = (16, 12), ) -> plt.Figure: """ Generate a single-figure diagnostic summary for an entire processed volume. Creates a publication-quality figure showing across all z-planes: - Row 1: ROI counts (accepted/rejected stacked bars), Mean signal per plane - Row 2: SNR distribution per plane, Size distribution per plane - Row 3: Compactness vs SNR (all planes), Skewness vs SNR (all planes) Parameters ---------- ops_files : list of str or Path List of paths to ops.npy files for each z-plane. save_path : str or Path, optional If provided, save figure to this path. figsize : tuple, default (16, 12) Figure size in inches. Returns ------- fig : matplotlib.figure.Figure The generated figure object. 
""" from lbm_suite2p_python.postprocessing import load_ops if not ops_files: fig = plt.figure(figsize=figsize, facecolor="black") fig.text(0.5, 0.5, "No ops files provided", ha="center", va="center", fontsize=16, fontweight="bold", color="white") return fig # Collect data from all planes plane_data = [] all_snr = [] all_npix = [] all_compactness = [] all_skewness = [] all_plane_ids = [] # track which plane each ROI belongs to for ops_file in ops_files: ops_file = Path(ops_file) ops = load_ops(ops_file) plane_dir = ops_file.parent raw_plane = ops.get("plane", None) # Extract plane number if raw_plane is not None: if isinstance(raw_plane, (int, np.integer)): plane_num = int(raw_plane) else: s = str(raw_plane) digits = "".join([c for c in s if c.isdigit()]) plane_num = int(digits) if digits else len(plane_data) else: plane_num = len(plane_data) # Load required files iscell_file = plane_dir / "iscell.npy" stat_file = plane_dir / "stat.npy" F_file = plane_dir / "F.npy" Fneu_file = plane_dir / "Fneu.npy" if not all([iscell_file.exists(), stat_file.exists(), F_file.exists()]): continue try: iscell_raw = np.load(iscell_file, allow_pickle=True) if not isinstance(iscell_raw, np.ndarray) or iscell_raw.size == 0: continue iscell = iscell_raw[:, 0].astype(bool) stat = np.load(stat_file, allow_pickle=True) F = np.load(F_file, allow_pickle=True) Fneu = np.load(Fneu_file, allow_pickle=True) if Fneu_file.exists() else np.zeros_like(F) except Exception: continue n_accepted = int(np.sum(iscell)) n_rejected = int(np.sum(~iscell)) # Compute SNR for accepted cells F_corr = F - 0.7 * Fneu baseline = np.percentile(F_corr, 20, axis=1, keepdims=True) baseline = np.maximum(baseline, 1e-6) dff = (F_corr - baseline) / baseline signal = np.std(dff, axis=1) noise = np.median(np.abs(np.diff(dff, axis=1)), axis=1) / 0.6745 snr = signal / (noise + 1e-6) # Extract ROI properties npix = np.array([s.get("npix", 0) for s in stat]) compactness = np.array([s.get("compact", np.nan) for s in stat]) 
skewness = np.array([s.get("skew", np.nan) for s in stat]) # Store per-plane stats mean_signal = float(np.nanmean(F)) std_signal = float(np.nanstd(F)) mean_snr = float(np.nanmean(snr[iscell])) if n_accepted > 0 else 0.0 plane_data.append({ "plane": plane_num, "n_accepted": n_accepted, "n_rejected": n_rejected, "mean_signal": mean_signal, "std_signal": std_signal, "mean_snr": mean_snr, }) # Collect accepted cell data for scatter plots if n_accepted > 0: all_snr.extend(snr[iscell]) all_npix.extend(npix[iscell]) all_compactness.extend(compactness[iscell]) all_skewness.extend(skewness[iscell]) all_plane_ids.extend([plane_num] * n_accepted) if not plane_data: fig = plt.figure(figsize=figsize, facecolor="black") fig.text(0.5, 0.5, "No valid plane data found\n\nCheck that ops.npy, stat.npy, F.npy exist", ha="center", va="center", fontsize=14, fontweight="bold", color="white") return fig # Convert to arrays planes = np.array([d["plane"] for d in plane_data]) n_accepted = np.array([d["n_accepted"] for d in plane_data]) n_rejected = np.array([d["n_rejected"] for d in plane_data]) mean_signals = np.array([d["mean_signal"] for d in plane_data]) std_signals = np.array([d["std_signal"] for d in plane_data]) mean_snrs = np.array([d["mean_snr"] for d in plane_data]) all_snr = np.array(all_snr) all_npix = np.array(all_npix) all_compactness = np.array(all_compactness) all_skewness = np.array(all_skewness) all_plane_ids = np.array(all_plane_ids) # Create figure with 3x2 grid fig = plt.figure(figsize=figsize, facecolor="black") gs = fig.add_gridspec(3, 2, hspace=0.35, wspace=0.25, left=0.08, right=0.95, top=0.93, bottom=0.08) # Color palette for planes n_planes = len(planes) cmap = plt.cm.viridis plane_colors = {p: cmap(i / max(1, n_planes - 1)) for i, p in enumerate(planes)} # ========== Panel 1: ROI Counts per Plane ========== ax1 = fig.add_subplot(gs[0, 0]) ax1.set_facecolor("black") bar_width = 0.8 bars1 = ax1.bar(planes, n_accepted, width=bar_width, label=f"Accepted 
({n_accepted.sum()})", alpha=0.85, color="#2ecc71", edgecolor="#27ae60", linewidth=0.5) bars2 = ax1.bar(planes, n_rejected, width=bar_width, bottom=n_accepted, label=f"Rejected ({n_rejected.sum()})", alpha=0.85, color="#e74c3c", edgecolor="#c0392b", linewidth=0.5) # Labels inside bars for bar in bars1: h = bar.get_height() if h > 5: ax1.text(bar.get_x() + bar.get_width()/2, h/2, f"{int(h)}", ha="center", va="center", fontsize=7, color="white", fontweight="bold") for b1, b2 in zip(bars1, bars2): h1, h2 = b1.get_height(), b2.get_height() if h2 > 5: ax1.text(b2.get_x() + b2.get_width()/2, h1 + h2/2, f"{int(h2)}", ha="center", va="center", fontsize=7, color="white", fontweight="bold") ax1.set_xlabel("Z-Plane", fontsize=9, fontweight="bold", color="white") ax1.set_ylabel("Number of ROIs", fontsize=9, fontweight="bold", color="white") ax1.set_title("ROI Counts per Plane", fontsize=10, fontweight="bold", color="white") ax1.tick_params(colors="white", labelsize=8) ax1.spines["top"].set_visible(False) ax1.spines["right"].set_visible(False) ax1.spines["bottom"].set_color("white") ax1.spines["left"].set_color("white") ax1.legend(fontsize=7, facecolor="#1a1a1a", edgecolor="white", labelcolor="white", loc="upper right") # ========== Panel 2: Mean Signal per Plane ========== ax2 = fig.add_subplot(gs[0, 1]) ax2.set_facecolor("black") ax2.errorbar(planes, mean_signals, yerr=std_signals, fmt="o-", color="#3498db", ecolor="#85c1e9", elinewidth=1.5, capsize=3, markersize=5, alpha=0.9, label="Mean ± STD") ax2.set_xlabel("Z-Plane", fontsize=9, fontweight="bold", color="white") ax2.set_ylabel("Mean Raw Signal", fontsize=9, fontweight="bold", color="white") ax2.set_title("Fluorescence Signal per Plane", fontsize=10, fontweight="bold", color="white") ax2.tick_params(colors="white", labelsize=8) ax2.spines["top"].set_visible(False) ax2.spines["right"].set_visible(False) ax2.spines["bottom"].set_color("white") ax2.spines["left"].set_color("white") ax2.legend(fontsize=7, facecolor="#1a1a1a", 
edgecolor="white", labelcolor="white") # ========== Panel 3: SNR Distribution (violin or box per plane) ========== ax3 = fig.add_subplot(gs[1, 0]) ax3.set_facecolor("black") if len(all_snr) > 0: # Box plot per plane snr_by_plane = [all_snr[all_plane_ids == p] for p in planes] snr_by_plane = [s[~np.isnan(s)] for s in snr_by_plane] bp = ax3.boxplot(snr_by_plane, positions=planes, widths=0.6, patch_artist=True, showfliers=False, medianprops=dict(color="#ffe66d", linewidth=2)) for patch in bp["boxes"]: patch.set_facecolor("#2ecc71") patch.set_alpha(0.7) for whisker in bp["whiskers"]: whisker.set_color("white") for cap in bp["caps"]: cap.set_color("white") # Add mean line ax3.plot(planes, mean_snrs, "o--", color="#e74c3c", markersize=4, label="Mean SNR") else: ax3.text(0.5, 0.5, "No SNR data", ha="center", va="center", fontsize=12, color="white") ax3.set_xlabel("Z-Plane", fontsize=9, fontweight="bold", color="white") ax3.set_ylabel("SNR", fontsize=9, fontweight="bold", color="white") ax3.set_title("SNR Distribution per Plane", fontsize=10, fontweight="bold", color="white") ax3.tick_params(colors="white", labelsize=8) ax3.spines["top"].set_visible(False) ax3.spines["right"].set_visible(False) ax3.spines["bottom"].set_color("white") ax3.spines["left"].set_color("white") if len(all_snr) > 0: ax3.legend(fontsize=7, facecolor="#1a1a1a", edgecolor="white", labelcolor="white") # ========== Panel 4: Size Distribution (box per plane) ========== ax4 = fig.add_subplot(gs[1, 1]) ax4.set_facecolor("black") if len(all_npix) > 0: npix_by_plane = [all_npix[all_plane_ids == p] for p in planes] npix_by_plane = [n[n > 0] for n in npix_by_plane] bp4 = ax4.boxplot(npix_by_plane, positions=planes, widths=0.6, patch_artist=True, showfliers=False, medianprops=dict(color="#ffe66d", linewidth=2)) for patch in bp4["boxes"]: patch.set_facecolor("#3498db") patch.set_alpha(0.7) for whisker in bp4["whiskers"]: whisker.set_color("white") for cap in bp4["caps"]: cap.set_color("white") # Mean size per 
plane mean_sizes = [np.mean(n) if len(n) > 0 else 0 for n in npix_by_plane] ax4.plot(planes, mean_sizes, "o--", color="#e74c3c", markersize=4, label="Mean Size") else: ax4.text(0.5, 0.5, "No size data", ha="center", va="center", fontsize=12, color="white") ax4.set_xlabel("Z-Plane", fontsize=9, fontweight="bold", color="white") ax4.set_ylabel("Size (pixels)", fontsize=9, fontweight="bold", color="white") ax4.set_title("ROI Size Distribution per Plane", fontsize=10, fontweight="bold", color="white") ax4.tick_params(colors="white", labelsize=8) ax4.spines["top"].set_visible(False) ax4.spines["right"].set_visible(False) ax4.spines["bottom"].set_color("white") ax4.spines["left"].set_color("white") if len(all_npix) > 0: ax4.legend(fontsize=7, facecolor="#1a1a1a", edgecolor="white", labelcolor="white") # ========== Panel 5: Compactness Distribution per Plane ========== ax5 = fig.add_subplot(gs[2, 0]) ax5.set_facecolor("black") if len(all_compactness) > 0: compact_by_plane = [all_compactness[all_plane_ids == p] for p in planes] compact_by_plane = [c[~np.isnan(c)] for c in compact_by_plane] # Only create boxplot if we have data valid_compact = [c for c in compact_by_plane if len(c) > 0] valid_planes_compact = [p for p, c in zip(planes, compact_by_plane) if len(c) > 0] if valid_compact: bp5 = ax5.boxplot(valid_compact, positions=valid_planes_compact, widths=0.6, patch_artist=True, showfliers=False, medianprops=dict(color="#ffe66d", linewidth=2)) for patch in bp5["boxes"]: patch.set_facecolor("#9b59b6") # Purple for compactness patch.set_alpha(0.7) for whisker in bp5["whiskers"]: whisker.set_color("white") for cap in bp5["caps"]: cap.set_color("white") # Mean compactness per plane mean_compact = [np.mean(c) if len(c) > 0 else np.nan for c in compact_by_plane] valid_mean_compact = [m for m, c in zip(mean_compact, compact_by_plane) if len(c) > 0] ax5.plot(valid_planes_compact, valid_mean_compact, "o--", color="#e74c3c", markersize=4, label="Mean") ax5.legend(fontsize=7, 
facecolor="#1a1a1a", edgecolor="white", labelcolor="white") else: ax5.text(0.5, 0.5, "No compactness data", ha="center", va="center", fontsize=12, color="white") else: ax5.text(0.5, 0.5, "No data", ha="center", va="center", fontsize=12, color="white") ax5.set_xlabel("Z-Plane", fontsize=9, fontweight="bold", color="white") ax5.set_ylabel("Compactness", fontsize=9, fontweight="bold", color="white") ax5.set_title("Compactness Distribution per Plane", fontsize=10, fontweight="bold", color="white") ax5.tick_params(colors="white", labelsize=8) ax5.spines["top"].set_visible(False) ax5.spines["right"].set_visible(False) ax5.spines["bottom"].set_color("white") ax5.spines["left"].set_color("white") # ========== Panel 6: Skewness Distribution per Plane ========== ax6 = fig.add_subplot(gs[2, 1]) ax6.set_facecolor("black") if len(all_skewness) > 0: skew_by_plane = [all_skewness[all_plane_ids == p] for p in planes] skew_by_plane = [s[~np.isnan(s)] for s in skew_by_plane] # Only create boxplot if we have data valid_skew = [s for s in skew_by_plane if len(s) > 0] valid_planes_skew = [p for p, s in zip(planes, skew_by_plane) if len(s) > 0] if valid_skew: bp6 = ax6.boxplot(valid_skew, positions=valid_planes_skew, widths=0.6, patch_artist=True, showfliers=False, medianprops=dict(color="#ffe66d", linewidth=2)) for patch in bp6["boxes"]: patch.set_facecolor("#e67e22") # Orange for skewness patch.set_alpha(0.7) for whisker in bp6["whiskers"]: whisker.set_color("white") for cap in bp6["caps"]: cap.set_color("white") # Mean skewness per plane mean_skew = [np.mean(s) if len(s) > 0 else np.nan for s in skew_by_plane] valid_mean_skew = [m for m, s in zip(mean_skew, skew_by_plane) if len(s) > 0] ax6.plot(valid_planes_skew, valid_mean_skew, "o--", color="#e74c3c", markersize=4, label="Mean") ax6.legend(fontsize=7, facecolor="#1a1a1a", edgecolor="white", labelcolor="white") else: ax6.text(0.5, 0.5, "No skewness data", ha="center", va="center", fontsize=12, color="white") else: ax6.text(0.5, 
0.5, "No data", ha="center", va="center", fontsize=12, color="white") ax6.set_xlabel("Z-Plane", fontsize=9, fontweight="bold", color="white") ax6.set_ylabel("Skewness", fontsize=9, fontweight="bold", color="white") ax6.set_title("Skewness Distribution per Plane", fontsize=10, fontweight="bold", color="white") ax6.tick_params(colors="white", labelsize=8) ax6.spines["top"].set_visible(False) ax6.spines["right"].set_visible(False) ax6.spines["bottom"].set_color("white") ax6.spines["left"].set_color("white") # Title total_accepted = n_accepted.sum() total_rejected = n_rejected.sum() fig.suptitle(f"Volume Quality Diagnostics: {n_planes} planes, {total_accepted} accepted, {total_rejected} rejected ROIs", fontsize=12, fontweight="bold", color="white", y=0.98) if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="black") plt.close(fig) return fig def plot_orthoslices( ops_files: list[str | Path], save_path: str | Path = None, figsize: tuple = (16, 6), use_mean: bool = True, ) -> plt.Figure: """ Generate orthogonal maximum intensity projections (XY, XZ, YZ) of the volume. Creates a 3-panel figure showing the volume from three orthogonal views, which is standard in microscopy for visualizing 3D structure. Axes are displayed in micrometers when valid voxel size metadata is available. Parameters ---------- ops_files : list of str or Path List of paths to ops.npy files for each z-plane, ordered by z. save_path : str or Path, optional If provided, save figure to this path. figsize : tuple, default (16, 6) Figure size in inches. use_mean : bool, default True If True, use meanImg. If False, use refImg (registered reference). Returns ------- fig : matplotlib.figure.Figure The generated figure object. 
""" from lbm_suite2p_python.postprocessing import load_ops if not ops_files: fig = plt.figure(figsize=figsize, facecolor="black") fig.text(0.5, 0.5, "No ops files provided", ha="center", va="center", fontsize=16, fontweight="bold", color="white") return fig # Get voxel size from first ops file first_ops = load_ops(ops_files[0]) try: from mbo_utilities.metadata import get_voxel_size voxel = get_voxel_size(first_ops) dx_um, dy_um, dz_um = voxel.dx, voxel.dy, voxel.dz except ImportError: # Fallback if mbo_utilities not available pixel_res = first_ops.get("pixel_resolution", [1.0, 1.0]) if isinstance(pixel_res, (int, float)): dx_um, dy_um = float(pixel_res), float(pixel_res) else: dx_um = float(pixel_res[0]) if len(pixel_res) > 0 else 1.0 dy_um = float(pixel_res[1]) if len(pixel_res) > 1 else dx_um dz_um = float(first_ops.get("dz", first_ops.get("z_step", 15.0))) # Check if we have valid (non-default) voxel sizes has_valid_xy = dx_um != 1.0 or dy_um != 1.0 has_valid_z = dz_um != 1.0 # Collect images from all planes images = [] plane_nums = [] for ops_file in ops_files: ops_file = Path(ops_file) ops = load_ops(ops_file) # Get image img_key = "meanImg" if use_mean else "refImg" img = ops.get(img_key) if img is None or not isinstance(img, np.ndarray): img = ops.get("meanImg" if not use_mean else "refImg") if img is None or not isinstance(img, np.ndarray): continue # Get plane number raw_plane = ops.get("plane", len(images)) if isinstance(raw_plane, (int, np.integer)): plane_num = int(raw_plane) else: s = str(raw_plane) digits = "".join([c for c in s if c.isdigit()]) plane_num = int(digits) if digits else len(images) images.append(img) plane_nums.append(plane_num) if not images: fig = plt.figure(figsize=figsize, facecolor="black") fig.text(0.5, 0.5, "No valid images found", ha="center", va="center", fontsize=16, fontweight="bold", color="white") return fig # Sort by plane number sort_idx = np.argsort(plane_nums) images = [images[i] for i in sort_idx] plane_nums = 
[plane_nums[i] for i in sort_idx] # Stack into 3D volume (Z, Y, X) volume = np.stack(images, axis=0) nz, ny, nx = volume.shape # Compute projections xy_proj = np.max(volume, axis=0) # Max along Z -> XY view xz_proj = np.max(volume, axis=1) # Max along Y -> XZ view yz_proj = np.max(volume, axis=2) # Max along X -> YZ view # Create figure fig = plt.figure(figsize=figsize, facecolor="black") # Calculate aspect ratios for proper scaling z_scale = dz_um xy_scale = (dx_um + dy_um) / 2 # Average XY scale gs = fig.add_gridspec(1, 3, wspace=0.15, left=0.05, right=0.95, top=0.88, bottom=0.1) # Determine axis labels and extent based on valid voxel size if has_valid_xy: x_label = "X (μm)" y_label = "Y (μm)" xy_extent = [0, nx * dx_um, ny * dy_um, 0] xz_extent = [0, nx * dx_um, nz * dz_um, 0] yz_extent = [0, nz * dz_um, ny * dy_um, 0] else: x_label = "X (pixels)" y_label = "Y (pixels)" xy_extent = None xz_extent = None yz_extent = None if has_valid_z: z_label = "Z (μm)" else: z_label = "Z (plane)" # Panel 1: XY projection (top-down view) ax1 = fig.add_subplot(gs[0, 0]) ax1.set_facecolor("black") im1 = ax1.imshow(xy_proj, cmap="magma", aspect="equal", extent=xy_extent, vmin=np.percentile(xy_proj, 1), vmax=np.percentile(xy_proj, 99.5)) ax1.set_xlabel(x_label, fontsize=10, fontweight="bold", color="white") ax1.set_ylabel(y_label, fontsize=10, fontweight="bold", color="white") ax1.set_title("XY Projection (top view)", fontsize=11, fontweight="bold", color="white") ax1.tick_params(colors="white", labelsize=8) for spine in ax1.spines.values(): spine.set_color("white") # Panel 2: XZ projection (side view) ax2 = fig.add_subplot(gs[0, 1]) ax2.set_facecolor("black") im2 = ax2.imshow(xz_proj, cmap="magma", aspect=z_scale/xy_scale, extent=xz_extent, vmin=np.percentile(xz_proj, 1), vmax=np.percentile(xz_proj, 99.5)) ax2.set_xlabel(x_label, fontsize=10, fontweight="bold", color="white") ax2.set_ylabel(z_label, fontsize=10, fontweight="bold", color="white") ax2.set_title("XZ Projection (front 
view)", fontsize=11, fontweight="bold", color="white") ax2.tick_params(colors="white", labelsize=8) for spine in ax2.spines.values(): spine.set_color("white") # Panel 3: YZ projection (side view) ax3 = fig.add_subplot(gs[0, 2]) ax3.set_facecolor("black") im3 = ax3.imshow(yz_proj.T, cmap="magma", aspect=xy_scale/z_scale, extent=yz_extent, vmin=np.percentile(yz_proj, 1), vmax=np.percentile(yz_proj, 99.5)) ax3.set_xlabel(z_label, fontsize=10, fontweight="bold", color="white") ax3.set_ylabel(y_label, fontsize=10, fontweight="bold", color="white") ax3.set_title("YZ Projection (side view)", fontsize=11, fontweight="bold", color="white") ax3.tick_params(colors="white", labelsize=8) for spine in ax3.spines.values(): spine.set_color("white") # Add colorbar cbar = fig.colorbar(im1, ax=[ax1, ax2, ax3], shrink=0.6, pad=0.02, location="right") cbar.set_label("Max Intensity", fontsize=10, color="white") cbar.ax.tick_params(colors="white") cbar.outline.set_edgecolor("white") # Title with volume dimensions in appropriate units if has_valid_xy and has_valid_z: vol_x = nx * dx_um vol_y = ny * dy_um vol_z = nz * dz_um title = f"Orthogonal Projections: {nz} planes, {vol_x:.0f}×{vol_y:.0f}×{vol_z:.0f} μm" else: title = f"Orthogonal Projections: {nz} planes, {ny}×{nx} pixels" fig.suptitle(title, fontsize=12, fontweight="bold", color="white", y=0.96) if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="black") plt.close(fig) return fig def plot_3d_roi_map( ops_files: list[str | Path], save_path: str | Path = None, figsize: tuple = (14, 10), color_by: str = "snr", show_rejected: bool = False, ) -> plt.Figure: """ Generate a 3D scatter plot of ROI centroids across the volume. Creates a 3D visualization showing the spatial distribution of detected cells colored by SNR. Axes are displayed in micrometers when valid voxel size metadata is available, otherwise in pixels/planes. 
Parameters ---------- ops_files : list of str or Path List of paths to ops.npy files for each z-plane. save_path : str or Path, optional If provided, save figure to this path. figsize : tuple, default (14, 10) Figure size in inches. color_by : str, default "snr" How to color the ROIs: "snr", "plane", "size", or "activity". show_rejected : bool, default False If True, also show rejected ROIs in gray. Returns ------- fig : matplotlib.figure.Figure The generated figure object. """ from lbm_suite2p_python.postprocessing import load_ops if not ops_files: fig = plt.figure(figsize=figsize, facecolor="black") fig.text(0.5, 0.5, "No ops files provided", ha="center", va="center", fontsize=16, fontweight="bold", color="white") return fig # Get voxel size from first ops file first_ops = load_ops(ops_files[0]) try: from mbo_utilities.metadata import get_voxel_size voxel = get_voxel_size(first_ops) dx_um, dy_um, dz_um = voxel.dx, voxel.dy, voxel.dz except ImportError: # Fallback if mbo_utilities not available pixel_res = first_ops.get("pixel_resolution", [1.0, 1.0]) if isinstance(pixel_res, (int, float)): dx_um, dy_um = float(pixel_res), float(pixel_res) else: dx_um = float(pixel_res[0]) if len(pixel_res) > 0 else 1.0 dy_um = float(pixel_res[1]) if len(pixel_res) > 1 else dx_um dz_um = float(first_ops.get("dz", first_ops.get("z_step", 15.0))) # Check if we have valid (non-default) voxel sizes has_valid_xy = dx_um != 1.0 or dy_um != 1.0 has_valid_z = dz_um != 1.0 # Collect ROI data from all planes all_x = [] all_y = [] all_z = [] all_colors = [] all_accepted = [] # For rejected ROIs rej_x, rej_y, rej_z = [], [], [] for ops_file in ops_files: ops_file = Path(ops_file) ops = load_ops(ops_file) plane_dir = ops_file.parent # Get plane number raw_plane = ops.get("plane", len(all_x)) if isinstance(raw_plane, (int, np.integer)): plane_num = int(raw_plane) else: s = str(raw_plane) digits = "".join([c for c in s if c.isdigit()]) plane_num = int(digits) if digits else 0 # Load required 
files stat_file = plane_dir / "stat.npy" iscell_file = plane_dir / "iscell.npy" if not stat_file.exists() or not iscell_file.exists(): continue try: stat = np.load(stat_file, allow_pickle=True) iscell_raw = np.load(iscell_file, allow_pickle=True) if not isinstance(iscell_raw, np.ndarray) or iscell_raw.size == 0: continue iscell = iscell_raw[:, 0].astype(bool) except Exception: continue # Get color values based on color_by if color_by == "snr": F_file = plane_dir / "F.npy" Fneu_file = plane_dir / "Fneu.npy" if F_file.exists(): F = np.load(F_file, allow_pickle=True) Fneu = np.load(Fneu_file, allow_pickle=True) if Fneu_file.exists() else np.zeros_like(F) F_corr = F - 0.7 * Fneu baseline = np.percentile(F_corr, 20, axis=1, keepdims=True) baseline = np.maximum(baseline, 1e-6) dff = (F_corr - baseline) / baseline signal = np.std(dff, axis=1) noise = np.median(np.abs(np.diff(dff, axis=1)), axis=1) / 0.6745 color_vals = signal / (noise + 1e-6) else: color_vals = np.ones(len(stat)) * plane_num elif color_by == "size": color_vals = np.array([s.get("npix", 100) for s in stat]) elif color_by == "activity": F_file = plane_dir / "F.npy" if F_file.exists(): F = np.load(F_file, allow_pickle=True) color_vals = np.std(F, axis=1) else: color_vals = np.ones(len(stat)) * plane_num else: # plane color_vals = np.ones(len(stat)) * plane_num # Extract centroids and convert to microns for i, s in enumerate(stat): med = s.get("med", [0, 0]) y_px, x_px = med[0], med[1] # Convert pixels to microns x_um = x_px * dx_um y_um = y_px * dy_um z_um = plane_num * dz_um if iscell[i]: all_x.append(x_um) all_y.append(y_um) all_z.append(z_um) all_colors.append(color_vals[i]) all_accepted.append(True) elif show_rejected: rej_x.append(x_um) rej_y.append(y_um) rej_z.append(z_um) if not all_x: fig = plt.figure(figsize=figsize, facecolor="black") fig.text(0.5, 0.5, "No ROI data found", ha="center", va="center", fontsize=16, fontweight="bold", color="white") return fig # Convert to arrays all_x = 
np.array(all_x) all_y = np.array(all_y) all_z = np.array(all_z) all_colors = np.array(all_colors) # Create figure with 3D axis fig = plt.figure(figsize=figsize, facecolor="black") ax = fig.add_subplot(111, projection="3d", facecolor="black") # Set pane colors to dark ax.xaxis.pane.fill = False ax.yaxis.pane.fill = False ax.zaxis.pane.fill = False ax.xaxis.pane.set_edgecolor("white") ax.yaxis.pane.set_edgecolor("white") ax.zaxis.pane.set_edgecolor("white") # Plot rejected ROIs first (if enabled) if show_rejected and rej_x: ax.scatter(rej_x, rej_y, rej_z, c="gray", s=10, alpha=0.3, label="Rejected") # Choose colormap based on color_by if color_by == "plane": cmap = "viridis" clabel = "Z-Plane" elif color_by == "snr": cmap = "plasma" clabel = "SNR" # Clip extreme values vmin, vmax = np.percentile(all_colors, [5, 95]) all_colors = np.clip(all_colors, vmin, vmax) elif color_by == "size": cmap = "cividis" clabel = "Size (pixels)" vmin, vmax = np.percentile(all_colors, [5, 95]) all_colors = np.clip(all_colors, vmin, vmax) else: # activity cmap = "magma" clabel = "Activity (std)" vmin, vmax = np.percentile(all_colors, [5, 95]) all_colors = np.clip(all_colors, vmin, vmax) # Plot accepted ROIs scatter = ax.scatter(all_x, all_y, all_z, c=all_colors, cmap=cmap, s=15, alpha=0.7, edgecolors="none") # Add colorbar cbar = fig.colorbar(scatter, ax=ax, shrink=0.6, pad=0.1) cbar.set_label(clabel, fontsize=10, color="white") cbar.ax.tick_params(colors="white") cbar.outline.set_edgecolor("white") # Style axes with appropriate units based on valid voxel size x_label = "X (μm)" if has_valid_xy else "X (pixels)" y_label = "Y (μm)" if has_valid_xy else "Y (pixels)" z_label = "Z (μm)" if has_valid_z else "Z (plane)" ax.set_xlabel(x_label, fontsize=10, fontweight="bold", color="white", labelpad=10) ax.set_ylabel(y_label, fontsize=10, fontweight="bold", color="white", labelpad=10) ax.set_zlabel(z_label, fontsize=10, fontweight="bold", color="white", labelpad=10) ax.tick_params(colors="white", 
labelsize=8) ax.xaxis.label.set_color("white") ax.yaxis.label.set_color("white") ax.zaxis.label.set_color("white") # Set grid color ax.xaxis._axinfo["grid"]["color"] = (1, 1, 1, 0.2) ax.yaxis._axinfo["grid"]["color"] = (1, 1, 1, 0.2) ax.zaxis._axinfo["grid"]["color"] = (1, 1, 1, 0.2) # Title with volume dimensions n_cells = len(all_x) n_planes = len(np.unique(all_z)) x_range = all_x.max() - all_x.min() y_range = all_y.max() - all_y.min() z_range = all_z.max() - all_z.min() if has_valid_xy and has_valid_z: vol_str = f"Volume: {x_range:.0f} × {y_range:.0f} × {z_range:.0f} μm" else: vol_str = f"Volume: {x_range:.0f} × {y_range:.0f} × {z_range:.0f}" fig.suptitle( f"3D ROI Distribution: {n_cells} cells across {n_planes} planes\n{vol_str}", fontsize=12, fontweight="bold", color="white", y=0.95 ) if show_rejected and rej_x: ax.legend(fontsize=9, facecolor="#1a1a1a", edgecolor="white", labelcolor="white") # Adjust view angle for better visualization ax.view_init(elev=20, azim=45) if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="black") plt.close(fig) return fig