import glob
import os
import subprocess
from pathlib import Path
import cv2
import numpy as np
from matplotlib import pyplot as plt
from lbm_suite2p_python.zplane import load_ops
from lbm_suite2p_python.utils import get_common_path
[docs]
def update_ops_paths(ops_files: str | list):
"""
Update save_path, save_path0, and save_folder in an ops dictionary based on its current location. Use after moving an ops_file or batch of ops_files."""
if isinstance(ops_files, (str, Path)):
ops_files = [ops_files]
for ops_file in ops_files:
ops = np.load(ops_file, allow_pickle=True).item()
ops_path = Path(ops_file)
plane0_folder = ops_path.parent
plane_folder = plane0_folder.parent
ops["save_path"] = str(plane0_folder)
ops["save_path0"] = str(plane_folder)
ops["save_folder"] = plane_folder.name
ops["ops_path"] = ops_path
np.save(ops_file, ops)
[docs]
def plot_execution_time(filepath, savepath):
    """
    Draw a stacked bar chart of per-plane processing time and save it.

    The timing array is loaded from ``filepath`` and the registration,
    detection and extraction times are stacked, one bar per z-plane, on a
    black background. The figure is written to ``savepath`` and then shown
    with ``plt.show()``.

    Parameters
    ----------
    filepath : str or Path
        Path to the `.npy` file containing the volume timing stats.
    savepath : str or Path
        Path to save the generated figure.

    Notes
    -----
    - The array must provide the fields `plane`, `registration`, `detection`,
      `extraction` and `total_plane_runtime`; only these are read here.
    - Each stack is annotated with `total_plane_runtime` (truncated to an
      integer) when that value is greater than 1.
    """
    stats = np.load(filepath)
    z_planes = stats["plane"]
    t_total = stats["total_plane_runtime"]

    # (legend label, per-plane durations, bar colour), bottom of the stack first.
    layers = (
        ("Registration", stats["registration"], "#FF5733"),
        ("Detection", stats["detection"], "#33FF57"),
        ("Extraction", stats["extraction"], "#3357FF"),
    )
    text_style = {"fontweight": "bold", "color": "white"}

    plt.figure(figsize=(10, 6), facecolor="black")
    ax = plt.gca()
    ax.set_facecolor("black")

    plt.xlabel("Z-Plane", fontsize=14, **text_style)
    plt.ylabel("Execution Time (s)", fontsize=14, **text_style)
    plt.title("Execution Time per Processing Step", fontsize=16, **text_style)

    stack_base = None
    top_bars = None
    for step_name, durations, shade in layers:
        top_bars = plt.bar(
            z_planes,
            durations,
            label=step_name,
            alpha=0.8,
            bottom=stack_base,
            color=shade,
        )
        stack_base = durations if stack_base is None else stack_base + durations

    # Annotate the top of each stack with the plane's total runtime.
    for rect, runtime in zip(top_bars, t_total):
        if not runtime > 1:  # too small to be visible (also skips NaN)
            continue
        stack_top = rect.get_y() + rect.get_height()
        plt.text(
            rect.get_x() + rect.get_width() / 2,
            stack_top + 2,
            f"{int(runtime)}",
            ha="center",
            va="bottom",
            fontsize=12,
            **text_style,
        )

    plt.xticks(z_planes, fontsize=12, **text_style)
    plt.yticks(fontsize=12, **text_style)
    plt.grid(axis="y", linestyle="--", alpha=0.4, color="white")

    for side in ("bottom", "left", "top", "right"):
        ax.spines[side].set_color("white")

    plt.legend(
        fontsize=12,
        facecolor="black",
        edgecolor="white",
        labelcolor="white",
        loc="upper left",
        bbox_to_anchor=(1, 1),
    )
    plt.savefig(savepath, bbox_inches="tight", facecolor="black")
    plt.show()
