Source code for lbm_caiman_python.helpers

import os
import shutil
import sys
import tempfile

import cv2
import scipy
import tifffile
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from typing import Any as ArrayLike

import caiman as cm
from tqdm import tqdm

from .lcp_io import get_metrics_path


def get_single_patch_coords(dims, stride, overlap, patch_index):
    """
    Get the pixel bounds of a single motion-correction patch.

    Patches are ``stride + overlap`` pixels on a side and their top-left
    corners are placed every ``stride`` pixels along each axis; only patches
    that fit entirely inside ``dims`` are part of the grid.

    Parameters
    ----------
    dims : tuple
        Dimensions of the image as (rows, cols).
    stride : int
        Step, in pixels, between the starts of neighbouring patches.
    overlap : int
        Number of pixels by which neighbouring patches overlap; each patch is
        ``stride + overlap`` pixels on a side.
    patch_index : tuple
        ``(row_idx, col_idx)`` position of the patch in the patch grid.
        Negative indices count from the end, as in NumPy indexing.

    Returns
    -------
    tuple
        ``(y_start, y_end, x_start, x_end)`` with exclusive end bounds, usable
        as ``image[y_start:y_end, x_start:x_end]``.

    Raises
    ------
    IndexError
        If ``patch_index`` lies outside the patch grid.
    """
    patch_height = stride + overlap
    patch_width = stride + overlap
    rows = np.arange(0, dims[0] - patch_height + 1, stride)
    cols = np.arange(0, dims[1] - patch_width + 1, stride)

    row_idx, col_idx = patch_index
    n_rows, n_cols = len(rows), len(cols)
    # Same accepted range as NumPy indexing, but with a message naming the grid size.
    if not (-n_rows <= row_idx < n_rows and -n_cols <= col_idx < n_cols):
        raise IndexError(
            f"patch_index {tuple(patch_index)} is outside the {n_rows} x {n_cols} patch grid."
        )
    y_start = rows[row_idx]
    x_start = cols[col_idx]
    return y_start, y_start + patch_height, x_start, x_start + patch_width
def _pad_image_for_even_patches(image, stride, overlap):
    """
    Zero-pad a 2D image on its bottom and right edges.

    After padding, both axes are whole multiples of the patch width
    ``stride + overlap``.

    Returns
    -------
    tuple
        ``(padded_image, rows_added, cols_added)``.
    """
    width = stride + overlap
    n_rows = image.shape[0]
    n_cols = image.shape[1]

    # Round each dimension up to the next multiple of the patch width.
    rows_added = int(np.ceil(n_rows / width) * width) - n_rows
    cols_added = int(np.ceil(n_cols / width) * width) - n_cols

    pad_spec = ((0, rows_added), (0, cols_added))
    padded = np.pad(image, pad_spec, mode='constant')
    return padded, rows_added, cols_added
def calculate_num_patches(image, stride, overlap):
    """
    Calculate the total number of motion-correction patches tiled over an image.

    The image is first zero-padded with ``_pad_image_for_even_patches``. The
    patch grid is then obtained from CaImAn's ``get_rectangle_coords``.

    Parameters
    ----------
    image : ndarray
        2D image of shape (rows, cols) to be tiled.
    stride : int
        Step, in pixels, between the starts of neighbouring patches.
    overlap : int
        Amount of overlap between patches in pixels; the padding assumes each
        patch is ``stride + overlap`` pixels wide.

    Returns
    -------
    int
        Total number of patches (patch rows multiplied by patch columns).
    """
    from caiman.utils.visualization import get_rectangle_coords

    # Pad the same way the patch visualisation does before asking for the grid.
    padded_image, _, _ = _pad_image_for_even_patches(image, stride, overlap)

    patch_rows, patch_cols = get_rectangle_coords(padded_image.shape, stride, overlap)
    return len(patch_rows) * len(patch_cols)
