API Reference

stochatreat.stochatreat

Stratified random assignment of treatments to units.

This module provides a function that assigns treatments to units in a stratified manner. The function works with pandas DataFrames and handles multiple strata. It offers several strategies for dealing with misfits (units left over after the stratified assignment procedure).

stochatreat(data: pd.DataFrame, stratum_cols: list[str] | str, treats: int, probs: list[float] | None = None, random_state: int | None = 42, idx_col: str | None = None, size: int | None = None, misfit_strategy: MisfitStrategy = 'stratum') -> pd.DataFrame

Assign treatments to units in a stratified manner.

Takes a dataframe and an arbitrary number of treatments over an arbitrary number of strata.

Attempts to return equally sized treatment groups, while randomly assigning misfits (leftovers from strata whose size is not divisible by the number of treatments).
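To make the notion of a misfit concrete, the following minimal sketch counts the leftover units per stratum. It uses only the standard library and assumes equal assignment probabilities; `count_misfits` and the toy stratum labels are hypothetical names for illustration and are not part of stochatreat.

```python
from collections import Counter


def count_misfits(strata, treats):
    """Return {stratum: leftover units} when each stratum is split into
    `treats` equally sized groups (equal probabilities assumed)."""
    sizes = Counter(strata)  # number of units in each stratum
    return {stratum: n % treats for stratum, n in sizes.items()}


# Toy data: stratum "a" has 7 units, "b" has 4, "c" has 5.
strata = ["a"] * 7 + ["b"] * 4 + ["c"] * 5
print(count_misfits(strata, treats=2))  # {'a': 1, 'b': 0, 'c': 1}
```

With two treatments, only strata with an odd number of units produce a misfit, so "a" and "c" each leave one unit over.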

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data | DataFrame | The data that contains unique ids and the stratification columns. | required |
| stratum_cols | list[str] \| str | The columns in 'data' that you want to stratify over. | required |
| treats | int | The number of treatments you would like to implement, including control. | required |
| probs | list[float] \| None | The assignment probabilities for each of the treatments. | None |
| random_state | int \| None | The seed for the rng instance. | 42 |
| idx_col | str \| None | The column name that indicates the ids for your data. | None |
| size | int \| None | The size of the sample if you would like to sample from your data. | None |
| misfit_strategy | MisfitStrategy | The strategy used to assign misfits. One of 'stratum' (default): assign misfits randomly within each stratum using probs; 'global': pool all misfits across strata and assign them together; 'none': leave misfits unassigned (treat = NA, stratum_id = NA) for manual handling. | 'stratum' |
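The three misfit_strategy options differ in which pool the leftover units are assigned from. The sketch below illustrates that difference in plain Python. It is not stochatreat's implementation: `sketch_assign_misfits` is a hypothetical helper, equal assignment probabilities are assumed, and the rule "balance complete blocks, then draw at random for the rest" is an assumed stand-in for whatever the library does internally.

```python
import random
from collections import defaultdict


def sketch_assign_misfits(misfits, treats, strategy="stratum", seed=42):
    """Assign leftover units to treatments 0..treats-1 (hypothetical helper).

    misfits: list of (unit_id, stratum) pairs.
    Returns a dict mapping unit_id to a treatment, or to None if unassigned.
    """
    if strategy not in ("stratum", "global", "none"):
        raise ValueError(f"unknown strategy: {strategy!r}")
    if strategy == "none":
        # Leave every misfit unassigned for manual handling.
        return {unit: None for unit, _ in misfits}

    rng = random.Random(seed)
    pools = defaultdict(list)
    for unit, stratum in misfits:
        # 'global' uses one shared pool; 'stratum' keeps one pool per stratum.
        key = "all" if strategy == "global" else stratum
        pools[key].append(unit)

    assignment = {}
    for units in pools.values():
        units = units[:]
        rng.shuffle(units)
        n_full = len(units) - len(units) % treats
        # Units that fill complete blocks get a balanced assignment ...
        for i, unit in enumerate(units[:n_full]):
            assignment[unit] = i % treats
        # ... and any unit still left over gets a random treatment.
        for unit in units[n_full:]:
            assignment[unit] = rng.randrange(treats)
    return assignment


# One misfit from each of four strata, two treatments.
misfits = [(1, "a"), (2, "b"), (3, "c"), (4, "d")]
pooled = sketch_assign_misfits(misfits, treats=2, strategy="global")
print(sorted(pooled.values()))  # [0, 0, 1, 1]
```

Under these assumptions, pooling the four misfits lets them form two complete blocks, so the pooled assignment is exactly balanced, whereas per-stratum assignment of single leftovers can only be balanced in expectation.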

Returns:

| Type | Description |
|---|---|
| DataFrame | pandas.DataFrame with idx_col, treat (treatment assignments) and stratum_id (the id of the stratum within which the assignment procedure was carried out) columns. Both treat and stratum_id use pandas nullable integer types (Int64) to support NA values. |

Examples:

Single stratum:

>>> treats = stochatreat(data=data,               # your dataframe
...                      stratum_cols='stratum1', # stratum variable
...                      treats=2,                # including control
...                      idx_col='myid',          # unique id column
...                      random_state=42)         # seed for rng
>>> data = data.merge(treats, how="left", on="myid")

Multiple strata:

>>> treats = stochatreat(data=data,
...                      stratum_cols=['stratum1', 'stratum2'],
...                      treats=2,
...                      probs=[1/3, 2/3],
...                      idx_col='myid',
...                      random_state=42)
>>> data = data.merge(treats, how="left", on="myid")
Source code in src/stochatreat/stochatreat.py
def stochatreat(
    data: pd.DataFrame,
    stratum_cols: list[str] | str,
    treats: int,
    probs: list[float] | None = None,
    random_state: int | None = 42,
    idx_col: str | None = None,
    size: int | None = None,
    misfit_strategy: MisfitStrategy = "stratum",
) -> pd.DataFrame:
    """Assign treatments to units in a stratified manner.

    Takes a dataframe and an arbitrary number of treatments over an
    arbitrary number of strata.

    Attempts to return equally sized treatment groups, while randomly
    assigning misfits (left overs from strata not divisible by the number
    of treatments).

    Args:
        data: The data that contains unique ids and the stratification columns.
        stratum_cols: The columns in 'data' that you want to stratify over.
        treats: The number of treatments you would like to implement,
            including control.
        probs: The assignment probabilities for each of the treatments.
        random_state: The seed for the rng instance.
        idx_col: The column name that indicates the ids for your data.
        size: The size of the sample if you would like to sample from your
            data.
        misfit_strategy: The strategy used to assign misfits. One of
            'stratum' (default) — assign misfits randomly within each stratum
            using probs; 'global' — pool all misfits across strata and assign
            together; 'none' — leave misfits unassigned (treat = NA,
            stratum_id = NA) for manual handling.

    Returns:
        pandas.DataFrame with idx_col, treat (treatment assignments) and
        stratum_id (the id of the stratum within which the assignment
        procedure was carried out) columns. Both treat and stratum_id use
        pandas nullable integer types (Int64) to support NA values.

    Examples:
        Single stratum:

        >>> treats = stochatreat(data=data,               # your dataframe
                                 stratum_cols='stratum1', # stratum variable
                                 treats=2,                # including control
                                 idx_col='myid',          # unique id column
                                 random_state=42)         # seed for rng
        >>> data = data.merge(treats, how="left", on="myid")

        Multiple strata:

        >>> treats = stochatreat(data=data,
                                 stratum_cols=['stratum1', 'stratum2'],
                                 treats=2,
                                 probs=[1/3, 2/3],
                                 idx_col='myid',
                                 random_state=42)
        >>> data = data.merge(treats, how="left", on="myid")

    """
    spec = TreatmentSpec(treats, probs, random_state)
    preparator = DataPreparator(
        data, stratum_cols, idx_col, size, random_state
    )
    prepared_data, resolved_idx_col = preparator.prepare()
    handler = make_misfit_handler(misfit_strategy)
    prepared_data = handler.handle(
        prepared_data, spec.lcm_prob_denominators, random_state
    )
    assigner = TreatmentAssigner(spec)
    return assigner.assign(prepared_data, resolved_idx_col)
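
The call to handler.handle above receives spec.lcm_prob_denominators. The name suggests (this is an inference, not something this page states) that the block size within a stratum is the least common multiple of the denominators of probs. The sketch below works through that arithmetic with the standard library only; `block_size` and `split_stratum` are hypothetical helpers, and exact `Fraction` inputs are assumed instead of floats.

```python
from fractions import Fraction
from functools import reduce
from math import gcd


def block_size(probs):
    """Least common multiple of the denominators of the probabilities."""
    denominators = [Fraction(p).denominator for p in probs]
    return reduce(lambda a, b: a * b // gcd(a, b), denominators, 1)


def split_stratum(n_units, probs):
    """Return (units per treatment from complete blocks, number of misfits)."""
    size = block_size(probs)
    n_blocks, n_misfits = divmod(n_units, size)
    # Within one block, treatment i receives probs[i] * size units.
    per_treatment = [int(Fraction(p) * size) * n_blocks for p in probs]
    return per_treatment, n_misfits


probs = [Fraction(1, 3), Fraction(2, 3)]
print(block_size(probs))         # 3
print(split_stratum(10, probs))  # ([3, 6], 1)
```

Under this assumption, a stratum of 10 units with probs of 1/3 and 2/3 fills three blocks of three units and leaves one misfit for the chosen misfit_strategy to handle.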