Source code for experanto.utils

import bisect
import logging

# inbuilt libraries
from collections import defaultdict

# third-party libraries
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Sampler

# local libraries

logger = logging.getLogger(__name__)


def replace_nan_with_batch_mean(data: np.ndarray) -> np.ndarray:
    """Replace every NaN in a 2-D array with the NaN-ignoring mean of its column.

    The array is modified in place and also returned. A column that holds no
    usable mean (every entry is NaN, or the mean itself evaluates to NaN) is
    filled with ``0`` instead.

    Parameters
    ----------
    data : np.ndarray
        Two-dimensional array; axis 0 is the axis the mean is taken over.

    Returns
    -------
    np.ndarray
        The same array object that was passed in, with NaNs replaced.
    """
    nan_mask = np.isnan(data)
    # Unpacking into exactly two index arrays keeps the original 2-D contract.
    row, col = np.where(nan_mask)
    if row.size == 0:
        return data

    # One mean per column instead of one per NaN cell. All-NaN columns are
    # left out of the nanmean call so it never warns about an empty slice.
    has_value = ~nan_mask.all(axis=0)
    fill = np.zeros(data.shape[1], dtype=np.float64)
    if has_value.any():
        fill[has_value] = np.nanmean(data[:, has_value], axis=0)
    # A mean can still be NaN (e.g. inf and -inf in one column); fall back to 0.
    fill[np.isnan(fill)] = 0.0

    data[row, col] = fill[col]
    return data


