import numpy as np
from experanto.interpolators import SequenceInterpolator
from experanto.intervals import (
TimeInterval,
find_complement_of_interval_array,
uniquefy_interval_array,
)
[docs]
def nan_filter(vicinity=0.05):
"""Create a filter that excludes time regions around NaN values.
Returns a closure that, given a :class:`~experanto.interpolators.SequenceInterpolator`,
identifies all time points containing NaN in any channel and marks a
symmetric window of ``vicinity`` seconds around each as invalid.
Parameters
----------
vicinity : float, optional
Half-width of the exclusion window in seconds around each NaN
time point. Default is 0.05.
Returns
-------
callable
A function that takes a
:class:`~experanto.interpolators.SequenceInterpolator` and returns
a list of :class:`~experanto.intervals.TimeInterval` representing
the valid (NaN-free) portions of the recording.
"""
def implementation(device_: SequenceInterpolator):
# Requires a SequenceInterpolator since it relies on time_delta,
# which other interpolator types do not expose.
time_delta = device_.time_delta
start_time = device_.start_time
end_time = device_.end_time
data = device_._data # (T, n_features)
nan_mask = np.isnan(data) # (T, n_features)
nan_mask = np.any(nan_mask, axis=1) # (T,)
# Find indices where nan_mask is True
nan_indices = np.where(nan_mask)[0]
# Create invalid TimeIntervals around each nan point
invalid_intervals = []
for idx in nan_indices:
time_point = start_time + idx * time_delta
interval_start = max(start_time, time_point - vicinity)
interval_end = min(end_time, time_point + vicinity)
invalid_intervals.append(TimeInterval(interval_start, interval_end))
# Merge overlapping invalid intervals
invalid_intervals = uniquefy_interval_array(invalid_intervals)
# Find the complement of invalid intervals to get valid intervals
valid_intervals = find_complement_of_interval_array(
start_time, end_time, invalid_intervals
)
return valid_intervals
return implementation