bead.active_learning¶
Stage 6 of the bead pipeline: active learning with GLMM support and convergence detection.
Active Learning Loop¶
loop
¶
Active learning loop orchestration.
This module orchestrates the iterative active learning loop (stages 3-6): construct items → deploy experiment → collect data → train model → select next items. It manages convergence detection and coordinates all components.
IterationResult
¶
Bases: TypedDict
Results from a single active learning iteration.
Attributes:
| Name | Type | Description |
|---|---|---|
| iteration | int | Iteration number. |
| selected_items | list[Item] | Items selected for annotation in this iteration. |
| model | ActiveLearningModel | Updated model after this iteration. |
| metadata | ModelMetadata \| None | Training metadata if model was retrained, None otherwise. |
ActiveLearningLoop
¶
Orchestrates the active learning loop (stages 3-6).
Manages the iterative process of selecting informative items, training models on collected data, and determining when to stop.
Note: Data collection integration is not yet implemented, so this loop uses placeholder interfaces for deployment and data collection. The focus is on the selection logic and loop orchestration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| item_selector | ItemSelector | Algorithm for selecting informative items. | required |
| config | ActiveLearningLoopConfig \| None | Configuration object. If None, uses default configuration. | None |
Attributes:
| Name | Type | Description |
|---|---|---|
| item_selector | ItemSelector | Item selection algorithm. |
| config | ActiveLearningLoopConfig | Loop configuration. |
| iteration_history | list[IterationResult] | History of all iterations with structured results. |
Examples:
>>> from bead.active_learning.selection import UncertaintySampler
>>> from bead.config.active_learning import ActiveLearningLoopConfig
>>> import numpy as np
>>> selector = UncertaintySampler()
>>> config = ActiveLearningLoopConfig(
... max_iterations=5,
... budget_per_iteration=100
... )
>>> loop = ActiveLearningLoop(
... item_selector=selector,
... config=config
... )
__init__(item_selector: ItemSelector, config: ActiveLearningLoopConfig | None = None) -> None
¶
Initialize active learning loop.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| item_selector | ItemSelector | Algorithm for selecting items. | required |
| config | ActiveLearningLoopConfig \| None | Configuration object. If None, uses default configuration. | None |
run(initial_items: list[Item], initial_model: ActiveLearningModel, item_template: ItemTemplate, unlabeled_pool: list[Item], human_ratings: dict[str, str] | None = None, convergence_detector: ConvergenceDetector | None = None) -> list[ModelMetadata]
¶
Run the complete active learning loop.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| initial_items | list[Item] | Initial labeled items for training. | required |
| initial_model | ActiveLearningModel | Model instance to use for active learning. | required |
| item_template | ItemTemplate | Template used to construct all items. Required for validating model compatibility with task type. | required |
| unlabeled_pool | list[Item] | Pool of unlabeled items to select from. | required |
| human_ratings | dict[str, str] \| None | Human ratings mapping item_id to option names. | None |
| convergence_detector | ConvergenceDetector \| None | Detector for checking convergence to human-level performance. If provided, will check convergence after each iteration. | None |
Returns:
| Type | Description |
|---|---|
| list[ModelMetadata] | Metadata for all trained models across iterations. |
Raises:
| Type | Description |
|---|---|
| ValueError | If stopping_criterion is invalid or threshold not provided when needed. |
Notes
Stopping criteria and performance thresholds are configured via the config parameter passed to __init__.
Examples:
>>> from uuid import uuid4
>>> from bead.items.item import Item
>>> from bead.config.active_learning import ActiveLearningLoopConfig
>>> selector = UncertaintySampler()
>>> config = ActiveLearningLoopConfig(max_iterations=3)
>>> loop = ActiveLearningLoop(
... item_selector=selector,
... config=config
... )
>>> # Run would typically be called here with real data
run_iteration(iteration: int, unlabeled_items: list[Item], current_model: ActiveLearningModel) -> IterationResult
¶
Run one iteration of the active learning loop.
Steps:

1. Select informative items using uncertainty sampling
2. (Placeholder) Deploy experiment for data collection
3. (Placeholder) Wait for and collect data
4. (Placeholder) Train new model on augmented dataset
5. Return results
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| iteration | int | Current iteration number. | required |
| unlabeled_items | list[Item] | Unlabeled items available for selection. | required |
| current_model | ActiveLearningModel | Current trained model for making predictions. | required |
Returns:
| Type | Description |
|---|---|
| IterationResult | Structured iteration results containing: iteration (iteration number), selected_items (list of selected items), model (updated model), and metadata (training metadata if available). |
Examples:
>>> from uuid import uuid4
>>> from bead.items.item import Item
>>> import numpy as np
>>> from bead.config.active_learning import ActiveLearningLoopConfig
>>> selector = UncertaintySampler()
>>> config = ActiveLearningLoopConfig(
...     max_iterations=5,
...     budget_per_iteration=2
... )
>>> loop = ActiveLearningLoop(
...     item_selector=selector,
...     config=config
... )
>>> items = [
... Item(item_template_id=uuid4(), rendered_elements={})
... for _ in range(5)
... ]
>>> result = loop.run_iteration(0, items, None)
>>> len(result["selected_items"])
2
>>> result["iteration"]
0
check_convergence(metrics_history: list[dict[str, float]], metric_name: str = 'accuracy', patience: int = 3, min_delta: float = 0.01) -> bool
¶
Check if model performance has converged.
Uses early stopping logic: if performance hasn't improved by at least min_delta for patience iterations, consider converged.
For metrics where lower is better (like "loss"), the logic checks if the best (lowest) value is from more than patience iterations ago.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics_history | list[dict[str, float]] | History of metrics from each iteration. | required |
| metric_name | str | Name of metric to track. | 'accuracy' |
| patience | int | Number of iterations without improvement before stopping. | 3 |
| min_delta | float | Minimum change to count as improvement. | 0.01 |
Returns:
| Type | Description |
|---|---|
| bool | True if converged, False otherwise. |
Examples:
>>> loop = ActiveLearningLoop(
...     item_selector=UncertaintySampler()
... )
>>> # Improving performance - not converged
>>> history = [
... {"accuracy": 0.7},
... {"accuracy": 0.75},
... {"accuracy": 0.8}
... ]
>>> loop.check_convergence(history, metric_name="accuracy", patience=2)
False
>>> # No improvement for 3 iterations - converged
>>> history = [
... {"accuracy": 0.8},
... {"accuracy": 0.81},
... {"accuracy": 0.805},
... {"accuracy": 0.81}
... ]
>>> loop.check_convergence(
... history, metric_name="accuracy", patience=3, min_delta=0.02
... )
True
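>>> # For lower-is-better metrics such as "loss", a hedged sketch under the
>>> # rule described above (converged once the best value is more than
>>> # patience iterations old)
>>> history = [
...     {"loss": 0.40},
...     {"loss": 0.45},
...     {"loss": 0.44},
...     {"loss": 0.46},
...     {"loss": 0.45},
...     {"loss": 0.44}
... ]
>>> loop.check_convergence(history, metric_name="loss", patience=3)
True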
get_summary() -> dict[str, int | dict[str, int]]
¶
Get summary statistics of the active learning loop.
Returns:
| Type | Description |
|---|---|
| dict[str, int \| dict[str, int]] | Summary dictionary with keys: total_iterations (total number of iterations run), total_items_selected (total items selected across all iterations), and convergence_info (configuration parameters: max_iterations, budget_per_iteration). |
Examples:
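A minimal sketch of the expected shape (key names follow the Returns table above; the counts depend on iterations actually run):
>>> loop = ActiveLearningLoop(item_selector=UncertaintySampler())
>>> summary = loop.get_summary()
>>> sorted(summary)
['convergence_info', 'total_items_selected', 'total_iterations']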
selection
¶
Item selectors for active learning.
This module implements sample selection algorithms that use uncertainty strategies to intelligently select the most informative items for labeling in the active learning loop.
ItemSelector
¶
Base class for item selection algorithms.
Item selectors determine which unlabeled items should be selected for annotation in each active learning iteration.
Examples:
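A hedged sketch of a minimal custom selector (the FirstKSelector class below is hypothetical, not part of bead; only the select signature is part of the interface):
>>> class FirstKSelector(ItemSelector):
...     def select(self, items, model, predict_fn, budget):
...         # Trivial baseline: take the first `budget` items
...         return items[:budget]
>>> selector = FirstKSelector()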
select(items: list[Item], model: ActiveLearningModel, predict_fn: Callable[[ActiveLearningModel, Item], np.ndarray], budget: int) -> list[Item]
¶
Select items for annotation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| items | list[Item] | Unlabeled items to select from. | required |
| model | ActiveLearningModel | Trained model for making predictions. | required |
| predict_fn | Callable[[ActiveLearningModel, Item], ndarray] | Function to get prediction probabilities from model. Should return array of shape (n_classes,) with probabilities. | required |
| budget | int | Number of items to select. | required |
Returns:
| Type | Description |
|---|---|
| list[Item] | Selected items for annotation. |
UncertaintySampler
¶
Bases: ItemSelector
Uncertainty-based item selector.
Selects items using uncertainty sampling strategies (entropy, margin, or least confidence). This is the main item selection algorithm for active learning in bead.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | UncertaintySamplerConfig \| None | Configuration for the uncertainty sampler. | None |
Attributes:
| Name | Type | Description |
|---|---|---|
| config | UncertaintySamplerConfig | Configuration for the sampler. |
| strategy | SamplingStrategy | The underlying sampling strategy. |
Examples:
>>> import numpy as np
>>> from uuid import uuid4
>>> from bead.items.item import Item
>>> from bead.config.active_learning import UncertaintySamplerConfig
>>> # Create sampler
>>> config = UncertaintySamplerConfig(method="entropy")
>>> sampler = UncertaintySampler(config=config)
>>> # Mock items
>>> items = [Item(item_template_id=uuid4(), rendered_elements={}) for _ in range(5)]
>>> # Mock model and predict function
>>> def predict_fn(model, item):
... return np.array([0.5, 0.5]) # Mock probabilities
>>> # Select items
>>> selected = sampler.select(items, None, predict_fn, budget=2)
>>> len(selected)
2
__init__(config: UncertaintySamplerConfig | None = None) -> None
¶
Initialize uncertainty sampler.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | UncertaintySamplerConfig \| None | Configuration for the sampler. If None, uses defaults. | None |
select(items: list[Item], model: Any, predict_fn: Callable[[Any, Item], np.ndarray], budget: int) -> list[Item]
¶
Select items using uncertainty sampling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| items | list[Item] | Unlabeled items to select from. | required |
| model | Any | Trained model for making predictions. | required |
| predict_fn | Callable[[Any, Item], ndarray] | Function to get prediction probabilities from model. Should return array of shape (n_classes,) for each item. | required |
| budget | int | Number of items to select. | required |
Returns:
| Type | Description |
|---|---|
| list[Item] | Selected items for annotation, ordered by uncertainty (most to least). |
Raises:
| Type | Description |
|---|---|
| ValueError | If items list is empty or budget is invalid. |
Examples:
>>> import numpy as np
>>> from uuid import uuid4
>>> from bead.items.item import Item
>>> from bead.config.active_learning import UncertaintySamplerConfig
>>> config = UncertaintySamplerConfig(method="entropy")
>>> sampler = UncertaintySampler(config=config)
>>> items = [
... Item(item_template_id=uuid4(), rendered_elements={"text": "item1"}),
... Item(item_template_id=uuid4(), rendered_elements={"text": "item2"}),
... ]
>>> def predict_fn(model, item):
... # First item is uncertain, second is confident
... if "item1" in item.rendered_elements.get("text", ""):
... return np.array([0.5, 0.5])
... return np.array([0.9, 0.1])
>>> selected = sampler.select(items, None, predict_fn, budget=1)
>>> "item1" in selected[0].rendered_elements["text"]
True
RandomSelector
¶
Bases: ItemSelector
Random item selector (baseline).
Selects items randomly without considering model predictions. Useful as a baseline for comparison with uncertainty-based methods.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| seed | int \| None | Random seed for reproducibility. | None |
Attributes:
| Name | Type | Description |
|---|---|---|
| rng | Generator | Random number generator. |
Examples:
>>> from uuid import uuid4
>>> from bead.items.item import Item
>>> selector = RandomSelector(seed=42)
>>> items = [
... Item(item_template_id=uuid4(), rendered_elements={})
... for _ in range(10)
... ]
>>> selected = selector.select(items, None, None, budget=3)
>>> len(selected)
3
__init__(seed: int | None = None) -> None
¶
Initialize random selector.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| seed | int \| None | Random seed for reproducibility. | None |
select(items: list[Item], model: Any, predict_fn: Callable[[Any, Item], np.ndarray], budget: int) -> list[Item]
¶
Select items randomly.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| items | list[Item] | Items to select from. | required |
| model | Any | Model (unused, kept for interface compatibility). | required |
| predict_fn | Callable[[Any, Item], ndarray] | Prediction function (unused, kept for interface compatibility). | required |
| budget | int | Number of items to select. | required |
Returns:
| Type | Description |
|---|---|
| list[Item] | Randomly selected items. |
Raises:
| Type | Description |
|---|---|
| ValueError | If items list is empty or budget is invalid. |
strategies
¶
Sampling strategies for active learning.
This module implements various uncertainty quantification methods for active learning item selection, including entropy, margin, and least confidence sampling, plus a random baseline.
SamplingStrategy
¶
Bases: ABC
Base class for active learning sampling strategies.
All sampling strategies must implement compute_scores to quantify uncertainty or informativeness of predictions; the inherited select_top_k then selects the most informative items.
Examples:
>>> import numpy as np
>>> class MyStrategy(SamplingStrategy):
... def compute_scores(self, probabilities):
... return np.max(probabilities, axis=1)
>>> strategy = MyStrategy()
>>> probs = np.array([[0.7, 0.2, 0.1], [0.4, 0.4, 0.2]])
>>> scores = strategy.compute_scores(probs)
>>> indices = strategy.select_top_k(scores, k=1)
>>> len(indices)
1
compute_scores(probabilities: np.ndarray) -> np.ndarray
abstractmethod
¶
Compute uncertainty scores from prediction probabilities.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| probabilities | ndarray | Prediction probabilities with shape (n_samples, n_classes). Each row should sum to 1.0. | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Uncertainty scores with shape (n_samples,). Higher scores indicate more informative/uncertain items. |
select_top_k(scores: np.ndarray, k: int) -> np.ndarray
¶
Select top k items by score.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| scores | ndarray | Uncertainty scores with shape (n_samples,). | required |
| k | int | Number of items to select. | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Indices of top k items with shape (k,). If k > len(scores), returns all indices. If k <= 0, returns empty array. |
Examples:
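A hedged sketch using the concrete UncertaintySampling subclass defined below (indices are sorted here because their return order is not specified above):
>>> import numpy as np
>>> strategy = UncertaintySampling()
>>> scores = np.array([0.1, 0.9, 0.5])
>>> sorted(strategy.select_top_k(scores, k=2).tolist())
[1, 2]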
UncertaintySampling
¶
Bases: SamplingStrategy
Entropy-based uncertainty sampling.
Selects items where the model's prediction entropy is highest, indicating maximum uncertainty across all classes.
Mathematical definition: H(p) = -∑(p_i * log(p_i))
where p is the probability distribution over classes.
Examples:
>>> import numpy as np
>>> strategy = UncertaintySampling()
>>> # Uniform distribution (high entropy)
>>> probs = np.array([[0.33, 0.33, 0.34]])
>>> score = strategy.compute_scores(probs)
>>> score[0] > 1.0 # High uncertainty
True
>>> # Confident prediction (low entropy)
>>> probs = np.array([[0.9, 0.05, 0.05]])
>>> score = strategy.compute_scores(probs)
>>> score[0] < 0.5 # Low uncertainty
True
compute_scores(probabilities: np.ndarray) -> np.ndarray
¶
Compute entropy for each prediction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| probabilities | ndarray | Prediction probabilities with shape (n_samples, n_classes). | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Entropy scores with shape (n_samples,). Higher entropy indicates more uncertainty. |
MarginSampling
¶
Bases: SamplingStrategy
Margin-based uncertainty sampling.
Selects items where the margin between the top two predicted classes is smallest, indicating uncertainty between the two most likely options.
Mathematical definition: margin(p) = 1 - (p₁ - p₂)
where p₁ and p₂ are the highest and second-highest probabilities.
Examples:
>>> import numpy as np
>>> strategy = MarginSampling()
>>> # Small margin (uncertain)
>>> probs = np.array([[0.51, 0.49, 0.0]])
>>> score = strategy.compute_scores(probs)
>>> score[0] > 0.95 # High uncertainty
True
>>> # Large margin (confident)
>>> probs = np.array([[0.9, 0.05, 0.05]])
>>> score = strategy.compute_scores(probs)
>>> score[0] < 0.2 # Low uncertainty
True
compute_scores(probabilities: np.ndarray) -> np.ndarray
¶
Compute margin scores for each prediction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| probabilities | ndarray | Prediction probabilities with shape (n_samples, n_classes). | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Margin scores with shape (n_samples,). Higher scores indicate smaller margin (more uncertainty). |
LeastConfidenceSampling
¶
Bases: SamplingStrategy
Least confidence sampling.
Selects items where the model is least confident, measured as 1 minus the maximum predicted probability.
Mathematical definition: lc(p) = 1 - max(p)
where p is the probability distribution over classes.
Examples:
>>> import numpy as np
>>> strategy = LeastConfidenceSampling()
>>> # Low confidence
>>> probs = np.array([[0.4, 0.3, 0.3]])
>>> score = strategy.compute_scores(probs)
>>> score[0] == 0.6 # 1 - 0.4
True
>>> # High confidence
>>> probs = np.array([[0.95, 0.03, 0.02]])
>>> score = strategy.compute_scores(probs)
>>> score[0] == 0.05 # 1 - 0.95
True
compute_scores(probabilities: np.ndarray) -> np.ndarray
¶
Compute least confidence scores for each prediction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| probabilities | ndarray | Prediction probabilities with shape (n_samples, n_classes). | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Least confidence scores with shape (n_samples,). Higher scores indicate lower confidence (more uncertainty). |
RandomSampling
¶
Bases: SamplingStrategy
Random sampling baseline.
Selects items randomly, serving as a baseline for comparison with uncertainty-based methods. Uses seeded random number generation for reproducibility.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| seed | int \| None | Random seed for reproducibility. If None, uses non-deterministic seed. | None |
Attributes:
| Name | Type | Description |
|---|---|---|
| rng | Generator | Random number generator. |
Examples:
>>> import numpy as np
>>> strategy = RandomSampling(seed=42)
>>> probs = np.array([[0.9, 0.1], [0.5, 0.5]])
>>> scores = strategy.compute_scores(probs)
>>> len(scores) == 2
True
>>> # Scores are random, not based on probabilities
>>> strategy2 = RandomSampling(seed=42)
>>> scores2 = strategy2.compute_scores(probs)
>>> np.allclose(scores, scores2) # Same seed gives same results
True
compute_scores(probabilities: np.ndarray) -> np.ndarray
¶
Generate random scores for each item.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| probabilities | ndarray | Prediction probabilities with shape (n_samples, n_classes). Not used in random sampling, but required by interface. | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Random scores with shape (n_samples,). |
create_strategy(method: StrategyMethod, seed: int | None = None) -> SamplingStrategy
¶
Create a sampling strategy instance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| method | StrategyMethod | Strategy method name ("entropy", "margin", "least_confidence", "random"). | required |
| seed | int \| None | Random seed for random strategy. Ignored for other strategies. | None |
Returns:
| Type | Description |
|---|---|
| SamplingStrategy | Instantiated sampling strategy. |
Raises:
| Type | Description |
|---|---|
| ValueError | If method is not recognized. |
Examples:
>>> strategy = create_strategy("entropy")
>>> isinstance(strategy, UncertaintySampling)
True
>>> strategy = create_strategy("margin")
>>> isinstance(strategy, MarginSampling)
True
>>> strategy = create_strategy("least_confidence")
>>> isinstance(strategy, LeastConfidenceSampling)
True
>>> strategy = create_strategy("random", seed=42)
>>> isinstance(strategy, RandomSampling)
True
Configuration¶
config
¶
Configuration models for mixed effects active learning.
Separated from base.py to avoid circular imports.
VarianceComponents
¶
Bases: BaseModel
Variance-covariance structure for random effects (G matrix in GLMM theory).
Tracks estimated variances for random effects, enabling:

- Shrinkage estimation (groups with few samples → prior mean)
- Model diagnostics (proportion of variance explained by random effects)
- Uncertainty quantification
In GLMM notation: u ~ N(0, G), where G is the variance-covariance matrix. For random intercepts, G is diagonal with entries σ²_u. For random slopes, G can be full (correlated) or diagonal (independent).
Attributes:
| Name | Type | Description |
|---|---|---|
| grouping_factor | str | Name of grouping factor (e.g., "participant", "item", "lab"). |
| effect_type | Literal['intercept', 'slope'] | Type of random effect. |
| variance | float | Estimated variance σ² for this random effect. Higher values indicate more heterogeneity across groups. |
| n_groups | int | Number of groups (e.g., 50 participants). |
| n_observations_per_group | dict[str, int] | Number of observations per group. Used for adaptive regularization and shrinkage. |
Examples:
>>> vc = VarianceComponents(
... grouping_factor="participant",
... effect_type="intercept",
... variance=0.25,
... n_groups=50,
... n_observations_per_group={"p1": 10, "p2": 15}
... )
>>> vc.variance
0.25
RandomEffectsSpec
¶
Bases: BaseModel
Specification of random effects structure.
Inspired by lme4 formula notation: (expr | factor).
Phase 5 (current): Supports single grouping factor (participant). Future phases: Multiple factors, crossed/nested structure.
Attributes:
| Name | Type | Description |
|---|---|---|
| grouping_factors | dict[str, Literal['intercept', 'slope', 'both']] | Mapping from grouping factor name to effect type. Phase 5: {"participant": "intercept"} or {"participant": "slope"}. Future: {"participant": "intercept", "item": "intercept"}. |
| correlation_structure | Literal['independent', 'correlated'] | If "both" specified: whether intercept and slope are correlated. Independent: G is diagonal. Correlated: G has off-diagonal covariances. |
Examples:
>>> # Random intercepts for participants
>>> spec = RandomEffectsSpec(
... grouping_factors={"participant": "intercept"}
... )
>>> # Random slopes for participants
>>> spec = RandomEffectsSpec(
... grouping_factors={"participant": "slope"}
... )
>>> # Future: Multiple grouping factors
>>> spec = RandomEffectsSpec(
... grouping_factors={"participant": "intercept", "item": "intercept"}
... )
MixedEffectsConfig
¶
Bases: BaseModel
Configuration for mixed effects modeling in active learning.
Based on GLMM theory: y = Xβ + Zu + ε
Where:

- Xβ: Fixed effects (population-level parameters, shared across all groups)
- Zu: Random effects (group-specific parameters, e.g., per-participant)
- u ~ N(0, G): Random effects with variance-covariance matrix G
- ε ~ N(0, σ²): Residuals
Attributes:
| Name | Type | Description |
|---|---|---|
| mode | Literal['fixed', 'random_intercepts', 'random_slopes'] | Modeling mode: 'fixed' (standard model, no group-specific parameters, Z = 0), 'random_intercepts' (per-group biases, Z = I, u = bias vectors), 'random_slopes' (per-group model parameters, Z = I, u = full model heads). |
| prior_mean | float | Mean μ₀ of Gaussian prior for random effects initialization. Random effects initialized from N(μ₀, σ²₀). |
| prior_variance | float | Variance σ²₀ of Gaussian prior for random effects initialization. Controls initial spread of random effects. |
| estimate_variance_components | bool | Whether to estimate variance components (G matrix) during training. If True, returns variance estimates in training metrics. |
| variance_estimation_method | Literal['mle', 'reml'] | Method for variance component estimation: 'mle' (Maximum Likelihood Estimation) or 'reml' (Restricted Maximum Likelihood, adjusts for fixed effects). |
| regularization_strength | float | Strength λ of regularization pulling random effects toward prior. Loss: L_total = L_data + λ * \|\|u - μ₀\|\|² |
| adaptive_regularization | bool | If True, use stronger regularization for groups with fewer samples. Weight: w_g = 1 / max(n_g, min_samples_for_random_effects) |
| min_samples_for_random_effects | int | Minimum samples before estimating group-specific random effects. Below threshold: use prior mean for predictions (shrinkage). |
| random_effects_spec | RandomEffectsSpec \| None | Advanced: specification for multiple grouping factors. If None: infer from mode (backward compatibility). Future: enable item random effects, crossed effects, etc. |
Examples:
>>> # Random intercepts (participant biases)
>>> config = MixedEffectsConfig(
... mode='random_intercepts',
... prior_mean=0.0,
... prior_variance=1.0,
... regularization_strength=0.01
... )
>>> # Random slopes (participant-specific models)
>>> config = MixedEffectsConfig(
... mode='random_slopes',
... prior_variance=0.1,
... adaptive_regularization=True
... )
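A hedged numerical sketch of the regularization terms in the attributes table above (all names here are illustrative, not config fields):
>>> import numpy as np
>>> u = np.array([0.3, -0.2, 0.8])  # random effects for three participants
>>> mu0, lam, min_samples = 0.0, 0.01, 5
>>> n_g = np.array([2, 10, 50])  # observations per participant
>>> w = 1.0 / np.maximum(n_g, min_samples)  # adaptive weights w_g
>>> penalty = lam * np.sum(w * (u - mu0) ** 2)  # weighted λ·||u − μ₀||²
>>> round(float(penalty), 6)
0.000348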
Base Model Interface¶
base
¶
Base interfaces for active learning models with mixed effects support.
This module implements Generalized Linear Mixed Effects Models (GLMMs) following the standard formulation:
y = Xβ + Zu + ε
Where:

- Xβ: Fixed effects (population-level parameters, shared across all groups)
- Zu: Random effects (group-specific parameters, e.g., per-participant)
- u ~ N(0, G): Random effects with variance-covariance matrix G
- ε: Residuals
The implementation supports three modeling modes:

1. Fixed effects: Standard model, ignores grouping structure
2. Random intercepts: Per-group biases (Zu = bias vector per group)
3. Random slopes: Per-group model parameters (Zu = separate model head per group)
References
- Bates et al. (2015). "Fitting Linear Mixed-Effects Models using lme4"
- Simchoni & Rosset (2022). "Integrating Random Effects in Deep Neural Networks"
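A minimal numerical sketch of this formulation with random intercepts (all names below are illustrative):
>>> import numpy as np
>>> rng = np.random.default_rng(0)
>>> X = rng.normal(size=(6, 2))  # fixed-effects design: 6 observations, 2 features
>>> beta = np.array([0.5, -1.0])  # population-level coefficients β
>>> groups = np.array([0, 0, 1, 1, 2, 2])  # participant index per observation
>>> u = rng.normal(0.0, 0.5, size=3)  # per-participant intercepts, u ~ N(0, G)
>>> Z = np.eye(3)[groups]  # indicator matrix mapping observations to participants
>>> eps = rng.normal(0.0, 0.1, size=6)  # residuals ε
>>> y = X @ beta + Z @ u + eps
>>> y.shape
(6,)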
ModelPrediction
¶
Bases: BeadBaseModel
Prediction output for a single item.
Attributes:
| Name | Type | Description |
|---|---|---|
| item_id | str | Unique identifier for the item. |
| probabilities | dict[str, float] | Predicted probabilities for each class/option. Keys are option names (e.g., "option_a", "option_b") or class labels. |
| predicted_class | str | The predicted class/option with highest probability. |
| confidence | float | Confidence score (max probability). |
Examples:
>>> prediction = ModelPrediction(
... item_id="abc123",
... probabilities={"option_a": 0.7, "option_b": 0.3},
... predicted_class="option_a",
... confidence=0.7
... )
>>> prediction.predicted_class
'option_a'
ActiveLearningModel
¶
Bases: ABC
Base class for all active learning models with mixed effects support.
Implements GLMM-based active learning: y = Xβ + Zu + ε
All models must:

1. Support mixed effects (fixed, random_intercepts, random_slopes modes)
2. Accept participant_ids in train/predict/predict_proba (None for fixed effects)
3. Validate items match supported task types
4. Track variance components (if estimate_variance_components=True)
Attributes:
| Name | Type | Description |
|---|---|---|
| config | dict[str, str \| int \| float \| bool \| None] \| BeadBaseModel | Model configuration (task-type-specific). Must include a mixed_effects field. |
| supported_task_types | list[TaskType] | List of task types this model can handle. |
Examples:
>>> class MyModel(ActiveLearningModel):
... def __init__(self, config):
... super().__init__(config) # Validates mixed_effects field
... @property
... def supported_task_types(self):
... return ["forced_choice"]
... def validate_item_compatibility(self, item, item_template):
... pass
... def train(self, items, labels, participant_ids):
... return {}
... def predict(self, items, participant_ids):
... return []
... def predict_proba(self, items, participant_ids):
... return np.array([])
... def save(self, path):
... pass
... def load(self, path):
... pass
supported_task_types: list[TaskType]
abstractmethod
property
¶
__init__(config: dict[str, str | int | float | bool | None] | BeadBaseModel) -> None
¶
Initialize model with configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | Any | Model configuration. Must have a mixed_effects field. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If config is invalid or missing required fields. |
validate_item_compatibility(item: Item, item_template: ItemTemplate) -> None
abstractmethod
¶
Validate that an item is compatible with this model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| item | Item | Item to validate. | required |
| item_template | ItemTemplate | Template the item was constructed from. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If item's task_type is not in supported_task_types. |
| ValueError | If item is missing required elements. |
| ValueError | If item structure is incompatible with model. |
train(items: list[Item], labels: list[str] | list[list[str]], participant_ids: list[str] | None = None, validation_items: list[Item] | None = None, validation_labels: list[str] | list[list[str]] | None = None) -> dict[str, float]
¶
Train model on labeled items with participant identifiers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| items | list[Item] | Training items. | required |
| labels | list[str] \| list[list[str]] | Training labels (format depends on task type). | required |
| participant_ids | list[str] \| None | Participant identifier for each item. For fixed effects (mode='fixed'): pass None (automatically handled). For mixed effects (mode='random_intercepts' or 'random_slopes'): must provide list[str] with the same length as items; must not contain empty strings. | None |
| validation_items | list[Item] \| None | Optional validation items. | None |
| validation_labels | list[str] \| list[list[str]] \| None | Optional validation labels. | None |
Returns:
| Type | Description |
|---|---|
| dict[str, float] | Training metrics including: "train_accuracy" and "train_loss" (standard metrics); "participant_variance" (σ²_u, if estimate_variance_components=True); "n_participants" (number of unique participants); "residual_variance" (σ²_ε, if estimated). |
Raises:
| Type | Description |
|---|---|
| ValueError | If participant_ids is None when mode is 'random_intercepts' or 'random_slopes'. |
| ValueError | If items, labels, and participant_ids have different lengths. |
| ValueError | If participant_ids contains empty strings. |
| ValueError | If validation data is incomplete. |
| ValueError | If labels are invalid for this task type. |
predict(items: list[Item], participant_ids: list[str] | None = None) -> list[ModelPrediction]
¶
Predict class labels for items with participant identifiers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| items | list[Item] | Items to predict. | required |
| participant_ids | list[str] \| None | Participant identifier for each item. For fixed effects (mode='fixed'): pass None. For mixed effects: must provide list[str] with the same length as items. For unknown participants: the population mean (prior) is used for random effects. | None |
Returns:
| Type | Description |
|---|---|
| list[ModelPrediction] | Predictions with probabilities and predicted class for each item. |
Raises:
| Type | Description |
|---|---|
| ValueError | If model has not been trained. |
| ValueError | If participant_ids is None when mode requires mixed effects. |
| ValueError | If items and participant_ids have different lengths. |
| ValueError | If participant_ids contains empty strings. |
| ValueError | If items are incompatible with model. |
predict_proba(items: list[Item], participant_ids: list[str] | None = None) -> np.ndarray
¶
Predict class probabilities for items with participant identifiers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| items | list[Item] | Items to predict. | required |
| participant_ids | list[str] \| None | Participant identifier for each item. For fixed effects (mode='fixed'): pass None. For mixed effects: must provide list[str] with the same length as items. | None |
Returns:
| Type | Description |
|---|---|
| ndarray | Array of shape (n_items, n_classes) with probabilities. Each row sums to 1.0 for classification tasks. |
Raises:
| Type | Description |
|---|---|
| ValueError | If model has not been trained. |
| ValueError | If participant_ids is None when mode requires mixed effects. |
| ValueError | If items and participant_ids have different lengths. |
| ValueError | If participant_ids contains empty strings. |
| ValueError | If items are incompatible with model. |
save(path: str) -> None
¶
Save model to disk.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str | File or directory path to save the model. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If model has not been trained. |
load(path: str) -> None
¶
Load model from disk.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str | File or directory path to load the model from. | required |
Raises:
| Type | Description |
|---|---|
| FileNotFoundError | If model file/directory does not exist. |
Task-Specific Models¶
forced_choice
¶
Model for forced choice tasks (2AFC, 3AFC, 4AFC, nAFC).
ForcedChoiceModel
¶
Bases: ActiveLearningModel
Model for forced_choice tasks with n alternatives.
Supports 2AFC, 3AFC, 4AFC, and general nAFC tasks using any HuggingFace transformer model. Provides two encoding strategies: single encoder (concatenate options) or dual encoder (separate embeddings).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ForcedChoiceModelConfig \| None | Configuration object containing all model parameters. | None |
Attributes:
| Name | Type | Description |
|---|---|---|
| config | ForcedChoiceModelConfig | Model configuration. |
| tokenizer | AutoTokenizer | Transformer tokenizer. |
| encoder | AutoModel | Transformer encoder model. |
| classifier_head | Sequential | Classification head (fixed effects head). |
| num_classes | int \| None | Number of classes (inferred from training data). |
| option_names | list[str] \| None | Option names (e.g., ["option_a", "option_b"]). |
| random_effects | RandomEffectsManager | Manager for participant-level random effects. |
| variance_history | list[VarianceComponents] | Variance component estimates over training (for diagnostics). |
| _is_fitted | bool | Whether model has been trained. |
Examples:
>>> from uuid import uuid4
>>> from bead.items.item import Item
>>> from bead.config.active_learning import ForcedChoiceModelConfig
>>> items = [
... Item(
... item_template_id=uuid4(),
... rendered_elements={"option_a": "sentence A", "option_b": "sentence B"}
... )
... for _ in range(10)
... ]
>>> labels = ["option_a"] * 5 + ["option_b"] * 5
>>> config = ForcedChoiceModelConfig(
... num_epochs=1, batch_size=2, device="cpu"
... )
>>> model = ForcedChoiceModel(config=config)
>>> metrics = model.train(items, labels)
>>> predictions = model.predict(items[:3])
supported_task_types: list[TaskType]
property
¶
Get supported task types.
Returns:
| Type | Description |
|---|---|
| list[TaskType] | List containing "forced_choice". |
__init__(config: ForcedChoiceModelConfig | None = None) -> None
¶
Initialize forced choice model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ForcedChoiceModelConfig \| None | Configuration object. If None, uses default configuration. | None |
validate_item_compatibility(item: Item, item_template: ItemTemplate) -> None
¶
Validate item is compatible with forced choice model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| item | Item | Item to validate. | required |
| item_template | ItemTemplate | Template the item was constructed from. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If task_type is not "forced_choice". |
| ValueError | If task_spec.options is not defined. |
| ValueError | If item is missing required rendered_elements. |
ordinal_scale
¶
Ordinal scale model for ordered rating scales (Likert, sliders, etc.).
Implements truncated normal distribution for bounded continuous responses on [0, 1]. Supports GLMM with participant-level random effects (intercepts and slopes).
OrdinalScaleModel
¶
Bases: ActiveLearningModel
Model for ordinal_scale tasks with bounded continuous responses.
Uses truncated normal distribution on [scale_min, scale_max] to model slider/Likert responses while properly handling endpoints (0 and 1). Supports three modes: fixed effects, random intercepts, random slopes.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | OrdinalScaleModelConfig \| None | Configuration object containing all model parameters. | None |
Attributes:
| Name | Type | Description |
|---|---|---|
| config | OrdinalScaleModelConfig | Model configuration. |
| tokenizer | AutoTokenizer | Transformer tokenizer. |
| encoder | AutoModel | Transformer encoder model. |
| regression_head | Sequential | Regression head (fixed effects head); outputs continuous μ. |
| random_effects | RandomEffectsManager | Manager for participant-level random effects. |
| variance_history | list[VarianceComponents] | Variance component estimates over training (for diagnostics). |
| _is_fitted | bool | Whether model has been trained. |
Examples:
>>> from uuid import uuid4
>>> from bead.items.item import Item
>>> from bead.config.active_learning import OrdinalScaleModelConfig
>>> items = [
... Item(
... item_template_id=uuid4(),
... rendered_elements={"text": f"Sentence {i}"}
... )
... for i in range(10)
... ]
>>> labels = ["0.3", "0.7"] * 5 # Continuous values as strings
>>> config = OrdinalScaleModelConfig(
... num_epochs=1, batch_size=2, device="cpu"
... )
>>> model = OrdinalScaleModel(config=config)
>>> metrics = model.train(items, labels, participant_ids=None)
>>> predictions = model.predict(items[:3], participant_ids=None)
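>>> # A hedged sketch of a truncated normal on [0, 1] as described above,
>>> # using scipy for illustration (not the model's internal API)
>>> from scipy.stats import truncnorm
>>> mu, sigma, lo, hi = 0.6, 0.2, 0.0, 1.0
>>> a, b = (lo - mu) / sigma, (hi - mu) / sigma  # standardized bounds
>>> dist = truncnorm(a, b, loc=mu, scale=sigma)
>>> float(dist.cdf(1.0))  # all probability mass lies within [0, 1]
1.0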
supported_task_types: list[TaskType]
property
¶
Get supported task types.
Returns:
| Type | Description |
|---|---|
| list[TaskType] | List containing "ordinal_scale". |
__init__(config: OrdinalScaleModelConfig | None = None) -> None
¶
Initialize ordinal scale model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | OrdinalScaleModelConfig \| None | Configuration object. If None, uses default configuration. | None |
validate_item_compatibility(item: Item, item_template: ItemTemplate) -> None
¶
Validate item is compatible with ordinal scale model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| item | Item | Item to validate. | required |
| item_template | ItemTemplate | Template the item was constructed from. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If task_type is not "ordinal_scale". |
binary
¶
Binary model for yes/no or true/false judgments.
Expected architecture: binary classification with a 2-class output. Differs from 2AFC in semantics: it represents an absolute judgment rather than a choice.
BinaryModel
¶
Bases: ActiveLearningModel
Model for binary tasks (yes/no, true/false judgments).
Uses true binary classification with a single output unit and sigmoid activation (logistic regression). This is more efficient than using 2-class softmax, as we only need to output P(y=1) and compute P(y=0) = 1 - P(y=1).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | BinaryModelConfig \| None | Configuration object containing all model parameters. | None |
Attributes:
| Name | Type | Description |
|---|---|---|
| config | BinaryModelConfig | Model configuration. |
| tokenizer | AutoTokenizer | Transformer tokenizer. |
| encoder | AutoModel | Transformer encoder model. |
| classifier_head | Sequential | Classification head (fixed effects head); outputs a single logit. |
| num_classes | int | Number of output units (always 1 for binary classification). |
| label_names | list[str] \| None | Label names (e.g., ["no", "yes"] sorted alphabetically). |
| positive_class | str \| None | Which label corresponds to y=1 (second alphabetically). |
| random_effects | RandomEffectsManager | Manager for participant-level random effects (scalar biases). |
| variance_history | list[VarianceComponents] | Variance component estimates over training (for diagnostics). |
| _is_fitted | bool | Whether model has been trained. |
Examples:
>>> from uuid import uuid4
>>> from bead.items.item import Item
>>> from bead.config.active_learning import BinaryModelConfig
>>> items = [
... Item(
... item_template_id=uuid4(),
... rendered_elements={"text": f"Sentence {i}"}
... )
... for i in range(10)
... ]
>>> labels = ["yes"] * 5 + ["no"] * 5
>>> config = BinaryModelConfig(
... num_epochs=1, batch_size=2, device="cpu"
... )
>>> model = BinaryModel(config=config)
>>> metrics = model.train(items, labels, participant_ids=None)
>>> predictions = model.predict(items[:3], participant_ids=None)
Notes
This model uses BCEWithLogitsLoss instead of CrossEntropyLoss, and applies sigmoid activation to get probabilities. Random intercepts are scalar values (1-dimensional) that shift the logit for each participant.
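A hedged sketch of the single-logit scheme from the notes (illustrative only):
>>> import torch
>>> logit = torch.tensor([0.8])  # one output unit
>>> p_yes = torch.sigmoid(logit)  # P(y=1)
>>> p_no = 1.0 - p_yes  # P(y=0); no softmax needed
>>> bool(torch.isclose(p_yes + p_no, torch.tensor(1.0)))
True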
supported_task_types: list[TaskType]
property
¶
Get supported task types.
Returns:
| Type | Description |
|---|---|
| list[TaskType] | List containing "binary". |
__init__(config: BinaryModelConfig | None = None) -> None
¶
Initialize binary model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | BinaryModelConfig \| None | Configuration object. If None, uses default configuration. | None |
validate_item_compatibility(item: Item, item_template: ItemTemplate) -> None
¶
Validate item is compatible with binary model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| item | Item | Item to validate. | required |
| item_template | ItemTemplate | Template the item was constructed from. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If task_type is not "binary". |
categorical
¶
Model for categorical tasks (unordered N-class classification).
CategoricalModel
¶
Bases: ActiveLearningModel
Model for categorical tasks with N unordered categories.
Supports N-class classification (N ≥ 2) using any HuggingFace transformer model. Provides two encoding strategies: single encoder (concatenate categories) or dual encoder (separate embeddings).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | CategoricalModelConfig \| None | Configuration object containing all model parameters. | None |
Attributes:
| Name | Type | Description |
|---|---|---|
| config | CategoricalModelConfig | Model configuration. |
| tokenizer | AutoTokenizer | Transformer tokenizer. |
| encoder | AutoModel | Transformer encoder model. |
| classifier_head | Sequential | Classification head (fixed effects head). |
| num_classes | int \| None | Number of classes (inferred from training data). |
| category_names | list[str] \| None | Category names (e.g., ["entailment", "neutral", "contradiction"]). |
| random_effects | RandomEffectsManager | Manager for participant-level random effects. |
| variance_history | list[VarianceComponents] | Variance component estimates over training (for diagnostics). |
| _is_fitted | bool | Whether model has been trained. |
Examples:
>>> from uuid import uuid4
>>> from bead.items.item import Item
>>> from bead.config.active_learning import CategoricalModelConfig
>>> items = [
... Item(
... item_template_id=uuid4(),
... rendered_elements={"premise": "sent A", "hypothesis": "sent B"}
... )
... for _ in range(10)
... ]
>>> labels = ["entailment"] * 5 + ["contradiction"] * 5
>>> config = CategoricalModelConfig(
... num_epochs=1, batch_size=2, device="cpu"
... )
>>> model = CategoricalModel(config=config)
>>> metrics = model.train(items, labels, participant_ids=None)
>>> predictions = model.predict(items[:3], participant_ids=None)
supported_task_types: list[TaskType]
property
¶
Get supported task types.
Returns:
| Type | Description |
|---|---|
| list[TaskType] | List containing "categorical". |
__init__(config: CategoricalModelConfig | None = None) -> None
¶
Initialize categorical model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | CategoricalModelConfig \| None | Configuration object. If None, uses default configuration. | None |
validate_item_compatibility(item: Item, item_template: ItemTemplate) -> None
¶
Validate item is compatible with categorical model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| item | Item | Item to validate. | required |
| item_template | ItemTemplate | Template the item was constructed from. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If task_type is not "categorical". |
multi_select
¶
Multi-select model for selecting multiple options.
Expected architecture: Multi-label classification with sigmoid output per option. Each option can be independently selected or not selected.
MultiSelectModel
¶
Bases: ActiveLearningModel
Model for multi_select tasks with N selectable options.
Uses multi-label classification where each option can be independently selected or not selected. Applies sigmoid activation to each option's logit and uses BCEWithLogitsLoss for training.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | MultiSelectModelConfig \| None | Configuration object containing all model parameters. | None |
Attributes:
| Name | Type | Description |
|---|---|---|
| config | MultiSelectModelConfig | Model configuration. |
| tokenizer | AutoTokenizer | Transformer tokenizer. |
| encoder | AutoModel | Transformer encoder model. |
| classifier_head | Sequential | Classification head (fixed effects head); outputs N logits. |
| num_options | int \| None | Number of selectable options (inferred from training data). |
| option_names | list[str] \| None | Option names (e.g., ["option_a", "option_b", "option_c"]). |
| random_effects | RandomEffectsManager | Manager for participant-level random effects. |
| variance_history | list[VarianceComponents] | Variance component estimates over training (for diagnostics). |
| _is_fitted | bool | Whether model has been trained. |
Examples:
>>> from uuid import uuid4
>>> from bead.items.item import Item
>>> from bead.config.active_learning import MultiSelectModelConfig
>>> items = [
... Item(
... item_template_id=uuid4(),
... rendered_elements={
... "option_a": "First option",
... "option_b": "Second option",
... "option_c": "Third option"
... }
... )
... for _ in range(10)
... ]
>>> # Labels as lists of selected options
>>> labels_list = [["option_a", "option_b"], ["option_c"], ["option_a"]]
>>> labels = labels_list * 3 + [["option_b"]]
>>> config = MultiSelectModelConfig(
... num_epochs=1, batch_size=2, device="cpu"
... )
>>> model = MultiSelectModel(config=config)
>>> import json
>>> # Convert labels to serialized format for train()
>>> label_strs = [json.dumps(sorted(lbls)) for lbls in labels]
>>> metrics = model.train(items, label_strs, participant_ids=None)
Notes
This model uses BCEWithLogitsLoss (not CrossEntropyLoss) and applies sigmoid activation to get independent probabilities for each option. Random intercepts are bias vectors (one per option) that shift logits independently for each participant.
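A hedged sketch of the independent per-option probabilities described in the notes (illustrative only):
>>> import torch
>>> logits = torch.tensor([[1.2, -0.3, 0.1]])  # one item, three options
>>> probs = torch.sigmoid(logits)  # each in (0, 1); a row need not sum to 1
>>> targets = torch.tensor([[1.0, 0.0, 1.0]])  # option_a and option_c selected
>>> loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, targets)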
supported_task_types: list[TaskType]
property
¶
Get supported task types.
Returns:
| Type | Description |
|---|---|
| list[TaskType] | List containing "multi_select". |
__init__(config: MultiSelectModelConfig | None = None) -> None
¶
Initialize multi-select model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | MultiSelectModelConfig \| None | Configuration object. If None, uses default configuration. | None |
validate_item_compatibility(item: Item, item_template: ItemTemplate) -> None
¶
Validate item is compatible with multi-select model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| item | Item | Item to validate. | required |
| item_template | ItemTemplate | Template the item was constructed from. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If task_type is not "multi_select". |
magnitude
¶
Magnitude model for unbounded and bounded numeric judgments.
Implements continuous regression with support for:

- Unbounded values: Normal distribution N(μ, σ²)
- Bounded values: Truncated Normal distribution N(μ, σ) T[min, max]

Supports GLMM with participant-level random effects (intercepts and slopes).
MagnitudeModel
¶
Bases: ActiveLearningModel
Model for magnitude tasks with unbounded or bounded continuous responses.
Uses Normal distribution for unbounded values (e.g., reading time, plausibility) or Truncated Normal for bounded values (e.g., confidence on 0-100 scale). Supports three modes: fixed effects, random intercepts, random slopes.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | MagnitudeModelConfig \| None | Configuration object containing all model parameters. | None |
Attributes:
| Name | Type | Description |
|---|---|---|
| config | MagnitudeModelConfig | Model configuration. |
| tokenizer | AutoTokenizer | Transformer tokenizer. |
| encoder | AutoModel | Transformer encoder model. |
| regression_head | Sequential | Regression head (fixed effects head); outputs continuous μ. |
| random_effects | RandomEffectsManager | Manager for participant-level random effects. |
| variance_history | list[VarianceComponents] | Variance component estimates over training (for diagnostics). |
| _is_fitted | bool | Whether model has been trained. |
Examples:
>>> from uuid import uuid4
>>> from bead.items.item import Item
>>> from bead.config.active_learning import MagnitudeModelConfig
>>> items = [
... Item(
... item_template_id=uuid4(),
... rendered_elements={"text": f"Sentence {i}"}
... )
... for i in range(10)
... ]
>>> labels = ["250.5", "300.2"] * 5 # Reading times in ms
>>> config = MagnitudeModelConfig(
... num_epochs=1, batch_size=2, device="cpu"
... )
>>> model = MagnitudeModel(config=config)
>>> metrics = model.train(items, labels, participant_ids=None)
>>> predictions = model.predict(items[:3], participant_ids=None)
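>>> # A hedged sketch of a Gaussian likelihood for unbounded magnitudes,
>>> # as described above (illustrative only; the variance values are assumed)
>>> import torch
>>> mu = torch.tensor([250.0, 310.0])  # predicted means
>>> var = torch.tensor([25.0, 25.0])  # variance estimates
>>> y = torch.tensor([255.0, 300.0])  # observed reading times (ms)
>>> nll = torch.nn.functional.gaussian_nll_loss(mu, y, var)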
supported_task_types: list[TaskType]
property
¶
Get supported task types.
Returns:
| Type | Description |
|---|---|
| list[TaskType] | List containing "magnitude". |
__init__(config: MagnitudeModelConfig | None = None) -> None
¶
Initialize magnitude model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | MagnitudeModelConfig \| None | Configuration object. If None, uses default configuration. | None |
validate_item_compatibility(item: Item, item_template: ItemTemplate) -> None
¶
Validate item is compatible with magnitude model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| item | Item | Item to validate. | required |
| item_template | ItemTemplate | Template the item was constructed from. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If task_type is not "magnitude". |
free_text
¶
Free text model for open-ended text generation with GLMM support.
Implements seq2seq generation with participant-level random effects using:

- Random intercepts: Bias on decoder output logits (token probability shifts)
- Random slopes: LoRA adapters on decoder attention layers
Architecture: T5-base or BART-base encoder-decoder model
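A minimal sketch of the random-intercepts idea on decoder logits (shapes are illustrative, not this class's API):
>>> import torch
>>> logits = torch.randn(2, 5, 32128)   # (batch, seq_len, vocab) from the decoder
>>> u_alice = torch.zeros(32128)        # participant-specific random intercept
>>> adjusted = logits + u_alice         # broadcasts over batch and sequence dims
>>> adjusted.shape
torch.Size([2, 5, 32128])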
FreeTextModel
¶
Bases: ActiveLearningModel
Model for free_text tasks with participant-level random effects.
Uses seq2seq architecture (T5 or BART) with three modes:
- Fixed effects: standard encoder-decoder
- Random intercepts: participant-specific bias on output logits
- Random slopes: participant-specific LoRA adapters on decoder
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | FreeTextModelConfig \| None | Configuration object containing all model parameters. | None |
Attributes:
| Name | Type | Description |
|---|---|---|
| config | FreeTextModelConfig | Model configuration. |
| tokenizer | AutoTokenizer | Seq2seq tokenizer. |
| model | AutoModelForSeq2SeqLM | Base seq2seq model (T5 or BART). |
| encoder | Module | Encoder module. |
| base_decoder | Module | Base decoder module (shared across participants in fixed/random_intercepts). |
| lm_head | Module | Language modeling head (projects decoder output to vocabulary). |
| random_effects | RandomEffectsManager | Manager for participant-level random effects. |
| variance_history | list[VarianceComponents] | Variance component estimates over training. |
| _is_fitted | bool | Whether model has been trained. |
Examples:
>>> from uuid import uuid4
>>> from bead.items.item import Item
>>> from bead.config.active_learning import FreeTextModelConfig
>>> items = [
... Item(
... item_template_id=uuid4(),
... rendered_elements={"prompt": "Summarize: The cat sat."}
... )
... for _ in range(10)
... ]
>>> labels = ["Cat sits."] * 10
>>> config = FreeTextModelConfig(
... num_epochs=1, batch_size=2, device="cpu"
... )
>>> model = FreeTextModel(config=config)
>>> metrics = model.train(items, labels, participant_ids=None)
supported_task_types: list[TaskType]
property
¶
Get supported task types.
Returns:
| Type | Description |
|---|---|
| list[TaskType] | List containing "free_text". |
__init__(config: FreeTextModelConfig | None = None) -> None
¶
Initialize free text model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | FreeTextModelConfig \| None | Configuration object. If None, uses default configuration. | None |
validate_item_compatibility(item: Item, item_template: ItemTemplate) -> None
¶
Validate item is compatible with free text model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| item | Item | Item to validate. | required |
| item_template | ItemTemplate | Template the item was constructed from. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If task_type is not "free_text". |
cloze
¶
Cloze model for fill-in-the-blank tasks with GLMM support.
Implements masked language modeling with participant-level random effects for predicting tokens at unfilled slots in partially-filled templates. Supports three modes: fixed effects, random intercepts, random slopes.
Architecture: Masked LM (BERT/RoBERTa) for token prediction
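A plain-transformers sketch of the underlying masked-token prediction (standard Hugging Face API, independent of this class; the model choice is illustrative):
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForMaskedLM
>>> tok = AutoTokenizer.from_pretrained("bert-base-uncased")
>>> mlm = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
>>> inputs = tok("The cat [MASK].", return_tensors="pt")
>>> with torch.no_grad():
...     logits = mlm(**inputs).logits
>>> pos = (inputs["input_ids"][0] == tok.mask_token_id).nonzero().item()
>>> tok.convert_ids_to_tokens(logits[0, pos].topk(3).indices.tolist())  # plausible fillers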
ClozeModel
¶
Bases: ActiveLearningModel
Model for cloze tasks with participant-level random effects.
Uses masked language modeling (BERT/RoBERTa) to predict tokens at unfilled slots in partially-filled templates. Supports three GLMM modes:
- Fixed effects: standard MLM
- Random intercepts: participant-specific bias on output logits
- Random slopes: participant-specific MLM heads
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ClozeModelConfig \| None | Configuration object containing all model parameters. | None |
Attributes:
| Name | Type | Description |
|---|---|---|
| config | ClozeModelConfig | Model configuration. |
| tokenizer | AutoTokenizer | Masked LM tokenizer. |
| model | AutoModelForMaskedLM | Masked language model (BERT or RoBERTa). |
| encoder | Module | Encoder module from the model. |
| mlm_head | Module | MLM prediction head. |
| random_effects | RandomEffectsManager | Manager for participant-level random effects. |
| variance_history | list[VarianceComponents] | Variance component estimates over training. |
| _is_fitted | bool | Whether model has been trained. |
Examples:
>>> from uuid import uuid4
>>> from bead.items.item import Item, UnfilledSlot
>>> from bead.config.active_learning import ClozeModelConfig
>>> items = [
... Item(
... item_template_id=uuid4(),
... rendered_elements={"text": "The cat ___."},
... unfilled_slots=[
... UnfilledSlot(slot_name="verb", position=2, constraint_ids=[])
... ]
... )
... for _ in range(6)
... ]
>>> labels = [["ran"], ["jumped"], ["slept"]] * 2 # One token per unfilled slot
>>> config = ClozeModelConfig(
... num_epochs=1, batch_size=2, device="cpu"
... )
>>> model = ClozeModel(config=config)
>>> metrics = model.train(items, labels, participant_ids=None)
supported_task_types: list[TaskType]
property
¶
Get supported task types.
Returns:
| Type | Description |
|---|---|
| list[TaskType] | List containing "cloze". |
__init__(config: ClozeModelConfig | None = None) -> None
¶
Initialize cloze model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ClozeModelConfig \| None | Configuration object. If None, uses default configuration. | None |
validate_item_compatibility(item: Item, item_template: ItemTemplate) -> None
¶
Validate item is compatible with cloze model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| item | Item | Item to validate. | required |
| item_template | ItemTemplate | Template the item was constructed from. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If task_type is not "cloze". |
| ValueError | If item has no unfilled_slots. |
Random Effects and LoRA¶
random_effects
¶
Manager for random effects in GLMM-based active learning.
Implements:
- Random effect storage and retrieval (intercepts and slopes)
- Variance component estimation (G matrix via MLE/REML)
- Empirical Bayes shrinkage for small groups
- Adaptive regularization based on sample counts
- Save/load with variance component history
RandomEffectsManager
¶
Manages random effects following GLMM theory: u ~ N(0, G).
Core responsibilities:
1. Store random effect values: u_i for each participant i
2. Estimate variance components: σ²_u (the G matrix)
3. Implement shrinkage: u_shrunk_i = λ_i * u_i + (1 - λ_i) * μ_0
4. Compute prior loss: L_prior = λ * Σ_i w_i * ||u_i - μ_0||²
5. Handle unknown participants: use population mean (μ_0)
Attributes:
| Name | Type | Description |
|---|---|---|
| config | MixedEffectsConfig | Configuration including mode, priors, regularization. |
| intercepts | dict[str, Tensor] | Random intercepts per participant. Key: participant_id; value: bias vector of shape (n_classes,). |
| slopes | dict[str, Module] | Random slopes per participant. Key: participant_id; value: model head (nn.Module). |
| participant_sample_counts | dict[str, int] | Training samples per participant (for adaptive regularization). |
| variance_components | VarianceComponents \| None | Latest variance component estimates. |
| variance_history | list[VarianceComponents] | Variance components over training (for diagnostics). |
Examples:
>>> config = MixedEffectsConfig(mode='random_intercepts')
>>> manager = RandomEffectsManager(config, n_classes=3)
>>> # Register participants during training
>>> manager.register_participant("alice", n_samples=10)
>>> manager.register_participant("bob", n_samples=15)
>>> # Get intercepts (creates if missing)
>>> bias_alice = manager.get_intercepts("alice", n_classes=3)
>>> # Estimate variance components after training (returns a dict keyed by param_name)
>>> var_comps = manager.estimate_variance_components()
>>> for name, vc in var_comps.items():
...     print(f"σ²_u[{name}] = {vc.variance:.3f}")
__init__(config: MixedEffectsConfig, **kwargs: Any) -> None
¶
Initialize random effects manager.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | MixedEffectsConfig | GLMM configuration. | required |
| **kwargs | Any | Additional arguments (e.g., n_classes, hidden_dim). Required arguments depend on mode. | {} |
Raises:
| Type | Description |
|---|---|
| ValueError | If mode='random_slopes' but required kwargs missing. |
register_participant(participant_id: str, n_samples: int) -> None
¶
Register participant and track sample count.
Used for:
- Adaptive regularization (fewer samples → stronger regularization)
- Shrinkage estimation (fewer samples → shrink toward mean)
- Variance component estimation
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| participant_id | str | Participant identifier. | required |
| n_samples | int | Number of samples for this participant. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If participant_id is empty or n_samples is not positive. |
Examples:
get_intercepts(participant_id: str, n_classes: int, param_name: str, create_if_missing: bool = True) -> torch.Tensor
¶
Get random intercepts for specific distribution parameter.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| participant_id | str | Participant identifier. | required |
| n_classes | int | Number of classes (length of bias vector). | required |
| param_name | str | Name of the distribution parameter (e.g., "mu", "cutpoint_1", "cutpoint_2"). | required |
| create_if_missing | bool | Whether to create new intercepts for unknown participants. True: training (create new random effects); False: prediction (use prior mean for unknown). | True |
Returns:
| Type | Description |
|---|---|
| Tensor | Bias vector of shape (n_classes,). |
Raises:
| Type | Description |
|---|---|
| ValueError | If mode is not 'random_intercepts'. |
Examples:
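A minimal sketch continuing the manager above (the param_name "mu" is illustrative):
>>> bias = manager.get_intercepts("alice", n_classes=3, param_name="mu")
>>> bias.shape
torch.Size([3])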
get_intercepts_with_shrinkage(participant_id: str, n_classes: int, param_name: str = 'bias') -> torch.Tensor
¶
Get random intercepts with Empirical Bayes shrinkage.
Implements shrinkage toward population mean:
u_shrunk_i = λ_i * u_mle_i + (1 - λ_i) * μ_0
where:
- λ_i = n_i / (n_i + k)
- k ≈ σ²_ε / σ²_u (ratio of residual to random effect variance)
For participants with few samples, shrink toward μ_0 (population mean). For participants with many samples, use their specific estimate.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| participant_id | str | Participant identifier. | required |
| n_classes | int | Number of classes. | required |
| param_name | str | Name of the distribution parameter. | "bias" |
Returns:
| Type | Description |
|---|---|
| Tensor | Shrunk bias vector of shape (n_classes,). |
Examples:
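A worked instance of the shrinkage weight (numbers made up): with n_i = 10 samples and k = 5, λ_i = 10/15 ≈ 0.67, so roughly two-thirds of the estimate comes from the participant and one-third from the population mean. A usage sketch continuing the example above:
>>> u_shrunk = manager.get_intercepts_with_shrinkage("alice", n_classes=3, param_name="mu")
>>> u_shrunk.shape
torch.Size([3])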
get_slopes(participant_id: str, fixed_head: nn.Module, create_if_missing: bool = True) -> nn.Module
¶
Get random slopes (model head) for participant.
Behavior:
- Known participant: return learned head.
- Unknown participant:
  - If create_if_missing=True: clone fixed_head and add noise.
  - If create_if_missing=False: return clone of fixed_head.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| participant_id | str | Participant identifier. | required |
| fixed_head | Module | Fixed effects head to clone for new participants. | required |
| create_if_missing | bool | Whether to create new slopes for unknown participants. | True |
Returns:
| Type | Description |
|---|---|
| Module | Model head for this participant. |
Raises:
| Type | Description |
|---|---|
| ValueError | If mode is not 'random_slopes'. |
Examples:
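A minimal sketch, assuming a manager in mode='random_slopes' and a simple linear head (the kwargs and dimensions here are illustrative):
>>> from torch import nn
>>> slopes_config = MixedEffectsConfig(mode='random_slopes')
>>> slopes_manager = RandomEffectsManager(slopes_config, n_classes=3, hidden_dim=768)
>>> fixed_head = nn.Linear(768, 3)
>>> head_alice = slopes_manager.get_slopes("alice", fixed_head)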
estimate_variance_components() -> dict[str, VarianceComponents] | None
¶
Estimate variance components (G matrix) from random effects.
Returns:
| Type | Description |
|---|---|
| dict[str, VarianceComponents] \| None | Dictionary mapping param_name -> VarianceComponents. For single-parameter models (most common), the dict has one key; for multi-parameter models (e.g., ordered beta), multiple keys. Returns None if mode='fixed' or there are no random slopes. |
Examples:
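A minimal sketch continuing the manager above (key names depend on the model's distribution parameters):
>>> var_comps = manager.estimate_variance_components()
>>> for name, vc in var_comps.items():
...     print(f"σ²_u[{name}] = {vc.variance:.3f}")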
compute_prior_loss() -> torch.Tensor
¶
Compute regularization loss toward prior.
Implements adaptive regularization:
L_prior = λ * Σ_i w_i * ||u_i - μ_0||²
where:
- w_i = 1 / max(n_i, min_samples) (adaptive weighting)
- λ = regularization_strength
Participants with fewer samples get stronger regularization. This prevents overfitting when participant has little data.
For multi-parameter random effects, sums over all parameters.
Returns:
| Type | Description |
|---|---|
| Tensor | Scalar regularization loss to add to training loss. |
Examples:
save(path: Path) -> None
¶
Save random effects to disk.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | Path | Directory to save random effects. | required |
load(path: Path, fixed_head: nn.Module | None = None) -> None
¶
Load random effects from disk.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | Path | Directory to load from. | required |
| fixed_head | Module \| None | Fixed head (required if mode='random_slopes'). | None |
Raises:
| Type | Description |
|---|---|
| FileNotFoundError | If path doesn't exist. |
| ValueError | If mode='random_slopes' but fixed_head is None. |
Examples:
lora
¶
LoRA (Low-Rank Adaptation) implementation for transformer personalization.
Implements participant-specific low-rank updates to attention layers for parameter-efficient fine-tuning (PEFT) in the GLMM framework.
References
- Hu et al. (2021): "LoRA: Low-Rank Adaptation of Large Language Models" https://arxiv.org/abs/2106.09685
- Microsoft LoRA: https://github.com/microsoft/LoRA
LoRALayer
¶
Bases: Module
Low-rank adaptation layer for attention projections.
Implements: ΔW = (α/r) * A @ B
where:
- A ∈ ℝ^(in_features × rank)
- B ∈ ℝ^(rank × out_features)
- r is the rank (much smaller than in_features, out_features)
- α is a scaling factor
This additive update is applied to frozen base weights: W' = W + ΔW
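A quick count shows why the low-rank factorization is cheap (dimensions follow the example below; this is just arithmetic):
>>> full = 768 * 768          # dense update: 589,824 trainable parameters
>>> lora = 768 * 8 + 8 * 768  # rank-8 A and B: 12,288 parameters
>>> round(100 * lora / full, 1)
2.1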
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Input dimension. | required |
| out_features | int | Output dimension. | required |
| rank | int | LoRA rank r. Typical values: 4-16. | 8 |
| alpha | float | Scaling factor α. Typically 2*rank. | 16.0 |
| dropout | float | Dropout probability for LoRA path. | 0.1 |
Attributes:
| Name | Type | Description |
|---|---|---|
| lora_A | Parameter | First low-rank matrix, shape (in_features, rank). Initialized with Kaiming uniform. |
| lora_B | Parameter | Second low-rank matrix, shape (rank, out_features). Initialized with zeros (so ΔW = 0 initially). |
| scaling | float | Computed as α/r. |
Examples:
>>> import torch
>>> lora = LoRALayer(768, 768, rank=8, alpha=16.0)
>>> x = torch.randn(2, 10, 768) # (batch, seq_len, in_features)
>>> delta = lora(x) # (batch, seq_len, out_features)
>>> delta.shape
torch.Size([2, 10, 768])
__init__(in_features: int, out_features: int, rank: int = 8, alpha: float = 16.0, dropout: float = 0.1) -> None
¶
Initialize LoRA layer.
forward(x: Tensor) -> Tensor
¶
Apply LoRA: x @ (A @ B) * scaling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor, shape (batch, seq_len, in_features). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | LoRA output, shape (batch, seq_len, out_features). |
LoRALinear
¶
Bases: Module
Linear layer with LoRA adaptation.
Wraps a frozen linear layer and adds trainable low-rank updates. Forward pass: output = base_layer(x) + lora(x)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| base_layer | Linear | The original linear layer to adapt (will be frozen). | required |
| rank | int | LoRA rank r. | 8 |
| alpha | float | LoRA scaling factor α. | 16.0 |
| dropout | float | Dropout for LoRA path. | 0.1 |
Attributes:
| Name | Type | Description |
|---|---|---|
| base_layer | Linear | Frozen base linear layer. |
| lora | LoRALayer | Low-rank adaptation layer. |
Examples:
>>> import torch
>>> from torch import nn
>>> base = nn.Linear(768, 768)
>>> lora_linear = LoRALinear(base, rank=8)
>>> x = torch.randn(2, 10, 768)
>>> out = lora_linear(x)
>>> out.shape
torch.Size([2, 10, 768])
__init__(base_layer: nn.Linear, rank: int = 8, alpha: float = 16.0, dropout: float = 0.1) -> None
¶
Initialize LoRA linear layer.
forward(x: Tensor) -> Tensor
¶
Forward pass: base output + LoRA adaptation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor, shape (batch, seq_len, in_features). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Output with LoRA adaptation, shape (batch, seq_len, out_features). |
ParticipantLoRAAdapter
¶
Bases: Module
Participant-specific LoRA adapters for seq2seq decoder.
Injects LoRA layers into specified target modules (typically query and value projections in attention layers). Used for random slopes mode in GLMM.
This class wraps a decoder module and applies participant-specific low-rank adaptations to attention projections.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| decoder | Module | The decoder module to adapt (e.g., T5 decoder, BART decoder). | required |
| rank | int | LoRA rank r. | required |
| alpha | float | LoRA scaling factor α. | required |
| dropout | float | Dropout for LoRA layers. | required |
| target_modules | list[str] | Names of modules to inject LoRA into (e.g., ["q_proj", "v_proj"]). | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| decoder | Module | The adapted decoder (with LoRA layers injected). |
| lora_layers | dict[str, LoRALinear] | Mapping from module name to LoRA linear layer. |
Examples:
>>> from transformers import AutoModelForSeq2SeqLM
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
>>> decoder = model.get_decoder()
>>> adapter = ParticipantLoRAAdapter(
...     decoder,
...     rank=8,
...     alpha=16.0,
...     dropout=0.1,
...     target_modules=["q", "v"]  # T5 uses "q" and "v"
... )
>>> # adapter.decoder now has LoRA layers injected
__init__(decoder: nn.Module, rank: int, alpha: float, dropout: float, target_modules: list[str]) -> None
¶
Initialize participant LoRA adapter.
forward(input_ids: Tensor, attention_mask: Tensor | None = None) -> Tensor
¶
Forward pass through decoder with LoRA.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_ids | Tensor | Input token IDs, shape (batch_size, seq_len). | required |
| attention_mask | Tensor \| None | Attention mask, shape (batch_size, seq_len). If None, no masking. | None |
Returns:
| Type | Description |
|---|---|
| Tensor | Decoder output tensor. |
get_lora_parameters() -> list[nn.Parameter]
¶
Get all LoRA parameters for optimization.
Returns:
| Type | Description |
|---|---|
| list[Parameter] | List of all trainable LoRA parameters (A and B matrices). |
create_participant_lora_adapter(base_decoder: nn.Module, rank: int, alpha: float, dropout: float, target_modules: list[str]) -> ParticipantLoRAAdapter
¶
Create a participant LoRA adapter.
Creates a deep copy of the base decoder and injects LoRA layers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| base_decoder | Module | Base decoder to copy and adapt. | required |
| rank | int | LoRA rank. | required |
| alpha | float | LoRA scaling factor. | required |
| dropout | float | LoRA dropout. | required |
| target_modules | list[str] | Target modules for LoRA injection. | required |
Returns:
| Type | Description |
|---|---|
| ParticipantLoRAAdapter | New adapter with LoRA injected into copied decoder. |
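A usage sketch mirroring the ParticipantLoRAAdapter example above (one adapter per participant; the model choice is illustrative):
>>> from transformers import AutoModelForSeq2SeqLM
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
>>> adapter_alice = create_participant_lora_adapter(
...     base_decoder=model.get_decoder(),
...     rank=8,
...     alpha=16.0,
...     dropout=0.1,
...     target_modules=["q", "v"]
... )
>>> params = adapter_alice.get_lora_parameters()  # only A/B matrices are trainable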