Source code for experanto.dataloaders

import logging
import warnings
import zlib
from typing import Optional

from omegaconf import DictConfig

from .datasets import ChunkDataset
from .utils import (
    FastSessionDataLoader,
    LongCycler,
    MultiEpochsDataLoader,
    SessionConcatDataset,
)

logger = logging.getLogger(__name__)


def get_multisession_dataloader(
    paths: list[str],
    configs: DictConfig | dict | list[DictConfig | dict] | None = None,
    shuffle_keys: bool = False,
    **kwargs,
) -> LongCycler:
    """Create a multi-session dataloader from multiple experiment paths.

    Creates a :class:`ChunkDataset` for each path, wraps each one in a
    :class:`MultiEpochsDataLoader`, and returns them inside a
    :class:`LongCycler` that yields ``(session_key, batch)`` pairs. The
    cycler continues until the longest session is exhausted.

    Parameters
    ----------
    paths : list of str
        Paths to experiment directories.
    configs : dict, DictConfig, list, optional
        Configuration for each dataset. If a single config is provided, it
        will be applied to all datasets. If a list is provided, it must
        match the length of ``paths``. Each config must have ``dataset``
        and ``dataloader`` keys.
    shuffle_keys : bool, default=False
        Whether to shuffle the order of session keys.
        NOTE(review): this function accepts the flag but does not read it;
        confirm whether shuffling is expected to happen elsewhere.
    **kwargs
        Additional keyword arguments. Supports ``config`` as an alias for
        ``configs``.

    Returns
    -------
    LongCycler
        A dataloader-like object that yields ``(session_key, batch)``
        tuples. Iterates until the longest session is exhausted.

    Raises
    ------
    ValueError
        If neither ``configs`` nor ``config`` is given, or if a list of
        configs does not have the same length as ``paths``.

    See Also
    --------
    get_multisession_concat_dataloader : Alternative that concatenates sessions.
    LongCycler : The underlying multi-session iterator.

    Examples
    --------
    >>> from experanto.dataloaders import get_multisession_dataloader
    >>> from experanto.configs import DEFAULT_CONFIG
    >>> paths = ['/path/to/exp1', '/path/to/exp2']
    >>> loader = get_multisession_dataloader(paths, configs=DEFAULT_CONFIG)
    >>> for session_key, batch in loader:
    ...     print(f"Session: {session_key}, batch shape: {batch['responses'].shape}")
    """
    if configs is None and "config" in kwargs:
        configs = kwargs.pop("config")

    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if configs is None:
        raise ValueError(
            "No configuration given: pass `configs` (or the `config` alias)."
        )

    # Convert single config to list for uniform handling
    if isinstance(configs, (DictConfig, dict)):
        configs = [configs] * len(paths)

    dataloaders = {}
    for i, (path, cfg) in enumerate(zip(paths, configs, strict=True)):
        # TODO use saved meta dict to find data key
        dataset_name = f"session_{i}"
        if "dynamic" in path:
            dataset_name = path.split("dynamic")[1].split("-Video")[0]
        elif "_gaze" in path:
            before_gaze = path.split("_gaze")[0]
            # Only derive the name from the path when the expected
            # "datasets/" segment is present; otherwise keep the positional
            # fallback instead of failing with an IndexError.
            if "datasets/" in before_gaze:
                dataset_name = before_gaze.split("datasets/")[1]

        if dataset_name in dataloaders:
            # Two paths mapped to the same key; the later one wins (as
            # before), but make the replacement visible.
            logger.warning(
                "Duplicate session name %r derived from %r; "
                "the earlier dataloader is replaced.",
                dataset_name,
                path,
            )

        dataset = ChunkDataset(path, **cfg["dataset"])
        dataloaders[dataset_name] = MultiEpochsDataLoader(
            dataset,
            **cfg["dataloader"],
        )

    return LongCycler(dataloaders)
[docs] def get_multisession_concat_dataloader( paths: list[str], configs: dict | list[dict] | None = None, seed: int | None = 0, dataloader_config: dict | None = None, **kwargs, ) -> Optional["FastSessionDataLoader"]: """Create a concatenated multi-session dataloader. Unlike :func:`get_multisession_dataloader`, this function concatenates all sessions into a single dataset and uses batch sampling to ensure each batch contains samples from only one session. This is more memory-efficient and provides better shuffling across sessions. Parameters ---------- paths : list of str Paths to experiment directories. configs : dict or list of dict, optional Configuration for each dataset. If a single config is provided, it will be applied to all datasets. Each config should have ``dataset`` and ``dataloader`` keys. seed : int, default=0 Random seed for reproducibility. Each dataset gets a deterministic seed derived from this value and its path hash. dataloader_config : dict, optional Configuration for the dataloader (batch_size, num_workers, etc.). If None, uses the dataloader config from the first config. **kwargs Additional keyword arguments. Supports ``config`` as an alias for ``configs``. Returns ------- FastSessionDataLoader or None A dataloader that yields ``(session_key, batch)`` tuples. Returns None if no valid datasets could be created. See Also -------- get_multisession_dataloader : Alternative using separate dataloaders. FastSessionDataLoader : The underlying dataloader implementation. SessionConcatDataset : Dataset that concatenates multiple sessions. Examples -------- >>> from experanto.dataloaders import get_multisession_concat_dataloader >>> from experanto.configs import DEFAULT_CONFIG >>> paths = ['/path/to/exp1', '/path/to/exp2', '/path/to/exp3'] >>> loader = get_multisession_concat_dataloader(paths, configs=DEFAULT_CONFIG) >>> for session_key, batch in loader: ... 
print(f"Session: {session_key}") """ if configs is None and "config" in kwargs: configs = kwargs.pop("config") assert configs is not None # Convert single config to list for uniform handling if not isinstance(configs, list): configs = [configs] * len(paths) # Create datasets datasets = [] session_names = [] for _i, (path, cfg) in enumerate(zip(paths, configs, strict=True)): # Create dataset with deterministic seed path_hash = hash(path) % 10000 dataset_seed = seed + path_hash if seed is not None else None # Set specific seed for this dataset if needed if hasattr(cfg.get("dataset", {}), "seed") and dataset_seed is not None: cfg["dataset"]["seed"] = dataset_seed if "dataset" in cfg: cfg = cfg["dataset"] try: # Assuming ChunkDataset is defined elsewhere dataset = ChunkDataset(path, **cfg) session_name = dataset.data_key # Only add datasets with non-zero length if len(dataset) > 0: datasets.append(dataset) session_names.append(session_name) except Exception as e: warnings.warn( f"Error creating dataset for {path}: {str(e)}", UserWarning, stacklevel=2, ) if not datasets: return None # Create the concatenated dataset concat_dataset = SessionConcatDataset(datasets, session_names) # Get dataloader config from the first config if dataloader_config is None: dataloader_config = dict(configs[0].get("dataloader", {})) # Create the dataloader with our simplified implementation return FastSessionDataLoader(dataset=concat_dataset, seed=seed, **dataloader_config)