import bisect
import logging
# inbuilt libraries
from collections import defaultdict
# third-party libraries
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Sampler
# local libraries
logger = logging.getLogger(__name__)
def replace_nan_with_batch_mean(data: np.ndarray) -> np.ndarray:
    """Replace NaNs in a 2-D array with the mean of their column, in place.

    Each NaN at ``data[i, j]`` is replaced by the mean of the non-NaN entries
    of column ``j``. Columns that contain only NaNs are filled with ``0``.

    Parameters
    ----------
    data : np.ndarray
        2-D floating-point array; modified in place.

    Returns
    -------
    np.ndarray
        The same array object as ``data``, with every NaN replaced.
    """
    nan_mask = np.isnan(data)
    if not nan_mask.any():
        return data
    # Column statistics over the non-NaN entries, computed once (not once per
    # NaN). Dividing by hand instead of calling ``np.nanmean`` lets all-NaN
    # columns fall back to 0 without "Mean of empty slice" RuntimeWarnings.
    valid_counts = np.count_nonzero(~nan_mask, axis=0)
    column_sums = np.where(nan_mask, 0, data).sum(axis=0)
    column_means = np.divide(
        column_sums,
        valid_counts,
        out=np.zeros(column_sums.shape, dtype=np.float64),
        where=valid_counts > 0,
    )
    rows, cols = np.nonzero(nan_mask)
    data[rows, cols] = column_means[cols]
    return data
def add_behavior_as_channels(data: dict[str, torch.Tensor]) -> dict:
    """Add behavioral data as additional channels to screen data.

    Parameters
    ----------
    data : dict
        Dictionary with keys 'screen', 'eye_tracker', 'treadmill'.
        Screen shape: ``(C, T, H, W)``.
        Behavior shapes: ``(T,)``, ``(T, C_behavior)`` or ``(T, H, W)``.
        ``(T,)`` and ``(T, C_behavior)`` signals are broadcast over ``H`` and
        ``W``; a ``(T, H, W)`` signal becomes a single channel.

    Returns
    -------
    dict
        The same dictionary, modified in place: ``data["screen"]`` is replaced
        by a contiguous tensor of shape ``(C + behavior_channels, T, H, W)``
        with channels ordered screen, eye tracker, treadmill.
    """
    screen = data["screen"]
    _, _, h, w = screen.shape
    channels = [
        screen,
        _behavior_to_channels(data["eye_tracker"], h, w),
        _behavior_to_channels(data["treadmill"], h, w),
    ]
    # torch.cat materializes the broadcast views exactly once; .contiguous()
    # is a no-op when cat already returned a contiguous tensor.
    data["screen"] = torch.cat(channels, dim=0).contiguous()
    return data


