Source code for experanto.experiment

from __future__ import annotations

import logging
import warnings
from pathlib import Path

import numpy as np
from hydra.utils import instantiate
from omegaconf import DictConfig

from .configs import DEFAULT_MODALITY_CONFIG
from .interpolators import Interpolator

# Module-level logger, named after this module's dotted import path (__name__).
logger = logging.getLogger(__name__)


class Experiment:
    """High-level interface for loading and querying neuroscience experiments.

    An Experiment represents a single recording session containing multiple
    modalities (e.g., visual stimuli, neural responses, behavioral data).
    Each modality is loaded as an Interpolator, allowing data to be queried
    at arbitrary time points.

    Parameters
    ----------
    root_folder : str
        Path to the experiment directory. Should contain subdirectories for
        each modality (e.g., ``screen/``, ``responses/``, ``eye_tracker/``).
        Subdirectories whose name is not a key of ``modality_config`` are
        skipped.
    modality_config : dict, optional
        Configuration dictionary specifying interpolation and processing
        settings for each modality. See :mod:`experanto.configs` for the
        default configuration structure.
    cache_data : bool, default=False
        If True, loads all trial data into memory for faster access.
        Useful for smaller datasets or when memory is not a constraint.

    Attributes
    ----------
    devices : dict
        Dictionary mapping device names to their :class:`Interpolator`
        instances, ordered alphabetically by device name.
    start_time : float
        Earliest valid timestamp across all devices that define a finite
        time range.
    end_time : float
        Latest valid timestamp across all devices that define a finite
        time range.

    Raises
    ------
    ValueError
        If no subfolder matches ``modality_config``, if no loaded device
        defines a finite time range, if the combined time range is empty,
        or if a Hydra-instantiated interpolator is not an
        :class:`Interpolator`.

    See Also
    --------
    ChunkDataset : Higher-level interface for ML training.
    Interpolator : Base class for modality-specific interpolators.

    Examples
    --------
    >>> import numpy as np
    >>> from experanto.experiment import Experiment
    >>> exp = Experiment('/path/to/experiment')
    >>> exp.device_names
    ('eye_tracker', 'responses', 'screen')
    >>> times = np.linspace(0, 10, 100)
    >>> data = exp.interpolate(times, device='responses')
    """

    def __init__(
        self,
        root_folder: str,
        modality_config: dict = DEFAULT_MODALITY_CONFIG,
        cache_data: bool = False,
    ) -> None:
        self.root_folder = Path(root_folder)
        self.devices = {}
        # Start from an empty (inverted) range; _load_devices widens it.
        self.start_time = np.inf
        self.end_time = -np.inf
        self.modality_config = modality_config
        self.cache_data = cache_data
        self._load_devices()

    def _load_devices(self) -> None:
        """Create one interpolator per configured modality subfolder.

        Populates :attr:`devices` and widens :attr:`start_time` /
        :attr:`end_time` to cover every device reporting a finite range.

        Raises
        ------
        ValueError
            If no subfolder matches ``modality_config``, if no loaded device
            has a finite time range, if the combined range is empty, or if a
            Hydra-instantiated object is not an :class:`Interpolator`.
        """
        # Path.iterdir() yields entries in arbitrary, filesystem-dependent
        # order; sorting by name makes device_names reproducible.
        device_folders = sorted(
            (d for d in self.root_folder.iterdir() if d.is_dir()),
            key=lambda d: d.name,
        )

        for d in device_folders:
            if d.name not in self.modality_config:
                logger.info("Skipping %s data", d.name)
                continue

            logger.info("Parsing %s data", d.name)

            # Get interpolation config for this device
            interp_conf = self.modality_config[d.name]["interpolation"]

            if (
                isinstance(interp_conf, (dict, DictConfig))
                and "_target_" in interp_conf
            ):
                # Custom interpolator: Hydra builds the object described by
                # ``_target_``, with the folder and cache flag passed as
                # extra keyword arguments.
                dev = instantiate(
                    interp_conf, root_folder=d, cache_data=self.cache_data
                )
                if not isinstance(dev, Interpolator):
                    raise ValueError(
                        "Instantiated object must inherit from Interpolator class."
                    )
            elif isinstance(interp_conf, Interpolator):
                # A ready-made interpolator instance is used as-is.
                dev = interp_conf
            else:
                # stacklevel=3 attributes the warning to the code that
                # constructed the Experiment (_load_devices <- __init__ <- caller).
                warnings.warn(
                    "Falling back to original Interpolator creation logic.",
                    UserWarning,
                    stacklevel=3,
                )
                dev = Interpolator.create(
                    d,
                    cache_data=self.cache_data,
                    **interp_conf,  # type: ignore[arg-type]
                )

            has_finite_range = (
                dev.start_time is not None
                and dev.end_time is not None
                and np.isfinite(dev.start_time)
                and np.isfinite(dev.end_time)
            )
            if has_finite_range:
                self.start_time = min(self.start_time, dev.start_time)
                self.end_time = max(self.end_time, dev.end_time)
            else:
                logger.warning(
                    "Device %s has undefined start_time or end_time and will be "
                    "excluded from the experiment-wide time range.",
                    d.name,
                )

            # Devices without a finite range are still queryable.
            self.devices[d.name] = dev

        logger.info("Parsing finished")

        if not self.devices:
            raise ValueError(
                "Experiment time range could not be determined: no devices were "
                f"loaded because no subfolder of {self.root_folder} matches a key "
                "of modality_config."
            )
        if not (np.isfinite(self.start_time) and np.isfinite(self.end_time)):
            raise ValueError(
                "Experiment time range could not be determined: at least one device "
                "must define finite start_time and end_time."
            )
        if self.start_time > self.end_time:
            raise ValueError(
                "Experiment time range is empty: start_time "
                f"({self.start_time}) is after end_time ({self.end_time})."
            )

    @property
    def device_names(self) -> tuple[str, ...]:
        """Names of the loaded devices, in the order they were loaded."""
        return tuple(self.devices.keys())

    def interpolate(
        self,
        times: np.ndarray,
        device: str | Interpolator | None = None,
        return_valid: bool = False,
    ) -> tuple[dict, dict] | dict | tuple[np.ndarray, np.ndarray] | np.ndarray:
        """Interpolate data from one or all devices at specified time points.

        Parameters
        ----------
        times : array_like
            1D array of time points (in seconds) at which to interpolate.
        device : str or Interpolator, optional
            Name of a loaded device, or an :class:`Interpolator` instance to
            query directly. If None, interpolates all devices and returns
            dictionaries.
        return_valid : bool, default=False
            If True, also return the indices into ``times`` that produced
            the returned values.

        Returns
        -------
        values : numpy.ndarray or dict
            If ``device`` is specified, returns the interpolated data array
            for the valid time points only (shape is modality-dependent, see
            below). Otherwise, returns a dict mapping device names to their
            data arrays.
        valid : numpy.ndarray or dict, optional
            Only present when ``return_valid=True``. Integer index array(s)
            into ``times`` indicating which entries were used to produce
            ``values``. ``values[i]`` corresponds to ``times[valid[i]]``, and
            ``len(valid) == values.shape[0]``. When a dict is returned,
            ``valid`` is a dict with the same keys, and ``len(valid[d])`` may
            differ across devices because each modality has its own valid
            range.

        Raises
        ------
        KeyError
            If ``device`` is a name that is not among the loaded devices.
        ValueError
            If ``device`` is neither None, a str, nor an Interpolator.

        Notes
        -----
        Output shapes per modality:

        * Sequence modalities (``responses``, ``eye_tracker``, ``treadmill``):
          ``(n_valid, n_signals)``
        * Screen modality: ``(n_valid, H, W)`` for grayscale,
          ``(n_valid, H, W, C)`` for colour.

        Examples
        --------
        Interpolate a single device:

        >>> data, valid = exp.interpolate(times, device='responses', return_valid=True)
        >>> data.shape
        (n_valid, 500)  # n_valid <= len(times), 500 neurons
        >>> times[valid].shape == (data.shape[0],)
        True

        Interpolate all devices:

        >>> data = exp.interpolate(times)
        >>> data.keys()
        dict_keys(['eye_tracker', 'responses', 'screen'])
        """
        if device is None:
            values, valid = {}, {}
            for name, interp in self.devices.items():
                res = interp.interpolate(times, return_valid=return_valid)
                if return_valid:
                    values[name], valid[name] = res
                else:
                    values[name] = res
            return (values, valid) if return_valid else values

        if isinstance(device, str):
            if device not in self.devices:
                raise KeyError(f"Unknown device '{device}'")
            return self.devices[device].interpolate(times, return_valid=return_valid)

        if isinstance(device, Interpolator):
            # The signature advertises Interpolator instances; query them directly.
            return device.interpolate(times, return_valid=return_valid)

        raise ValueError(f"Unsupported device type: {type(device)}")

    def get_valid_range(self, device_name: str) -> tuple[float, float]:
        """Get the valid time range for a specific device.

        Parameters
        ----------
        device_name : str
            Name of the device (e.g., 'screen', 'responses').

        Returns
        -------
        tuple
            A tuple ``(start_time, end_time)`` built from the device's
            ``valid_interval``, in seconds.

        Raises
        ------
        KeyError
            If ``device_name`` is not among the loaded devices.
        """
        if device_name not in self.devices:
            raise KeyError(f"Unknown device '{device_name}'")
        return tuple(self.devices[device_name].valid_interval)