"""Common filter factories (module ``experanto.filters.common_filters``)."""

import numpy as np

from experanto.interpolators import SequenceInterpolator
from experanto.intervals import (
    TimeInterval,
    find_complement_of_interval_array,
    uniquefy_interval_array,
)


def nan_filter(vicinity=0.05):
    """Create a filter that excludes time regions around NaN values.

    Returns a closure that, given a
    :class:`~experanto.interpolators.SequenceInterpolator`, identifies all
    time points containing NaN in any channel and marks a symmetric window
    of ``vicinity`` seconds around each as invalid.

    Parameters
    ----------
    vicinity : float, optional
        Half-width of the exclusion window in seconds around each NaN time
        point. Must be non-negative. Default is 0.05.

    Returns
    -------
    callable
        A function that takes a
        :class:`~experanto.interpolators.SequenceInterpolator` and returns a
        list of :class:`~experanto.intervals.TimeInterval` representing the
        valid (NaN-free) portions of the recording.

    Raises
    ------
    ValueError
        If ``vicinity`` is negative (it would yield intervals whose start
        lies after their end).
    """
    if vicinity < 0:
        raise ValueError(f"vicinity must be non-negative, got {vicinity}")

    def implementation(device_: SequenceInterpolator):
        # Requires a SequenceInterpolator since it relies on time_delta,
        # which other interpolator types do not expose.
        time_delta = device_.time_delta
        start_time = device_.start_time
        end_time = device_.end_time

        # Axis 0 is treated as the time axis (one row per sample).
        nan_mask = np.isnan(device_._data)
        if nan_mask.ndim > 1:
            # A sample is invalid if ANY of its channels is NaN. Flattening
            # the trailing axes also accepts >2-D data; 1-D data (a single
            # channel) needs no reduction, whereas `axis=1` would raise.
            nan_mask = nan_mask.reshape(nan_mask.shape[0], -1).any(axis=1)

        nan_indices = np.flatnonzero(nan_mask)

        # Timestamps of NaN samples. This assumes uniform sampling that
        # starts at start_time, which is the same assumption the index ->
        # time mapping below has always made.
        time_points = start_time + nan_indices * time_delta

        # Symmetric exclusion windows, clipped to the recording bounds.
        # Vectorized instead of a per-NaN Python loop; same arithmetic.
        interval_starts = np.maximum(start_time, time_points - vicinity)
        interval_ends = np.minimum(end_time, time_points + vicinity)
        invalid_intervals = [
            TimeInterval(lo, hi)
            for lo, hi in zip(interval_starts.tolist(), interval_ends.tolist())
        ]

        # Merge overlapping invalid intervals (consecutive NaN samples
        # produce heavily overlapping windows).
        invalid_intervals = uniquefy_interval_array(invalid_intervals)

        # The valid intervals are the complement of the invalid ones
        # within [start_time, end_time].
        return find_complement_of_interval_array(
            start_time, end_time, invalid_intervals
        )

    return implementation