def _behavior_to_channels(behavior: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Reshape one behavioral tensor to channel-first ``(C_b, T, h, w)``."""
    if behavior.dim() == 1:  # (t,) -> one channel: (t, 1)
        behavior = behavior.unsqueeze(-1)
    if behavior.dim() == 2:  # (t, c_b) -> (c_b, t, 1, 1) -> (c_b, t, h, w)
        # expand() returns a view; no per-pixel copy is made here.
        return behavior.transpose(0, 1)[:, :, None, None].expand(-1, -1, h, w)
    return behavior.unsqueeze(0)  # (t, h, w) -> (1, t, h, w)
class MultiEpochsDataLoader(DataLoader):
    """DataLoader that keeps workers alive across epochs.

    Solves a bug where worker processes are re-spawned at the start of each
    epoch, causing significant overhead. Workers are initialized once and
    reused throughout training.

    Parameters
    ----------
    *args
        Positional arguments forwarded to :class:`torch.utils.data.DataLoader`.
    shuffle_each_epoch : bool, default=False
        If True and the underlying dataset has a ``shuffle_valid_screen_times``
        method, that method is called at the start of every epoch.
    **kwargs
        Keyword arguments forwarded to :class:`torch.utils.data.DataLoader`.

    Notes
    -----
    NOTE(review): ``shuffle_valid_screen_times`` is called on ``self.dataset``
    in the calling process. With ``num_workers > 0`` the workers presumably
    hold their own dataset copies, created when the persistent iterator is
    built in ``__init__``, so the reshuffle may not reach them. Confirm before
    relying on it with multiprocessing.

    References
    ----------
    https://discuss.pytorch.org/t/enumerate-dataloader-slow/87778
    https://github.com/huggingface/pytorch-image-models/blob/d72ac0db259275233877be8c1d4872163954dfbb/timm/data/loader.py#L209-L238
    """

    def __init__(
        self,
        *args,
        shuffle_each_epoch=False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.shuffle_each_epoch = shuffle_each_epoch
        # DataLoader refuses reassignment of ``batch_sampler``/``sampler`` once
        # initialised, so its guard flag is lowered while we wrap them.
        self._DataLoader__initialized = False
        if self.batch_sampler is None:
            # Automatic batching disabled (``batch_size=None``): indices come
            # straight from ``sampler``, which is what must repeat.
            self.sampler = _RepeatSampler(self.sampler)
        else:
            self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        # One never-ending iterator shared by every epoch.
        self.iterator = super().__iter__()

    def __len__(self):
        """Return the number of batches in one epoch."""
        index_sampler = self.sampler if self.batch_sampler is None else self.batch_sampler
        return len(index_sampler)

    def __iter__(self):  # type: ignore[override]
        if self.shuffle_each_epoch and hasattr(
            self.dataset, "shuffle_valid_screen_times"
        ):
            self.dataset.shuffle_valid_screen_times()  # type: ignore[union-attr]
        # Draw exactly one epoch's worth of items from the endless iterator.
        for _ in range(len(self)):
            yield next(self.iterator)
# borrowed with <3 from
# https://github.com/sinzlab/neuralpredictors/blob/main/neuralpredictors/training/cyclers.py
def cycle(iterable):
    """Yield items from ``iterable`` endlessly, restarting it when exhausted.

    Unlike :func:`itertools.cycle`, items are not cached: ``iterable`` is
    re-iterated on every pass, so e.g. a shuffling DataLoader produces a
    fresh order each time (see https://github.com/pytorch/pytorch/issues/23900).

    The generator stops if a complete pass yields nothing. This happens for an
    empty iterable or a one-shot iterator that is already exhausted. Previously
    such input made it loop forever.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            return
class Exhauster:
    """Step through several data loaders one after another.

    Given a dictionary mapping ``data_key`` to a data loader, yields
    ``(data_key, batch)`` for every batch of the first loader, then every
    batch of the next one, and so on. A loader is left only once it is
    exhausted.
    """

    def __init__(self, loaders):
        self.loaders = loaders

    def __iter__(self):
        for data_key, loader in self.loaders.items():
            for batch in loader:
                yield data_key, batch

    def __len__(self):
        # Total batch count over the loaders (dict values, not the keys).
        return sum(len(loader) for loader in self.loaders.values())
class LongCycler:
    """Cycle through multiple dataloaders until the longest is exhausted.

    Useful for training with multiple sessions of unequal size. Cycles through
    all loaders, yielding ``(session_key, batch)`` pairs. Shorter loaders are
    recycled until the longest loader completes one full epoch.

    Parameters
    ----------
    loaders : dict
        Dictionary mapping session keys to DataLoader instances.

    Attributes
    ----------
    max_batches : int
        Number of batches in the longest loader (0 if ``loaders`` is empty).

    Examples
    --------
    >>> loaders = {'session_1': loader1, 'session_2': loader2}
    >>> cycler = LongCycler(loaders)
    >>> for session_key, batch in cycler:
    ...     print(f"Processing {session_key}")
    """

    def __init__(self, loaders):
        self.loaders = loaders
        self.max_batches = max(
            (len(loader) for loader in self.loaders.values()), default=0
        )

    def __iter__(self):
        keys = list(self.loaders)
        iterators = [cycle(loader) for loader in self.loaders.values()]
        # Round-robin over the loaders for exactly len(self) steps. Zipping the
        # infinite cycles with a finite range under ``strict=True`` (the old
        # approach) raised ValueError as soon as the range ran out.
        for step in range(len(self)):
            slot = step % len(keys)
            yield keys[slot], next(iterators[slot])

    def __len__(self):
        return len(self.loaders) * self.max_batches
class ShortCycler:
    """Cycle through multiple dataloaders until the shortest is exhausted.

    Similar to :class:`LongCycler`, but stops when the smallest loader
    completes one epoch. No recycling occurs.

    Parameters
    ----------
    loaders : dict
        Dictionary mapping session keys to DataLoader instances.

    Attributes
    ----------
    min_batches : int
        Number of batches in the shortest loader (0 if ``loaders`` is empty).
    """

    def __init__(self, loaders):
        self.loaders = loaders
        self.min_batches = min(
            (len(loader) for loader in self.loaders.values()), default=0
        )

    def __iter__(self):
        keys = list(self.loaders)
        iterators = [cycle(loader) for loader in self.loaders.values()]
        # Round-robin over the loaders for exactly len(self) steps. Zipping the
        # infinite cycles with a finite range under ``strict=True`` (the old
        # approach) raised ValueError as soon as the range ran out.
        for step in range(len(self)):
            slot = step % len(keys)
            yield keys[slot], next(iterators[slot])

    def __len__(self):
        return len(self.loaders) * self.min_batches
class _RepeatSampler:
"""Simple sampler that repeats indefinitely."""
def __init__(self, sampler):
self.sampler = sampler
def __len__(self):
"""Return the length of the original sampler."""
return len(self.sampler)
def __iter__(self):
while True:
yield from iter(self.sampler)
[docs]
class SessionConcatDataset(Dataset):
"""Memory-efficient concatenated dataset that reliably tracks sessions."""
[docs]
def __init__(self, datasets, session_names=None):
"""Initialize the concatenated dataset with session tracking."""
if not datasets:
raise ValueError("datasets is empty")
# Store datasets
self.datasets = list(datasets)
# Track session names
if session_names is None:
session_names = [f"session_{i}" for i in range(len(datasets))]
self.session_names = session_names
# Log dataset sizes for debugging
for i, (name, dataset) in enumerate(zip(session_names, datasets, strict=True)):
logger.debug("Dataset %s: %s, length = %s", i, name, len(dataset))
# Compute cumulative sizes for efficient indexing
self.cumulative_sizes = []
current_size = 0
for dataset in self.datasets:
current_size += len(dataset)
self.cumulative_sizes.append(current_size)
# Create session indices dictionary for fast lookup
self.session_indices = {}
start_idx = 0
for i, dataset in enumerate(datasets):
session_name = session_names[i]
session_size = len(dataset)
self.session_indices[session_name] = (
start_idx,
start_idx + session_size,
)
start_idx += session_size
[docs]
def __len__(self):
"""Return total length of the concatenated dataset."""
return self.cumulative_sizes[-1] if self.cumulative_sizes else 0
[docs]
def __getitem__(self, idx):
"""Get item from the appropriate dataset."""
if idx < 0 or idx >= len(self):
raise IndexError(
f"Index {idx} out of range for dataset of size {len(self)}"
)
# Find which dataset the index belongs to
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
# Get the data from the dataset
data = self.datasets[dataset_idx][sample_idx]
# Return the data along with session information
return data
[docs]
def get_session_for_idx(self, idx):
"""Get the session name for a given index."""
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
return self.session_names[dataset_idx]
[docs]
def get_indices_for_session(self, session_name):
"""Get all indices belonging to a given session."""
if session_name in self.session_indices:
start, end = self.session_indices[session_name]
return list(range(start, end))
return []
[docs]
def get_sessions_count(self):
"""Get number of sessions and count of samples per session."""
return {
name: end - start for name, (start, end) in self.session_indices.items()
}
[docs]
class SessionBatchSampler(Sampler):
    """Session ordering and bookkeeping for multi-session batching.

    Produces cycles of sessions in which each session appears exactly once
    before any session repeats (see :meth:`get_session_cycle`). It also
    precomputes how many batches each session yields. This class defines no
    ``__iter__``: it hands out session orders, not index batches.
    """

    def __init__(self, dataset, batch_size, drop_last=False, shuffle=False, seed=None):
        """Initialize session batch sampler.

        Parameters
        ----------
        dataset : SessionConcatDataset
            The dataset to sample from (needs ``session_indices`` and
            ``get_indices_for_session``).
        batch_size : int
            Number of samples per batch.
        drop_last : bool, optional
            Whether to drop the last batch if smaller than batch_size.
            Default is False.
        shuffle : bool, optional
            Whether to shuffle the session order of each cycle. Default is False.
        seed : int, optional
            Random seed for reproducibility.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.seed = seed
        # Private RNG; RandomState(None) seeds itself from OS entropy.
        self.rng = np.random.RandomState(seed)
        # RNG state captured right before the most recent session shuffle.
        self.prv_rng_state = None

        self.session_names = list(dataset.session_indices.keys())
        logger.debug("Sessions: %s", self.session_names)
        # Sessions to leave out of the next cycle built by get_session_cycle.
        self.consumed_sessions = []

        # Only sessions with at least one sample get an entry.
        self.session_indices = {}
        for session_name in self.session_names:
            indices = dataset.get_indices_for_session(session_name)
            if indices:
                self.session_indices[session_name] = indices

        self.batches_per_session = {
            name: self._count_batches(len(indices))
            for name, indices in self.session_indices.items()
        }
        logger.debug("Batches per session: %s", self.batches_per_session)
        logger.debug("Total batches: %s", len(self))

    def _count_batches(self, n_samples):
        """Return how many batches ``n_samples`` samples form under ``drop_last``."""
        if self.drop_last:
            return n_samples // self.batch_size
        return (n_samples + self.batch_size - 1) // self.batch_size

    def __len__(self):
        """Return the total number of batches across all sessions."""
        return sum(self.batches_per_session.values())

    def get_session_cycle(self):
        """Return one cycle of sessions, each appearing at most once.

        The order is shuffled with the sampler's RNG when ``shuffle`` is True,
        and sessions listed in ``consumed_sessions`` are left out. The RNG
        state from just before the shuffle is kept in ``prv_rng_state``, so
        restoring it via :meth:`set_state` reproduces the same order.
        """
        order = list(self.session_names)
        self.prv_rng_state = self.rng.get_state()
        if self.shuffle:
            self.rng.shuffle(order)
        # Filter instead of list.remove(): duplicate or unknown names (e.g.
        # from an older checkpoint) no longer raise ValueError.
        consumed = set(self.consumed_sessions)
        return [name for name in order if name not in consumed]

    def get_state(self):
        """Return the state of the sampler (including RNG state)."""
        return {
            "prv_rng_state": self.prv_rng_state,
            # Copy so later in-place changes do not alter a saved checkpoint.
            "consumed_sessions": list(self.consumed_sessions),
        }

    def set_state(self, state):
        """Restore the state of the sampler (including RNG state)."""
        rng_state = state.get("prv_rng_state")
        if rng_state is not None:
            self.rng.set_state(rng_state)
        # Keep the restored state (instead of None) so a checkpoint taken before
        # the next get_session_cycle() still reproduces the same order.
        self.prv_rng_state = rng_state
        self.consumed_sessions = list(state.get("consumed_sessions", []))
class FastSessionDataLoader:
    """Optimized multi-session dataloader with state tracking.

    Provides efficient data loading across multiple sessions with guarantees:

    - Each session appears at most once per cycle before any session repeats
    - Epoch ends when the longest session is exhausted
    - Every yielded batch is tagged with the session it came from
    - Progress can be saved with :meth:`get_state` and restored with
      :meth:`set_state`

    Parameters
    ----------
    dataset : SessionConcatDataset
        Concatenated dataset with session tracking.
    batch_size : int, default=1
        Number of samples per batch.
    shuffle : bool, default=False
        Whether to shuffle samples within each session (and the session order).
    num_workers : int, default=0
        Number of worker processes for data loading.
    pin_memory : bool, default=False
        Whether to pin memory for GPU transfer.
    drop_last : bool, default=False
        Whether to drop incomplete batches.
    seed : int, optional
        Random seed for reproducibility.
    **kwargs
        Additional arguments passed to underlying DataLoaders.

    Attributes
    ----------
    session_names : list
        Names of all sessions in the dataset.
    batches_per_session : dict
        Number of batches in each session (sessions without samples are absent).
    current_batch : int
        Number of batches yielded so far, accumulated across epochs.
    epoch : int
        Number of completed epochs.

    See Also
    --------
    SessionConcatDataset : Dataset that tracks session membership.
    LongCycler : Simpler alternative without state tracking.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=False,
        drop_last=False,
        seed=None,
        **kwargs,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.seed = seed
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.kwargs = kwargs

        # Session-level ordering and checkpointable cycle state.
        self.batch_sampler = SessionBatchSampler(
            dataset=dataset,
            batch_size=batch_size,
            drop_last=drop_last,
            shuffle=shuffle,
            seed=seed,
        )
        self.session_names = self.batch_sampler.session_names
        self.session_indices = self.batch_sampler.session_indices
        self.batches_per_session = self.batch_sampler.batches_per_session
        # Upper bound on the number of cycles in one epoch.
        self.max_batches_per_session = max(
            self.batches_per_session.values(), default=0
        )

        # One DataLoader per session, built once and reused every epoch.
        self.session_dataloaders = {}
        for i, session_name in enumerate(self.session_names):
            # Sessions without samples have no session_indices entry; give them
            # an empty sampler instead of raising KeyError.
            indices = self.session_indices.get(session_name, [])
            # Derive a distinct seed per session sampler.
            session_seed = None if seed is None else seed + i + 1
            session_sampler = SessionSpecificSampler(
                indices=indices,
                batch_size=batch_size,
                drop_last=drop_last,
                shuffle=shuffle,
                seed=session_seed,
            )
            self.session_dataloaders[session_name] = DataLoader(
                dataset=dataset,
                batch_sampler=session_sampler,
                num_workers=num_workers,
                pin_memory=pin_memory,
                **kwargs,
            )

        # Progress state (see get_state / set_state).
        self.current_batch = 0
        self.epoch = 0
        self.session_positions = dict.fromkeys(self.session_names, 0)
        # Batches yielded per session in the current epoch.
        self.batches_from_session = defaultdict(int)
        self.active_sessions = set(self.session_names)
        logger.debug(
            "Created FastSessionDataLoader with %s sessions and %s total batches",
            len(self.session_names),
            len(self),
        )

    def __len__(self):
        """Return the total number of batches in an epoch."""
        return sum(self.batches_per_session.values())

    def get_state(self):
        """Return the current state of the dataloader."""
        return {
            "current_batch": self.current_batch,
            "epoch": self.epoch,
            "session_positions": self.session_positions.copy(),
            "batches_from_session": self.batches_from_session.copy(),
            # Stored as a list for serialization.
            "active_sessions": list(self.active_sessions),
            "batch_sampler_state": self.batch_sampler.get_state(),
            "session_sampler_states": {
                name: dl.batch_sampler.get_state()
                for name, dl in self.session_dataloaders.items()
            },
        }

    def set_state(self, state):
        """Restore the dataloader state produced by :meth:`get_state`."""
        if not state:
            return
        self.current_batch = state.get("current_batch", 0)
        self.epoch = state.get("epoch", 0)

        session_positions = state.get("session_positions")
        if session_positions:
            # Copy instead of aliasing the checkpoint, and keep a 0 entry for
            # every known session so position updates never hit KeyError.
            self.session_positions = {
                **dict.fromkeys(self.session_names, 0),
                **session_positions,
            }

        # Legacy key: applied only if an ``rng`` attribute exists (this class
        # does not create one itself).
        dataloader_rng_state = state.get("dataloader_rng_state")
        rng = getattr(self, "rng", None)
        if dataloader_rng_state is not None and rng is not None:
            rng.set_state(dataloader_rng_state)

        batch_sampler_state = state.get("batch_sampler_state")
        if batch_sampler_state is not None and hasattr(self.batch_sampler, "set_state"):
            self.batch_sampler.set_state(batch_sampler_state)

        # Missing/None (older checkpoints) -> start the epoch's counts afresh.
        self.batches_from_session = defaultdict(
            int, state.get("batches_from_session") or {}
        )

        active_sessions_list = state.get("active_sessions")
        # Default to all sessions if absent (backward compatibility).
        self.active_sessions = set(
            self.session_names if active_sessions_list is None else active_sessions_list
        )

        session_sampler_states = state.get("session_sampler_states", {})
        for session_name, dataloader in self.session_dataloaders.items():
            sampler = dataloader.batch_sampler
            if hasattr(sampler, "set_position"):
                sampler.set_position(self.session_positions.get(session_name, 0))
            sampler_state = session_sampler_states.get(session_name)
            if sampler_state is not None and hasattr(sampler, "set_state"):
                sampler.set_state(sampler_state)

        logger.info(
            "Restored dataloader state to batch %s, epoch %s",
            self.current_batch,
            self.epoch,
        )

    def _start_session_iterators(self):
        """Create one iterator per session, starting at its stored position.

        A position that has already reached the session's batch count is
        wrapped back to 0 first.
        """
        iterators = {}
        for session_name, dataloader in self.session_dataloaders.items():
            if self.session_positions.get(
                session_name, 0
            ) >= self.batches_per_session.get(session_name, 0):
                self.session_positions[session_name] = 0
            sampler = dataloader.batch_sampler
            if hasattr(sampler, "set_position"):
                sampler.set_position(self.session_positions.get(session_name, 0))
            iterators[session_name] = iter(dataloader)
        return iterators

    def _mark_consumed(self, session_name):
        """Record that ``session_name`` has had its turn in the current cycle.

        Rebinds the list instead of appending, so state dicts returned earlier
        by ``get_state`` are never mutated afterwards.
        """
        sampler = self.batch_sampler
        sampler.consumed_sessions = [*sampler.consumed_sessions, session_name]

    def __iter__(self):
        """Yield ``(session_name, batch)`` pairs for one epoch.

        Sessions are visited in cycles from the batch sampler; each cycle
        yields at most one batch per still-active session. The epoch ends when
        no session is active or after ``max_batches_per_session`` cycles.
        Progress is written to ``self`` as batches are produced, so an
        interrupted epoch can be saved with :meth:`get_state`.
        """
        # Aliases: mutations persist on self for checkpointing mid-epoch.
        active_sessions = self.active_sessions
        batches_from_session = self.batches_from_session
        session_iterators = self._start_session_iterators()

        cycles_done = 0
        while active_sessions and cycles_done < self.max_batches_per_session:
            for session_name in self.batch_sampler.get_session_cycle():
                if session_name not in active_sessions:
                    continue
                # All of this session's batches for the epoch already served.
                if batches_from_session[session_name] >= self.batches_per_session.get(
                    session_name, 0
                ):
                    active_sessions.discard(session_name)
                    continue
                iterator = session_iterators.get(session_name)
                if iterator is None:
                    continue
                try:
                    batch = next(iterator)
                except StopIteration:
                    # This session is exhausted for the current epoch.
                    active_sessions.discard(session_name)
                    self._mark_consumed(session_name)
                    continue
                self.current_batch += 1
                self.session_positions[session_name] += 1
                batches_from_session[session_name] += 1
                # Mark the session as served *before* yielding, so a checkpoint
                # taken while this batch is being processed makes the restored
                # cycle skip it instead of giving it a second turn.
                self._mark_consumed(session_name)
                yield session_name, batch
            self.batch_sampler.consumed_sessions = []
            cycles_done += 1

        # End of epoch: reset per-epoch bookkeeping.
        self.epoch += 1
        for session_name in self.session_names:
            self.session_positions[session_name] = 0
        self.batches_from_session = defaultdict(int)
        self.active_sessions = set(self.session_names)
class SessionSpecificSampler(Sampler):
    """Batch sampler over the dataset indices of a single session.

    Yields lists of at most ``batch_size`` indices (the last one may be shorter
    unless ``drop_last``), optionally shuffled, starting at the batch selected
    with :meth:`set_position`.
    """

    def __init__(self, indices, batch_size, drop_last=False, shuffle=False, seed=None):
        """Initialize session-specific sampler.

        Parameters
        ----------
        indices : iterable of int
            Dataset indices belonging to this session (copied into a list).
        batch_size : int
            Number of samples per batch.
        drop_last : bool, optional
            Whether to drop the last batch if smaller than batch_size.
            Default is False.
        shuffle : bool, optional
            Whether to shuffle indices on every pass. Default is False.
        seed : int, optional
            Random seed for reproducibility.
        """
        self.indices = list(indices)  # own copy, safe from caller mutation
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle
        # RNG state captured right before the most recent shuffle.
        self.prv_rng_state = None
        # RandomState(None) seeds itself from OS entropy.
        self.rng = np.random.RandomState(seed)
        n_samples = len(self.indices)
        if drop_last:
            self.num_batches = n_samples // batch_size
        else:
            self.num_batches = (n_samples + batch_size - 1) // batch_size
        # Batch index at which the next pass starts.
        self.position = 0

    def __len__(self):
        """Return the number of batches."""
        return self.num_batches

    def set_position(self, position):
        """Set the batch index the next pass starts from (wrapped modulo len)."""
        self.position = position % self.num_batches if self.num_batches > 0 else 0

    def get_state(self):
        """Return the state of the sampler (including RNG state)."""
        return {"prv_rng_state": self.prv_rng_state}

    def set_state(self, state):
        """Restore the state of the sampler (including RNG state)."""
        rng_state = state.get("prv_rng_state")
        if rng_state is not None:
            self.rng.set_state(rng_state)
        # The next pass re-shuffles from this state; keep it so an immediate
        # get_state() still returns it.
        self.prv_rng_state = rng_state

    def __iter__(self):
        """Yield batches of indices starting from the current position."""
        if self.shuffle:
            # Save the pre-shuffle state so get_state()/set_state() can replay
            # this exact permutation when resuming mid-pass. Without it, a
            # resumed pass used a different permutation and duplicated/skipped
            # samples.
            self.prv_rng_state = self.rng.get_state()
            indices = self.indices.copy()
            self.rng.shuffle(indices)
        else:
            indices = self.indices
        start = self.position * self.batch_size
        for offset in range(start, len(indices), self.batch_size):
            batch = indices[offset : offset + self.batch_size]
            # Only the final batch can be short.
            if self.drop_last and len(batch) < self.batch_size:
                break
            yield batch