def calculate_neurons_per_patch(rf, pixel_resolution, neuron_density):
    """
    Estimate how many neurons fall inside one square 2D patch.

    A patch with half-size ``rf`` spans ``2 * rf + 1`` pixels per side. Its
    area in square microns is multiplied by the surface density.

    Parameters
    ----------
    rf : int
        Patch half-size in pixels.
    pixel_resolution : tuple
        ``(row_size, col_size)`` in microns per pixel.
    neuron_density : float
        Neuron density in neurons per square micron (2D slice).

    Returns
    -------
    float
        Expected number of neurons per patch.
    """
    um_per_row, um_per_col = pixel_resolution
    side_px = 2 * rf + 1

    height_um = side_px * um_per_row
    width_um = side_px * um_per_col
    area_um2 = height_um * width_um

    return area_um2 * neuron_density
def generate_patch_view(image: ArrayLike, pixel_resolution: float, target_patch_size: float = 40,
                        overlap_fraction: float = 0.5):
    """
    Generate a patch visualization for a 2D image.

    ``target_patch_size`` (microns) is converted to a stride in pixels, and
    ``overlap_fraction`` of that stride is used as the overlap. The image is
    zero-padded with ``_pad_image_for_even_patches``. The patch rectangles
    returned by CaImAn's ``get_rectangle_coords`` are then drawn over it.
    Note that the padding helper treats each patch as ``stride + overlap``
    pixels wide, so drawn patches are larger than ``target_patch_size`` when
    the overlap is non-zero.

    Parameters
    ----------
    image : ndarray
        A 2D NumPy array representing the input image to be divided into patches.
    pixel_resolution : float
        The pixel resolution of the image in microns per pixel.
    target_patch_size : float, optional
        Desired patch stride in microns; converted to
        ``int(target_patch_size / pixel_resolution)`` pixels. Default is 40.
    overlap_fraction : float, optional
        Fraction of the stride used as overlap between patches. Default is 0.5 (50%).

    Returns
    -------
    fig : matplotlib.figure.Figure
        A matplotlib figure containing the patch visualization.
    ax : matplotlib.axes.Axes
        A matplotlib axes object showing the patch layout on the image.
    stride : int
        Patch stride in pixels.
    overlap : int
        Patch overlap in pixels.

    Raises
    ------
    ValueError
        If ``target_patch_size`` corresponds to less than one pixel.

    Examples
    --------
    >>> import numpy as np
    >>> from matplotlib import pyplot as plt
    >>> data = np.random.random((144, 600))  # Example 2D image
    >>> pixel_resolution = 0.5  # Microns per pixel
    >>> fig, ax, stride, overlap = generate_patch_view(data, pixel_resolution)
    >>> plt.show()
    """
    from caiman.utils.visualization import get_rectangle_coords, rect_draw

    # Calculate stride and overlap in pixels.
    stride = int(target_patch_size / pixel_resolution)
    overlap = int(overlap_fraction * stride)
    if stride < 1:
        # A zero stride would otherwise surface as a ZeroDivisionError in the padding step.
        raise ValueError(
            f"target_patch_size={target_patch_size} μm is smaller than one pixel "
            f"at pixel_resolution={pixel_resolution} μm/pixel."
        )

    # Pad the image to whole patches, using the shared module helper.
    padded_image, _, _ = _pad_image_for_even_patches(image, stride, overlap)

    patch_rows, patch_cols = get_rectangle_coords(padded_image.shape, stride, overlap)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(padded_image, cmap='gray')

    for patch_row in patch_rows:
        for patch_col in patch_cols:
            rect_draw(patch_row, patch_col, color='white', alpha=0.2, ax=ax)

    ax.set_title(f"Stride: {stride} pixels (~{stride * pixel_resolution:.1f} μm)\n"
                 f"Overlap: {overlap} pixels (~{overlap * pixel_resolution:.1f} μm)\n")
    plt.tight_layout()
    return fig, ax, stride, overlap
