import glob
import subprocess
from pathlib import Path
import cv2
import numpy as np
from matplotlib import pyplot as plt
from lbm_suite2p_python.utils import get_common_path
from lbm_suite2p_python.postprocessing import load_ops
def update_ops_paths(ops_files: str | list):
"""
Update save_path, save_path0, and save_folder in an ops dictionary based on its current location. Use after moving an ops_file or batch of ops_files.
"""
if isinstance(ops_files, (str, Path)):
ops_files = [ops_files]
for ops_file in ops_files:
ops = np.load(ops_file, allow_pickle=True).item()
ops_path = Path(ops_file)
plane0_folder = ops_path.parent
plane_folder = plane0_folder.parent
ops["save_path"] = str(plane0_folder)
ops["save_path0"] = str(plane_folder)
ops["save_folder"] = plane_folder.name
ops["ops_path"] = ops_path
np.save(ops_file, ops)
[docs]
def plot_execution_time(filepath, savepath):
    """
    Plot the execution time of each processing step per z-plane.

    Loads timing data from a `.npy` file and draws the registration,
    detection and extraction runtimes as a stacked bar plot on a black
    background. Each stack is annotated with the plane's total runtime when
    that total exceeds one second. The figure is saved, shown, and closed.

    Parameters
    ----------
    filepath : str or Path
        Path to the `.npy` file containing the volume timing stats.
    savepath : str or Path
        Path to save the generated figure.

    Notes
    -----
    - The `.npy` file must provide the fields `plane`, `registration`,
      `detection`, `extraction` and `total_plane_runtime`.
    - NaN timings are drawn as zero-height segments so that a missing value
      for one step does not remove the segments stacked above it.
    """

    def _nan_to_zero(values):
        # Only NaN is replaced; every other value is kept as stored.
        return np.where(np.isnan(values), 0.0, values)

    plane_stats = np.load(filepath)

    planes = plane_stats["plane"]
    reg_time = _nan_to_zero(plane_stats["registration"])
    detect_time = _nan_to_zero(plane_stats["detection"])
    extract_time = _nan_to_zero(plane_stats["extraction"])
    total_time = plane_stats["total_plane_runtime"]

    fig = plt.figure(figsize=(10, 6), facecolor="black")
    ax = plt.gca()
    ax.set_facecolor("black")

    plt.xlabel("Z-Plane", fontsize=14, fontweight="bold", color="white")
    plt.ylabel("Execution Time (s)", fontsize=14, fontweight="bold", color="white")
    plt.title(
        "Execution Time per Processing Step",
        fontsize=16,
        fontweight="bold",
        color="white",
    )

    plt.bar(planes, reg_time, label="Registration", alpha=0.8, color="#FF5733")
    plt.bar(
        planes,
        detect_time,
        label="Detection",
        alpha=0.8,
        bottom=reg_time,
        color="#33FF57",
    )
    top_bars = plt.bar(
        planes,
        extract_time,
        label="Extraction",
        alpha=0.8,
        bottom=reg_time + detect_time,
        color="#3357FF",
    )

    # Annotate the top of each stack with the plane's total runtime.
    for bar, total in zip(top_bars, total_time):
        stack_top = bar.get_y() + bar.get_height()
        if total > 1:  # also skips NaN totals, since NaN > 1 is False
            plt.text(
                bar.get_x() + bar.get_width() / 2,
                stack_top + 2,
                f"{int(total)}",
                ha="center",
                va="bottom",
                fontsize=12,
                color="white",
                fontweight="bold",
            )

    plt.xticks(planes, fontsize=12, fontweight="bold", color="white")
    plt.yticks(fontsize=12, fontweight="bold", color="white")
    plt.grid(axis="y", linestyle="--", alpha=0.4, color="white")

    for side in ("bottom", "left", "top", "right"):
        ax.spines[side].set_color("white")

    plt.legend(
        fontsize=12,
        facecolor="black",
        edgecolor="white",
        labelcolor="white",
        loc="upper left",
        bbox_to_anchor=(1, 1),
    )

    plt.savefig(savepath, bbox_inches="tight", facecolor="black")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
[docs]
def plot_volume_signal(zstats, savepath):
    """
    Plot the mean fluorescence signal of every z-plane with STD error bars.

    The statistics are read from a `.npy` file and drawn as a connected
    marker line on a black background; the figure is written to disk and
    then closed.

    Parameters
    ----------
    zstats : str or Path
        Path to the `.npy` file containing the volume stats. The output of `get_zstats()`.
    savepath : str or Path
        Where to save the figure. If this is an existing directory, the
        figure is saved inside it as ``mean_volume_signal.png``.

    Notes
    -----
    - The `.npy` file should contain structured data with `plane`, `mean_trace`, and `std_trace` fields.
    - Error bars represent the standard deviation of the fluorescence signal.
    """
    stats = np.load(zstats)
    z_index = stats["plane"]
    signal_mean = stats["mean_trace"]
    signal_std = stats["std_trace"]

    fig, ax = plt.subplots(figsize=(10, 5), facecolor="black")
    ax.set_facecolor("black")

    errorbar_style = dict(
        fmt="o-",
        color="#3498db",
        ecolor="#85c1e9",
        elinewidth=1.5,
        capsize=3,
        markersize=5,
        alpha=0.9,
        label="Mean ± STD",
    )
    ax.errorbar(z_index, signal_mean, yerr=signal_std, **errorbar_style)

    axis_font = dict(fontsize=10, fontweight="bold", color="white")
    ax.set_xlabel("Z-Plane", **axis_font)
    ax.set_ylabel("Mean Raw Signal", **axis_font)
    ax.set_title(
        "Mean Fluorescence Signal per Z-Plane",
        fontsize=11,
        fontweight="bold",
        color="white",
    )
    ax.tick_params(colors="white", labelsize=9)

    for hidden in ("top", "right"):
        ax.spines[hidden].set_visible(False)
    for shown in ("bottom", "left"):
        ax.spines[shown].set_color("white")

    ax.legend(fontsize=8, facecolor="#1a1a1a", edgecolor="white", labelcolor="white")

    target = Path(savepath)
    if target.is_dir():
        target = target / "mean_volume_signal.png"
    plt.savefig(target, bbox_inches="tight", facecolor="black", dpi=150)
    plt.close(fig)
def plot_volume_neuron_counts(zstats, savepath):
    """
    Plot accepted and rejected neuron counts for every z-plane.

    Counts are read from a `.npy` file and drawn as a stacked bar plot on a
    black background. Segments taller than 5 are labelled with their count.
    The figure is saved as ``all_neurons_<N>acc_<M>rej.png`` and closed.

    Parameters
    ----------
    zstats : str, Path
        Full path to the zstats.npy file.
    savepath : str or Path
        Path to directory where generated figure will be saved.

    Raises
    ------
    FileNotFoundError
        If `zstats` is not an existing file.

    Notes
    -----
    - The `.npy` file should contain structured data with `plane`, `accepted`, and `rejected` fields.
    """
    zstats = Path(zstats)
    if not zstats.is_file():
        raise FileNotFoundError(f"{zstats} is not a valid zstats.npy file.")

    plane_stats = np.load(zstats)
    out_dir = Path(savepath)

    z_index = plane_stats["plane"]
    accepted = plane_stats["accepted"]
    rejected = plane_stats["rejected"]

    savename = out_dir / f"all_neurons_{accepted.sum()}acc_{rejected.sum()}rej.png"

    fig, ax = plt.subplots(figsize=(10, 5), facecolor="black")
    ax.set_facecolor("black")

    shared_bar_kw = dict(width=0.8, alpha=0.85, linewidth=0.5)
    accepted_bars = ax.bar(
        z_index,
        accepted,
        label=f"Accepted ({accepted.sum()})",
        color="#2ecc71",
        edgecolor="#27ae60",
        **shared_bar_kw,
    )
    rejected_bars = ax.bar(
        z_index,
        rejected,
        bottom=accepted,
        label=f"Rejected ({rejected.sum()})",
        color="#e74c3c",
        edgecolor="#c0392b",
        **shared_bar_kw,
    )

    def write_count(bar, y_center, count):
        # Centered white count; callers decide whether the segment is tall enough.
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            y_center,
            f"{int(count)}",
            ha="center",
            va="center",
            fontsize=8,
            color="white",
            fontweight="bold",
        )

    # Accepted segments start at zero, so their middle is half their height.
    for acc_bar in accepted_bars:
        acc_height = acc_bar.get_height()
        if acc_height > 5:
            write_count(acc_bar, acc_height / 2, acc_height)

    # Rejected segments sit on top of the accepted ones.
    for acc_bar, rej_bar in zip(accepted_bars, rejected_bars):
        acc_height = acc_bar.get_height()
        rej_height = rej_bar.get_height()
        if rej_height > 5:
            write_count(rej_bar, acc_height + rej_height / 2, rej_height)

    axis_font = dict(fontsize=10, fontweight="bold", color="white")
    ax.set_xlabel("Z-Plane", **axis_font)
    ax.set_ylabel("Number of ROIs", **axis_font)
    ax.set_title(
        "Accepted vs Rejected ROIs per Z-Plane",
        fontsize=11,
        fontweight="bold",
        color="white",
    )
    ax.tick_params(colors="white", labelsize=9)

    for hidden in ("top", "right"):
        ax.spines[hidden].set_visible(False)
    for shown in ("bottom", "left"):
        ax.spines[shown].set_color("white")

    ax.legend(
        fontsize=8,
        facecolor="#1a1a1a",
        edgecolor="white",
        labelcolor="white",
        loc="upper right",
    )

    plt.savefig(savename, bbox_inches="tight", facecolor="black", dpi=150)
    plt.close(fig)
