quivers.data.schema

DatasetSchema and the compose helper.

schema

Dataframe-to-QVR schema bridge.

DatasetSchema is the single point that turns "I have a dataframe" into "I have the object cardinalities, the observations dict, and the plate-index tensors a QVR program needs." It accepts pandas, polars, or any other Narwhals-compatible dataframe and emits the two artefacts inference consumes:

  • A .qvr declaration prelude with one object X : FinSet N line per declared object axis. N is the number of distinct non-null values in the column; the canonical (sorted) ordering of categories is cached so plate indices are reproducible across reruns.

  • An observations dict mapping observe-site, plate-index, and covariate names to torch.Tensor values, ready to pass to quivers.inference.MCMC.run or quivers.inference.SVI.step.
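The sketch below illustrates how both artefacts fall out of a single categorical column. It is a pure-Python illustration, not the library's code: a dict of lists stands in for the dataframe, plain lists stand in for the torch.Tensor values, and the helper names canonical_categories and sketch_artefacts are hypothetical.

```python
from typing import Optional, Sequence


def canonical_categories(values: Sequence[Optional[str]]) -> tuple[str, ...]:
    # Sorted unique non-null values: the documented canonical ordering.
    return tuple(sorted({v for v in values if v is not None}))


def sketch_artefacts(
    columns: dict[str, list],
    objects: dict[str, str],
    plate_indices: dict[str, str],
) -> tuple[str, dict[str, list[int]]]:
    # Canonical categories for every declared object column.
    cats = {col: canonical_categories(columns[col]) for col in objects}
    # Artefact 1: one declaration line per object, sorted by object name.
    prelude = "".join(
        f"object {name} : FinSet {len(cats[col])}\n"
        for col, name in sorted(objects.items(), key=lambda kv: kv[1])
    )
    # Artefact 2 (plate indices only here): per-row codes into the ordering.
    # The real schema returns LongTensor values; plain lists stand in.
    obs = {
        var: [cats[col].index(v) for v in columns[col]]
        for col, var in plate_indices.items()
    }
    return prelude, obs


columns = {"verb": ["run", "eat", "run", "sit"]}
prelude, obs = sketch_artefacts(columns, {"verb": "Verb"}, {"verb": "verb_idx"})
print(prelude, end="")  # object Verb : FinSet 3
print(obs)              # {'verb_idx': [1, 0, 1, 2]}
```

Because "eat" < "run" < "sit", the codes depend only on the set of values in the column, not on row order.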

The companion compose wraps quivers.dsl.loads so a user can write a .qvr body without spelling out object Verb : FinSet 40 when 40 is just the number of distinct values in a dataframe column.

DatasetSchema

Bases: Model

Mapping from dataframe columns to QVR program artefacts.

ATTRIBUTE DESCRIPTION
df

Source dataframe; pandas, polars, modin, dask, pyarrow, or anything else Narwhals' from_native accepts. Stored as an opaque field so the schema can be serialized without depending on a specific dataframe flavour.

TYPE: Any

objects

Map from column name to the QVR object name. The object's cardinality is inferred from the column's number of distinct non-null values; the canonical ordering is the sorted set of those values, so plate indices are deterministic across reruns.

TYPE: Mapping[str, str]

observations

Map from column name to the QVR observe-site name. Categorical columns are encoded as LongTensor codes, using the object's canonical ordering when the column is also listed under objects and a sorted-unique fallback for other string columns; numeric columns are encoded as FloatTensor.

TYPE: Mapping[str, str]

plate_indices

Map from column name (which must also appear under objects) to the per-row plate-index variable name. Encoded as LongTensor of category codes; one entry per row.

TYPE: Mapping[str, str]

covariates

Map from numeric column name to the QVR variable name to bind the column's values to (as a FloatTensor).

TYPE: Mapping[str, str]

missing_policy

Policy applied to every column that contains nulls. Defaults to quivers.data.encoding.MissingPolicy.RAISE.

TYPE: MissingPolicy
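The encoding rule described under observations can be restated as a small function. This is a hedged, pure-Python sketch: the name encode_observation is hypothetical, lists replace LongTensor/FloatTensor, nulls are assumed absent (the real encoder applies missing_policy), and "string column" is approximated by checking that every value is a str.

```python
from typing import Optional, Sequence, Union


def encode_observation(
    values: Sequence[Union[str, int, float]],
    object_categories: Optional[tuple[str, ...]] = None,
) -> list:
    if object_categories is not None:
        # Column is also declared under `objects`: reuse its ordering.
        cats = object_categories
    elif all(isinstance(v, str) for v in values):
        # Other string columns: sorted-unique fallback.
        cats = tuple(sorted(set(values)))
    else:
        # Numeric columns become floats (FloatTensor in the real schema).
        return [float(v) for v in values]
    # Categorical columns become integer codes (LongTensor in the real schema).
    return [cats.index(v) for v in values]


print(encode_observation(["no", "yes", "no"]))                    # [0, 1, 0]
print(encode_observation([1, 2.5]))                               # [1.0, 2.5]
print(encode_observation(["yes"], object_categories=("no", "yes")))  # [1]
```

The last call shows why the object ordering matters: on its own, the fallback would code "yes" as 0, while the object's ordering codes it as 1, consistent with any plate index built from the same column.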

cardinalities

cardinalities() -> Mapping[str, int]

Inferred object cardinalities, keyed by QVR object name.

Source code in src/quivers/data/schema.py
@dx.derived
def cardinalities(self) -> Mapping[str, int]:
    """Inferred object cardinalities, keyed by QVR object name."""
    # Touch _nw_df to trigger validation even on schemas that
    # declare no object columns.
    _ = self._nw_df
    return {
        obj_name: len(self._categories[col])
        for col, obj_name in self.objects.items()
    }

categories

categories(column: str) -> tuple[str, ...]

Canonical ordering of values for an object-column.

Codes are assigned as categories.index(value); the ordering is the column's sorted unique non-null values, so the same dataframe always produces the same indices.

Source code in src/quivers/data/schema.py
def categories(self, column: str) -> tuple[str, ...]:
    """Canonical ordering of values for an object-column.

    Codes are assigned as ``categories.index(value)``; the
    ordering is the column's sorted unique non-null values, so
    the same dataframe always produces the same indices.
    """
    if column not in self._categories:
        raise KeyError(
            f"DatasetSchema.categories: column {column!r} is not "
            f"declared as an object column"
        )
    return self._categories[column]
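A worked example of the determinism guarantee, written as a standalone sketch (the free function categories here only mimics the method's documented ordering; it is not the library's implementation). It also shows that categories.index(value) and a precomputed value-to-code dict give identical codes; the dict is the usual way to avoid a linear scan per row.

```python
from typing import Optional, Sequence


def categories(values: Sequence[Optional[str]]) -> tuple[str, ...]:
    # Sorted unique non-null values, as the docstring above describes.
    return tuple(sorted({v for v in values if v is not None}))


run_a = ["sit", "run", None, "eat", "run"]
run_b = ["eat", "run", "sit", "run", None]  # same values, different row order

cats = categories(run_a)
assert cats == categories(run_b) == ("eat", "run", "sit")

# Codes as categories.index(value); nulls are skipped in this sketch,
# whereas the real encoder applies missing_policy to them.
codes = [cats.index(v) for v in run_a if v is not None]

# Equivalent codes via a precomputed lookup (constant time per row).
lookup = {c: i for i, c in enumerate(cats)}
assert codes == [lookup[v] for v in run_a if v is not None] == [2, 1, 0, 1]

# The object's cardinality is the length of this ordering.
assert len(cats) == 3
```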

declarations

declarations() -> str

Emit a .qvr declaration prelude.

Lines are object <Name> : FinSet <cardinality>, sorted by object name for reproducibility. Suitable for prepending to a user's .qvr source via compose.

Source code in src/quivers/data/schema.py
def declarations(self) -> str:
    """Emit a ``.qvr`` declaration prelude.

    Lines are ``object <Name> : FinSet <cardinality>``, sorted
    by name for reproducibility. Suitable for prepending to a
    user's ``.qvr`` source via `compose`.
    """
    sorted_objs = sorted(self.objects.items(), key=lambda kv: kv[1])
    lines = [
        f"object {obj_name} : FinSet {self.cardinalities[obj_name]}"
        for _, obj_name in sorted_objs
    ]
    return "\n".join(lines) + ("\n" if lines else "")
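To make the output format concrete, the sketch below mirrors the method above with plain dicts in place of a schema (the free function declarations is a stand-in, not the library's API). The "Verb": 40 and "Subject": 12 cardinalities are illustrative values.

```python
def declarations(objects: dict[str, str], cardinalities: dict[str, int]) -> str:
    # Sort (column, object name) pairs by object name, as the method does.
    sorted_objs = sorted(objects.items(), key=lambda kv: kv[1])
    lines = [f"object {name} : FinSet {cardinalities[name]}" for _, name in sorted_objs]
    # Trailing newline only when there is at least one line.
    return "\n".join(lines) + ("\n" if lines else "")


print(declarations({"verb": "Verb", "subject": "Subject"}, {"Verb": 40, "Subject": 12}), end="")
# object Subject : FinSet 12
# object Verb : FinSet 40
```

The sort uses Python's default string ordering, which is case-sensitive: an object named Zeta is emitted before one named alpha, because uppercase letters sort before lowercase ones.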

observations_dict

observations_dict() -> dict[str, Tensor]

Build the observations dict for inference.

Contains entries for every observation, plate-index, and covariate column. Categorical observations and plate indices use the canonical ordering returned by categories; numeric observations and covariates become FloatTensor.

Source code in src/quivers/data/schema.py
def observations_dict(self) -> dict[str, torch.Tensor]:
    """Build the observations dict for inference.

    Contains entries for every observation, plate-index, and
    covariate column.  Categorical observations and plate
    indices use the canonical ordering returned by
    `categories`; numeric observations and covariates
    become ``FloatTensor``.
    """
    result: dict[str, torch.Tensor] = {}

    for col, site in self.observations.items():
        cats: tuple[str, ...] | None = None
        if col in self.objects:
            cats = self._categories[col]
        else:
            dtype = self._nw_df[col].dtype
            if dtype == nw.String:
                cats = tuple(
                    str(v)
                    for v in self._nw_df[col].drop_nulls().unique().sort().to_list()
                )
        result[site] = encode_column(
            self._nw_df,
            col,
            role=ColumnRole.OBSERVATION,
            categories=cats,
            missing_policy=self.missing_policy,
        )

    for col, var in self.plate_indices.items():
        result[var] = encode_column(
            self._nw_df,
            col,
            role=ColumnRole.PLATE_INDEX,
            categories=self._categories[col],
            missing_policy=self.missing_policy,
        )

    for col, var in self.covariates.items():
        result[var] = encode_column(
            self._nw_df,
            col,
            role=ColumnRole.COVARIATE,
            missing_policy=self.missing_policy,
        )

    return result
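One consequence of the loop structure above is that the result is keyed by QVR names, not column names, and the three loops run in a fixed order (observations, then plate indices, then covariates). If the same QVR name appears under two roles, the later loop silently overwrites the earlier entry. The sketch below isolates that behaviour; assemble and the tagging encoder tag are hypothetical stand-ins for the method and for encode_column.

```python
from typing import Any, Callable, Mapping


def assemble(
    observations: Mapping[str, str],
    plate_indices: Mapping[str, str],
    covariates: Mapping[str, str],
    encode: Callable[[str, str], Any],
) -> dict[str, Any]:
    # Same loop order as observations_dict; keys are QVR names.
    result: dict[str, Any] = {}
    for col, site in observations.items():
        result[site] = encode(col, "observation")
    for col, var in plate_indices.items():
        result[var] = encode(col, "plate_index")
    for col, var in covariates.items():
        result[var] = encode(col, "covariate")
    return result


# Hypothetical encoder that just records what it was asked to encode.
def tag(col: str, role: str) -> str:
    return f"{role}:{col}"


out = assemble({"verb": "y"}, {"verb": "verb_idx"}, {"age": "y"}, tag)
print(out)  # {'y': 'covariate:age', 'verb_idx': 'plate_index:verb'}
```

Here the covariate bound to "y" replaced the observation bound to the same name, so keeping observe-site, plate-index, and covariate names distinct avoids losing data.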

compose

compose(qvr_body: str, schema: DatasetSchema, **kwargs)

Compile a .qvr body against a dataset schema.

Prepends the schema's object declarations to qvr_body, then calls quivers.dsl.loads. The user writes only the program body (latents, kernels, observations, return); object cardinalities inferred from the dataframe are slotted in automatically. If the body re-declares an object that appears in the schema, the body's declaration wins.

PARAMETER DESCRIPTION
qvr_body

QVR source without the object declarations covered by schema.objects.

TYPE: str

schema

Dataframe schema providing cardinalities.

TYPE: DatasetSchema

**kwargs

Forwarded to quivers.dsl.loads (e.g. data=... for from_data lookups).

DEFAULT: {}

Source code in src/quivers/data/schema.py
def compose(qvr_body: str, schema: DatasetSchema, **kwargs):
    """Compile a ``.qvr`` body against a dataset schema.

    Prepends the schema's ``object`` declarations to ``qvr_body``,
    then calls `quivers.dsl.loads`.  The user writes only the
    program body (latents, kernels, observations, return); object
    cardinalities inferred from the dataframe are slotted in
    automatically.  If the body re-declares an object that appears
    in the schema, the body's declaration wins.

    Parameters
    ----------
    qvr_body : str
        QVR source without the ``object`` declarations covered by
        ``schema.objects``.
    schema : DatasetSchema
        Dataframe schema providing cardinalities.
    **kwargs
        Forwarded to `quivers.dsl.loads` (e.g. ``data=...`` for
        ``from_data`` lookups).
    """
    body_declares: set[str] = set()
    for line in qvr_body.splitlines():
        stripped = line.strip()
        if stripped.startswith("object "):
            after = stripped[len("object ") :].split(":")[0].split("=")[0]
            body_declares.add(after.strip())

    prelude_lines = []
    for _, obj_name in sorted(schema.objects.items(), key=lambda kv: kv[1]):
        if obj_name in body_declares:
            continue
        prelude_lines.append(
            f"object {obj_name} : FinSet {schema.cardinalities[obj_name]}"
        )
    prelude = "\n".join(prelude_lines)
    if prelude:
        prelude += "\n\n"
    return loads(prelude + qvr_body, **kwargs)
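The prelude-building half of compose can be exercised without the DSL. The sketch below restates the two steps shown above (scan the body for object lines, then emit declarations only for schema objects the body does not already declare) and stops before the loads call; find_body_declarations and build_prelude are hypothetical names, and the cardinalities are illustrative.

```python
def find_body_declarations(qvr_body: str) -> set[str]:
    # Only lines whose stripped text starts with "object " are considered.
    declared: set[str] = set()
    for line in qvr_body.splitlines():
        stripped = line.strip()
        if stripped.startswith("object "):
            name = stripped[len("object "):].split(":")[0].split("=")[0]
            declared.add(name.strip())
    return declared


def build_prelude(objects: dict[str, str], cardinalities: dict[str, int], qvr_body: str) -> str:
    declared = find_body_declarations(qvr_body)
    lines = [
        f"object {name} : FinSet {cardinalities[name]}"
        for _, name in sorted(objects.items(), key=lambda kv: kv[1])
        if name not in declared  # the body's own declaration wins
    ]
    prelude = "\n".join(lines)
    # Blank line between the prelude and the body, as compose does.
    return prelude + "\n\n" if prelude else ""


body = "object Verb : FinSet 50\n"
prelude = build_prelude({"verb": "Verb", "subject": "Subject"}, {"Verb": 40, "Subject": 12}, body)
print(repr(prelude))  # 'object Subject : FinSet 12\n\n'
```

In compose, prelude + qvr_body is what reaches quivers.dsl.loads, so in this example Verb keeps the body's cardinality of 50 rather than the 40 inferred from the dataframe.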