def _is_missing_id(value):
    """
    Return True if ``value`` should be treated as a missing identifier.

    ``None``, the strings ``'None'`` and ``'nan'``, and floating-point NaN are
    considered missing. Float NaN needs an explicit check because it never
    compares equal to anything, and ``int(nan)`` raises ``ValueError``.
    """
    if value is None:
        return True
    if isinstance(value, str):
        return value in ('None', 'nan')
    return isinstance(value, (float, np.floating)) and bool(np.isnan(value))


def _compute_metrics(fname, uuid, batch_id, final_size_x, final_size_y, swap_dim=False, pyr_scale=.5, levels=3,
                     winsize=100, iterations=15, poly_n=5, poly_sigma=1.2 / 5, flags=0, resize_fact_flow=.2,
                     template=None, gSig_filt=None):
    """
    Compute motion-correction quality metrics for a movie file and save them next to it.

    NOTE(review): this looks like a local copy of CaImAn's
    ``compute_metrics_motion_correction`` with extra ``uuid``/``batch_id``
    fields. Confirm parity before relying on it.

    Parameters
    ----------
    fname : str or Path
        Movie loaded with ``cm.load``. Metrics are written with ``np.savez`` to
        ``<fname without extension>_metrics`` (NumPy appends ``.npz``).
    uuid : object
        Identifier stored in the output; must be truthy.
    batch_id : int-like or None
        Stored as ``int``; ``None``/``'None'``/``'nan'``/NaN are stored as ``-1``.
    final_size_x, final_size_y : int
        Size of the crop taken along axes 1 and 2 of the movie.
    swap_dim : bool, optional
        Passed to ``local_correlations``.
    pyr_scale, levels, winsize, iterations, poly_n, poly_sigma, flags
        Passed to ``cv2.calcOpticalFlowFarneback``.
    resize_fact_flow : float, optional
        Third resize factor applied to the movie before optical flow.
    template : ndarray, optional
        Reference image. Defaults to ``cm.motion_correction.bin_median`` of the
        cropped movie.
    gSig_filt : optional
        If given, passed to ``high_pass_filter_space`` before cropping.

    Raises
    ------
    ValueError
        If ``uuid`` is falsy or the cropped movie contains NaN.

    Notes
    -----
    Saved keys: ``uuid``, ``batch_id``, ``flows``, ``norms``, ``correlations``,
    ``smoothness``, ``tmpl``, ``smoothness_corr``, ``img_corr``.
    """
    if not uuid:
        raise ValueError("UUID must be provided.")

    m = cm.load(fname)
    if gSig_filt is not None:
        m = cm.motion_correction.high_pass_filter_space(m, gSig_filt)

    # Crop a final_size_x x final_size_y window (axes 1 and 2) starting ceil(excess / 2)
    # pixels in. A zero end offset would give an empty slice, so it becomes None.
    max_shft_x = int(np.ceil((np.shape(m)[1] - final_size_x) / 2))
    max_shft_y = int(np.ceil((np.shape(m)[2] - final_size_y) / 2))
    max_shft_x_1 = -((np.shape(m)[1] - max_shft_x) - final_size_x)
    max_shft_y_1 = -((np.shape(m)[2] - max_shft_y) - final_size_y)
    if max_shft_x_1 == 0:
        max_shft_x_1 = None
    if max_shft_y_1 == 0:
        max_shft_y_1 = None
    m = m[:, max_shft_x:max_shft_x_1, max_shft_y:max_shft_y_1]
    if np.any(np.isnan(m)):
        raise ValueError('Movie contains NaN')

    img_corr = m.local_correlations(eight_neighbours=True, swap_dim=swap_dim)
    tmpl = cm.motion_correction.bin_median(m) if template is None else template

    # L2 norm of the spatial gradient of the mean image / of the correlation image.
    smoothness = np.sqrt(np.sum(np.sum(np.array(np.gradient(np.mean(m, 0))) ** 2, 0)))
    smoothness_corr = np.sqrt(np.sum(np.sum(np.array(np.gradient(img_corr)) ** 2, 0)))

    sys.stdout.flush()
    tmpl_flat = tmpl.flatten()  # loop-invariant: flatten the template once
    correlations = [scipy.stats.pearsonr(fr.flatten(), tmpl_flat)[0]
                    for fr in tqdm(m, desc="Correlations")]

    # NOTE(review): each resized frame is compared against the full-size template,
    # so this resize presumably resamples the frame axis only — confirm against
    # caiman's movie.resize.
    m = m.resize(1, 1, resize_fact_flow)
    flows = []
    norms = []
    sys.stdout.flush()
    for fr in tqdm(m, desc="Optical flow"):
        flow = cv2.calcOpticalFlowFarneback(
            tmpl, fr, None, pyr_scale, levels, winsize, iterations, poly_n, poly_sigma, flags)
        flows.append(flow)
        norms.append(np.linalg.norm(flow))

    # Cast to numpy-loadable primitives; missing identifiers become 'None' / -1.
    uuid = 'None' if _is_missing_id(uuid) else str(uuid)
    batch_id = -1 if _is_missing_id(batch_id) else int(batch_id)

    np.savez(
        os.path.splitext(fname)[0] + '_metrics',
        uuid=uuid,
        batch_id=batch_id,
        flows=flows,
        norms=norms,
        correlations=correlations,
        smoothness=smoothness,
        tmpl=tmpl,
        smoothness_corr=smoothness_corr,
        img_corr=img_corr,
    )