def get_volume_stats(ops_files: list[str | Path], overwrite: bool = True):
"""
Given a list of ops.npy files, accumulate common statistics for assessing zplane quality.
Parameters
----------
ops_files : list of str or Path
Each item in the list should be a path pointing to a z-lanes `ops.npy` file.
The number of items in this list should match the number of z-planes in your session.
overwrite : bool
If a file already exists, it will be overwritten. Defaults to True.
Notes
-----
- The `.npy` file should contain structured data with `plane`, `accepted`, and `rejected` fields.
"""
if not ops_files:
print("No ops files found.")
return None
plane_stats = {}
for i, file in enumerate(ops_files):
output_ops = load_ops(file)
raw_z = output_ops.get("plane", None)
if raw_z is None:
zplane_num = i
else:
if isinstance(raw_z, (int, np.integer)):
zplane_num = int(raw_z)
else:
s = str(raw_z)
digits = "".join([c for c in s if c.isdigit()])
zplane_num = int(digits) if digits else i
save_path = Path(output_ops["save_path"])
timing = output_ops.get("timing", {})
# Check if required files exist
iscell_file = save_path / "iscell.npy"
traces_file = save_path / "F.npy"
if not iscell_file.exists():
print(f"Skipping plane {zplane_num}: iscell.npy not found at {save_path}")
continue
if not traces_file.exists():
print(f"Skipping plane {zplane_num}: F.npy not found at {save_path}")
continue
# Load files
try:
iscell_raw = np.load(iscell_file, allow_pickle=True)
traces = np.load(traces_file, allow_pickle=True)
# Validate iscell data - check for _NoValue or other invalid types
if not isinstance(iscell_raw, np.ndarray) or iscell_raw.size == 0:
print(f"Skipping plane {zplane_num}: iscell.npy is empty or invalid")
continue
# Handle potential _NoValue entries by filtering to valid numeric data
try:
iscell = iscell_raw[:, 0].astype(bool)
except (TypeError, ValueError) as e:
print(f"Skipping plane {zplane_num}: iscell data conversion failed - {e}")
continue
# Validate traces
if not isinstance(traces, np.ndarray) or traces.size == 0:
print(f"Skipping plane {zplane_num}: F.npy is empty or invalid")
continue
except Exception as e:
print(f"Skipping plane {zplane_num}: Error loading files - {e}")
continue
# Safe stat computations with explicit conversion
try:
n_accepted = int(np.sum(iscell))
n_rejected = int(np.sum(~iscell))
trace_mean = float(np.nanmean(traces))
trace_std = float(np.nanstd(traces))
except (TypeError, ValueError) as e:
print(f"Skipping plane {zplane_num}: Error computing statistics - {e}")
continue
plane_stats[zplane_num] = {
"accepted": n_accepted,
"rejected": n_rejected,
"mean": trace_mean,
"std": trace_std,
"registration": timing.get("registration", np.nan),
"detection": timing.get("detection", timing.get("detect", np.nan)),
"extraction": timing.get("extraction", np.nan),
"classification": timing.get("classification", np.nan),
"deconvolution": timing.get("deconvolution", np.nan),
"total_runtime": timing.get("total_plane_runtime", np.nan),
"filepath": str(file),
"zplane": zplane_num,
}
# Check if any planes had valid statistics
if not plane_stats:
print("No valid plane statistics collected - all planes skipped or missing files")
return None
common = get_common_path(ops_files)
out = []
for p, stats in sorted(plane_stats.items()):
out.append(
(
p,
stats["accepted"],
stats["rejected"],
stats["mean"],
stats["std"],
stats["registration"],
stats["detection"],
stats["extraction"],
stats["classification"],
stats["deconvolution"],
stats["total_runtime"],
stats["filepath"],
stats["zplane"],
)
)
dtype = [
("plane", "i4"),
("accepted", "i4"),
("rejected", "i4"),
("mean_trace", "f8"),
("std_trace", "f8"),
("registration", "f8"),
("detection", "f8"),
("extraction", "f8"),
("classification", "f8"),
("deconvolution", "f8"),
("total_plane_runtime", "f8"),
("filepath", "U255"),
("zplane", "i4"),
]
arr = np.array(out, dtype=dtype)
save_path = Path(common) / "zstats.npy"
if overwrite or not save_path.exists():
np.save(save_path, arr)
return str(save_path)
def _run_ffmpeg(args):
    """Run ``ffmpeg -y <args>`` quietly; return True only if it ran and exited with status 0."""
    try:
        result = subprocess.run(
            ["ffmpeg", "-y", *args],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except OSError:
        # Raised when the ffmpeg executable cannot be launched (e.g. not installed).
        return False
    return result.returncode == 0


def save_images_to_movie(image_input, savepath, duration=None, format=".mp4"):
    """
    Convert a sequence of saved images into a movie.

    Frames are first written with OpenCV. For ``.mp4`` and ``.mov`` the
    OpenCV output is an intermediate file that is re-encoded with ffmpeg;
    if ffmpeg is unavailable or fails, a message is printed and the OpenCV
    output is kept (moved to `savepath` for ``.mp4``, left as ``.avi`` for
    ``.mov``). Unreadable images are skipped, and frames whose size differs
    from the first readable frame are resized to match it.

    TODO: move to mbo_utilities.

    Parameters
    ----------
    image_input : str, Path, or list
        Directory containing saved segmentation images or a list of image file paths.
        For a directory, ``*.png``, ``*.jpg`` and ``*.tif`` files are used.
    savepath : str or Path
        Path to save the video file. Its suffix is replaced by `format`.
    duration : int, optional
        Desired total video duration in seconds. If None, defaults to 1 FPS (1 image per second).
        The frame rate is never set below 1 FPS.
    format : str, optional
        Video format: ".mp4" (PowerPoint-compatible), ".avi" (lossless), ".mov" (ProRes). Default is ".mp4".

    Raises
    ------
    ValueError
        If `image_input` is neither a path nor a list, or `format` is unsupported.

    Examples
    --------
    >>> import mbo_utilities as mbo
    >>> import lbm_suite2p_python as lsp

    Get all png files autosaved during LBM-Suite2p-Python `run_volume()`

    >>> segmentation_pngs = mbo.get_files("path/suite3d/results/", "segmentation.png", max_depth=3)
    >>> lsp.save_images_to_movie(segmentation_pngs, "path/to/save/segmentation.png", format=".mp4")
    """
    savepath = Path(savepath).with_suffix(format)  # Ensure correct file extension
    savepath.parent.mkdir(parents=True, exist_ok=True)

    if isinstance(image_input, (str, Path)):
        image_dir = Path(image_input)
        image_files = sorted(
            glob.glob(str(image_dir / "*.png"))
            + glob.glob(str(image_dir / "*.jpg"))
            + glob.glob(str(image_dir / "*.tif"))
        )
    elif isinstance(image_input, list):
        image_files = sorted(map(str, image_input))
    else:
        raise ValueError(
            "image_input must be a directory path or a list of file paths."
        )

    if not image_files:
        return

    if format == ".mp4":
        fourcc_code = "XVID"
        # ffmpeg refuses to read and write the same file, so OpenCV writes
        # to a sibling file that is re-encoded into `savepath` afterwards.
        video_path = savepath.with_name(f"{savepath.stem}_opencv{savepath.suffix}")
    elif format == ".avi":
        fourcc_code = "HFYU"
        video_path = savepath
    elif format == ".mov":
        fourcc_code = "HFYU"
        video_path = savepath.with_suffix(".avi")  # intermediate for ProRes conversion
    else:
        raise ValueError("Invalid format. Use '.mp4', '.avi', or '.mov'.")

    fps = len(image_files) / duration if duration else 1

    video_writer = None
    frame_size = None  # (width, height) of the first readable frame
    try:
        for image_file in image_files:
            frame = cv2.imread(image_file)
            if frame is None:
                print(f"Skipping unreadable image: {image_file}")
                continue
            if video_writer is None:
                height, width = frame.shape[:2]
                frame_size = (width, height)
                video_writer = cv2.VideoWriter(
                    str(video_path),
                    cv2.VideoWriter_fourcc(*fourcc_code),
                    max(fps, 1),
                    frame_size,
                )
            elif (frame.shape[1], frame.shape[0]) != frame_size:
                # Resize so every frame matches the size the writer was opened with.
                frame = cv2.resize(frame, frame_size)
            video_writer.write(frame)
    finally:
        if video_writer is not None:
            video_writer.release()

    if video_writer is None:
        print("No readable images found; no movie was written.")
        return

    if format == ".mp4":
        encoded = _run_ffmpeg(
            [
                "-i",
                str(video_path),
                "-vcodec",
                "libx264",
                "-acodec",
                "aac",
                "-preset",
                "slow",
                "-crf",
                "18",
                str(savepath),
            ]
        )
        if encoded:
            video_path.unlink()
            print(f"MP4 saved at {savepath}")
        else:
            # Fall back to the OpenCV encode rather than leaving no output.
            video_path.replace(savepath)
            print(
                "ffmpeg re-encode failed or ffmpeg is not installed; "
                f"saved OpenCV-encoded MP4 at {savepath}"
            )
    elif format == ".mov":
        encoded = _run_ffmpeg(
            [
                "-i",
                str(video_path),
                "-c:v",
                "prores_ks",  # Use Apple ProRes codec
                "-profile:v",
                "3",  # prores_ks profile 3 = ProRes 422 HQ
                "-pix_fmt",
                "yuv422p10le",
                str(savepath),
            ]
        )
        if encoded:
            video_path.unlink()
        else:
            print(
                "ffmpeg conversion to MOV failed or ffmpeg is not installed; "
                f"lossless AVI kept at {video_path}"
            )
def _plane_index(plane_dir):
    """Return the integer formed by the digits of a plane folder name, or None if it has none."""
    digits = "".join(filter(str.isdigit, plane_dir.name))
    return int(digits) if digits else None


def consolidate_volume(suite2p_path, merged_dir="merged", overwrite=False):
    """
    Consolidate all plane results into a single merged directory.

    Combines ops.npy, stat.npy, iscell.npy, F.npy, Fneu.npy, and spks.npy
    from all planeXX_stitched folders into a single merged/ folder. Planes
    are processed in numeric order, and each stat entry is tagged with its
    plane number under the ``iplane`` key.

    Planes missing stat.npy, iscell.npy or ops.npy are skipped. A trace
    file (F, Fneu, spks) is only written when every consolidated plane
    provides it; otherwise its rows could not be matched to stat.npy.

    Parameters
    ----------
    suite2p_path : str or Path
        Path to suite2p directory containing planeXX_stitched folders
    merged_dir : str, optional
        Name of merged directory to create (default: "merged")
    overwrite : bool, optional
        Whether to overwrite existing merged directory (default: False)

    Returns
    -------
    merged_path : Path
        Path to the created merged directory

    Raises
    ------
    ValueError
        If no ``plane*_stitched`` directories exist in `suite2p_path`.

    Notes
    -----
    The merged ops.npy is a copy of the first consolidated plane's ops with
    ``nplanes``, ``n_cells_per_plane``, ``plane_dirs`` and ``save_path``
    replaced. These describe only the planes that were actually
    consolidated, so they stay index-aligned with each other.

    Examples
    --------
    >>> import lbm_suite2p_python as lsp
    >>> merged = lsp.consolidate_volume("path/to/suite2p")
    >>> # Load consolidated results
    >>> ops = np.load(merged / "ops.npy", allow_pickle=True).item()
    >>> stat = np.load(merged / "stat.npy", allow_pickle=True)
    >>> iscell = np.load(merged / "iscell.npy")
    """
    suite2p_path = Path(suite2p_path)
    merged_path = suite2p_path / merged_dir

    # Numeric ordering: a plain name sort would place plane10 before plane2.
    # Folders without any digits sort last and are skipped below.
    plane_dirs = sorted(
        suite2p_path.glob("plane*_stitched"),
        key=lambda d: (_plane_index(d) is None, _plane_index(d) or 0, d.name),
    )
    if not plane_dirs:
        raise ValueError(f"No plane*_stitched directories found in {suite2p_path}")
    print(f"Found {len(plane_dirs)} planes to consolidate")

    if merged_path.exists() and not overwrite:
        print(f"Merged directory already exists: {merged_path}")
        print("Use overwrite=True to recreate")
        return merged_path

    merged_path.mkdir(exist_ok=True)

    trace_names = ("F", "Fneu", "spks")
    all_stats = []
    all_iscell = []
    all_ops = []
    all_traces = {name: [] for name in trace_names}
    n_cells_per_plane = []
    consolidated_dirs = []  # only the planes that actually contributed data

    for plane_dir in plane_dirs:
        plane_num = _plane_index(plane_dir)
        if plane_num is None:
            print(f"  WARNING: No plane number in {plane_dir.name}, skipping")
            continue
        print(f"  Processing plane {plane_num}: {plane_dir.name}")

        stat_file = plane_dir / "stat.npy"
        iscell_file = plane_dir / "iscell.npy"
        ops_file = plane_dir / "ops.npy"
        if not (stat_file.exists() and iscell_file.exists() and ops_file.exists()):
            print(f"    WARNING: Missing required files, skipping plane {plane_num}")
            continue

        stat = np.load(stat_file, allow_pickle=True)
        iscell = np.load(iscell_file)
        ops = np.load(ops_file, allow_pickle=True).item()

        # Tag every ROI with the plane it came from.
        for s in stat:
            s["iplane"] = plane_num

        all_stats.extend(stat)
        all_iscell.append(iscell)
        all_ops.append(ops)
        n_cells_per_plane.append(len(stat))
        consolidated_dirs.append(plane_dir)

        # Optional trace files
        for name in trace_names:
            trace_file = plane_dir / f"{name}.npy"
            if trace_file.exists():
                all_traces[name].append(np.load(trace_file))
            else:
                print(f"    WARNING: {name}.npy not found for plane {plane_num}")

    print(f"\nSaving consolidated results to {merged_path}")

    stat_array = np.array(all_stats, dtype=object)
    np.save(merged_path / "stat.npy", stat_array)
    print(f"  Saved stat.npy: {len(stat_array)} total ROIs")

    if all_iscell:
        iscell_array = np.vstack(all_iscell)
        np.save(merged_path / "iscell.npy", iscell_array)
        n_accepted = (iscell_array[:, 0] == 1).sum()
        print(f"  Saved iscell.npy: {n_accepted} accepted, {len(iscell_array) - n_accepted} rejected")

    for name in trace_names:
        chunks = all_traces[name]
        if not chunks:
            continue
        if len(chunks) != len(consolidated_dirs):
            # Stacking a partial set would shift rows relative to stat.npy.
            print(
                f"  WARNING: {name}.npy is missing for some planes; not saved "
                "because its rows would not align with stat.npy"
            )
            continue
        trace_array = np.vstack(chunks)
        np.save(merged_path / f"{name}.npy", trace_array)
        print(f"  Saved {name}.npy: shape {trace_array.shape}")

    # Use first plane's ops as template and add plane-specific info
    consolidated_ops = all_ops[0].copy() if all_ops else {}
    consolidated_ops["nplanes"] = len(consolidated_dirs)
    consolidated_ops["n_cells_per_plane"] = n_cells_per_plane
    consolidated_ops["plane_dirs"] = [str(d) for d in consolidated_dirs]
    consolidated_ops["save_path"] = str(merged_path)
    np.save(merged_path / "ops.npy", consolidated_ops)
    print(f"  Saved ops.npy: {len(consolidated_dirs)} planes consolidated")

    print("\nConsolidation complete!")
    print(f"Total ROIs: {len(stat_array)}")
    print(f"Planes: {len(consolidated_dirs)}")

    return merged_path
def _diag_message_figure(figsize, message, fontsize):
    """Return a black figure that shows only a centered white message."""
    fig = plt.figure(figsize=figsize, facecolor="black")
    fig.text(0.5, 0.5, message, ha="center", va="center",
             fontsize=fontsize, fontweight="bold", color="white")
    return fig


def _diag_style_axes(ax, xlabel, ylabel, title):
    """Apply the shared white-on-black labels, ticks and spines to one panel."""
    ax.set_xlabel(xlabel, fontsize=9, fontweight="bold", color="white")
    ax.set_ylabel(ylabel, fontsize=9, fontweight="bold", color="white")
    ax.set_title(title, fontsize=10, fontweight="bold", color="white")
    ax.tick_params(colors="white", labelsize=8)
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    for side in ("bottom", "left"):
        ax.spines[side].set_color("white")


def _diag_legend(ax, **kwargs):
    """Add the shared dark-themed legend to a panel."""
    ax.legend(fontsize=7, facecolor="#1a1a1a", edgecolor="white",
              labelcolor="white", **kwargs)


def _diag_boxplot(ax, groups, positions, box_color):
    """Draw one box per group (outliers hidden) using the dark theme colors."""
    bp = ax.boxplot(groups, positions=positions, widths=0.6, patch_artist=True,
                    showfliers=False, medianprops=dict(color="#ffe66d", linewidth=2))
    for patch in bp["boxes"]:
        patch.set_facecolor(box_color)
        patch.set_alpha(0.7)
    for line in (*bp["whiskers"], *bp["caps"]):
        line.set_color("white")


def _diag_nonempty_boxplot_panel(ax, values, plane_ids, planes, box_color, empty_message):
    """Box-plot the non-NaN values of each plane, leaving out planes with no values."""
    if len(values) == 0:
        ax.text(0.5, 0.5, "No data", ha="center", va="center", fontsize=12, color="white")
        return
    groups = [values[plane_ids == p] for p in planes]
    groups = [g[~np.isnan(g)] for g in groups]
    kept_planes = [p for p, g in zip(planes, groups) if len(g) > 0]
    kept_groups = [g for g in groups if len(g) > 0]
    if not kept_groups:
        ax.text(0.5, 0.5, empty_message, ha="center", va="center", fontsize=12, color="white")
        return
    _diag_boxplot(ax, kept_groups, kept_planes, box_color)
    ax.plot(kept_planes, [np.mean(g) for g in kept_groups], "o--",
            color="#e74c3c", markersize=4, label="Mean")
    _diag_legend(ax)


def _diag_load_plane(ops_file, fallback_plane):
    """Load one plane's files and return its summary and accepted-ROI metrics, or None to skip it."""
    ops_file = Path(ops_file)
    ops = load_ops(ops_file)
    plane_dir = ops_file.parent

    # Plane number: integer entry, else the digits of its string form, else the fallback.
    raw_plane = ops.get("plane", None)
    if isinstance(raw_plane, (int, np.integer)):
        plane_num = int(raw_plane)
    elif raw_plane is not None:
        digits = "".join(c for c in str(raw_plane) if c.isdigit())
        plane_num = int(digits) if digits else fallback_plane
    else:
        plane_num = fallback_plane

    iscell_file = plane_dir / "iscell.npy"
    stat_file = plane_dir / "stat.npy"
    F_file = plane_dir / "F.npy"
    Fneu_file = plane_dir / "Fneu.npy"
    if not (iscell_file.exists() and stat_file.exists() and F_file.exists()):
        return None

    try:
        iscell_raw = np.load(iscell_file, allow_pickle=True)
        if not isinstance(iscell_raw, np.ndarray) or iscell_raw.size == 0:
            return None
        iscell = iscell_raw[:, 0].astype(bool)
        stat = np.load(stat_file, allow_pickle=True)
        F = np.load(F_file, allow_pickle=True)
        # Without a neuropil file the correction below subtracts zeros.
        Fneu = np.load(Fneu_file, allow_pickle=True) if Fneu_file.exists() else np.zeros_like(F)
    except Exception:
        return None

    # The boolean mask indexes rows of F-derived arrays and entries of stat,
    # so all of them must describe the same number of ROIs.
    n_rois = len(iscell)
    if (F.ndim != 2 or F.shape[0] != n_rois or Fneu.shape != F.shape
            or stat.ndim != 1 or stat.shape[0] != n_rois):
        print(f"Skipping plane {plane_num}: iscell, stat, F and Fneu have inconsistent shapes")
        return None

    n_accepted = int(np.sum(iscell))
    n_rejected = int(np.sum(~iscell))

    # SNR per ROI: std of dF/F over a noise estimate from frame-to-frame differences.
    F_corr = F - 0.7 * Fneu
    baseline = np.percentile(F_corr, 20, axis=1, keepdims=True)
    baseline = np.maximum(baseline, 1e-6)  # keep the division below well-defined
    dff = (F_corr - baseline) / baseline
    signal = np.std(dff, axis=1)
    noise = np.median(np.abs(np.diff(dff, axis=1)), axis=1) / 0.6745
    snr = signal / (noise + 1e-6)

    npix = np.array([s.get("npix", 0) for s in stat])
    compactness = np.array([s.get("compact", np.nan) for s in stat])
    skewness = np.array([s.get("skew", np.nan) for s in stat])

    return {
        "plane": plane_num,
        "n_accepted": n_accepted,
        "n_rejected": n_rejected,
        "mean_signal": float(np.nanmean(F)),
        "std_signal": float(np.nanstd(F)),
        "mean_snr": float(np.nanmean(snr[iscell])) if n_accepted > 0 else 0.0,
        # Per-ROI values for accepted cells only
        "snr": snr[iscell],
        "npix": npix[iscell],
        "compactness": compactness[iscell],
        "skewness": skewness[iscell],
    }


def plot_volume_diagnostics(
    ops_files: list[str | Path],
    save_path: str | Path = None,
    figsize: tuple = (16, 12),
) -> plt.Figure:
    """
    Generate a single-figure diagnostic summary for an entire processed volume.

    Creates a 3x2 figure on a black background showing, across all z-planes:
    - Row 1: ROI counts (accepted/rejected stacked bars), Mean signal per plane
    - Row 2: SNR distribution per plane, Size distribution per plane
    - Row 3: Compactness distribution per plane, Skewness distribution per plane

    The distribution panels use accepted ROIs only. Planes are skipped when
    iscell.npy, stat.npy or F.npy is missing, unreadable or empty, or when
    their shapes disagree. Fneu.npy is optional.

    Parameters
    ----------
    ops_files : list of str or Path
        List of paths to ops.npy files for each z-plane. The other files are
        read from the folder containing each ops file.
    save_path : str or Path, optional
        If provided, save figure to this path (parent folders are created)
        and close the figure afterwards.
    figsize : tuple, default (16, 12)
        Figure size in inches.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure object. When there are no ops files or no
        usable planes, a figure containing only a message is returned.
    """
    if not ops_files:
        return _diag_message_figure(figsize, "No ops files provided", 16)

    plane_data = []
    for ops_file in ops_files:
        # Planes without a usable plane entry are numbered by how many planes were kept so far.
        metrics = _diag_load_plane(ops_file, fallback_plane=len(plane_data))
        if metrics is not None:
            plane_data.append(metrics)

    if not plane_data:
        return _diag_message_figure(
            figsize,
            "No valid plane data found\n\nCheck that ops.npy, stat.npy, F.npy exist",
            14,
        )

    planes = np.array([d["plane"] for d in plane_data])
    n_accepted = np.array([d["n_accepted"] for d in plane_data])
    n_rejected = np.array([d["n_rejected"] for d in plane_data])
    mean_signals = np.array([d["mean_signal"] for d in plane_data])
    std_signals = np.array([d["std_signal"] for d in plane_data])
    mean_snrs = np.array([d["mean_snr"] for d in plane_data])

    all_snr = np.concatenate([d["snr"] for d in plane_data])
    all_npix = np.concatenate([d["npix"] for d in plane_data])
    all_compactness = np.concatenate([d["compactness"] for d in plane_data])
    all_skewness = np.concatenate([d["skewness"] for d in plane_data])
    # Plane id of every accepted ROI, aligned with the arrays above.
    all_plane_ids = np.concatenate(
        [np.full(d["n_accepted"], d["plane"]) for d in plane_data]
    )

    fig = plt.figure(figsize=figsize, facecolor="black")
    gs = fig.add_gridspec(3, 2, hspace=0.35, wspace=0.25,
                          left=0.08, right=0.95, top=0.93, bottom=0.08)
    n_planes = len(planes)

    # Panel 1: ROI counts per plane
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.set_facecolor("black")
    bar_width = 0.8
    bars1 = ax1.bar(planes, n_accepted, width=bar_width, label=f"Accepted ({n_accepted.sum()})",
                    alpha=0.85, color="#2ecc71", edgecolor="#27ae60", linewidth=0.5)
    bars2 = ax1.bar(planes, n_rejected, width=bar_width, bottom=n_accepted,
                    label=f"Rejected ({n_rejected.sum()})", alpha=0.85, color="#e74c3c",
                    edgecolor="#c0392b", linewidth=0.5)
    # Counts are written inside segments taller than 5.
    count_font = dict(ha="center", va="center", fontsize=7, color="white", fontweight="bold")
    for b1, b2 in zip(bars1, bars2):
        h1, h2 = b1.get_height(), b2.get_height()
        if h1 > 5:
            ax1.text(b1.get_x() + b1.get_width() / 2, h1 / 2, f"{int(h1)}", **count_font)
        if h2 > 5:
            ax1.text(b2.get_x() + b2.get_width() / 2, h1 + h2 / 2, f"{int(h2)}", **count_font)
    _diag_style_axes(ax1, "Z-Plane", "Number of ROIs", "ROI Counts per Plane")
    _diag_legend(ax1, loc="upper right")

    # Panel 2: mean signal per plane
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.set_facecolor("black")
    ax2.errorbar(planes, mean_signals, yerr=std_signals, fmt="o-",
                 color="#3498db", ecolor="#85c1e9", elinewidth=1.5,
                 capsize=3, markersize=5, alpha=0.9, label="Mean ± STD")
    _diag_style_axes(ax2, "Z-Plane", "Mean Raw Signal", "Fluorescence Signal per Plane")
    _diag_legend(ax2)

    # Panel 3: SNR distribution per plane (one box per plane, even if empty)
    ax3 = fig.add_subplot(gs[1, 0])
    ax3.set_facecolor("black")
    if len(all_snr) > 0:
        snr_by_plane = [all_snr[all_plane_ids == p] for p in planes]
        snr_by_plane = [s[~np.isnan(s)] for s in snr_by_plane]
        _diag_boxplot(ax3, snr_by_plane, planes, "#2ecc71")
        ax3.plot(planes, mean_snrs, "o--", color="#e74c3c", markersize=4, label="Mean SNR")
    else:
        ax3.text(0.5, 0.5, "No SNR data", ha="center", va="center", fontsize=12, color="white")
    _diag_style_axes(ax3, "Z-Plane", "SNR", "SNR Distribution per Plane")
    if len(all_snr) > 0:
        _diag_legend(ax3)

    # Panel 4: size distribution per plane (ROIs with npix == 0 are ignored)
    ax4 = fig.add_subplot(gs[1, 1])
    ax4.set_facecolor("black")
    if len(all_npix) > 0:
        npix_by_plane = [all_npix[all_plane_ids == p] for p in planes]
        npix_by_plane = [n[n > 0] for n in npix_by_plane]
        _diag_boxplot(ax4, npix_by_plane, planes, "#3498db")
        mean_sizes = [np.mean(n) if len(n) > 0 else 0 for n in npix_by_plane]
        ax4.plot(planes, mean_sizes, "o--", color="#e74c3c", markersize=4, label="Mean Size")
    else:
        ax4.text(0.5, 0.5, "No size data", ha="center", va="center", fontsize=12, color="white")
    _diag_style_axes(ax4, "Z-Plane", "Size (pixels)", "ROI Size Distribution per Plane")
    if len(all_npix) > 0:
        _diag_legend(ax4)

    # Panel 5: compactness distribution per plane (purple)
    ax5 = fig.add_subplot(gs[2, 0])
    ax5.set_facecolor("black")
    _diag_nonempty_boxplot_panel(ax5, all_compactness, all_plane_ids, planes,
                                 "#9b59b6", "No compactness data")
    _diag_style_axes(ax5, "Z-Plane", "Compactness", "Compactness Distribution per Plane")

    # Panel 6: skewness distribution per plane (orange)
    ax6 = fig.add_subplot(gs[2, 1])
    ax6.set_facecolor("black")
    _diag_nonempty_boxplot_panel(ax6, all_skewness, all_plane_ids, planes,
                                 "#e67e22", "No skewness data")
    _diag_style_axes(ax6, "Z-Plane", "Skewness", "Skewness Distribution per Plane")

    total_accepted = n_accepted.sum()
    total_rejected = n_rejected.sum()
    fig.suptitle(f"Volume Quality Diagnostics: {n_planes} planes, {total_accepted} accepted, {total_rejected} rejected ROIs",
                 fontsize=12, fontweight="bold", color="white", y=0.98)

    if save_path:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        # Save this figure explicitly rather than whichever figure is current.
        fig.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="black")
        # Closed only after saving; an unsaved figure stays open for the caller.
        plt.close(fig)

    return fig
