experanto.datasets.ChunkDataset

class ChunkDataset(*args, **kwargs)

Bases: Dataset

PyTorch Dataset for chunked experiment data.

This dataset loads an experiment and provides temporally chunked samples suitable for training neural networks. Each sample contains synchronized data from the selected modalities (e.g., screen, responses, eye_tracker, treadmill) within a given time window.

Parameters:
  • root_folder (str) – Path to the experiment directory containing modality subfolders.

  • global_sampling_rate (float, optional) – Sampling rate (Hz) applied to all modalities. If None, uses per-modality rates from modality_config.

  • global_chunk_size (int, optional) – Number of samples per chunk for all modalities. If None, uses per-modality sizes from modality_config.

  • add_behavior_as_channels (bool, default=False) – If True, concatenates behavioral data as additional image channels. Deprecated: use separate modality outputs instead.

  • replace_nans_with_means (bool, default=False) – If True, replaces NaN values with column means.

  • cache_data (bool, default=False) – If True, keeps loaded data in memory for faster access.

  • out_keys (iterable, optional) – Which modalities to include in output. Defaults to all modalities plus ‘timestamps’.

  • normalize_timestamps (bool, default=True) – If True, normalizes timestamps relative to recording start.

  • modality_config (dict, optional) – Configuration for each modality, including sampling rates, chunk sizes, offsets, transforms, filters, and interpolation settings. See Notes for the structure and the __init__ signature for the default value.

  • seed (int, optional) – Random seed for reproducible shuffling of valid time points.

  • safe_interval_threshold (float, default=0.5) – Safety margin (in seconds) to exclude from edges of valid intervals.

  • interpolate_precision (int, default=5) – Number of decimal places for time precision. Prevents floating-point accumulation errors during interpolation.
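The precedence between the global arguments and modality_config can be sketched as follows. This is a minimal illustration, assuming that a non-None global value simply replaces the per-modality setting, as the parameter descriptions state; resolve_modality_settings is a hypothetical helper, not part of experanto.

```python
from typing import Optional


def resolve_modality_settings(
    modality_config: dict,
    global_sampling_rate: Optional[float] = None,
    global_chunk_size: Optional[int] = None,
) -> dict:
    """Return {modality: (sampling_rate, chunk_size)} after applying global overrides.

    Hypothetical helper: assumes a non-None global value replaces the
    per-modality value from modality_config.
    """
    resolved = {}
    for name, cfg in modality_config.items():
        rate = global_sampling_rate if global_sampling_rate is not None else cfg["sampling_rate"]
        size = global_chunk_size if global_chunk_size is not None else cfg["chunk_size"]
        resolved[name] = (rate, size)
    return resolved


config = {
    "screen": {"sampling_rate": 30, "chunk_size": 60},
    "responses": {"sampling_rate": 8, "chunk_size": 16},
}

# Without overrides, each modality keeps its own rate and chunk size.
print(resolve_modality_settings(config))
# {'screen': (30, 60), 'responses': (8, 16)}

# With global overrides, every modality uses 30 Hz and 60 samples.
print(resolve_modality_settings(config, global_sampling_rate=30, global_chunk_size=60))
# {'screen': (30, 60), 'responses': (30, 60)}
```

In the default configuration the screen (60 samples at 30 Hz) and responses (16 samples at 8 Hz) chunks both span 2 s, so per-modality settings can keep chunks aligned in time even when sample counts differ.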

data_key

Unique identifier for this dataset, extracted from metadata.

Type:

str

device_names

Names of loaded modalities.

Type:

tuple

start_time

Start of valid time range (after applying safety threshold).

Type:

float

end_time

End of valid time range (after applying safety threshold).

Type:

float
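The effect of safe_interval_threshold on start_time and end_time can be pictured with a small sketch. It assumes the margin is removed from both edges of an interval and that an interval shorter than twice the margin is discarded; trim_interval is a hypothetical helper used only for illustration.

```python
def trim_interval(start: float, end: float, threshold: float = 0.5):
    """Shrink [start, end] by `threshold` seconds on both sides.

    Hypothetical helper: assumes intervals shorter than 2 * threshold
    are dropped entirely (returns None).
    """
    new_start, new_end = start + threshold, end - threshold
    if new_end <= new_start:
        return None
    return new_start, new_end


print(trim_interval(10.0, 120.0))  # (10.5, 119.5)
print(trim_interval(3.0, 3.8))     # None: shorter than twice the margin
```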

See also

Experiment

Lower-level interface for data access.

get_multisession_dataloader

Load multiple datasets.

experanto.configs

Default configuration values.

Notes

The dataset handles:

  • Per-modality sampling rates and chunk sizes

  • Time offset alignment between modalities

  • Data normalization and transforms

  • Filtering based on trial conditions and data quality

  • Reproducible random sampling with seeds
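Offset alignment and the rounding controlled by interpolate_precision can be illustrated with a short sketch. It assumes, without confirmation from this page, that each modality is read at start + offset + i / sampling_rate for i in range(chunk_size); chunk_times is a hypothetical helper, and the 0.1 s responses offset is taken from the example configuration below.

```python
def chunk_times(start: float, sampling_rate: float, chunk_size: int,
                offset: float = 0.0, precision: int = 5) -> list:
    """Timestamps at which one modality is sampled for a chunk starting at `start`.

    Hypothetical helper: the formula start + offset + i / sampling_rate is an
    assumption; rounding to `precision` decimals mirrors interpolate_precision.
    """
    return [round(start + offset + i / sampling_rate, precision) for i in range(chunk_size)]


# Screen at 30 Hz; responses at 8 Hz, shifted by a 0.1 s offset.
screen_t = chunk_times(100.0, sampling_rate=30, chunk_size=60)
resp_t = chunk_times(100.0, sampling_rate=8, chunk_size=16, offset=0.1)
print(screen_t[:3])  # [100.0, 100.03333, 100.06667]
print(resp_t[:3])    # [100.1, 100.225, 100.35]
```

Rounding keeps values such as 100.0 + 1/30 from carrying floating-point noise into later time lookups, which is the accumulation problem interpolate_precision is described as preventing.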

modality_config is a nested dictionary of per-modality settings. The following YAML example configures the screen, responses, eye_tracker, and treadmill modalities:

screen:
  sampling_rate: null
  chunk_size: null
  valid_condition:
    tier: test
    stim_type: stimulus.Frame
  offset: 0
  sample_stride: 4
  include_blanks: false
  transforms:
    ToTensor:
      _target_: torchvision.transforms.ToTensor
    Normalize:
      _target_: torchvision.transforms.Normalize
      mean: 80.0
      std: 60.0
    Resize:
      _target_: torchvision.transforms.Resize
      size:
      - 144
      - 256
    CenterCrop:
      _target_: torchvision.transforms.CenterCrop
      size: 144
  interpolation: {}
responses:
  sampling_rate: null
  chunk_size: null
  offset: 0.1
  transforms:
    normalize_variance_only: true # old standardize: true
  interpolation:
    interpolation_mode: nearest_neighbor
eye_tracker:
  sampling_rate: null
  chunk_size: null
  offset: 0
  transforms:
    normalize: true
  interpolation:
    interpolation_mode: nearest_neighbor
treadmill:
  sampling_rate: null
  chunk_size: null
  offset: 0
  transforms:
    normalize: true
  interpolation:
    interpolation_mode: nearest_neighbor
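Transform entries such as `_target_: torchvision.transforms.Resize` name a class by its import path, with the remaining keys as constructor arguments. Assuming these are resolved in the style of Hydra's `_target_` convention, the mechanism can be sketched with importlib. instantiate_target and build_transforms are hypothetical helpers, and stdlib classes stand in for torchvision so the sketch runs anywhere.

```python
import importlib


def instantiate_target(spec: dict):
    """Build an object from a {'_target_': 'pkg.mod.Name', **kwargs} dict.

    Minimal stand-in for Hydra-style instantiation; assumes all keys other
    than '_target_' are keyword arguments for the target callable.
    """
    kwargs = {k: v for k, v in spec.items() if k != "_target_"}
    module_path, _, attr = spec["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), attr)
    return target(**kwargs)


def build_transforms(transform_cfg: dict) -> list:
    """Instantiate every entry that carries a '_target_' key, in config order.

    Entries without '_target_' (e.g. 'normalization: normalize') are skipped
    here; how experanto handles them is not covered by this sketch.
    """
    return [instantiate_target(spec) for spec in transform_cfg.values()
            if isinstance(spec, dict) and "_target_" in spec]


cfg = {
    "Delay": {"_target_": "datetime.timedelta", "seconds": 5},
    "normalization": "normalize",
}
print(build_transforms(cfg))  # [datetime.timedelta(seconds=5)]
```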

Examples

>>> from experanto.datasets import ChunkDataset
>>> from experanto.configs import DEFAULT_MODALITY_CONFIG
>>> dataset = ChunkDataset(
...     '/path/to/experiment',
...     global_sampling_rate=30,
...     global_chunk_size=60,
...     modality_config=DEFAULT_MODALITY_CONFIG,
... )
>>> len(dataset)
1000
>>> sample = dataset[0]
>>> sample['screen'].shape
torch.Size([1, 60, 144, 256])
>>> sample['responses'].shape
torch.Size([60, 500])

Methods

__init__(root_folder[, ...])

add_channel_function(x)

get_condition_mask_from_meta_conditions(...)

Create a boolean mask for trials satisfying given conditions.

get_data_key_from_root_folder(root_folder)

Extract a data key from the root folder path.

get_full_valid_sample_times([...])

Get all valid chunk starting times based on meta conditions.

get_screen_sample_mask_from_meta_conditions(...)

Create a boolean mask for screen samples satisfying given conditions.

get_state()

Return the current state of the dataset's RNG.

get_valid_intervals_from_filters([visualize])

initialize_statistics()

Initialize normalization statistics for each device.

initialize_transforms()

Initialize data transforms for each device based on modality config.

set_state(state)

Restore the dataset's RNG state.

shuffle_valid_screen_times()

Shuffle valid screen times using the dataset's random number generator for reproducibility.

__init__(root_folder, global_sampling_rate=None, global_chunk_size=None, add_behavior_as_channels=False, replace_nans_with_means=False, cache_data=False, out_keys=None, normalize_timestamps=True, modality_config={'screen': {'keep_nans': False, 'sampling_rate': 30, 'chunk_size': 60, 'valid_condition': {'tier': 'train'}, 'offset': 0, 'sample_stride': 1, 'include_blanks': True, 'transforms': {'normalization': 'normalize', 'Resize': {'_target_': 'torchvision.transforms.v2.Resize', 'size': [144, 256]}}, 'interpolation': {'rescale': True, 'rescale_size': [144, 256]}}, 'responses': {'keep_nans': False, 'sampling_rate': 8, 'chunk_size': 16, 'offset': 0.0, 'transforms': {'normalization': 'normalize_variance_only'}, 'interpolation': {'interpolation_mode': 'nearest_neighbor'}, 'filters': {'nan_filter': {'__target__': 'experanto.filters.common_filters.nan_filter', '__partial__': True, 'vicinity': 0.05}}}, 'eye_tracker': {'keep_nans': False, 'sampling_rate': 30, 'chunk_size': 60, 'offset': 0, 'transforms': {'normalization': 'normalize'}, 'interpolation': {'interpolation_mode': 'nearest_neighbor'}, 'filters': {'nan_filter': {'__target__': 'experanto.filters.common_filters.nan_filter', '__partial__': True, 'vicinity': 0.05}}}, 'treadmill': {'keep_nans': False, 'sampling_rate': 30, 'chunk_size': 60, 'offset': 0, 'transforms': {'normalization': 'normalize'}, 'interpolation': {'interpolation_mode': 'nearest_neighbor'}, 'filters': {'nan_filter': {'__target__': 'experanto.filters.common_filters.nan_filter', '__partial__': True, 'vicinity': 0.05}}}}, seed=None, safe_interval_threshold=0.5, interpolate_precision=5)

initialize_statistics()

Initialize normalization statistics for each device.

Loads mean and standard deviation values from each device’s meta folder and stores them in self._statistics for use during data transforms.
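The loaded statistics feed the normalization modes named in modality_config ('normalize', 'normalize_variance_only'). A sketch of what those modes plausibly compute follows; the formulas, (x - mean) / std and x / std, are assumptions inferred from the mode names, and normalize is a hypothetical helper operating on plain lists of rows (time x channels).

```python
def normalize(values, mean, std, mode="normalize"):
    """Apply per-channel statistics to a list of rows (time x channels).

    Hypothetical helper: assumes 'normalize' computes (x - mean) / std and
    'normalize_variance_only' computes x / std, leaving the mean in place.
    """
    out = []
    for row in values:
        if mode == "normalize":
            out.append([(x - m) / s for x, m, s in zip(row, mean, std)])
        elif mode == "normalize_variance_only":
            out.append([x / s for x, s in zip(row, std)])
        else:
            raise ValueError(f"unknown normalization mode: {mode!r}")
    return out


rows = [[2.0, 10.0], [4.0, 30.0]]
mean, std = [3.0, 20.0], [1.0, 10.0]
print(normalize(rows, mean, std))                             # [[-1.0, -1.0], [1.0, 1.0]]
print(normalize(rows, mean, std, "normalize_variance_only"))  # [[2.0, 1.0], [4.0, 3.0]]
```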

static add_channel_function(x)

initialize_transforms()

Initialize data transforms for each device based on modality config.

get_valid_intervals_from_filters(visualize=False)

get_condition_mask_from_meta_conditions(valid_conditions_sum_of_product)

Create a boolean mask for trials satisfying given conditions.

Parameters:

valid_conditions_sum_of_product (list of dict) – Condition dictionaries combined with OR logic, where conditions within each dictionary use AND logic.

Returns:

Boolean mask indicating which trials satisfy at least one set of conditions.

Return type:

numpy.ndarray

Notes

For example, [{'tier': 'train', 'stim_type': 'natural'}, {'tier': 'blank'}] matches trials that are either (tier 'train' AND stim_type 'natural') OR tier 'blank'.
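The OR-of-ANDs logic can be sketched in pure Python. Here trials is a hypothetical list of per-trial metadata dicts; the real method works on the experiment's stored trial metadata and returns a numpy array.

```python
def condition_mask(trials, valid_conditions_sum_of_product):
    """True for each trial matching at least one condition dict (OR of ANDs).

    Sketch only: `trials` is a hypothetical list of per-trial metadata dicts.
    """
    return [
        any(all(trial.get(key) == value for key, value in cond.items())
            for cond in valid_conditions_sum_of_product)
        for trial in trials
    ]


trials = [
    {"tier": "train", "stim_type": "natural"},
    {"tier": "train", "stim_type": "noise"},
    {"tier": "blank", "stim_type": None},
    {"tier": "test", "stim_type": "natural"},
]
conditions = [{"tier": "train", "stim_type": "natural"}, {"tier": "blank"}]
print(condition_mask(trials, conditions))  # [True, False, True, False]
```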

get_screen_sample_mask_from_meta_conditions(satisfy_for_next, valid_conditions_sum_of_product, filter_for_valid_intervals=True)

Create a boolean mask for screen samples satisfying given conditions.

Parameters:
  • satisfy_for_next (int) – Number of consecutive samples that must satisfy conditions.

  • valid_conditions_sum_of_product (list of dict) – Condition dictionaries combined with OR logic, where conditions within each dictionary use AND logic.

  • filter_for_valid_intervals (bool, default=True) – Whether to apply interval-based filtering.

Returns:

Boolean array aligned with the screen sample times, True where the conditions are met for satisfy_for_next consecutive samples.

Return type:

numpy.ndarray
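The satisfy_for_next requirement can be sketched as a window check over a per-sample mask. This assumes a sample qualifies only if it and the following samples, satisfy_for_next in total, all pass; consecutive_mask is a hypothetical helper that counts run lengths from the end so the check runs in linear time.

```python
def consecutive_mask(sample_ok, satisfy_for_next):
    """Mark index i True only if sample_ok[i : i + satisfy_for_next] are all True.

    Sketch: streak[i] counts consecutive True values starting at i, so
    incomplete windows at the end of the array come out False.
    """
    n = len(sample_ok)
    streak = [0] * (n + 1)
    for i in range(n - 1, -1, -1):
        streak[i] = streak[i + 1] + 1 if sample_ok[i] else 0
    return [streak[i] >= satisfy_for_next for i in range(n)]


ok = [True, True, True, False, True, True]
print(consecutive_mask(ok, 2))  # [True, True, False, False, True, False]
```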

get_full_valid_sample_times(filter_for_valid_intervals=True)

Get all valid chunk starting times based on meta conditions.

Iterates through the sample times and checks whether each can serve as a chunk start time, i.e., whether the next chunk_size points all passed the preceding meta-condition filtering.

Parameters:

filter_for_valid_intervals (bool, default=True) – Whether to apply interval-based filtering.

Returns:

Array of valid starting time points.

Return type:

numpy.ndarray
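Interval-based filtering of chunk start times can be sketched as keeping only times whose whole chunk fits inside one valid interval. valid_start_times is a hypothetical helper; taking the chunk duration as a fixed number of seconds is an assumption made for illustration.

```python
def valid_start_times(sample_times, chunk_duration, intervals):
    """Keep times t whose chunk [t, t + chunk_duration] lies inside one interval.

    Hypothetical helper; `intervals` is a list of (start, end) pairs in seconds.
    """
    return [
        t for t in sample_times
        if any(start <= t and t + chunk_duration <= end for start, end in intervals)
    ]


times = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
print(valid_start_times(times, chunk_duration=2.0, intervals=[(0.0, 3.5), (4.0, 6.0)]))
# [0.0, 1.0, 4.0]
```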

shuffle_valid_screen_times()

Shuffle valid screen times using the dataset’s random number generator for reproducibility.
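Reproducible shuffling comes from giving the dataset its own seeded generator rather than relying on global random state. The sketch below uses Python's random.Random as a stand-in for whatever generator the dataset actually holds; shuffled_times is a hypothetical helper.

```python
import random


def shuffled_times(valid_times, seed):
    """Return a shuffled copy of valid_times using a private, seeded RNG.

    Sketch only: a dedicated random.Random instance keeps the order
    reproducible and independent of the global `random` module state.
    """
    rng = random.Random(seed)
    times = list(valid_times)
    rng.shuffle(times)
    return times


a = shuffled_times(range(10), seed=42)
b = shuffled_times(range(10), seed=42)
print(a == b)                        # True: same seed, same order
print(sorted(a) == list(range(10)))  # True: only the order changes
```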

get_data_key_from_root_folder(root_folder)

Extract a data key from the root folder path.

Checks for a meta.json file and extracts the data_key or scan_key.

Parameters:

root_folder (str or Path) – Path to the root folder containing the dataset.

Returns:

The extracted data key, or folder name if meta.json doesn’t exist or lacks data_key.

Return type:

str
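A minimal sketch of the lookup, using pathlib and json. The order data_key, then scan_key, then folder name is an assumption based on the description above, and data_key_from_folder is a hypothetical stand-in for the method; "example_key" is illustrative data.

```python
import json
import tempfile
from pathlib import Path


def data_key_from_folder(root_folder):
    """Read data_key (or scan_key) from meta.json, else fall back to the folder name.

    Hypothetical re-implementation; the lookup order is an assumption.
    """
    root = Path(root_folder)
    meta_path = root / "meta.json"
    if meta_path.exists():
        meta = json.loads(meta_path.read_text())
        for key in ("data_key", "scan_key"):
            if key in meta:
                return str(meta[key])
    return root.name


with tempfile.TemporaryDirectory() as tmp:
    exp = Path(tmp) / "session_01"
    exp.mkdir()
    print(data_key_from_folder(exp))  # session_01 (no meta.json yet)
    (exp / "meta.json").write_text(json.dumps({"scan_key": "example_key"}))
    print(data_key_from_folder(exp))  # example_key
```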

__getitem__(idx)

Return a single data sample at the given index.

Parameters:

idx (int) – Index of the sample to retrieve.

Returns:

Dictionary containing data for each modality in out_keys, e.g.:

  • 'screen': torch.Tensor of shape (C, T, H, W)

  • 'responses': torch.Tensor of shape (T, N_neurons)

  • 'eye_tracker': torch.Tensor of shape (T, N_features)

  • 'treadmill': torch.Tensor of shape (T, N_features)

  • 'timestamps': dict mapping modality names to time arrays

Here T is the chunk size (which may differ per modality), C the number of channels, H the height, and W the width.

Return type:

dict

get_state()

Return the current state of the dataset’s RNG.

set_state(state)

Restore the dataset’s RNG state.
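get_state and set_state support checkpoint-and-resume: saving the RNG state alongside a training checkpoint lets a resumed run reproduce the same shuffles. The type of the returned state is not documented here, so the sketch uses Python's random.Random.getstate() and setstate() purely as an analogy for the save/restore pattern.

```python
import random

rng = random.Random(0)
rng.random()              # advance the generator, as earlier epochs would
saved = rng.getstate()    # analogous to state = dataset.get_state()
first = [rng.random() for _ in range(3)]

rng.setstate(saved)       # analogous to dataset.set_state(state)
second = [rng.random() for _ in range(3)]
print(first == second)    # True: draws resume identically after the restore
```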