def _compute_raw_mcorr_metrics(raw_fname: Path, overwrite=False) -> Path:
    """
    Compute motion-correction metrics for a raw TIFF movie.

    The raw movie is memory-mapped and copied to a temporary TIFF. Next,
    ``_compute_metrics`` is run on that copy with a fresh ``raw_<uuid4>``
    identifier and no batch id, and the resulting metrics file is moved to
    ``get_metrics_path(raw_fname)``. The original note on this function says
    the copy is needed because the metrics routine takes filenames, not
    memory-mapped arrays.

    Parameters
    ----------
    raw_fname : Path
        The path to the raw data file. Must be a TIFF file.
    overwrite : bool, optional
        If True, recompute the metrics even if the file already exists. Default is False.

    Returns
    -------
    final_metrics_path : Path
        The path to the computed metrics file.

    Raises
    ------
    FileNotFoundError
        If the metrics file expected for the temporary copy was not produced.

    Notes
    -----
    The final metrics file contains the keys written by ``_compute_metrics``:

    - 'correlations': Pearson correlation of each frame with the template.
    - 'flows': The optical-flow fields of each frame against the template.
    - 'norms': Magnitude of the optical flow for each frame; represents the amount of motion.
    - 'smoothness': A measure of the sharpness of the mean image.
    - 'smoothness_corr', 'img_corr', 'tmpl', 'uuid', 'batch_id'.
    """
    import uuid

    final_metrics_path = get_metrics_path(raw_fname)
    if final_metrics_path.exists() and not overwrite:
        return final_metrics_path

    raw_uuid = f'raw_{uuid.uuid4()}'
    data = tifffile.memmap(raw_fname)
    if final_metrics_path.exists() and overwrite:
        final_metrics_path.unlink()
    final_size_x, final_size_y = data.shape[1], data.shape[2]

    with tempfile.NamedTemporaryFile(suffix='.tiff', delete=False) as temp_file:
        temp_path = Path(temp_file.name)

    temp_metrics_path = None
    try:
        tifffile.imwrite(temp_path, data)
        # Release our reference to the memmap; only the temporary copy is needed from here on.
        del data
        _compute_metrics(temp_path, raw_uuid, None, final_size_x, final_size_y, swap_dim=False)
        temp_metrics_path = Path(get_metrics_path(temp_path))
        if not temp_metrics_path.exists():
            raise FileNotFoundError(f"Expected metrics file {temp_metrics_path} not found.")
        shutil.move(str(temp_metrics_path), str(final_metrics_path))
    finally:
        temp_path.unlink(missing_ok=True)
        # No-op after a successful move; removes a stray temp metrics file otherwise.
        if temp_metrics_path is not None:
            temp_metrics_path.unlink(missing_ok=True)
    return final_metrics_path