def plot_orthoslices(
    ops_files: list[str | Path],
    save_path: str | Path = None,
    figsize: tuple = (16, 6),
    use_mean: bool = True,
) -> plt.Figure:
    """
    Generate orthogonal maximum intensity projections (XY, XZ, YZ) of the volume.

    Creates a 3-panel figure showing the volume from three orthogonal views.
    The volume is linearly interpolated along Z (zoom factor capped at 10x)
    before the XZ and YZ projections are taken. Axes are displayed in
    micrometers using voxel size metadata, with defaults of 1.0 um in X/Y and
    15.0 um in Z when no metadata is found.

    Parameters
    ----------
    ops_files : list of str or Path
        List of paths to ops.npy files for each z-plane, ordered by z.
    save_path : str or Path, optional
        If provided, save figure to this path (parent directories are created)
        and close the figure.
    figsize : tuple, default (16, 6)
        Figure size in inches.
    use_mean : bool, default True
        If True, use meanImg. If False, use refImg (registered reference).
        When the preferred image is missing, the other one is used instead.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure object. A placeholder figure with a text message
        is returned when `ops_files` is empty or no usable image is found.

    Notes
    -----
    - Planes whose image is missing or not a 2D array are skipped.
    - If the per-plane images differ in shape they are cropped (from the
      top-left corner) to the largest common shape so they can be stacked;
      a warning is emitted when this happens.
    - The single colorbar is built from the XY panel; the XZ and YZ panels use
      their own percentile-based intensity limits.
    """
    import warnings

    from scipy.ndimage import zoom
    from lbm_suite2p_python.postprocessing import load_ops

    if not ops_files:
        fig = plt.figure(figsize=figsize, facecolor="black")
        fig.text(0.5, 0.5, "No ops files provided", ha="center", va="center",
                 fontsize=16, fontweight="bold", color="white")
        return fig

    # get voxel size from first ops file
    first_ops = load_ops(ops_files[0])
    dx_um, dy_um, dz_um = 1.0, 1.0, 15.0  # defaults (15um is typical z-step)
    try:
        from mbo_utilities.metadata import get_voxel_size

        voxel = get_voxel_size(first_ops)
        # only use voxel sizes if they're not the default 1.0 placeholder
        if voxel.dx > 0 and voxel.dx != 1.0:
            dx_um = voxel.dx
        if voxel.dy > 0 and voxel.dy != 1.0:
            dy_um = voxel.dy
        if voxel.dz > 0 and voxel.dz != 1.0:
            dz_um = voxel.dz
    except Exception:
        # best effort: mbo_utilities is optional and the metadata may be
        # incomplete; the ops-field fallbacks below are used instead
        pass

    # fallback to ops fields
    if dx_um == 1.0:
        pixel_res = first_ops.get("pixel_resolution", first_ops.get("um_per_pixel", None))
        if pixel_res is not None:
            # np.ndim also recognises NumPy scalars / 0-d arrays, which a plain
            # isinstance(..., (int, float)) check would send to len() and crash
            if np.ndim(pixel_res) == 0:
                dx_um = dy_um = float(pixel_res)
            else:
                dx_um = float(pixel_res[0]) if len(pixel_res) > 0 else 1.0
                dy_um = float(pixel_res[1]) if len(pixel_res) > 1 else dx_um
    if dz_um == 15.0:  # still default
        dz_from_ops = first_ops.get("dz", first_ops.get("z_step", None))
        if dz_from_ops is not None and dz_from_ops != 1.0:
            dz_um = float(dz_from_ops)

    # collect one 2D image per plane
    preferred_key, fallback_key = ("meanImg", "refImg") if use_mean else ("refImg", "meanImg")
    images = []
    for ops_file in ops_files:
        ops = load_ops(Path(ops_file))
        img = ops.get(preferred_key)
        if not isinstance(img, np.ndarray):
            img = ops.get(fallback_key)
        # a non-2D image cannot be stacked into a (Z, Y, X) volume
        if not isinstance(img, np.ndarray) or img.ndim != 2:
            continue
        images.append(img)

    if not images:
        fig = plt.figure(figsize=figsize, facecolor="black")
        fig.text(0.5, 0.5, "No valid images found", ha="center", va="center",
                 fontsize=16, fontweight="bold", color="white")
        return fig

    # np.stack requires identical shapes; crop to the common region if needed
    min_ny = min(img.shape[0] for img in images)
    min_nx = min(img.shape[1] for img in images)
    if any(img.shape != (min_ny, min_nx) for img in images):
        warnings.warn(
            "Plane images differ in shape; cropping all planes to "
            f"({min_ny}, {min_nx}) for orthoslice projections",
            stacklevel=2,
        )
        images = [img[:min_ny, :min_nx] for img in images]

    # stack into 3D volume (Z, Y, X)
    volume = np.stack(images, axis=0)
    nz, ny, nx = volume.shape

    # compute volume dimensions in microns
    vol_x_um = nx * dx_um
    vol_y_um = ny * dy_um
    vol_z_um = (nz - 1) * dz_um if nz > 1 else dz_um

    # interpolate volume along Z towards isotropic resolution
    xy_res = (dx_um + dy_um) / 2
    z_zoom = dz_um / xy_res if xy_res > 0 else 1.0
    z_zoom = min(z_zoom, 10.0)  # cap to avoid memory issues
    if z_zoom > 1.1:
        volume_resampled = zoom(volume, (z_zoom, 1, 1), order=1)
    else:
        volume_resampled = volume

    # compute projections
    xy_proj = np.max(volume, axis=0)
    xz_proj = np.max(volume_resampled, axis=1)
    yz_proj = np.max(volume_resampled, axis=2)

    # create figure
    fig = plt.figure(figsize=figsize, facecolor="black")
    gs = fig.add_gridspec(1, 3, wspace=0.15, left=0.05, right=0.95, top=0.88, bottom=0.1)

    # extents for proper axis scaling (all in microns)
    xy_extent = [0, vol_x_um, vol_y_um, 0]
    xz_extent = [0, vol_x_um, vol_z_um, 0]
    yz_extent = [0, vol_z_um, vol_y_um, 0]

    # panel 1: XY projection
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.set_facecolor("black")
    im1 = ax1.imshow(xy_proj, cmap="magma", aspect="equal", extent=xy_extent,
                     vmin=np.percentile(xy_proj, 1), vmax=np.percentile(xy_proj, 99.5))
    ax1.set_xlabel("X (μm)", fontsize=10, fontweight="bold", color="white")
    ax1.set_ylabel("Y (μm)", fontsize=10, fontweight="bold", color="white")
    ax1.set_title("XY Projection", fontsize=11, fontweight="bold", color="white")
    ax1.tick_params(colors="white", labelsize=8)
    for spine in ax1.spines.values():
        spine.set_color("white")

    # panel 2: XZ projection
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.set_facecolor("black")
    ax2.imshow(xz_proj, cmap="magma", aspect="auto", extent=xz_extent,
               vmin=np.percentile(xz_proj, 1), vmax=np.percentile(xz_proj, 99.5),
               interpolation="bilinear")
    ax2.set_xlabel("X (μm)", fontsize=10, fontweight="bold", color="white")
    ax2.set_ylabel("Z (μm)", fontsize=10, fontweight="bold", color="white")
    ax2.set_title("XZ Projection", fontsize=11, fontweight="bold", color="white")
    ax2.tick_params(colors="white", labelsize=8)
    for spine in ax2.spines.values():
        spine.set_color("white")

    # panel 3: YZ projection (transposed so Z runs along the horizontal axis)
    ax3 = fig.add_subplot(gs[0, 2])
    ax3.set_facecolor("black")
    ax3.imshow(yz_proj.T, cmap="magma", aspect="auto", extent=yz_extent,
               vmin=np.percentile(yz_proj, 1), vmax=np.percentile(yz_proj, 99.5),
               interpolation="bilinear")
    ax3.set_xlabel("Z (μm)", fontsize=10, fontweight="bold", color="white")
    ax3.set_ylabel("Y (μm)", fontsize=10, fontweight="bold", color="white")
    ax3.set_title("YZ Projection", fontsize=11, fontweight="bold", color="white")
    ax3.tick_params(colors="white", labelsize=8)
    for spine in ax3.spines.values():
        spine.set_color("white")

    # colorbar (scaled to the XY panel's intensity limits)
    cbar = fig.colorbar(im1, ax=[ax1, ax2, ax3], shrink=0.6, pad=0.02, location="right")
    cbar.set_label("Intensity (a.u.)", fontsize=10, color="white")
    cbar.ax.tick_params(colors="white")
    cbar.outline.set_edgecolor("white")

    # title with volume dimensions in microns
    title = f"Orthogonal Projections: {nz} planes, {vol_x_um:.0f}×{vol_y_um:.0f}×{vol_z_um:.0f} μm"
    fig.suptitle(title, fontsize=12, fontweight="bold", color="white", y=0.96)

    if save_path:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="black")
        plt.close(fig)
    return fig