[docs]
def plot_volume_signal(zstats, savepath):
    """
    Plot the mean fluorescence signal per z-plane with standard-deviation error bars.

    The statistics are loaded from a `.npy` file, drawn as a line with error
    bars on a black background, written to ``savepath``, and the figure is
    closed afterwards.

    Parameters
    ----------
    zstats : str or Path
        Path to the `.npy` file containing the volume stats, e.g. the file
        written by `get_volume_stats()`.
    savepath : str or Path
        Path to save the generated figure.

    Notes
    -----
    - The `.npy` file should contain structured data with `plane`, `mean_trace`, and `std_trace` fields.
    - Error bars represent the standard deviation of the fluorescence signal.
    """
    plane_stats = np.load(zstats)
    planes = plane_stats["plane"]
    mean_signal = plane_stats["mean_trace"]
    std_signal = plane_stats["std_trace"]

    # plt.subplots returns (figure, axes); plt.figure returns only the figure
    # and cannot be unpacked into two names.
    fig, ax = plt.subplots(figsize=(10, 6), facecolor="black")
    try:
        ax.set_facecolor("black")
        ax.set_xlabel("Z-Plane", fontsize=14, fontweight="bold", color="white")
        ax.set_ylabel("Mean Raw Signal", fontsize=14, fontweight="bold", color="white")
        ax.set_title(
            "Mean Fluorescence Signal per Z-Plane",
            fontsize=16,
            fontweight="bold",
            color="white",
        )
        ax.errorbar(
            planes,
            mean_signal,
            yerr=std_signal,
            fmt="o-",
            color="cyan",
            ecolor="lightblue",
            elinewidth=2,
            capsize=4,
            markersize=6,
            alpha=0.8,
            label="Mean ± STD",
        )
        # `ax` is the current axes, so the pyplot tick helpers act on it.
        plt.xticks(planes, fontsize=12, fontweight="bold", color="white")
        plt.yticks(fontsize=12, fontweight="bold", color="white")
        ax.grid(axis="y", linestyle="--", alpha=0.4, color="white")
        for side in ("bottom", "left", "top", "right"):
            ax.spines[side].set_color("white")
        ax.legend(fontsize=12, facecolor="black", edgecolor="white", labelcolor="white")
        fig.savefig(savepath, bbox_inches="tight", facecolor="black")
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
def plot_volume_neuron_counts(zstats, savepath):
    """
    Plot the number of accepted and rejected neurons per z-plane.

    The counts are loaded from a `.npy` file and drawn as a stacked bar plot
    (accepted at the bottom, rejected on top) with a black background. Each
    non-empty bar segment is labelled with its count.

    Parameters
    ----------
    zstats : str, Path
        Full path to the zstats.npy file.
    savepath : str or Path
        Directory where the generated figure will be saved. It is created if
        it does not exist. The file is named
        ``all_neurons_<accepted>acc_<rejected>rej.png`` using the summed counts.

    Raises
    ------
    FileNotFoundError
        If ``zstats`` is not an existing file.

    Notes
    -----
    - The `.npy` file should contain structured data with `plane`, `accepted`, and `rejected` fields.
    """
    zstats = Path(zstats)
    if not zstats.is_file():
        raise FileNotFoundError(f"{zstats} is not a valid zstats.npy file.")

    plane_stats = np.load(zstats)
    planes = plane_stats["plane"]
    accepted = plane_stats["accepted"]
    rejected = plane_stats["rejected"]

    savepath = Path(savepath)
    # savepath is used as a directory below; make sure savefig has somewhere to write.
    savepath.mkdir(parents=True, exist_ok=True)
    savename = savepath.joinpath(
        f"all_neurons_{accepted.sum()}acc_{rejected.sum()}rej.png"
    )

    label_style = {
        "ha": "center",
        "va": "center",
        "fontsize": 12,
        "color": "white",
        "fontweight": "bold",
    }

    plt.figure(figsize=(10, 6), facecolor="black")
    ax = plt.gca()
    ax.set_facecolor("black")
    plt.xlabel("Z-Plane", fontsize=14, fontweight="bold", color="white")
    plt.ylabel("Number of Neurons", fontsize=14, fontweight="bold", color="white")
    plt.title(
        "Accepted vs. Rejected Neurons per Z-Plane",
        fontsize=16,
        fontweight="bold",
        color="white",
    )

    accepted_bars = plt.bar(
        planes, accepted, label="Accepted Neurons", alpha=0.8, color="#4CAF50"
    )
    rejected_bars = plt.bar(
        planes,
        rejected,
        label="Rejected Neurons",
        alpha=0.8,
        bottom=accepted,
        color="#F57C00",
    )

    # Write each count at the vertical centre of its own bar segment.
    for acc_bar, rej_bar in zip(accepted_bars, rejected_bars):
        n_acc = acc_bar.get_height()
        n_rej = rej_bar.get_height()
        if n_acc > 0:
            plt.text(
                acc_bar.get_x() + acc_bar.get_width() / 2,
                n_acc / 2,
                f"{int(n_acc)}",
                **label_style,
            )
        if n_rej > 0:
            plt.text(
                rej_bar.get_x() + rej_bar.get_width() / 2,
                n_acc + n_rej / 2,
                f"{int(n_rej)}",
                **label_style,
            )

    plt.xticks(planes, fontsize=12, fontweight="bold", color="white")
    plt.yticks(fontsize=12, fontweight="bold", color="white")
    plt.grid(axis="y", linestyle="--", alpha=0.4, color="white")
    for side in ("bottom", "left", "top", "right"):
        ax.spines[side].set_color("white")
    plt.legend(fontsize=12, facecolor="black", edgecolor="white", labelcolor="white")
    # NOTE(review): the figure is left open after saving, as before; consider
    # closing it if this is called repeatedly in a batch job.
    plt.savefig(savename, bbox_inches="tight", facecolor="black")
