Should I balance class weights on logistic regression?

To balance or not to balance?

Despite the totally off-topic cover image, in this post I will share, and then dismantle, my newbie-ish feeling that balancing class weights in scikit-learn's logistic regression might help when dealing with imbalanced datasets.

What does it do?

We can weight our log-loss / cross-entropy loss depending on the class like this:

$$ LogLoss = - \frac{1}{N} \sum \limits_{n=1}^{N} w_{n} \cdot \Bigg[ y_{n}\log\hat{y}_{n} + (1 - y_{n})\log(1-\hat{y}_{n}) \Bigg], $$

where $w_{n}$ is the weight associated with the class of $y_{n}$. If we want to "fix" the class imbalance, we set $w_{\text{under-represented}} > w_{\text{over-represented}}$ so that a training example from the under-represented class "weighs more" than one from the over-represented class.

Quoting the scikit-learn docs on class_weight='balanced':

The “balanced” mode uses the values of y to automatically adjust weights inversely proportional to class frequencies in the input data as n_samples / (n_classes * np.bincount(y)).
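
For intuition, here is a minimal sketch (not part of the original analysis) of the weights that rule produces on a toy label vector:

import numpy as np

y = np.array([0, 0, 0, 1])  # toy labels with a 3:1 imbalance
n_samples, n_classes = y.size, np.unique(y).size
weights = n_samples / (n_classes * np.bincount(y))
print(weights)  # the under-represented class gets the larger weight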

Just don't

It is useful in some scenarios, like when the class balance in your dataset is not representative of the true population but you want your model to be (the fit can also be weighted for that). However, in other scenarios, e.g. when your dataset is supposed to be representative but simply has class imbalance, it might look useful but it is not. If anything, it will be detrimental to the model. I'll give some context.
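
As a minimal sketch of that first scenario (all numbers here are made up for illustration): if positives were deliberately over-sampled but the assumed population prevalence is much lower, the fit can be reweighted with sample_weight:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=1_000, random_state=0)  # roughly balanced classes
true_prevalence = 0.05             # assumed population prevalence (made up)
sample_prevalence = y.mean()
w = np.where(y == 1,
             true_prevalence / sample_prevalence,              # down-weight positives
             (1 - true_prevalence) / (1 - sample_prevalence))  # up-weight negatives
model = LogisticRegression().fit(X, y, sample_weight=w)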

Some boring context

I came across an ongoing project that used a logistic regression (from now on, LR). A grid of parameters was being explored for its scikit-learn implementation. While reviewing the code, I took the opportunity to remove parameters that I felt were not worth exploring. Notably, the kwarg class_weight was set to 'balanced', and because we had to reach a high sensitivity and the odds of the event were roughly 1:4, I thought

" meh, just leave it, it won't hurt"

But it looks like I was wrong.

Does balancing data in LogReg (when sampling is fine) help?

No.
Reading/googling about it, I found a nice entry on Chandler Zuo's site. In that post he explores the benefits of weighting or subsampling the data when fitting LR on (synthetic) imbalanced data, and the result is: NO IMPROVEMENT in terms of discriminability. A few years later he published a preprint on the topic. There's also this research article, which I honestly did not read, but it looks like its core idea is the one I show below.

TLDR?

When our sampling methodology is correct, using class_weight='balanced' will not increase auROC and it is likely to wreak havoc on our model's calibration. To meet specific requirements (sensitivity, specificity, etc.), just move the cutoff point instead of weighting the log-loss so that the final threshold ends up at 0.5.
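
As a tiny, self-contained sketch of what "move the cutoff point" means (the 0.1 value is picked arbitrarily here; the experiment below chooses the cutoff from a sensitivity target):

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(weights=[0.75], random_state=0)
clf = LogisticRegression().fit(X, y)        # no class_weight
proba = clf.predict_proba(X)[:, 1]
preds = (proba >= 0.1).astype(int)          # custom cutoff instead of the default 0.5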

Minimal experiment with synthetic data.

We will:

  • generate some synthetic data
  • compare two LogReg models, one naive and the other with balanced weights, in terms of:
    • auROC
    • Threshold for sensitivity >95% and the resulting specificity
    • Calibration curves
    • Decision curve analysis

Generate data

In [121]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
from sklearn.metrics import RocCurveDisplay, confusion_matrix, roc_auc_score, f1_score
from sklearn.calibration import CalibrationDisplay, calibration_curve


def gen_dataset(random_state=0):
    """ returns synthetic X, y """
    return make_classification(
        n_samples=10_000,
        n_features=10,
        n_informative=10,
        n_redundant=0,
        weights=[0.75],
        random_state=random_state
    )
X, y = gen_dataset()
print(f'resulting True class fraction is {y.mean()}')
resulting True class fraction is 0.2531

ROC and AUC

In [111]:
def split_and_train(X, y, random_state=0):
    """ return data splits and naive & balanced models already trained """
    # because of its synthetic nature, we can skip scaling, etc.
    # We will also skip the parameter grid search (using defaults below).
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=random_state)

    naive = LogisticRegression().fit(X_train, y_train)
    balanced = LogisticRegression(class_weight='balanced').fit(X_train, y_train)
    return X_train, X_test, y_train, y_test, naive, balanced

X_train, X_test, y_train, y_test, naive, balanced = split_and_train(X, y)

f, ax = plt.subplots()
for model, label in [[naive,'naive'],[balanced, 'balanced']]:
    RocCurveDisplay.from_estimator(
        model, X_test, y_test, ax=ax, name=label
    )
plt.show()

Both models show almost the same ROC curve and associated AUC.
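
If we prefer numbers over eyeballing the curves, a quick check (reusing the fitted models and test split from the cell above):

# numeric counterpart of the plot above
for model, label in [[naive, 'naive'], [balanced, 'balanced']]:
    auc = roc_auc_score(y_test, model.predict_proba(X_test)[:, 1])
    print(f'{label}: auROC = {auc:.4f}')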

What about sensitivity and specificity? And the predicted probability distribution?

Given that we aim at a 95% sensitivity, we search for the highest threshold that still fulfills it and retrieve the corresponding specificity. We will also show the $\hat{y}$ distribution and the calibration curves (a.k.a. predicted vs. actual probabilities, binned).

In [112]:
def calc_sens_esp(gt, proba, thr):
    """returns sensitivity and specificity for a given thr"""
    preds = (proba >= thr) * 1
    tn, fp, fn, tp = confusion_matrix(gt, preds).ravel()
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    return sensitivity, specificity

def find_best_thr(gt, proba, sens_thr=.95):
    """ gets the best threshold to satisfy sens thr,
    proba is (n,) dimensional, with probs of label=1
    """
    prev_thr = proba.min()
    for thr in np.unique(proba): # np.unique sorts
        sens, _ = calc_sens_esp(gt, proba, thr)
        if sens < sens_thr:
            return prev_thr
        prev_thr = thr
    
    return thr


f, ax = plt.subplots(ncols=2, figsize=(12,6))

for model, label in [[naive,'naive'],[balanced, 'balanced']]:
    predicted_probability = model.predict_proba(X_test)[:,1]
    best_thr = find_best_thr(y_test, predicted_probability)
    sens, spec = calc_sens_esp(y_test, predicted_probability, best_thr)
    print(
        f'{label} model best threshold was {best_thr:.3g} ' 
        f'with {sens:.3g} sensitivity and {spec:.3g} specificity')

    ax[0].hist(
        predicted_probability, bins=np.linspace(0,1,21), 
        label=label, alpha=.5)
    
    CalibrationDisplay.from_estimator(
        model, X_test, y_test, name=label, ax=ax[1]
    )
    
ax[0].set_ylabel('counts')
ax[0].set_xlabel('$\hat{y}$ distribution')
ax[0].legend()
plt.show()
naive model best threshold was 0.0946 with 0.951 sensitivity and 0.329 specificity
balanced model best threshold was 0.235 with 0.951 sensitivity and 0.35 specificity

As we can see, specificity tends to be slightly lower in the naive model, so one might be tempted to use the balanced model. However, if we have imbalanced data, is it OK for our model to output "normal-shaped" predicted probabilities? If the model does a good job predicting the probability of a relatively uncommon event, the predictions should be skewed as well.

With the calibration curves we can clearly see that the 'balanced' model overestimates $p(y)$: the predicted probability is higher than the actual probability in every bin.
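
If you want a single number summarizing this miscalibration, a quick sketch using the Brier score (not part of the original analysis; it reuses the fitted models and test split from above):

from sklearn.metrics import brier_score_loss

# mean squared error between predicted probabilities and outcomes; lower is better
for model, label in [[naive, 'naive'], [balanced, 'balanced']]:
    score = brier_score_loss(y_test, model.predict_proba(X_test)[:, 1])
    print(f'{label}: Brier score = {score:.4f}')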

But I just care about sensitivity!

Well then, go ahead. Remember, though, that the $p(y)$ estimate is wrong, and since both models have the same discriminability, the modest improvement in specificity seen in this particular setup will be paid for elsewhere. Also, the resources I shared above may make stronger points than mine.

Run the experiment several times

So we can draw some bands/errors.
I will stack some metrics taking the naive model as reference. Hence, $\Delta auROC = auROC_{balanced} - auROC_{naive}$, etc.
For the calibration curves, we will use a custom function to control the bins so we stack the same thing when aggregating them. Remember that the number of items per bin might differ quite a lot across bins/curves.

In [120]:
from statsmodels.stats.proportion import proportion_confint 
from scipy.stats import binned_statistic
from tqdm.notebook import tqdm

def custom_calib_curve(y_pred, y_true, bins=np.linspace(0,1,6)):
    args = (y_pred, y_true)
    kwargs = dict(bins=bins, range=[0,1])
    obs, _, _  = binned_statistic(*args, **kwargs, statistic='count')
    counts, _, _ = binned_statistic(*args, **kwargs, statistic='sum')
    average, _, _ = binned_statistic(*args, **kwargs, statistic='mean')
    # not a consistent naming, but following proportion_confint semantics
    return obs, counts, average

results = {
    '$\Delta$ auROC': [],
    '$\Delta$ f1': [],
    '$\Delta$ specificity': [],
    'obs_naive': [],
    'counts_naive': [],
    'average_naive': [],
    'obs_balanced': [],
    'counts_balanced': [],
    'average_balanced': [],
}
for i in tqdm(range(250)):
    X, y = gen_dataset(random_state=i)
    X_train, X_test, y_train, y_test, naive, balanced = split_and_train(
        X, y, random_state=i)
    proba_naive = naive.predict_proba(X_test)[:,1]
    thr_naive = find_best_thr(y_test, proba_naive)
    _, specificity_naive = calc_sens_esp(y_test, proba_naive, thr_naive)
    proba_balanced = balanced.predict_proba(X_test)[:,1]
    thr_balanced = find_best_thr(y_test, proba_balanced)
    _, specificity_balanced = calc_sens_esp(y_test, proba_balanced, thr_balanced)

    results['$\Delta$ auROC'].append(
        roc_auc_score(y_test, proba_balanced) - roc_auc_score(y_test, proba_naive)
    ) 
    results['$\Delta$ f1'].append(
        f1_score(y_test, (proba_balanced >= thr_balanced)*1) -\
        f1_score(y_test, (proba_naive >= thr_naive)*1)
    )
    results['$\Delta$ specificity'].append(
        specificity_balanced - specificity_naive
    )
    obs, counts, average = custom_calib_curve(proba_naive, y_test)
    results['obs_naive'].append(obs)
    results['counts_naive'].append(counts)
    results['average_naive'].append(average)

    obs, counts, average = custom_calib_curve(proba_balanced, y_test)
    results['obs_balanced'].append(obs)
    results['counts_balanced'].append(counts)
    results['average_balanced'].append(average)
In [136]:
from statsmodels.stats.weightstats import DescrStatsW 
f, ax = plt.subplots(ncols=3, figsize=(9,5))
for i, title in enumerate(
    ['$\Delta$ auROC', '$\Delta$ f1', '$\Delta$ specificity']):
    arr = np.asarray(results[title])
    lci, hci = DescrStatsW(arr).tconfint_mean()
    sns.boxenplot(y=arr, ax=ax[i], color='green')
    ax[i].errorbar(
        -0.25, arr.mean(), yerr=[[arr.mean() - lci], [hci - arr.mean()]],
        marker='o', capsize=4, color='crimson'
    )
    ax[i].set_title(title)
    ax[i].axhline(0, ls=':', color='gray')

Whoa, this was truly unexpected! I wonder why it is different from 0 in all three cases. Although the effects are small, this goes against the previous claims. Unfortunately, I cannot come up with any intuition for it.
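
To probe that "different from 0" impression a bit more formally, a quick sketch with a one-sample t-test against zero (this was not part of the original analysis):

from scipy.stats import ttest_1samp

# test each stacked delta against a population mean of 0
for metric in ['$\Delta$ auROC', '$\Delta$ f1', '$\Delta$ specificity']:
    stat, pval = ttest_1samp(np.asarray(results[metric]), 0)
    print(f'{metric}: t = {stat:.3g}, p = {pval:.3g}')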
And now the calibration curves, either averaging across experiments or pooling the results.

In [150]:
f, ax = plt.subplots(ncols=2, figsize=(12,6))
x_dim = np.array([.1,.3,.5,.7,.9])
for model in ['naive', 'balanced']:
    mat = np.squeeze(results[f'average_{model}'])
    cis = [DescrStatsW(mat[:, i]).tconfint_mean() for i in range(5)]
    lower = np.array([x[0] for x in cis])
    upper = np.array([x[1] for x in cis])
    ax[0].fill_between(x_dim, lower, upper, alpha=0.5)
    ax[0].plot(x_dim, mat.mean(axis=0), label=model, marker='o')
    ax[0].set_title('averaging experiments')

    counts = np.squeeze(results[f'counts_{model}']).sum(axis=0)
    nobs = np.squeeze(results[f'obs_{model}']).sum(axis=0)
    average = counts / nobs
    lower, upper = proportion_confint(counts, nobs, method='beta')
    ax[1].fill_between(x_dim, lower, upper, alpha=0.5)
    ax[1].plot(x_dim, average, label=model, marker='o')
    ax[1].set_title('pooling experiments')

for a in ax:
    a.plot([0,1],[0,1], ls=':', color='k', label='Perfectly calibrated')
    a.set_xlabel('Mean predicted probability')
    a.set_ylabel('Fraction of positives')
    a.legend()

plt.show()

No big surprises here.

Bonus: Decision Curve Analysis (DCA)

If this particular model were to be used as a diagnostic tool, you can weigh the concerns of treating vs. not treating using DCA. You can read more about it here, here and here. I think it can be useful; however, its interpretation is sometimes not as obvious as one could expect beforehand.
Because, back when I was asked to use it, I found no implementation working consistently in Python, I share mine here.
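
For reference, the net benefit curve computed below follows the usual DCA definition: for a threshold probability $p_t$,

$$ NB(p_t) = \frac{TP}{N} - \frac{FP}{N} \cdot \frac{p_t}{1 - p_t}, $$

which is exactly what calc_net_benefit_dcurve vectorizes over self.thresholds.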

In [153]:
class DCA:
    """
    incomplete, written just to fulfill my own requirements:
    (several probs already calculated elsewhere)
    check rdma R package for a complete and high qual solution
    """
    def __init__(
        self,
        data: pd.DataFrame = None,
        ground_truth: str = 'target',
        thresholds: np.ndarray = np.linspace(0,.5,51)
    ):
        """instantiates class, assign to internal attrs"""
        self.df = data
        self.gt = data[ground_truth].values.astype(bool)  # boolean ground truth, shape (N,)
        self.thresholds = thresholds

    def calc_net_benefit_dcurve(self, probs, gt):
        """ 
        returns a curve of net benefit with dims (self.thresholds,)
        using the formula defined here: https://doi.org/10.1136/bmj.i6
        """
        # ensure dimensionality for proper broadcasting
        res = probs.reshape(-1, 1) >= self.thresholds.reshape(1,-1)
        N = gt.size
        net_benefit = res[gt].sum(axis=0)/N - res[~gt].sum(axis=0)/N * (self.thresholds/(1-self.thresholds))
        return net_benefit

    def plot_dca(
        self,
        cols: list = [],
        labels: list = None,
        ax=None,
        pal: str = 'tab10' # qualitative colormap, else will fail
    ):
        """ plots dca in a given ax and returns it"""
        if ax is None:
            f, ax = plt.subplots()

        if labels is None:
            labels = cols 
            
        # plot treat-none and treat-all strategies
        ax.plot(self.thresholds, np.zeros(self.thresholds.size), 
            c='k', label='treat none')
        treat_all = (
            self.gt.mean() - (~self.gt).mean() *  
            (self.thresholds/(1-self.thresholds)))

        ax.plot(self.thresholds, treat_all, c='magenta', label='treat all')

        colors = plt.get_cmap(pal)

        for i, col in enumerate(cols):
            curve = self.calc_net_benefit_dcurve(self.df[col].values, self.gt)
            ax.plot(self.thresholds, curve, color=colors(i), label=labels[i])

        ax.set_ylim([-0.05, None])
        ax.set_xlabel('threshold')
        ax.set_ylabel('net benefit')
        return ax


# reuse in-memory data
dca = DCA(
    data = pd.DataFrame({
        'ground_truth': y_test,
        'naive': proba_naive,
        'balanced': proba_balanced
    }),
    ground_truth='ground_truth'
)
ax = dca.plot_dca(
    cols=['naive', 'balanced'],
)
ax.legend(frameon=False, fancybox=False)
ax.set_ylabel('Net Benefit')
ax.set_xlabel('Threshold')
plt.show()

According to the DCA, the naive model outperforms the balanced model across the entire tested threshold range.

Cheers!
