experanto.datasets.ChunkDataset
- class ChunkDataset(*args, **kwargs)
Bases: Dataset
PyTorch Dataset for chunked experiment data.
This dataset loads an experiment and provides temporally chunked samples suitable for training neural networks. Each sample contains synchronized data from all modalities (e.g., screen, responses, eye_tracker, treadmill) for a given time window.
- Parameters:
root_folder (str) – Path to the experiment directory containing modality subfolders.
global_sampling_rate (float, optional) – Sampling rate (Hz) applied to all modalities. If None, uses per-modality rates from modality_config.
global_chunk_size (int, optional) – Number of samples per chunk for all modalities. If None, uses per-modality sizes from modality_config.
add_behavior_as_channels (bool, default=False) – If True, concatenates behavioral data as additional image channels. Deprecated: use separate modality outputs instead.
replace_nans_with_means (bool, default=False) – If True, replaces NaN values with column means.
cache_data (bool, default=False) – If True, keeps loaded data in memory for faster access.
out_keys (iterable, optional) – Which modalities to include in output. Defaults to all modalities plus ‘timestamps’.
normalize_timestamps (bool, default=True) – If True, normalizes timestamps relative to recording start.
modality_config (dict) – Configuration for each modality including sampling rates, transforms, filters, and interpolation settings. See Notes for structure.
seed (int, optional) – Random seed for reproducible shuffling of valid time points.
safe_interval_threshold (float, default=0.5) – Safety margin (in seconds) excluded from the edges of valid intervals.
interpolate_precision (int, default=5) – Number of decimal places to which time points are rounded. This prevents floating-point accumulation errors during interpolation.
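The safe_interval_threshold description can be illustrated with a small interval-trimming sketch. It assumes the margin is removed from both ends of each (start, end) interval in seconds and that intervals with no remaining length are dropped; shrink_intervals is a hypothetical helper, not part of experanto.

```python
from typing import List, Tuple

Interval = Tuple[float, float]


def shrink_intervals(intervals: List[Interval], margin: float = 0.5) -> List[Interval]:
    """Trim `margin` seconds from both edges of each (start, end) interval.

    Hypothetical helper: assumes a symmetric margin and drops intervals
    that are no longer than 2 * margin.
    """
    shrunk = []
    for start, end in intervals:
        new_start, new_end = start + margin, end - margin
        if new_end > new_start:  # keep only intervals that still have positive length
            shrunk.append((new_start, new_end))
    return shrunk


print(shrink_intervals([(0.0, 10.0), (12.0, 12.8), (20.0, 25.0)]))
# [(0.5, 9.5), (20.5, 24.5)]
```

Short intervals such as (12.0, 12.8) vanish entirely under this assumption, which is why a large threshold can noticeably reduce the number of usable chunks.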
See also
Experiment – Lower-level interface for data access.
get_multisession_dataloader – Load multiple datasets.
experanto.configs – Default configuration values.
Notes
The dataset handles:
Per-modality sampling rates and chunk sizes
Time offset alignment between modalities
Data normalization and transforms
Filtering based on trial conditions and data quality
Reproducible random sampling with seeds
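The per-modality rates, chunk sizes and offsets listed above can be pictured as a separate time grid per modality: chunk_size points spaced 1 / sampling_rate seconds apart, shifted by the modality's offset. The sketch below assumes the offset is added to the chunk start and that a global setting, when not None, replaces the per-modality value; chunk_times and resolve are hypothetical helpers, and rounding to `precision` decimals stands in for what interpolate_precision is described as doing.

```python
from typing import List, Optional


def resolve(global_value: Optional[float], per_modality_value: float) -> float:
    """Return the global setting if one was given, else the per-modality value."""
    return per_modality_value if global_value is None else global_value


def chunk_times(start: float, sampling_rate: float, chunk_size: int,
                offset: float = 0.0, precision: int = 5) -> List[float]:
    """Time points (s) for one chunk of one modality.

    Assumption: the offset is added to the chunk start. Each time is computed
    directly from its index and then rounded to `precision` decimals.
    """
    return [round(start + offset + i / sampling_rate, precision) for i in range(chunk_size)]


# Repeatedly adding a step accumulates binary rounding error; rounding removes it.
drift = 0.0
for _ in range(10):
    drift += 0.1
print(drift, round(drift, 5))  # 0.9999999999999999 1.0

# Hypothetical per-modality settings: (sampling_rate, chunk_size, offset).
settings = {"screen": (30, 60, 0.0), "responses": (8, 16, 0.1)}
global_sampling_rate = None  # None means "use the per-modality rate"
for name, (rate, size, offset) in settings.items():
    times = chunk_times(10.0, resolve(global_sampling_rate, rate), size, offset)
    print(name, len(times), times[0], times[-1])
```

Computing each time from its index, rather than accumulating steps, keeps errors from growing along the chunk; the final rounding then makes timestamps from different modalities directly comparable.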
The modality_config is a nested dictionary with per-modality settings. The following example configures the screen, responses, eye_tracker, and treadmill modalities:

    screen:
      sampling_rate: null
      chunk_size: null
      valid_condition:
        tier: test
        stim_type: stimulus.Frame
      offset: 0
      sample_stride: 4
      include_blanks: false
      transforms:
        ToTensor:
          _target_: torchvision.transforms.ToTensor
        Normalize:
          _target_: torchvision.transforms.Normalize
          mean: 80.0
          std: 60.0
        Resize:
          _target_: torchvision.transforms.Resize
          size:
            - 144
            - 256
        CenterCrop:
          _target_: torchvision.transforms.CenterCrop
          size: 144
      interpolation: {}
    responses:
      sampling_rate: null
      chunk_size: null
      offset: 0.1
      transforms:
        normalize_variance_only: true # old standardize: true
      interpolation:
        interpolation_mode: nearest_neighbor
    eye_tracker:
      sampling_rate: null
      chunk_size: null
      offset: 0
      transforms:
        normalize: true
      interpolation:
        interpolation_mode: nearest_neighbor
    treadmill:
      sampling_rate: null
      chunk_size: null
      offset: 0
      transforms:
        normalize: true
      interpolation:
        interpolation_mode: nearest_neighbor
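Transform entries such as Resize carry a `_target_` key holding a dotted import path, which follows Hydra's instantiate convention. This page does not say how ChunkDataset resolves these entries, so the following is only a standard-library approximation with a hypothetical helper, instantiate_entry: import the target and call it with the remaining keys as keyword arguments.

```python
import importlib
from typing import Any, Dict


def instantiate_entry(entry: Dict[str, Any]) -> Any:
    """Build an object from a `_target_`-style config entry (hypothetical helper).

    The dotted path in `_target_` is imported and called with the remaining
    keys as keyword arguments. The input dict is not modified.
    """
    kwargs = dict(entry)
    module_path, _, attr_name = kwargs.pop("_target_").rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    return target(**kwargs)


# Stand-in target from the standard library so the sketch runs without torchvision.
delta = instantiate_entry({"_target_": "datetime.timedelta", "seconds": 90})
print(delta)  # 0:01:30
```

With torchvision installed, an entry like {"_target_": "torchvision.transforms.Resize", "size": [144, 256]} would go through the same path; Hydra's own instantiate additionally handles recursion and partial application, which this sketch omits.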
Examples
>>> from experanto.datasets import ChunkDataset
>>> from experanto.configs import DEFAULT_MODALITY_CONFIG
>>> dataset = ChunkDataset(
...     '/path/to/experiment',
...     global_sampling_rate=30,
...     global_chunk_size=60,
...     modality_config=DEFAULT_MODALITY_CONFIG,
... )
>>> len(dataset)
1000
>>> sample = dataset[0]
>>> sample['screen'].shape
torch.Size([1, 60, 144, 256])
>>> sample['responses'].shape
torch.Size([16, 500])
Methods
__init__(root_folder[, ...])
get_condition_mask_from_meta_conditions(valid_conditions_sum_of_product) – Create a boolean mask for trials satisfying given conditions.
get_data_key_from_root_folder(root_folder) – Extract a data key from the root folder path.
get_full_valid_sample_times([...]) – Get all valid chunk starting times based on meta conditions.
get_screen_sample_mask_from_meta_conditions(satisfy_for_next, ...) – Create a boolean mask for screen samples satisfying given conditions.
Return the current state of the dataset's RNG.
get_valid_intervals_from_filters([visualize])
initialize_statistics() – Initialize normalization statistics for each device.
initialize_transforms() – Initialize data transforms for each device based on modality config.
set_state(state) – Restore the dataset's RNG state.
shuffle_valid_screen_times() – Shuffle valid screen times using the dataset's random number generator for reproducibility.
- __init__(root_folder, global_sampling_rate=None, global_chunk_size=None, add_behavior_as_channels=False, replace_nans_with_means=False, cache_data=False, out_keys=None, normalize_timestamps=True, modality_config={'screen': {'keep_nans': False, 'sampling_rate': 30, 'chunk_size': 60, 'valid_condition': {'tier': 'train'}, 'offset': 0, 'sample_stride': 1, 'include_blanks': True, 'transforms': {'normalization': 'normalize', 'Resize': {'_target_': 'torchvision.transforms.v2.Resize', 'size': [144, 256]}}, 'interpolation': {'rescale': True, 'rescale_size': [144, 256]}}, 'responses': {'keep_nans': False, 'sampling_rate': 8, 'chunk_size': 16, 'offset': 0.0, 'transforms': {'normalization': 'normalize_variance_only'}, 'interpolation': {'interpolation_mode': 'nearest_neighbor'}, 'filters': {'nan_filter': {'__target__': 'experanto.filters.common_filters.nan_filter', '__partial__': True, 'vicinity': 0.05}}}, 'eye_tracker': {'keep_nans': False, 'sampling_rate': 30, 'chunk_size': 60, 'offset': 0, 'transforms': {'normalization': 'normalize'}, 'interpolation': {'interpolation_mode': 'nearest_neighbor'}, 'filters': {'nan_filter': {'__target__': 'experanto.filters.common_filters.nan_filter', '__partial__': True, 'vicinity': 0.05}}}, 'treadmill': {'keep_nans': False, 'sampling_rate': 30, 'chunk_size': 60, 'offset': 0, 'transforms': {'normalization': 'normalize'}, 'interpolation': {'interpolation_mode': 'nearest_neighbor'}, 'filters': {'nan_filter': {'__target__': 'experanto.filters.common_filters.nan_filter', '__partial__': True, 'vicinity': 0.05}}}}, seed=None, safe_interval_threshold=0.5, interpolate_precision=5)
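In the default modality_config, the modalities use different sampling rates and chunk sizes, but the ratio chunk_size / sampling_rate is the same for all of them. Assuming chunk_size counts samples taken at sampling_rate, as the parameter descriptions state, every default chunk spans a nominal two seconds. The values below are copied from the signature above.

```python
# Per-modality defaults from the __init__ signature: (sampling_rate in Hz, chunk_size in samples).
DEFAULTS = {
    "screen": (30, 60),
    "responses": (8, 16),
    "eye_tracker": (30, 60),
    "treadmill": (30, 60),
}

# Nominal duration of one chunk in seconds = chunk_size / sampling_rate.
durations = {name: size / rate for name, (rate, size) in DEFAULTS.items()}
print(durations)
# {'screen': 2.0, 'responses': 2.0, 'eye_tracker': 2.0, 'treadmill': 2.0}
```

Keeping this ratio equal across modalities is what lets one sample pair a 60-frame screen chunk with a 16-step response chunk covering the same time window.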
- initialize_statistics()
Initialize normalization statistics for each device.
Loads mean and standard deviation values from each device's meta folder and stores them in self._statistics for use during data transforms.
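The default config names two normalization modes: 'normalize' (eye_tracker, treadmill) and 'normalize_variance_only' (responses). The sketch below shows what these names most plausibly mean given a stored mean and standard deviation. The exact formulas used by the dataset (for example, whether an epsilon guards the division, or whether statistics are per feature) are not documented here, so treat both functions as hypothetical.

```python
from typing import List


def normalize(values: List[float], mean: float, std: float) -> List[float]:
    """Z-score: subtract the mean, then divide by the standard deviation."""
    return [(v - mean) / std for v in values]


def normalize_variance_only(values: List[float], std: float) -> List[float]:
    """Divide by the standard deviation without centering (zero stays zero)."""
    return [v / std for v in values]


print(normalize([2.0, 4.0, 6.0], mean=4.0, std=2.0))      # [-1.0, 0.0, 1.0]
print(normalize_variance_only([2.0, 4.0, 6.0], std=2.0))  # [1.0, 2.0, 3.0]
```

Skipping the mean subtraction keeps non-negative signals such as neural responses non-negative, which is one reason a variance-only mode is useful for that modality.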
- initialize_transforms()
Initialize data transforms for each device based on modality config.
- get_condition_mask_from_meta_conditions(valid_conditions_sum_of_product)
Create a boolean mask for trials satisfying given conditions.
- Parameters:
valid_conditions_sum_of_product (list of dict) – Condition dictionaries combined with OR logic, where conditions within each dictionary use AND logic.
- Returns:
Boolean mask indicating which trials satisfy at least one set of conditions.
- Return type:
np.ndarray
Notes
For example, [{'tier': 'train', 'stim_type': 'natural'}, {'tier': 'blank'}] matches trials that are either (train AND natural) OR blank.
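The sum-of-products rule can be sketched in plain Python. This assumes per-trial metadata is available as a list of dicts, a hypothetical representation: the real method reads metadata from the experiment and returns an np.ndarray. condition_mask is a hypothetical name, and the trial values are illustrative.

```python
from typing import Any, Dict, List


def condition_mask(trials: List[Dict[str, Any]], conditions: List[Dict[str, Any]]) -> List[bool]:
    """True for trials matching at least one condition dict (OR),
    where every key/value pair inside that dict must match (AND)."""
    return [
        any(all(trial.get(key) == value for key, value in cond.items()) for cond in conditions)
        for trial in trials
    ]


trials = [
    {"tier": "train", "stim_type": "natural"},
    {"tier": "train", "stim_type": "noise"},
    {"tier": "blank", "stim_type": "gray"},
    {"tier": "test", "stim_type": "natural"},
]
mask = condition_mask(trials, [{"tier": "train", "stim_type": "natural"}, {"tier": "blank"}])
print(mask)  # [True, False, True, False]
```

An empty condition list matches nothing, since any() over no terms is False.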
- get_screen_sample_mask_from_meta_conditions(satisfy_for_next, valid_conditions_sum_of_product, filter_for_valid_intervals=True)
Create a boolean mask for screen samples satisfying given conditions.
- Parameters:
satisfy_for_next (int) – Number of consecutive samples that must satisfy conditions.
valid_conditions_sum_of_product (list of dict) – Condition dictionaries combined with OR logic, where conditions within each dictionary use AND logic.
filter_for_valid_intervals (bool, default=True) – Whether to apply interval-based filtering.
- Returns:
Boolean array matching screen sample times, True where conditions are met.
- Return type:
- get_full_valid_sample_times(filter_for_valid_intervals=True)
Get all valid chunk starting times based on meta conditions.
Iterates through sample times and checks whether each can serve as a chunk start time (i.e., whether the next chunk_size points are all valid according to the preceding meta-condition filtering).
- Parameters:
filter_for_valid_intervals (bool, default=True) – Whether to apply interval-based filtering.
- Returns:
Array of valid starting time points.
- Return type:
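get_screen_sample_mask_from_meta_conditions (through satisfy_for_next) and get_full_valid_sample_times (through chunk_size) rest on the same idea: a sample can open a chunk only if it and the following samples are all valid. The sketch below assumes the window includes the starting sample and that starts too close to the end are rejected. window_valid and valid_start_times are hypothetical helpers.

```python
from typing import List


def window_valid(mask: List[bool], satisfy_for_next: int) -> List[bool]:
    """True at index i if mask[i : i + satisfy_for_next] is entirely True.

    Starts too close to the end to fit a full window are marked False.
    """
    n = len(mask)
    return [
        i + satisfy_for_next <= n and all(mask[i : i + satisfy_for_next])
        for i in range(n)
    ]


def valid_start_times(times: List[float], mask: List[bool], chunk_size: int) -> List[float]:
    """Sample times that can start a chunk of `chunk_size` valid samples."""
    return [t for t, ok in zip(times, window_valid(mask, chunk_size)) if ok]


times = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5]
mask = [True, True, True, False, True, True]
print(window_valid(mask, 2))                 # [True, True, False, False, True, False]
print(valid_start_times(times, mask, 2))     # [0.0, 0.5, 2.0]
```

This direct scan costs O(n * k) for n samples and window length k. For long recordings, a cumulative sum over the mask would answer each window query in constant time.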
- shuffle_valid_screen_times()
Shuffle valid screen times using the dataset’s random number generator for reproducibility.
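The seed parameter and the RNG state methods listed under Methods (set_state restores the state) support a reproducible shuffle-and-resume pattern. The sketch below uses Python's random.Random as a stand-in. The dataset may use a different generator, such as NumPy's, and the SeededShuffler class and its method names are hypothetical.

```python
import random
from typing import Any, List, Optional


class SeededShuffler:
    """Hypothetical stand-in for seeded shuffling of valid start times."""

    def __init__(self, times: List[float], seed: Optional[int] = None) -> None:
        self._times = list(times)
        # Private RNG: reproducible for a fixed seed, independent of the global `random` state.
        self._rng = random.Random(seed)

    def shuffled(self) -> List[float]:
        """Return a new permutation of the stored times; the stored order is unchanged."""
        order = list(self._times)
        self._rng.shuffle(order)
        return order

    def rng_state(self) -> Any:
        return self._rng.getstate()

    def restore(self, state: Any) -> None:
        self._rng.setstate(state)


shuffler = SeededShuffler([0.0, 0.5, 1.0, 1.5, 2.0], seed=42)
state = shuffler.rng_state()  # e.g. saved alongside a training checkpoint
first = shuffler.shuffled()
shuffler.restore(state)       # resume from the checkpoint
print(shuffler.shuffled() == first)  # True
```

Because shuffled() always permutes a copy of the original order, restoring the RNG state alone is enough to reproduce the next permutation.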
- get_data_key_from_root_folder(root_folder)
Extract a data key from the root folder path.
Checks for a meta.json file and extracts the data_key or scan_key.
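The lookup described above can be sketched with the standard library. The sketch assumes that meta.json sits directly in root_folder, that data_key takes precedence over scan_key, and that a missing file yields None. All three are assumptions, and data_key_from_folder is a hypothetical name.

```python
import json
import tempfile
from pathlib import Path
from typing import Optional, Union


def data_key_from_folder(root_folder: Union[str, Path]) -> Optional[str]:
    """Read meta.json in root_folder and return its data_key, else its scan_key.

    Hypothetical: the file location, the precedence and the None fallback are assumptions.
    """
    meta_path = Path(root_folder) / "meta.json"
    if not meta_path.is_file():
        return None
    with meta_path.open() as f:
        meta = json.load(f)
    return meta.get("data_key", meta.get("scan_key"))


with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "meta.json").write_text(json.dumps({"scan_key": "example_scan"}))
    print(data_key_from_folder(tmp))  # example_scan
```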
- __getitem__(idx)
Return a single data sample at the given index.
- Parameters:
idx (int) – Index of the sample to retrieve.
- Returns:
Dictionary containing data for each modality in out_keys, e.g.:
'screen': torch.Tensor of shape (C, T, H, W)
'responses': torch.Tensor of shape (T, N_neurons)
'eye_tracker': torch.Tensor of shape (T, N_features)
'treadmill': torch.Tensor of shape (T, N_features)
'timestamps': dict mapping modality names to time arrays
where T is the chunk size (may differ per modality), C is the number of channels, H is the height, and W is the width.
- Return type: