quivers.data.encoding

Column-level encoding utilities and the ColumnRole / MissingPolicy enums.

encoding

Column-level encoding utilities: dtype dispatch, missing-data policies, and the role enum that classifies what a column is for in a probabilistic program (object axis, observed site, plate index, covariate).

ColumnRole

Bases: str, Enum

How a dataframe column participates in a QVR program.
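Here is a minimal sketch of a str-backed enum with this shape. The member names OBJECT, OBSERVATION, PLATE_INDEX, and COVARIATE are the ones encode_column uses. The string values are assumptions, because this page does not list them. Because the enum has a str base, each member compares equal to a plain string and can be looked up from one, which helps when roles are read from a config file.

```python
from enum import Enum


class ColumnRole(str, Enum):
    """Hypothetical reconstruction: member names match those used by
    encode_column; the string values are assumed, not documented."""

    OBJECT = "object"            # object axis of the program
    OBSERVATION = "observation"  # observed site
    PLATE_INDEX = "plate_index"  # plate index
    COVARIATE = "covariate"      # covariate


# The str base lets members be looked up by value and compared to strings.
role = ColumnRole("plate_index")
print(role is ColumnRole.PLATE_INDEX)  # True
print(role == "plate_index")           # True
```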

MissingPolicy

Bases: str, Enum

How to handle NaN / null entries when encoding a column.
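The sketch below mirrors the null-handling step of encode_column on a plain Python list, with None standing in for a null. The member names come from the source. The string values and the helper fill_missing are hypothetical and are not part of quivers. As in the source, IMPUTE fills numeric data with the mean and other data with the most frequent value.

```python
from collections import Counter
from enum import Enum
from statistics import mean


class MissingPolicy(str, Enum):
    """Hypothetical reconstruction: member names come from encode_column;
    the string values are assumed."""

    RAISE = "raise"
    DROP = "drop"
    IMPUTE = "impute"
    MASK = "mask"


def fill_missing(values, policy, *, numeric):
    """Illustrative stand-in (not part of quivers) for the null-handling
    step of encode_column; None marks a null entry."""
    n_missing = sum(v is None for v in values)
    if n_missing == 0 or policy == MissingPolicy.MASK:
        # MASK leaves nulls in place; encoding later maps them to
        # NaN (numeric) or -1 (categorical).
        return list(values)
    if policy in (MissingPolicy.RAISE, MissingPolicy.DROP):
        # DROP raises too: the caller is expected to pre-filter rows.
        raise ValueError(f"{n_missing} missing values under {policy.value}")
    present = [v for v in values if v is not None]
    if numeric:
        fill = mean(present)  # numeric columns: column mean
    else:
        fill = Counter(present).most_common(1)[0][0]  # otherwise: mode
    return [fill if v is None else v for v in values]


print(fill_missing([1.0, None, 3.0], MissingPolicy.IMPUTE, numeric=True))
print(fill_missing(["a", None, "b", "a"], MissingPolicy.IMPUTE, numeric=False))
```

Counter.most_common breaks ties by first occurrence. This page does not say how encode_column breaks ties between equally frequent values.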

encode_column

encode_column(df: DataFrame, column: str, *, role: ColumnRole, categories: tuple[str, ...] | None = None, missing_policy: MissingPolicy = MissingPolicy.RAISE) -> Tensor

Encode a single column into a torch.Tensor ready for QVR inference.

PARAMETER DESCRIPTION
df

Narwhals-wrapped dataframe.

TYPE: DataFrame

column

Column to encode.

TYPE: str

role

How the column participates in the program. PLATE_INDEX and OBJECT columns require a categories tuple for reproducible code assignment.

TYPE: ColumnRole

categories

Canonical ordering of categorical values; if provided, codes are assigned by categories.index(value). Required for PLATE_INDEX, OBJECT, and non-numeric OBSERVATION columns. None is allowed for numeric OBSERVATION and COVARIATE columns. A value not listed in categories raises KeyError.

TYPE: tuple[str, ...] or None DEFAULT: None

missing_policy

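Policy for NaN / null handling. RAISE and DROP both raise ValueError (DROP expects the caller to have removed incomplete rows beforehand). IMPUTE fills with the column mean for numeric columns and the most frequent value otherwise. MASK keeps missing entries as NaN in float output or as code -1 in categorical output.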

TYPE: MissingPolicy DEFAULT: RAISE
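A canonical categories tuple matters because first-appearance ordering is unstable: if codes followed the order in which labels first appear, the same label could get a different code after the rows are shuffled or subset. The sketch below reproduces the mapping encode_column builds, which is the position in categories, with -1 for a null kept under MissingPolicy.MASK. The helper categorical_codes and the example labels are hypothetical.

```python
def categorical_codes(values, categories):
    """Hypothetical helper (not part of quivers) mirroring how encode_column
    assigns codes: position in `categories`, -1 for a null entry."""
    cat_index = {c: i for i, c in enumerate(categories)}
    # An unlisted value raises KeyError here, as it does in encode_column.
    return [cat_index[v] if v is not None else -1 for v in values]


categories = ("ctrl", "drug_a", "drug_b")

# Each label gets the same code regardless of row order or subsetting.
print(categorical_codes(["drug_b", "ctrl", None], categories))  # [2, 0, -1]
print(categorical_codes(["ctrl", "drug_b"], categories))        # [0, 2]
```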

RETURNS DESCRIPTION
Tensor

LongTensor of category codes for categorical encodings (PLATE_INDEX, OBJECT, and non-numeric OBSERVATION columns); float32 FloatTensor otherwise.
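The following summarizes the branch order that picks the output dtype. It is a sketch rather than the library's code. Roles are plain strings standing in for ColumnRole members, and the helper output_dtype is hypothetical. The returned names correspond to torch.long and torch.float32.

```python
def output_dtype(role, is_numeric, categories=None):
    """Hypothetical helper (not part of quivers) reproducing the branch
    order of encode_column; `role` is a string standing in for ColumnRole."""
    needs_codes = role in ("plate_index", "object") or (
        role == "observation" and not is_numeric
    )
    if needs_codes:
        if categories is None:
            raise ValueError(f"role={role} requires a categories ordering")
        return "int64"  # torch.long category codes
    return "float32"    # numeric observations and all covariates


print(output_dtype("plate_index", True, ("a", "b")))  # int64
print(output_dtype("observation", True))              # float32
```

One consequence is that a non-numeric COVARIATE column takes the float path, where float(v) raises ValueError on strings that do not parse as numbers.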

Source code in src/quivers/data/encoding.py
def encode_column(
    df: nw.DataFrame,
    column: str,
    *,
    role: ColumnRole,
    categories: tuple[str, ...] | None = None,
    missing_policy: MissingPolicy = MissingPolicy.RAISE,
) -> torch.Tensor:
    """Encode a single column into a ``torch.Tensor`` ready for
    QVR inference.

    Parameters
    ----------
    df : nw.DataFrame
        Narwhals-wrapped dataframe.
    column : str
        Column to encode.
    role : ColumnRole
        How the column participates in the program. ``PLATE_INDEX``
        and ``OBJECT`` columns require a categories tuple for
        reproducible code assignment.
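    categories : tuple[str, ...] or None
        Canonical ordering of categorical values; if provided, codes
        are assigned by ``categories.index(value)``. Required for
        ``PLATE_INDEX``, ``OBJECT``, and non-numeric ``OBSERVATION``
        columns. ``None`` is allowed for numeric ``OBSERVATION`` /
        ``COVARIATE`` columns.
    missing_policy : MissingPolicy
        Policy for ``NaN`` / null handling: ``RAISE`` and ``DROP``
        raise ``ValueError`` (``DROP`` expects a pre-filtered
        dataframe), ``IMPUTE`` fills with the mean (numeric) or the
        mode (otherwise), and ``MASK`` keeps missing entries as
        ``NaN`` or a ``-1`` code.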

    Returns
    -------
    torch.Tensor
        ``LongTensor`` of category codes for categorical encodings,
        ``float32`` ``FloatTensor`` otherwise.
    """
    series = df[column]
    dtype = series.dtype
    is_numeric = _is_numeric_dtype(dtype)
    null_count = series.is_null().sum()

    if null_count > 0:
        if missing_policy == MissingPolicy.RAISE:
            raise ValueError(
                f"column {column!r} has {null_count} missing values "
                f"but missing_policy={MissingPolicy.RAISE.value}"
            )
        if missing_policy == MissingPolicy.DROP:
            raise ValueError(
                f"column {column!r}: MissingPolicy.DROP requires the "
                "caller to pre-filter the dataframe; this function "
                "encodes the column as given"
            )
        if missing_policy == MissingPolicy.IMPUTE:
            if is_numeric:
                fill = series.mean()
            else:
                # Modal value: take the value with the highest count.
                counts = series.drop_nulls().value_counts(name="_count_")
                fill = counts.sort("_count_", descending=True)[column][0]
            series = series.fill_null(fill)
        # MASK falls through; NaN -> NaN for numeric, -1 code for
        # categorical (handled below).

    if role in (ColumnRole.PLATE_INDEX, ColumnRole.OBJECT):
        if categories is None:
            raise ValueError(
                f"encode_column: role={role.value} requires a "
                f"categories ordering for column {column!r}"
            )
        cat_index = {c: i for i, c in enumerate(categories)}
        values = series.to_list()
        codes = [cat_index[v] if v is not None else -1 for v in values]
        return torch.tensor(codes, dtype=torch.long)

    if role == ColumnRole.OBSERVATION and not is_numeric:
        if categories is None:
            raise ValueError(
                f"encode_column: non-numeric observation column "
                f"{column!r} requires a categories ordering"
            )
        cat_index = {c: i for i, c in enumerate(categories)}
        values = series.to_list()
        codes = [cat_index[v] if v is not None else -1 for v in values]
        return torch.tensor(codes, dtype=torch.long)

    # Numeric observation or covariate path.
    values = series.to_list()
    return torch.tensor(
        [float("nan") if v is None else float(v) for v in values],
        dtype=torch.float32,
    )
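MissingPolicy.DROP raises instead of dropping rows, so incomplete rows have to be removed before encode_column is called. They should be removed jointly across every column being encoded, so that the resulting tensors have the same length and stay row-aligned. With a narwhals frame, df.drop_nulls(subset=[...]) does this. The sketch below shows the same idea on a list of dicts. The helper drop_incomplete_rows and the sample data are hypothetical.

```python
def drop_incomplete_rows(rows, columns):
    """Hypothetical pre-filter (not part of quivers): keep only rows with a
    non-null value in every column that will be encoded."""
    return [row for row in rows if all(row[c] is not None for c in columns)]


rows = [
    {"subject": "s1", "rt": 0.52, "cond": "ctrl"},
    {"subject": "s2", "rt": None, "cond": "drug"},
    {"subject": "s1", "rt": 0.61, "cond": None},
    {"subject": "s3", "rt": 0.47, "cond": "drug"},
]

# Filtering on both columns at once keeps "rt" and "cond" aligned row by row.
kept = drop_incomplete_rows(rows, ["rt", "cond"])
print([r["subject"] for r in kept])  # ['s1', 's3']
```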