def plot_3d_roi_map(
    ops_files: list[str | Path],
    save_path: str | Path = None,
    figsize: tuple = (14, 10),
    color_by: str = "snr",
    show_rejected: bool = False,
) -> plt.Figure:
    """
    Generate a 3D scatter plot of ROI centroids across the volume.

    Creates a 3D visualization showing the spatial distribution of accepted
    ROIs, colored according to `color_by`. Centroids are scaled by the voxel
    size read from the first ops file; when no voxel metadata is found the
    defaults (1.0 um in X/Y, 15.0 um in Z) are used. Axes are always labelled
    in micrometers.

    Parameters
    ----------
    ops_files : list of str or Path
        List of paths to ops.npy files for each z-plane, ordered by z. The
        position in this list is used as the plane index for z-depth.
        `stat.npy`, `iscell.npy` and (for "snr"/"activity") `F.npy`/`Fneu.npy`
        are read from each ops file's parent directory.
    save_path : str or Path, optional
        If provided, save figure to this path (parent directories are created)
        and close the figure.
    figsize : tuple, default (14, 10)
        Figure size in inches.
    color_by : str, default "snr"
        How to color the ROIs: "snr", "plane", "size", or "activity".
        Any other value is treated as "plane".
    show_rejected : bool, default False
        If True, also show rejected ROIs in gray.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure object. A placeholder figure with a text message
        is returned when `ops_files` is empty or no accepted ROI is found.

    Notes
    -----
    - Planes with a missing or unreadable `stat.npy`/`iscell.npy` are skipped.
    - For every mode except "plane", color values are clipped to their
      5th-95th percentile range (NaN values are ignored when computing it).
    - When `F.npy` is missing, "snr" and "activity" color values are zeros.
    """
    from lbm_suite2p_python.postprocessing import load_ops

    if not ops_files:
        fig = plt.figure(figsize=figsize, facecolor="black")
        fig.text(0.5, 0.5, "No ops files provided", ha="center", va="center",
                 fontsize=16, fontweight="bold", color="white")
        return fig

    # Colormap and colorbar label for each supported mode
    color_styles = {
        "plane": ("viridis", "Z-depth (μm)"),
        "snr": ("plasma", "SNR"),
        "size": ("cividis", "Size (pixels)"),
        "activity": ("magma", "Z-score"),
    }
    # Unknown modes are colored by z-depth; normalise here so that the data,
    # colormap and colorbar label all agree.
    if color_by not in color_styles:
        color_by = "plane"

    # Get voxel size from first ops file
    first_ops = load_ops(ops_files[0])
    dx_um, dy_um, dz_um = 1.0, 1.0, 15.0  # defaults (15um is typical z-step)
    try:
        from mbo_utilities.metadata import get_voxel_size

        voxel = get_voxel_size(first_ops)
        # only use voxel sizes if they're not the default 1.0 placeholder
        if voxel.dx > 0 and voxel.dx != 1.0:
            dx_um = voxel.dx
        if voxel.dy > 0 and voxel.dy != 1.0:
            dy_um = voxel.dy
        if voxel.dz > 0 and voxel.dz != 1.0:
            dz_um = voxel.dz
    except Exception:
        # best effort: mbo_utilities is optional and the metadata may be
        # incomplete; the ops-field fallbacks below are used instead
        pass

    # fallback to ops fields
    if dx_um == 1.0:
        pixel_res = first_ops.get("pixel_resolution", first_ops.get("um_per_pixel", None))
        if pixel_res is not None:
            # np.ndim also recognises NumPy scalars / 0-d arrays, which a plain
            # isinstance(..., (int, float)) check would send to len() and crash
            if np.ndim(pixel_res) == 0:
                dx_um = dy_um = float(pixel_res)
            else:
                dx_um = float(pixel_res[0]) if len(pixel_res) > 0 else 1.0
                dy_um = float(pixel_res[1]) if len(pixel_res) > 1 else dx_um
    if dz_um == 15.0:  # still default
        dz_from_ops = first_ops.get("dz", first_ops.get("z_step", None))
        if dz_from_ops is not None and dz_from_ops != 1.0:
            dz_um = float(dz_from_ops)

    # Collect ROI data from all planes
    all_x, all_y, all_z = [], [], []
    all_colors = []
    # For rejected ROIs
    rej_x, rej_y, rej_z = [], [], []

    for plane_idx, ops_file in enumerate(ops_files):
        plane_dir = Path(ops_file).parent

        # Load required files
        stat_file = plane_dir / "stat.npy"
        iscell_file = plane_dir / "iscell.npy"
        if not stat_file.exists() or not iscell_file.exists():
            continue
        try:
            stat = np.load(stat_file, allow_pickle=True)
            iscell_raw = np.load(iscell_file, allow_pickle=True)
            if not isinstance(iscell_raw, np.ndarray) or iscell_raw.size == 0:
                continue
            # first column holds the accept/reject flag
            iscell = iscell_raw[:, 0].astype(bool)
        except Exception:
            # unreadable or malformed plane: skip it rather than abort the plot
            continue

        # z-depth in microns; the list position is the plane index
        z_um = plane_idx * dz_um

        # Get color values based on color_by
        if color_by in ("snr", "activity"):
            F_file = plane_dir / "F.npy"
            Fneu_file = plane_dir / "Fneu.npy"
            if F_file.exists():
                F = np.load(F_file, allow_pickle=True)
                Fneu = np.load(Fneu_file, allow_pickle=True) if Fneu_file.exists() else np.zeros_like(F)
                # neuropil-corrected trace
                F_corr = F - 0.7 * Fneu
                if color_by == "snr":
                    baseline = np.percentile(F_corr, 20, axis=1, keepdims=True)
                    baseline = np.maximum(baseline, 1e-6)
                    dff = (F_corr - baseline) / baseline
                    signal = np.std(dff, axis=1)
                    # MAD of frame-to-frame differences, scaled by 0.6745
                    noise = np.median(np.abs(np.diff(dff, axis=1)), axis=1) / 0.6745
                    color_vals = signal / (noise + 1e-6)
                else:
                    # compute zscore: (x - mean) / std
                    mean_f = np.mean(F_corr, axis=1, keepdims=True)
                    std_f = np.std(F_corr, axis=1, keepdims=True)
                    std_f = np.maximum(std_f, 1e-6)
                    zscore = (F_corr - mean_f) / std_f
                    # use max zscore as activity metric
                    color_vals = np.max(zscore, axis=1)
            else:
                # no F data - use zeros
                color_vals = np.zeros(len(stat))
        elif color_by == "size":
            color_vals = np.array([s.get("npix", 100) for s in stat])
        else:  # plane - use z-depth in microns
            color_vals = np.full(len(stat), z_um, dtype=float)

        # Extract centroids and convert to microns
        for i, s in enumerate(stat):
            # "med" is read as (y, x) in pixels
            med = s.get("med", [0, 0])
            y_px, x_px = med[0], med[1]
            x_um = x_px * dx_um
            y_um = y_px * dy_um
            if iscell[i]:
                all_x.append(x_um)
                all_y.append(y_um)
                all_z.append(z_um)
                all_colors.append(color_vals[i])
            elif show_rejected:
                rej_x.append(x_um)
                rej_y.append(y_um)
                rej_z.append(z_um)

    if not all_x:
        fig = plt.figure(figsize=figsize, facecolor="black")
        fig.text(0.5, 0.5, "No ROI data found", ha="center", va="center",
                 fontsize=16, fontweight="bold", color="white")
        return fig

    # Convert to arrays
    all_x = np.array(all_x)
    all_y = np.array(all_y)
    all_z = np.array(all_z)
    all_colors = np.array(all_colors)

    # Create figure with 3D axis
    fig = plt.figure(figsize=figsize, facecolor="black")
    ax = fig.add_subplot(111, projection="3d", facecolor="black")

    # Set pane colors to dark
    ax.xaxis.pane.fill = False
    ax.yaxis.pane.fill = False
    ax.zaxis.pane.fill = False
    ax.xaxis.pane.set_edgecolor("white")
    ax.yaxis.pane.set_edgecolor("white")
    ax.zaxis.pane.set_edgecolor("white")

    # Plot rejected ROIs first (if enabled)
    if show_rejected and rej_x:
        ax.scatter(rej_x, rej_y, rej_z, c="gray", s=10, alpha=0.3, label="Rejected")

    # Choose colormap based on color_by
    cmap, clabel = color_styles[color_by]
    if color_by != "plane":
        # Clip extreme values; nanpercentile keeps a single NaN metric from
        # turning both bounds (and hence every color) into NaN
        vmin, vmax = np.nanpercentile(all_colors, [5, 95])
        all_colors = np.clip(all_colors, vmin, vmax)

    # Plot accepted ROIs
    scatter = ax.scatter(all_x, all_y, all_z, c=all_colors, cmap=cmap,
                         s=15, alpha=0.7, edgecolors="none")

    # Add colorbar
    cbar = fig.colorbar(scatter, ax=ax, shrink=0.6, pad=0.1)
    cbar.set_label(clabel, fontsize=10, color="white")
    cbar.ax.tick_params(colors="white")
    cbar.outline.set_edgecolor("white")

    # Style axes - always show in microns
    ax.set_xlabel("X (μm)", fontsize=10, fontweight="bold", color="white", labelpad=10)
    ax.set_ylabel("Y (μm)", fontsize=10, fontweight="bold", color="white", labelpad=10)
    ax.set_zlabel("Z (μm)", fontsize=10, fontweight="bold", color="white", labelpad=10)
    ax.tick_params(colors="white", labelsize=8)
    ax.xaxis.label.set_color("white")
    ax.yaxis.label.set_color("white")
    ax.zaxis.label.set_color("white")

    # Set grid color
    ax.xaxis._axinfo["grid"]["color"] = (1, 1, 1, 0.2)
    ax.yaxis._axinfo["grid"]["color"] = (1, 1, 1, 0.2)
    ax.zaxis._axinfo["grid"]["color"] = (1, 1, 1, 0.2)

    # Title with volume dimensions in microns
    n_cells = len(all_x)
    n_planes = len(np.unique(all_z))
    x_range = all_x.max() - all_x.min()
    y_range = all_y.max() - all_y.min()
    z_range = all_z.max() - all_z.min()
    fig.suptitle(
        f"3D ROI Distribution: {n_cells} cells across {n_planes} planes\n"
        f"Volume: {x_range:.0f} × {y_range:.0f} × {z_range:.0f} μm",
        fontsize=12, fontweight="bold", color="white", y=0.95
    )

    if show_rejected and rej_x:
        ax.legend(fontsize=9, facecolor="#1a1a1a", edgecolor="white", labelcolor="white")

    # Adjust view angle for better visualization
    ax.view_init(elev=20, azim=45)

    if save_path:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="black")
        plt.close(fig)
    return fig