def get_volume_stats(ops_files: list[str | Path], overwrite: bool = True):
"""
Given a list of ops.npy files, accumulate common statistics for assessing zplane quality.
Parameters
----------
ops_files : list of str or Path
Each item in the list should be a path pointing to a z-lanes `ops.npy` file.
The number of items in this list should match the number of z-planes in your session.
overwrite : bool
If a file already exists, it will be overwritten. Defaults to True.
Notes
-----
- The `.npy` file should contain structured data with `plane`, `accepted`, and `rejected` fields.
"""
if not ops_files:
print("No ops files found.")
return None
plane_stats = {}
for i, file in enumerate(ops_files):
output_ops = load_ops(file)
raw_z = output_ops.get("zplane")
if raw_z is None:
zplane_num = i + 1
else:
zplane_num = int(str(raw_z).removeprefix("plane"))
save_path = Path(output_ops["save_path"])
iscell = np.load(save_path / "iscell.npy", allow_pickle=True)[:, 0].astype(bool)
traces = np.load(save_path / "F.npy", allow_pickle=True)
timing = output_ops.get("timing", {})
plane_stats[zplane_num] = {
"accepted": iscell.sum(),
"rejected": (~iscell).sum(),
"mean": traces.mean(),
"std": traces.std(),
"registration": timing.get("registration", np.nan),
"detection": timing.get("detection", timing.get("detect", np.nan)),
"extraction": timing.get("extraction", np.nan),
"classification": timing.get("classification", np.nan),
"deconvolution": timing.get("deconvolution", np.nan),
"total_runtime": timing.get("total_plane_runtime", np.nan),
"filepath": str(file),
"zplane": zplane_num,
}
common = get_common_path(ops_files)
out = []
for p, stats in sorted(plane_stats.items()):
out.append(
(
p,
stats["accepted"],
stats["rejected"],
stats["mean"],
stats["std"],
stats["registration"],
stats["detection"],
stats["extraction"],
stats["classification"],
stats["deconvolution"],
stats["total_runtime"],
stats["filepath"],
stats["zplane"],
)
)
dtype = [
("plane", "i4"),
("accepted", "i4"),
("rejected", "i4"),
("mean_trace", "f8"),
("std_trace", "f8"),
("registration", "f8"),
("detection", "f8"),
("extraction", "f8"),
("classification", "f8"),
("deconvolution", "f8"),
("total_plane_runtime", "f8"),
("filepath", "U255"),
("zplane", "i4"),
]
arr = np.array(out, dtype=dtype)
save_path = Path(common) / "zstats.npy"
if overwrite or not save_path.exists():
np.save(save_path, arr)
return str(save_path)
[docs]
def save_images_to_movie(image_input, savepath, duration=None, format=".mp4"):
    """
    Convert a sequence of saved images into a movie.

    Frames are first written with OpenCV. For ".mp4" and ".mov" the frames go
    to a temporary AVI next to the output, which is then transcoded with the
    `ffmpeg` command-line tool and removed. ".avi" is written directly.

    TODO: move to mbo_utilities.

    Parameters
    ----------
    image_input : str, Path, or list
        Directory containing saved segmentation images (``*.png``, ``*.jpg``
        and ``*.tif`` are collected) or a list of image file paths. Files are
        used in sorted order.
    savepath : str or Path
        Path to save the video file. Its suffix is replaced by `format`, and
        missing parent directories are created.
    duration : int, optional
        Desired total video duration in seconds. If None, defaults to 1 FPS (1 image per second).
        The frame rate is never allowed to drop below 1 FPS.
    format : str, optional
        Video format: ".mp4" (H.264), ".avi" (HuffYUV), ".mov" (ProRes). Default is ".mp4".

    Returns
    -------
    None
        Nothing is returned. If no images are found the function returns
        without writing a video.

    Raises
    ------
    ValueError
        If `image_input` is neither a path nor a list, if `format` is not one
        of the supported formats, or if the first image cannot be read.
    RuntimeError
        If the OpenCV writer cannot be opened or `ffmpeg` exits with an error.

    Examples
    --------
    >>> import mbo_utilities as mbo
    >>> import lbm_suite2p_python as lsp

    Get all png files autosaved during LBM-Suite2p-Python `run_volume()`

    >>> segmentation_pngs = mbo.get_files("path/suite3d/results/", "segmentation.png", max_depth=3)
    >>> lsp.save_images_to_movie(segmentation_pngs, "path/to/save/segmentation", format=".mp4")
    """
    savepath = Path(savepath).with_suffix(format)  # Ensure correct file extension
    savepath.parent.mkdir(parents=True, exist_ok=True)

    if isinstance(image_input, (str, Path)):
        image_dir = Path(image_input)
        image_files = sorted(
            glob.glob(str(image_dir / "*.png"))
            + glob.glob(str(image_dir / "*.jpg"))
            + glob.glob(str(image_dir / "*.tif"))
        )
    elif isinstance(image_input, list):
        image_files = sorted(map(str, image_input))
    else:
        raise ValueError(
            "image_input must be a directory path or a list of file paths."
        )

    if not image_files:
        return

    # FourCC used for the frames written by OpenCV, per output format.
    writer_codecs = {".mp4": "XVID", ".avi": "HFYU", ".mov": "HFYU"}
    if format not in writer_codecs:
        raise ValueError("Invalid format. Use '.mp4', '.avi', or '.mov'.")

    first_image = cv2.imread(image_files[0])
    if first_image is None:
        raise ValueError(f"Could not read image: {image_files[0]}")
    height, width = first_image.shape[:2]
    fps = len(image_files) / duration if duration else 1

    # ffmpeg cannot read and write the same file, so formats that need a
    # transcode are staged in a separate, uniquely named AVI. The distinct
    # name also avoids overwriting a user's "<stem>.avi" next to the output.
    needs_transcode = format != ".avi"
    if needs_transcode:
        raw_video = savepath.with_name(f"{savepath.stem}.tmp.avi")
    else:
        raw_video = savepath

    video_writer = cv2.VideoWriter(
        str(raw_video),
        cv2.VideoWriter_fourcc(*writer_codecs[format]),
        max(fps, 1),
        (width, height),
    )
    try:
        if not video_writer.isOpened():
            raise RuntimeError(f"Could not open video writer for {raw_video}")
        for image_file in image_files:
            frame = cv2.imread(image_file)
            if frame is None:
                print(f"Skipping unreadable image: {image_file}")
                continue
            # The writer expects every frame to match the size it was opened with.
            if frame.shape[:2] != (height, width):
                frame = cv2.resize(frame, (width, height))
            video_writer.write(frame)
    finally:
        video_writer.release()

    if not needs_transcode:
        return

    if format == ".mp4":
        codec_args = [
            "-vcodec",
            "libx264",
            "-acodec",
            "aac",
            "-preset",
            "slow",
            "-crf",
            "18",
        ]
    else:  # ".mov"
        codec_args = [
            "-c:v",
            "prores_ks",  # Use Apple ProRes codec
            "-profile:v",
            "3",  # prores_ks profile 3 is "hq" (ProRes 422 HQ)
            "-pix_fmt",
            "yuv422p10le",
        ]
    ffmpeg_cmd = ["ffmpeg", "-y", "-i", str(raw_video), *codec_args, str(savepath)]

    try:
        result = subprocess.run(
            ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
        )
    finally:
        # Remove the staging file whether or not ffmpeg succeeded.
        raw_video.unlink(missing_ok=True)

    if result.returncode != 0:
        raise RuntimeError(
            f"ffmpeg failed with exit code {result.returncode} while writing {savepath}"
        )
    if format == ".mp4":
        print(f"MP4 saved at {savepath}")
