experanto.utils.SessionBatchSampler

class SessionBatchSampler(*args, **kwargs)

Bases: Sampler

A batch sampler that cycles through sessions, ensuring each session appears exactly once before repeating any session.
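The cycling behaviour can be illustrated with a small, self-contained sketch. It is not the experanto implementation. It assumes that every batch is drawn from a single session and that each cycle visits every session that still has batches exactly once, in a shuffled order. The function name `iter_session_batches` and the plain dict of per-session indices are hypothetical stand-ins for the real `SessionConcatDataset`.

```python
import random
from typing import Dict, Hashable, Iterator, List, Optional


def iter_session_batches(
    session_indices: Dict[Hashable, List[int]],
    batch_size: int,
    drop_last: bool = False,
    seed: Optional[int] = None,
) -> Iterator[List[int]]:
    """Hypothetical sketch: yield single-session batches, one session per step of a cycle."""
    rng = random.Random(seed)
    # Pre-split each session's indices into consecutive batches of batch_size.
    pending: Dict[Hashable, List[List[int]]] = {}
    for session, indices in session_indices.items():
        batches = [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]
        # Optionally discard an incomplete final batch for this session.
        if drop_last and batches and len(batches[-1]) < batch_size:
            batches.pop()
        pending[session] = batches
    while any(pending.values()):
        # One cycle: every session that still has batches appears exactly once.
        cycle = [s for s, b in pending.items() if b]
        rng.shuffle(cycle)
        for session in cycle:
            yield pending[session].pop(0)
```

With sessions `{"a": [0, 1, 2, 3, 4], "b": [10, 11]}` and `batch_size=2`, the first two batches come from different sessions, and session `"b"` only reappears once session `"a"` has also been visited.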

Methods

__init__(dataset, batch_size[, drop_last, ...])

Initialize session batch sampler.

get_session_cycle()

Generate one cycle of sessions, with each session appearing exactly once.

get_state()

Return the state of the sampler (including RNG state).

set_state(state)

Restore the state of the sampler (including RNG state).

__init__(dataset, batch_size, drop_last=False, shuffle=False, seed=None)

Initialize session batch sampler.

Parameters:
  • dataset (SessionConcatDataset) – The dataset to sample from.

  • batch_size (int) – Number of samples per batch.

  • drop_last (bool, optional) – Whether to drop the last batch if it is smaller than batch_size. Default is False.

  • shuffle (bool, optional) – Whether to shuffle samples within each session. Default is False.

  • seed (int, optional) – Random seed for reproducibility. Default is None.
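How batch_size, drop_last, shuffle and seed might combine for one session can be sketched as follows. This is an assumption-based illustration, not the library's code: the helper `split_session` is hypothetical, it takes a plain list of sample indices, and it uses Python's `random.Random` as the seeded RNG.

```python
import random
from typing import List, Optional


def split_session(
    indices: List[int],
    batch_size: int,
    drop_last: bool = False,
    shuffle: bool = False,
    seed: Optional[int] = None,
) -> List[List[int]]:
    """Hypothetical helper: split one session's sample indices into batches."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    order = list(indices)
    if shuffle:
        # A fixed seed makes the within-session order reproducible.
        random.Random(seed).shuffle(order)
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        # The final, incomplete batch is discarded only when drop_last is set.
        batches.pop()
    return batches


print(split_session(list(range(5)), 2))                  # [[0, 1], [2, 3], [4]]
print(split_session(list(range(5)), 2, drop_last=True))  # [[0, 1], [2, 3]]
```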

__len__()

Return the total number of batches across all sessions.
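If batches never span two sessions (an assumption, not stated above), the total is a per-session sum: floor(n / batch_size) with drop_last, and ceil(n / batch_size) without it. The function `total_batches` below is a hypothetical sketch of that arithmetic.

```python
from typing import Sequence


def total_batches(session_sizes: Sequence[int], batch_size: int, drop_last: bool = False) -> int:
    """Hypothetical sketch: count batches when each batch stays within one session."""
    if drop_last:
        # Incomplete final batches are dropped, so round down per session.
        return sum(n // batch_size for n in session_sizes)
    # Otherwise each session contributes ceil(n / batch_size) batches (integer ceiling).
    return sum(-(-n // batch_size) for n in session_sizes)


print(total_batches([5, 2], batch_size=2))                  # 4
print(total_batches([5, 2], batch_size=2, drop_last=True))  # 3
```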

get_session_cycle()

Generate one cycle of sessions in which each session appears exactly once. Sessions are shuffled unless their order of appearance needs to be controlled.
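A single cycle is essentially a permutation of the session identifiers. The sketch below is illustrative only: the function `session_cycle`, its `shuffle` flag standing in for "order needs to be controlled", and the use of `random.Random` are all assumptions rather than the documented API.

```python
import random
from typing import List, Sequence


def session_cycle(session_ids: Sequence[int], rng: random.Random, shuffle: bool = True) -> List[int]:
    """Hypothetical sketch: return each session id exactly once."""
    cycle = list(session_ids)  # a copy, so every session appears exactly once
    if shuffle:
        # Shuffling a permutation never duplicates or drops a session.
        rng.shuffle(cycle)
    return cycle


rng = random.Random(0)
print(session_cycle([3, 1, 2, 0], rng))                 # some permutation of [3, 1, 2, 0]
print(session_cycle([3, 1, 2, 0], rng, shuffle=False))  # [3, 1, 2, 0]
```

Passing the RNG in, rather than creating one per call, lets successive cycles differ while the whole sequence stays reproducible from one seed.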

get_state()

Return the state of the sampler (including RNG state).

set_state(state)

Restore the state of the sampler (including RNG state).
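A get_state / set_state pair like this is typically used to resume sampling after a checkpoint. The sketch assumes the state is a dict holding a Python `random.Random` state plus a progress counter; the class `CycleState` and the keys `rng_state` and `cycles_emitted` are hypothetical, and the real sampler may store a different RNG (for example a NumPy or PyTorch generator) and different fields.

```python
import random
from typing import Any, Dict, List, Optional


class CycleState:
    """Hypothetical sampler stand-in whose RNG state can be saved and restored."""

    def __init__(self, num_sessions: int, seed: Optional[int] = None) -> None:
        self.num_sessions = num_sessions
        self.rng = random.Random(seed)
        self.cycles_emitted = 0

    def next_cycle(self) -> List[int]:
        # Each cycle is a shuffled permutation of all session indices.
        cycle = list(range(self.num_sessions))
        self.rng.shuffle(cycle)
        self.cycles_emitted += 1
        return cycle

    def get_state(self) -> Dict[str, Any]:
        # Capture the RNG state together with the progress counter.
        return {"rng_state": self.rng.getstate(), "cycles_emitted": self.cycles_emitted}

    def set_state(self, state: Dict[str, Any]) -> None:
        # Restoring the RNG state makes later cycles identical to the original run.
        self.rng.setstate(state["rng_state"])
        self.cycles_emitted = state["cycles_emitted"]


sampler = CycleState(5, seed=1)
sampler.next_cycle()
checkpoint = sampler.get_state()
expected = sampler.next_cycle()

resumed = CycleState(5, seed=99)  # a different seed is overridden by set_state
resumed.set_state(checkpoint)
print(resumed.next_cycle() == expected)  # True
```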