def plot_3d_rastermap_clusters(
    suite2p_path: str | Path,
    save_path: str | Path = None,
    figsize: tuple = (14, 10),
    n_clusters: int = 40,
    show_rejected: bool = False,
    rastermap_kwargs: dict = None,
) -> plt.Figure:
    """
    Generate a 3D scatter plot of ROI centroids colored by rastermap cluster.

    Loads volumetric suite2p output, runs rastermap clustering on
    neuropil-corrected fluorescence of the accepted ROIs, and visualizes ROI
    positions in 3D colored by cluster assignment. Checks for an existing
    rastermap model file before running.

    Parameters
    ----------
    suite2p_path : str or Path
        Path to suite2p directory containing plane*_stitched folders or merged folder.
        A `merged` folder containing `stat.npy` takes precedence.
    save_path : str or Path, optional
        If provided, save figure to this path and close it. Otherwise the
        figure is shown with `plt.show()`.
    figsize : tuple, default (14, 10)
        Figure size in inches.
    n_clusters : int, default 40
        Number of rastermap clusters. Ignored if loading existing model.
    show_rejected : bool, default False
        If True, also show rejected ROIs in gray.
    rastermap_kwargs : dict, optional
        Additional kwargs passed to Rastermap(). Ignored if loading existing model.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure object. A placeholder figure with a text message
        is returned when no data directory, required file, or accepted ROI is
        found.

    Notes
    -----
    Looks for existing rastermap model at:
    - suite2p_path/rastermap_model.npy
    - suite2p_path/merged/rastermap_model.npy

    If not found, runs rastermap on the consolidated fluorescence data
    and saves the model for future use.

    A saved model is only used when its number of cluster assignments matches
    either the accepted ROIs or all ROIs (in which case it is filtered by
    iscell). Otherwise it is ignored with a warning. Whenever no cluster
    assignment is available (rastermap missing, failed, or model unusable),
    ROIs are colored by z-depth instead.

    Plane folders are ordered by the integer following "plane" in their name,
    so unpadded names such as plane2_stitched and plane10_stitched get the
    correct z order.

    Model files are loaded with ``allow_pickle=True``; only point this
    function at output directories you trust.

    Examples
    --------
    >>> import lbm_suite2p_python as lsp
    >>> fig = lsp.plot_3d_rastermap_clusters("path/to/suite2p")
    >>> fig = lsp.plot_3d_rastermap_clusters("path/to/suite2p", n_clusters=50)
    """
    import logging
    import re

    from lbm_suite2p_python.postprocessing import load_ops

    logger = logging.getLogger(__name__)
    suite2p_path = Path(suite2p_path)

    # check if rastermap is available
    try:
        from rastermap import Rastermap
        has_rastermap = True
    except ImportError:
        has_rastermap = False
        logger.warning("rastermap not installed, cannot generate cluster visualization")

    def _plane_sort_key(path):
        # numeric plane index first so "plane10" sorts after "plane2";
        # names without an index go last, ties are broken by name
        match = re.match(r"plane(\d+)", path.name)
        return (int(match.group(1)) if match else float("inf"), path.name)

    # find plane directories
    plane_dirs = sorted(suite2p_path.glob("plane*_stitched"), key=_plane_sort_key)
    merged_dir = suite2p_path / "merged"
    if merged_dir.exists() and (merged_dir / "stat.npy").exists():
        # use merged directory
        data_dirs = [merged_dir]
        use_merged = True
    elif plane_dirs:
        data_dirs = plane_dirs
        use_merged = False
    else:
        fig = plt.figure(figsize=figsize, facecolor="black")
        fig.text(0.5, 0.5, "No plane directories or merged folder found",
                 ha="center", va="center", fontsize=14, fontweight="bold", color="white")
        return fig

    # check data is volumetric (4D = multiple planes)
    if not use_merged and len(plane_dirs) < 2:
        logger.warning("Only 1 plane found - this visualization is designed for volumetric (4D) data")

    # look for existing rastermap model
    model_paths = [
        suite2p_path / "rastermap_model.npy",
        merged_dir / "rastermap_model.npy",
    ]
    rastermap_model = None
    for mp in model_paths:
        if mp.exists():
            try:
                # NOTE: allow_pickle=True unpickles arbitrary objects; the
                # model file must come from a trusted source
                rastermap_model = np.load(mp, allow_pickle=True).item()
                logger.info(f"Loaded existing rastermap model from {mp}")
                break
            except Exception as e:
                logger.warning(f"Failed to load rastermap model from {mp}: {e}")

    # get voxel size from first ops file
    first_ops_file = None
    for d in data_dirs:
        ops_file = d / "ops.npy"
        if ops_file.exists():
            first_ops_file = ops_file
            break

    dx_um, dy_um, dz_um = 1.0, 1.0, 15.0
    if first_ops_file:
        first_ops = load_ops(first_ops_file)
        try:
            from mbo_utilities.metadata import get_voxel_size

            voxel = get_voxel_size(first_ops)
            # only use voxel sizes if they're not the 1.0 placeholder
            if voxel.dx > 0 and voxel.dx != 1.0:
                dx_um = voxel.dx
            if voxel.dy > 0 and voxel.dy != 1.0:
                dy_um = voxel.dy
            if voxel.dz > 0 and voxel.dz != 1.0:
                dz_um = voxel.dz
        except Exception:
            # best effort: mbo_utilities is optional and the metadata may be
            # incomplete; the ops-field fallbacks below are used instead
            pass

        if dx_um == 1.0:
            pixel_res = first_ops.get("pixel_resolution", first_ops.get("um_per_pixel", None))
            if pixel_res is not None:
                # np.ndim also recognises NumPy scalars / 0-d arrays, which a
                # plain isinstance(..., (int, float)) check would send to len()
                if np.ndim(pixel_res) == 0:
                    dx_um = dy_um = float(pixel_res)
                else:
                    dx_um = float(pixel_res[0]) if len(pixel_res) > 0 else 1.0
                    dy_um = float(pixel_res[1]) if len(pixel_res) > 1 else dx_um
        if dz_um == 15.0:
            dz_from_ops = first_ops.get("dz", first_ops.get("z_step", None))
            if dz_from_ops is not None and dz_from_ops != 1.0:
                dz_um = float(dz_from_ops)

    # collect ROI data and fluorescence
    all_x, all_y, all_z = [], [], []
    all_F = []
    all_iscell = []  # track iscell for filtering saved rastermap
    rej_x, rej_y, rej_z = [], [], []

    if use_merged:
        # load from merged directory
        stat_file = merged_dir / "stat.npy"
        iscell_file = merged_dir / "iscell.npy"
        F_file = merged_dir / "F.npy"
        Fneu_file = merged_dir / "Fneu.npy"
        if not all([stat_file.exists(), iscell_file.exists(), F_file.exists()]):
            fig = plt.figure(figsize=figsize, facecolor="black")
            fig.text(0.5, 0.5, "Missing required files in merged directory",
                     ha="center", va="center", fontsize=14, fontweight="bold", color="white")
            return fig

        stat = np.load(stat_file, allow_pickle=True)
        iscell = np.load(iscell_file)[:, 0].astype(bool)
        F = np.load(F_file)
        Fneu = np.load(Fneu_file) if Fneu_file.exists() else np.zeros_like(F)

        for i, s in enumerate(stat):
            # "med" is read as (y, x) in pixels; "iplane" gives the plane index
            med = s.get("med", [0, 0])
            y_px, x_px = med[0], med[1]
            plane_num = s.get("iplane", 0)
            x_um = x_px * dx_um
            y_um = y_px * dy_um
            z_um = plane_num * dz_um
            if iscell[i]:
                all_x.append(x_um)
                all_y.append(y_um)
                all_z.append(z_um)
            elif show_rejected:
                rej_x.append(x_um)
                rej_y.append(y_um)
                rej_z.append(z_um)

        # neuropil correction
        F_corr = F - 0.7 * Fneu
        all_F = F_corr[iscell]
        all_iscell = iscell
    else:
        # load from individual plane directories
        F_list = []
        iscell_list = []
        for plane_idx, plane_dir in enumerate(plane_dirs):
            stat_file = plane_dir / "stat.npy"
            iscell_file = plane_dir / "iscell.npy"
            F_file = plane_dir / "F.npy"
            Fneu_file = plane_dir / "Fneu.npy"
            if not all([stat_file.exists(), iscell_file.exists(), F_file.exists()]):
                continue
            try:
                stat = np.load(stat_file, allow_pickle=True)
                iscell_raw = np.load(iscell_file)
                iscell = iscell_raw[:, 0].astype(bool)
                F = np.load(F_file)
                Fneu = np.load(Fneu_file) if Fneu_file.exists() else np.zeros_like(F)
            except Exception:
                # unreadable or malformed plane: skip it rather than abort
                continue

            # position in the sorted folder list is used as the plane index
            z_um = plane_idx * dz_um
            F_corr = F - 0.7 * Fneu
            iscell_list.append(iscell)

            for i, s in enumerate(stat):
                med = s.get("med", [0, 0])
                y_px, x_px = med[0], med[1]
                x_um = x_px * dx_um
                y_um = y_px * dy_um
                if iscell[i]:
                    all_x.append(x_um)
                    all_y.append(y_um)
                    all_z.append(z_um)
                elif show_rejected:
                    rej_x.append(x_um)
                    rej_y.append(y_um)
                    rej_z.append(z_um)

            F_list.append(F_corr[iscell])

        if F_list:
            all_F = np.vstack(F_list)
        if iscell_list:
            all_iscell = np.concatenate(iscell_list)

    if len(all_x) == 0:
        fig = plt.figure(figsize=figsize, facecolor="black")
        fig.text(0.5, 0.5, "No accepted ROIs found",
                 ha="center", va="center", fontsize=14, fontweight="bold", color="white")
        return fig

    all_x = np.array(all_x)
    all_y = np.array(all_y)
    all_z = np.array(all_z)

    # run or load rastermap
    cluster_ids = None
    n_actual_clusters = 0
    if rastermap_model is not None:
        # use existing model (handle both dict and Rastermap object)
        if hasattr(rastermap_model, "embedding_clust"):
            embedding_clust = rastermap_model.embedding_clust
        elif isinstance(rastermap_model, dict):
            embedding_clust = rastermap_model.get("embedding_clust", None)
        else:
            embedding_clust = None

        if embedding_clust is not None:
            cluster_ids = np.asarray(embedding_clust).flatten()
            # filter by iscell if model was trained on all ROIs
            if len(cluster_ids) != len(all_x) and len(all_iscell) > 0:
                if len(cluster_ids) == len(all_iscell):
                    cluster_ids = cluster_ids[all_iscell]
                    logger.info(f"Filtered cluster_ids from {len(embedding_clust)} to {len(cluster_ids)} accepted ROIs")

            if len(cluster_ids) != len(all_x):
                # stale model (ROI set changed since it was saved): passing it
                # to scatter would raise, so fall back to z-depth coloring
                logger.warning(
                    f"Saved rastermap model has {len(cluster_ids)} cluster assignments "
                    f"but {len(all_x)} accepted ROIs were found; ignoring saved model"
                )
                cluster_ids = None
            else:
                n_actual_clusters = len(np.unique(cluster_ids[~np.isnan(cluster_ids)]))
                logger.info(f"Using {n_actual_clusters} clusters from saved model")
    elif has_rastermap and len(all_F) > 0:
        # run rastermap
        logger.info(f"Running rastermap with n_clusters={n_clusters}")
        model = None
        try:
            # zscore the data
            from scipy.stats import zscore

            spks = zscore(all_F.astype("float32"), axis=1)
            # set up rastermap params
            kwargs = {"n_clusters": n_clusters, "n_PCs": min(200, spks.shape[0] - 1), "verbose": False}
            if rastermap_kwargs:
                kwargs.update(rastermap_kwargs)
            model = Rastermap(**kwargs).fit(spks)
            cluster_ids = model.embedding_clust.flatten()
            n_actual_clusters = model.n_clusters
        except Exception as e:
            logger.warning(f"Rastermap failed: {e}")
            cluster_ids = None
            model = None

        if model is not None:
            # save model for future use; a failed write (e.g. read-only
            # directory) must not throw away the clustering just computed
            save_model_path = suite2p_path / "rastermap_model.npy"
            try:
                model_save = {
                    "embedding": model.embedding,
                    "isort": model.isort,
                    "embedding_clust": model.embedding_clust,
                    "n_clusters": model.n_clusters,
                }
                np.save(save_model_path, model_save)
                logger.info(f"Saved rastermap model to {save_model_path}")
            except Exception as e:
                logger.warning(f"Failed to save rastermap model to {save_model_path}: {e}")

    # fallback to z-plane coloring if no clusters
    if cluster_ids is None:
        logger.warning("No rastermap clusters available, coloring by z-plane instead")
        cluster_ids = all_z
        n_actual_clusters = len(np.unique(all_z))
        color_label = "Z-depth (μm)"
        cmap = "viridis"
    else:
        color_label = "Cluster"
        cmap = "tab20" if n_actual_clusters <= 20 else "nipy_spectral"

    # create figure
    fig = plt.figure(figsize=figsize, facecolor="black")
    ax = fig.add_subplot(111, projection="3d", facecolor="black")

    # style panes
    ax.xaxis.pane.fill = False
    ax.yaxis.pane.fill = False
    ax.zaxis.pane.fill = False
    ax.xaxis.pane.set_edgecolor("white")
    ax.yaxis.pane.set_edgecolor("white")
    ax.zaxis.pane.set_edgecolor("white")

    # plot rejected ROIs first
    if show_rejected and rej_x:
        ax.scatter(rej_x, rej_y, rej_z, c="gray", s=10, alpha=0.3, label="Rejected")

    # plot accepted ROIs colored by cluster
    scatter = ax.scatter(all_x, all_y, all_z, c=cluster_ids, cmap=cmap,
                         s=15, alpha=0.8, edgecolors="none")

    # colorbar
    cbar = fig.colorbar(scatter, ax=ax, shrink=0.6, pad=0.1)
    cbar.set_label(color_label, fontsize=10, color="white")
    cbar.ax.tick_params(colors="white")
    cbar.outline.set_edgecolor("white")

    # style axes
    ax.set_xlabel("X (μm)", fontsize=10, fontweight="bold", color="white", labelpad=10)
    ax.set_ylabel("Y (μm)", fontsize=10, fontweight="bold", color="white", labelpad=10)
    ax.set_zlabel("Z (μm)", fontsize=10, fontweight="bold", color="white", labelpad=10)
    ax.tick_params(colors="white", labelsize=8)
    ax.xaxis._axinfo["grid"]["color"] = (1, 1, 1, 0.2)
    ax.yaxis._axinfo["grid"]["color"] = (1, 1, 1, 0.2)
    ax.zaxis._axinfo["grid"]["color"] = (1, 1, 1, 0.2)

    # title
    n_cells = len(all_x)
    n_planes = len(np.unique(all_z))
    x_range = all_x.max() - all_x.min()
    y_range = all_y.max() - all_y.min()
    z_range = all_z.max() - all_z.min()
    title = f"Rastermap Clusters: {n_cells} cells, {n_actual_clusters} clusters, {n_planes} planes"
    subtitle = f"Volume: {x_range:.0f} × {y_range:.0f} × {z_range:.0f} μm"
    fig.suptitle(f"{title}\n{subtitle}", fontsize=12, fontweight="bold", color="white", y=0.95)

    if show_rejected and rej_x:
        ax.legend(fontsize=9, facecolor="#1a1a1a", edgecolor="white", labelcolor="white")

    ax.view_init(elev=20, azim=45)

    if save_path:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="black")
        plt.close(fig)
    else:
        plt.show()
    return fig