def get_fcells_list(ops_list: list):
    """
    Load the `F.npy` array belonging to each entry of `ops_list`.

    Parameters
    ----------
    ops_list : list
        Items accepted by `load_ops`; each resulting ops must provide a
        `save_path` entry naming the directory that holds `F.npy`.

    Returns
    -------
    list
        One loaded `F.npy` array per entry, in the same order as `ops_list`.

    Raises
    ------
    ValueError
        If `ops_list` is not a list.
    """
    if not isinstance(ops_list, list):
        raise ValueError("`ops_list` must be a list")
    return [
        np.load(Path(load_ops(entry)["save_path"]) / "F.npy")
        for entry in ops_list
    ]
def collect_result_png(ops_list):
    """
    Collect the path of the `segmentation.png` saved for each entry of `ops_list`.

    Parameters
    ----------
    ops_list : list
        Items accepted by `load_ops`; each resulting ops must provide a
        `save_path` entry naming the directory that holds `segmentation.png`.

    Returns
    -------
    list of Path
        One `<save_path>/segmentation.png` path per entry, in the same order
        as `ops_list`.

    Raises
    ------
    ValueError
        If `ops_list` is not a list.
    FileNotFoundError
        If an entry has no `segmentation.png` in its `save_path`.

    Notes
    -----
    A PNG is not a NumPy file, so it cannot be read with `np.load`; the paths
    are returned instead so callers can open the images with the tool of
    their choice.
    NOTE(review): returning paths (rather than decoded images) is an
    assumption about the intended result — confirm against callers.
    """
    if not isinstance(ops_list, list):
        raise ValueError("`ops_list` must be a list")

    png_list = []
    for entry in ops_list:
        png_path = Path(load_ops(entry)["save_path"]) / "segmentation.png"
        if not png_path.is_file():
            raise FileNotFoundError(f"{png_path} does not exist.")
        png_list.append(png_path)
    return png_list