def _behavior_to_channels(signal: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Lay a behavior tensor out channel-first so it can be stacked onto the screen."""
    if signal.dim() == 2:
        # (t, c) -> (c, t) -> (c, t, 1, 1) -> broadcast view of shape (c, t, h, w)
        return signal.transpose(0, 1)[:, :, None, None].expand(-1, -1, h, w)
    # Anything else is treated as a single (t, h, w) map -> (1, t, h, w)
    return signal.unsqueeze(0)


def add_behavior_as_channels(data: dict[str, torch.Tensor]) -> dict:
    """Add behavioral data as additional channels to screen data.

    Parameters
    ----------
    data : dict
        Dictionary with keys 'screen', 'eye_tracker', 'treadmill'.
        Screen shape: ``(C, T, H, W)``.
        Behavior shapes: ``(T, C_behavior)`` or ``(T, H, W)``.

    Returns
    -------
    dict
        The same dictionary object, with ``data["screen"]`` replaced by the
        screen concatenated with the eye tracker and treadmill channels (in
        that order). Screen shape becomes ``(C + behavior_channels, T, H, W)``.
    """
    screen = data["screen"]
    _, _, h, w = screen.shape
    eye_tracker = _behavior_to_channels(data["eye_tracker"], h, w)
    treadmill = _behavior_to_channels(data["treadmill"], h, w)

    # torch.cat allocates a fresh contiguous tensor, so the broadcast views
    # above do not need to be materialised first.
    data["screen"] = torch.cat([screen, eye_tracker, treadmill], dim=0)
    return data
class MultiEpochsDataLoader(DataLoader):
    """DataLoader that keeps workers alive across epochs.

    Works around worker processes being re-spawned at the start of each
    epoch, which causes significant overhead: a single underlying iterator is
    created once at construction time and reused for every epoch.

    Parameters
    ----------
    *args
        Positional arguments forwarded to :class:`torch.utils.data.DataLoader`.
    shuffle_each_epoch : bool, default=False
        If True and the underlying dataset has a ``shuffle_valid_screen_times``
        method, that method is called at the start of every epoch.
    **kwargs
        Keyword arguments forwarded to :class:`torch.utils.data.DataLoader`.

    References
    ----------
    https://discuss.pytorch.org/t/enumerate-dataloader-slow/87778
    https://github.com/huggingface/pytorch-image-models/blob/d72ac0db259275233877be8c1d4872163954dfbb/timm/data/loader.py#L209-L238
    """

    def __init__(self, *args, shuffle_each_epoch=False, **kwargs):
        super().__init__(*args, **kwargs)
        # The private ``__initialized`` flag is cleared only while the batch
        # sampler is swapped for its endless wrapper (presumably because
        # DataLoader rejects reassigning it after construction).
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        # One long-lived iterator shared by all epochs.
        self.iterator = super().__iter__()
        self.shuffle_each_epoch = shuffle_each_epoch

    def __len__(self):
        """Return the number of batches in one epoch."""
        return len(self.batch_sampler)

    def __iter__(self):  # type: ignore[override]
        """Yield one epoch (``len(self)`` batches) from the shared iterator."""
        wants_shuffle = self.shuffle_each_epoch and hasattr(
            self.dataset, "shuffle_valid_screen_times"
        )
        if wants_shuffle:
            self.dataset.shuffle_valid_screen_times()  # type: ignore[union-attr]

        remaining = len(self)
        while remaining > 0:
            yield next(self.iterator)
            remaining -= 1
# Borrowed from
# https://github.com/sinzlab/neuralpredictors/blob/main/neuralpredictors/training/cyclers.py
def cycle(iterable):
    """Yield the items of ``iterable`` over and over, re-iterating it each pass.

    A fresh iterator is requested from ``iterable`` for every pass and no
    items are cached; see https://github.com/pytorch/pytorch/issues/23900.
    The generator ends if a whole pass produces no items (empty or already
    exhausted input) instead of spinning forever.
    """
    while True:
        produced = False
        for item in iterable:
            produced = True
            yield item
        if not produced:
            return


class Exhauster:
    """Step through a dictionary of data loaders one loader at a time.

    Given a dictionary mapping ``data_key`` to a data loader, yields
    ``(data_key, batch)`` pairs, moving on to the next loader only after the
    current one is exhausted.
    """

    def __init__(self, loaders):
        self.loaders = loaders

    def __iter__(self):
        for data_key, loader in self.loaders.items():
            for batch in loader:
                yield data_key, batch

    def __len__(self):
        """Return the total number of batches over all loaders."""
        # Sum over the loaders themselves; iterating the dict directly would
        # measure the length of each key instead.
        return sum(len(loader) for loader in self.loaders.values())
class LongCycler:
    """Cycle through multiple dataloaders until the longest is exhausted.

    Useful for training with multiple sessions of unequal size. Cycles
    through all loaders in round-robin order, yielding ``(session_key, batch)``
    pairs. Shorter loaders are recycled until the longest loader completes one
    full epoch.

    Parameters
    ----------
    loaders : dict
        Dictionary mapping session keys to DataLoader instances.

    Attributes
    ----------
    max_batches : int
        Number of batches in the longest loader.

    Examples
    --------
    >>> loaders = {'session_1': loader1, 'session_2': loader2}
    >>> cycler = LongCycler(loaders)
    >>> for session_key, batch in cycler:
    ...     print(f"Processing {session_key}")
    """

    def __init__(self, loaders):
        self.loaders = loaders
        self.max_batches = max(len(loader) for loader in self.loaders.values())

    def __iter__(self):
        cycles = [cycle(loader) for loader in self.loaders.values()]
        total = len(self.loaders) * self.max_batches
        # The two cycles are endless, so only the range bounds the loop.
        # strict=True would raise ValueError when the range runs out; the
        # range goes first so nothing is pulled from the cycles past the end.
        for _, key, loader in zip(
            range(total),
            cycle(self.loaders.keys()),
            cycle(cycles),
            strict=False,
        ):
            yield key, next(loader)

    def __len__(self):
        """Return the number of ``(session_key, batch)`` pairs per epoch."""
        return len(self.loaders) * self.max_batches
class ShortCycler:
    """Cycle through multiple dataloaders until the shortest is exhausted.

    Similar to :class:`LongCycler`, but stops when the smallest loader
    completes one epoch, so no loader is recycled. Yields
    ``(session_key, batch)`` pairs in round-robin order.

    Parameters
    ----------
    loaders : dict
        Dictionary mapping session keys to DataLoader instances.

    Attributes
    ----------
    min_batches : int
        Number of batches in the shortest loader.
    """

    def __init__(self, loaders):
        self.loaders = loaders
        self.min_batches = min(len(loader) for loader in self.loaders.values())

    def __iter__(self):
        cycles = [cycle(loader) for loader in self.loaders.values()]
        total = len(self.loaders) * self.min_batches
        # The two cycles are endless, so only the range bounds the loop.
        # strict=True would raise ValueError when the range runs out; the
        # range goes first so nothing is pulled from the cycles past the end.
        for _, key, loader in zip(
            range(total),
            cycle(self.loaders.keys()),
            cycle(cycles),
            strict=False,
        ):
            yield key, next(loader)

    def __len__(self):
        """Return the number of ``(session_key, batch)`` pairs per epoch."""
        return len(self.loaders) * self.min_batches
class _RepeatSampler: """Simple sampler that repeats indefinitely.""" def __init__(self, sampler): self.sampler = sampler def __len__(self): """Return the length of the original sampler.""" return len(self.sampler) def __iter__(self): while True: yield from iter(self.sampler)
[docs] class SessionConcatDataset(Dataset): """Memory-efficient concatenated dataset that reliably tracks sessions."""
[docs] def __init__(self, datasets, session_names=None): """Initialize the concatenated dataset with session tracking.""" if not datasets: raise ValueError("datasets is empty") # Store datasets self.datasets = list(datasets) # Track session names if session_names is None: session_names = [f"session_{i}" for i in range(len(datasets))] self.session_names = session_names # Log dataset sizes for debugging for i, (name, dataset) in enumerate(zip(session_names, datasets, strict=True)): logger.debug("Dataset %s: %s, length = %s", i, name, len(dataset)) # Compute cumulative sizes for efficient indexing self.cumulative_sizes = [] current_size = 0 for dataset in self.datasets: current_size += len(dataset) self.cumulative_sizes.append(current_size) # Create session indices dictionary for fast lookup self.session_indices = {} start_idx = 0 for i, dataset in enumerate(datasets): session_name = session_names[i] session_size = len(dataset) self.session_indices[session_name] = ( start_idx, start_idx + session_size, ) start_idx += session_size
[docs] def __len__(self): """Return total length of the concatenated dataset.""" return self.cumulative_sizes[-1] if self.cumulative_sizes else 0
[docs] def __getitem__(self, idx): """Get item from the appropriate dataset.""" if idx < 0 or idx >= len(self): raise IndexError( f"Index {idx} out of range for dataset of size {len(self)}" ) # Find which dataset the index belongs to dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) if dataset_idx == 0: sample_idx = idx else: sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] # Get the data from the dataset data = self.datasets[dataset_idx][sample_idx] # Return the data along with session information return data
[docs] def get_session_for_idx(self, idx): """Get the session name for a given index.""" dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) return self.session_names[dataset_idx]
[docs] def get_indices_for_session(self, session_name): """Get all indices belonging to a given session.""" if session_name in self.session_indices: start, end = self.session_indices[session_name] return list(range(start, end)) return []
[docs] def get_sessions_count(self): """Get number of sessions and count of samples per session.""" return { name: end - start for name, (start, end) in self.session_indices.items() }
class SessionBatchSampler(Sampler):
    """Session-order generator with per-session batch bookkeeping.

    Produces cycles in which each session appears exactly once before any
    session repeats. This class does not define ``__iter__`` itself; the
    ordering is obtained through :meth:`get_session_cycle`.
    """

    def __init__(self, dataset, batch_size, drop_last=False, shuffle=False, seed=None):
        """Initialize session batch sampler.

        Parameters
        ----------
        dataset : SessionConcatDataset
            The dataset to sample from. Must expose ``session_indices`` and
            ``get_indices_for_session``.
        batch_size : int
            Number of samples per batch.
        drop_last : bool, optional
            Whether to drop the last batch if smaller than batch_size.
            Default is False.
        shuffle : bool, optional
            Whether to shuffle the session order of each cycle.
            Default is False.
        seed : int, optional
            Random seed for reproducibility.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.seed = seed

        # Private RNG; RandomState(None) seeds itself, same as RandomState().
        self.rng = np.random.RandomState(seed)
        # RNG state captured just before the most recent cycle was drawn.
        self.prv_rng_state = None

        self.session_names = list(dataset.session_indices.keys())
        logger.debug("Sessions: %s", self.session_names)
        # Sessions already served in the cycle currently in progress.
        self.consumed_sessions = []

        # Sessions without samples stay in session_names but get no entry in
        # session_indices / batches_per_session.
        self.session_indices = {}
        for session_name in self.session_names:
            indices = dataset.get_indices_for_session(session_name)
            if indices:
                self.session_indices[session_name] = indices

        self.batches_per_session = {}
        for session_name, indices in self.session_indices.items():
            full_batches, remainder = divmod(len(indices), batch_size)
            keep_partial = remainder > 0 and not drop_last
            self.batches_per_session[session_name] = full_batches + int(keep_partial)

        logger.debug("Batches per session: %s", self.batches_per_session)
        logger.debug("Total batches: %s", sum(self.batches_per_session.values()))

    def __len__(self):
        """Return the total number of batches across all sessions."""
        return sum(self.batches_per_session.values())

    def get_session_cycle(self):
        """Generate one cycle of sessions, each appearing at most once.

        The RNG state is recorded before drawing so the same cycle can be
        reproduced after :meth:`set_state`. The order is shuffled when
        ``shuffle`` is set. Sessions listed in ``consumed_sessions`` are left
        out of the returned order.

        Returns
        -------
        list
            Session names still to be served in this cycle.
        """
        order = list(self.session_names)
        self.prv_rng_state = self.rng.get_state()
        if self.shuffle:
            self.rng.shuffle(order)
        # Filtering (rather than list.remove) tolerates unknown or repeated
        # names in consumed_sessions.
        consumed = set(self.consumed_sessions)
        return [name for name in order if name not in consumed]

    def get_state(self):
        """Return the state of the sampler (including RNG state).

        ``consumed_sessions`` is copied so the snapshot does not change when
        the sampler advances.
        """
        return {
            "prv_rng_state": self.prv_rng_state,
            "consumed_sessions": list(self.consumed_sessions),
        }

    def set_state(self, state):
        """Restore the state of the sampler (including RNG state)."""
        rng_state = state.get("prv_rng_state")
        if rng_state is not None and self.rng is not None:
            self.rng.set_state(rng_state)
        # NOTE(review): the recorded state is cleared on every restore;
        # confirm this reset is meant to be unconditional.
        self.prv_rng_state = None
        # Copy so later appends do not leak into the caller's state dict.
        self.consumed_sessions = list(state.get("consumed_sessions") or [])
class FastSessionDataLoader:
    """Optimized multi-session dataloader with state tracking.

    Provides efficient data loading across multiple sessions with guarantees:

    - Each session appears exactly once before repeating
    - Epoch ends when the longest session is exhausted
    - Perfect alignment between sessions and batches is maintained
    - State is properly tracked and can be restored

    Parameters
    ----------
    dataset : SessionConcatDataset
        Concatenated dataset with session tracking.
    batch_size : int, default=1
        Number of samples per batch.
    shuffle : bool, default=False
        Whether to shuffle samples within each session.
    num_workers : int, default=0
        Number of worker processes for data loading.
    pin_memory : bool, default=False
        Whether to pin memory for GPU transfer.
    drop_last : bool, default=False
        Whether to drop incomplete batches.
    seed : int, optional
        Random seed for reproducibility.
    **kwargs
        Additional arguments passed to underlying DataLoaders.

    Attributes
    ----------
    session_names : list
        Names of all sessions in the dataset.
    batches_per_session : dict
        Number of batches in each session.

    See Also
    --------
    SessionConcatDataset : Dataset that tracks session membership.
    LongCycler : Simpler alternative without state tracking.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=False,
        drop_last=False,
        seed=None,
        **kwargs,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.seed = seed
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.kwargs = kwargs

        # Decides the session order of every cycle.
        self.batch_sampler = SessionBatchSampler(
            dataset=dataset,
            batch_size=batch_size,
            drop_last=drop_last,
            shuffle=shuffle,
            seed=seed,
        )

        # Shared references into the batch sampler, kept for faster access.
        self.session_names = self.batch_sampler.session_names
        self.session_indices = self.batch_sampler.session_indices
        self.batches_per_session = self.batch_sampler.batches_per_session

        # Number of cycles in an epoch (length of the longest session).
        self.max_batches_per_session = max(
            self.batches_per_session.values(), default=0
        )

        # One DataLoader per session, built once and reused for every batch.
        self.session_dataloaders = {}
        for i, session_name in enumerate(self.session_names):
            # The batch sampler omits sessions without samples from
            # session_indices; skip them instead of raising KeyError.
            indices = self.session_indices.get(session_name)
            if not indices:
                continue

            # Distinct, reproducible seed per session.
            session_seed = None if seed is None else seed + i + 1
            session_sampler = SessionSpecificSampler(
                indices=indices,
                batch_size=batch_size,
                drop_last=drop_last,
                shuffle=shuffle,
                seed=session_seed,
            )
            self.session_dataloaders[session_name] = DataLoader(
                dataset=dataset,
                batch_sampler=session_sampler,
                num_workers=num_workers,
                pin_memory=pin_memory,
                **kwargs,
            )

        # State tracking variables
        self.current_batch = 0
        self.epoch = 0
        self.session_positions = dict.fromkeys(self.session_names, 0)
        # Batches yielded per session in the current epoch.
        self.batches_from_session = defaultdict(int)
        # Sessions that may still yield batches in the current epoch.
        self.active_sessions = set(self.session_names)

        logger.debug(
            "Created FastSessionDataLoader with %s sessions and %s total batches",
            len(self.session_names),
            len(self),
        )

    def __len__(self):
        """Return the total number of batches in an epoch."""
        return sum(self.batches_per_session.values())

    def get_state(self):
        """Return a snapshot of the current dataloader state as a dict."""
        return {
            "current_batch": self.current_batch,
            "epoch": self.epoch,
            "session_positions": self.session_positions.copy(),
            "batches_from_session": self.batches_from_session.copy(),
            # Stored as a list for serialization.
            "active_sessions": list(self.active_sessions),
            "batch_sampler_state": self.batch_sampler.get_state(),
            "session_sampler_states": {
                name: dl.batch_sampler.get_state()
                for name, dl in self.session_dataloaders.items()
            },
        }

    def set_state(self, state):
        """Restore the dataloader state from a dict produced by :meth:`get_state`.

        A falsy ``state`` is ignored. Missing keys fall back to defaults.
        """
        if not state:
            return

        self.current_batch = state.get("current_batch", 0)
        self.epoch = state.get("epoch", 0)

        session_positions = state.get("session_positions")
        if session_positions:
            # Copy: positions are mutated while iterating and must not write
            # through to the caller's state dict.
            self.session_positions = dict(session_positions)

        # Optional RNG on the loader itself; this class never creates one, so
        # this only applies if a ``rng`` attribute has been attached.
        dataloader_rng_state = state.get("dataloader_rng_state")
        loader_rng = getattr(self, "rng", None)
        if dataloader_rng_state is not None and loader_rng is not None:
            loader_rng.set_state(dataloader_rng_state)

        batch_sampler_state = state.get("batch_sampler_state")
        if batch_sampler_state is not None and hasattr(self.batch_sampler, "set_state"):
            self.batch_sampler.set_state(batch_sampler_state)

        # A state without this key starts from an empty counter.
        self.batches_from_session = defaultdict(int)
        batches_from_session_state = state.get("batches_from_session")
        if batches_from_session_state is not None:
            self.batches_from_session.update(batches_from_session_state)

        # A state without this key defaults to all sessions being active.
        active_sessions_list = state.get("active_sessions")
        if active_sessions_list is not None:
            self.active_sessions = set(active_sessions_list)
        else:
            self.active_sessions = set(self.session_names)

        # Re-position every session sampler and restore its RNG.
        session_sampler_states = state.get("session_sampler_states", {})
        for session_name, dataloader in self.session_dataloaders.items():
            sampler = dataloader.batch_sampler
            if hasattr(sampler, "set_position"):
                sampler.set_position(self.session_positions.get(session_name, 0))
            sampler_state = session_sampler_states.get(session_name)
            if sampler_state is not None and hasattr(sampler, "set_state"):
                sampler.set_state(sampler_state)

        logger.info(
            "Restored dataloader state to batch %s, epoch %s",
            self.current_batch,
            self.epoch,
        )

    def __iter__(self):
        """Yield ``(session_name, batch)`` pairs for the rest of the current epoch.

        The iteration scheme ensures:

        1. Each session appears at most once in each cycle
        2. Samples within a session are properly batched and optionally shuffled
        3. The epoch ends when the longest session is exhausted

        All bookkeeping for a batch, including marking its session as consumed
        for the running cycle, is done before the batch is yielded, so a state
        captured while the caller holds the batch resumes with the next
        session.
        """
        # Aliases of the live state objects: mutations stay visible to get_state().
        active_sessions = self.active_sessions
        batches_from_session = self.batches_from_session
        position_in_epoch = 0

        # A position at or past the end of a session wraps back to the start.
        for session_name in self.session_names:
            if self.session_positions.get(
                session_name, 0
            ) >= self.batches_per_session.get(session_name, 0):
                self.session_positions[session_name] = 0

        # Fresh iterators, each starting at its session's current position.
        session_iterators = {}
        for session_name, dataloader in self.session_dataloaders.items():
            sampler = dataloader.batch_sampler
            if hasattr(sampler, "set_position"):
                sampler.set_position(self.session_positions.get(session_name, 0))
            session_iterators[session_name] = iter(dataloader)

        # One loop iteration is one cycle over the sessions.
        while active_sessions and position_in_epoch < self.max_batches_per_session:
            cycle_order = self.batch_sampler.get_session_cycle()

            for session_name in cycle_order:
                if session_name not in active_sessions:
                    continue

                # Session already delivered all its batches this epoch.
                if batches_from_session[session_name] >= self.batches_per_session.get(
                    session_name, 0
                ):
                    active_sessions.discard(session_name)
                    continue

                iterator = session_iterators.get(session_name)
                if iterator is None:
                    continue

                try:
                    batch = next(iterator)
                except StopIteration:
                    # This session is exhausted for the current epoch.
                    active_sessions.discard(session_name)
                    self.batch_sampler.consumed_sessions.append(session_name)
                    continue

                self.current_batch += 1
                self.session_positions[session_name] += 1
                batches_from_session[session_name] += 1
                # Recorded before yielding; see the method docstring.
                self.batch_sampler.consumed_sessions.append(session_name)
                yield session_name, batch

            # Cycle complete: nothing is consumed in the next cycle yet.
            self.batch_sampler.consumed_sessions = []
            position_in_epoch += 1

        # End of epoch: advance the counter and reset per-epoch state.
        self.epoch += 1
        for session_name in self.session_names:
            self.session_positions[session_name] = 0
        self.batches_from_session = defaultdict(int)
        self.active_sessions = set(self.session_names)
class SessionSpecificSampler(Sampler):
    """Batch sampler over the indices of a single session.

    Each iteration yields lists of dataset indices, ``batch_size`` at a time,
    starting from the batch position set through :meth:`set_position`.
    """

    def __init__(self, indices, batch_size, drop_last=False, shuffle=False, seed=None):
        """Initialize session-specific sampler.

        Parameters
        ----------
        indices : iterable of int
            Dataset indices belonging to this session.
        batch_size : int
            Number of samples per batch.
        drop_last : bool, optional
            Whether to drop the last batch if smaller than batch_size.
            Default is False.
        shuffle : bool, optional
            Whether to shuffle indices. Default is False.
        seed : int, optional
            Random seed for reproducibility.
        """
        # Private copy: later changes to the caller's sequence have no effect.
        self.indices = list(indices)
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle
        # RNG state captured just before the most recent shuffle.
        self.prv_rng_state = None
        # RandomState(None) seeds itself, same as RandomState().
        self.rng = np.random.RandomState(seed)

        # Measured on the copy so one-shot iterables are supported too.
        full_batches, remainder = divmod(len(self.indices), batch_size)
        keep_partial = remainder > 0 and not drop_last
        self.num_batches = full_batches + int(keep_partial)

        # Index of the batch the next iteration starts from.
        self.position = 0

    def __len__(self):
        """Return the number of batches."""
        return self.num_batches

    def set_position(self, position):
        """Set the current batch position (wrapped into ``[0, num_batches)``)."""
        self.position = position % self.num_batches if self.num_batches > 0 else 0

    def get_state(self):
        """Return the state of the sampler (including RNG state)."""
        return {"prv_rng_state": self.prv_rng_state}

    def set_state(self, state):
        """Restore the state of the sampler (including RNG state)."""
        rng_state = state.get("prv_rng_state")
        if rng_state is not None and self.rng is not None:
            self.rng.set_state(rng_state)
        # NOTE(review): the recorded state is cleared on every restore;
        # confirm this reset is meant to be unconditional.
        self.prv_rng_state = None

    def __iter__(self):
        """Yield batches of indices starting from the current position.

        When shuffling, the RNG state is recorded first so that
        :meth:`get_state` followed by :meth:`set_state` reproduces the same
        permutation.
        """
        if self.shuffle and self.rng is not None:
            self.prv_rng_state = self.rng.get_state()
            indices = self.indices.copy()
            self.rng.shuffle(indices)
        else:
            indices = self.indices

        # A start past the end gives an empty range, i.e. no batches.
        start_idx = self.position * self.batch_size
        for i in range(start_idx, len(indices), self.batch_size):
            batch_indices = indices[i : i + self.batch_size]
            # Only the final batch can be short.
            if self.drop_last and len(batch_indices) < self.batch_size:
                continue
            yield batch_indices