Applications

Fairness

We will use the Adult dataset from UCI to demonstrate how to impose fairness constraints. Here, the goal is to decide whether to grant a loan to an individual by predicting whether they make more than US$ 50k per year. However, we want to make sure that loans are as likely to be granted to women as to men.

You can find more information in [CR, NeurIPS’20].

For this example, you will need to download adult.data and adult.test from the UCI repository and place them in a folder named data.
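
For instance, assuming the classic UCI repository paths still resolve (verify the URLs before relying on them), a minimal sketch to fetch both files would be:

# Illustrative download helper; the URLs below assume the standard UCI layout
import os, urllib.request

os.makedirs('data', exist_ok=True)
base = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/'
for fname in ['adult.data', 'adult.test']:
    urllib.request.urlretrieve(base + fname, os.path.join('data', fname))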

You can try the full code on GitHub.

Basic setup

import torch
import torch.nn.functional as F
import torchvision

import matplotlib.pyplot as plot

import functools

import sys, os
sys.path.append(os.path.abspath('../'))

import csl, csl.datasets

Loading data

We use csl.datasets.utils to do a bit of data wrangling: drop some variables, bin others, and dummy-code the categorical ones.

# Preprocessing
preprocess = torchvision.transforms.Compose([
    csl.datasets.utils.Drop(['fnlwgt', 'educational-num', 'relationship', 'capital-gain', 'capital-loss']),
    csl.datasets.utils.Recode('education', {'<= K-12': ['Preschool', '1st-4th', '5th-6th', '7th-8th',
                                      '9th', '10th', '11th', '12th']}),
    csl.datasets.utils.Recode('race', {'Other': ['Other', 'Amer-Indian-Eskimo']}),
    csl.datasets.utils.Recode('marital-status', {'Married': ['Married-civ-spouse', 'Married-AF-spouse',
                                          'Married-spouse-absent'],
                              'Divorced/separated': ['Divorced', 'Separated']}),
    csl.datasets.utils.Recode('native-country', {'South/Central America': ['Columbia', 'Cuba', 'Guatemala',
                                                        'Haiti', 'Ecuador', 'El-Salvador',
                                                        'Dominican-Republic', 'Honduras',
                                                        'Jamaica', 'Nicaragua', 'Peru',
                                                        'Trinadad&Tobago'],
                              'Europe': ['England', 'France', 'Germany', 'Greece',
                                          'Holand-Netherlands', 'Hungary', 'Italy',
                                          'Ireland', 'Portugal', 'Scotland', 'Poland',
                                          'Yugoslavia'],
                              'Southeast Asia': ['Cambodia', 'Laos', 'Philippines',
                                                  'Thailand', 'Vietnam'],
                              'Chinas': ['China', 'Hong', 'Taiwan'],
                              'USA': ['United-States', 'Outlying-US(Guam-USVI-etc)',
                                      'Puerto-Rico']}),
    csl.datasets.utils.QuantileBinning('age', 6),
    csl.datasets.utils.Binning('hours-per-week', bins = [0,40,100]),
    csl.datasets.utils.Dummify(csl.datasets.Adult.categorical + ['age', 'hours-per-week'])
    ])

# Load Adult data
trainset = csl.datasets.Adult(root = 'data', train = True, target_name = 'income', preprocess = preprocess,
                              transform = csl.datasets.utils.ToTensor(dtype = torch.float),
                              target_transform = csl.datasets.utils.ToTensor(dtype = torch.long))

testset = csl.datasets.Adult(root = 'data', train = False, target_name = 'income', preprocess = preprocess,
                             transform = csl.datasets.utils.ToTensor(dtype = torch.float),
                             target_transform = csl.datasets.utils.ToTensor(dtype = torch.long))

# Gender column index
fullset = csl.datasets.Adult(root = 'data', train = False, target_name = 'income', preprocess = preprocess)
gender_idx = [idx for idx, name in enumerate(fullset[0][0].columns) if name.startswith('gender')]
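
As a quick sanity check (illustrative, not in the original script), fullset returns pandas rows, so you can inspect the dummified feature names and confirm which columns gender_idx picked out:

print(list(fullset[0][0].columns)[:5])                  # first few dummified feature names
print('gender columns:', gender_idx)                    # indices of the gender dummies
print('feature vector size:', trainset[0][0].shape[0])  # length of each feature vector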

Defining a logistic model

Here we construct a simple logistic model that we will use to decide whether to grant the loan by predicting whether the individual makes more than US$ 50k.

class Logistic:
    def __init__(self, n_features):
        # Trainable bias and weight vector
        self.parameters = [torch.zeros(1, dtype = torch.float, requires_grad = True),
                           torch.zeros([n_features,1], dtype = torch.float, requires_grad = True)]

    def __call__(self, x):
        # Probability of the positive class
        yhat = self.logit(torch.mm(x, self.parameters[1]) + self.parameters[0])

        # Per-class probabilities: column 0 is P(y = 0), column 1 is P(y = 1)
        return torch.cat((1-yhat, yhat), dim=1)

    def predict(self, x):
        _, predicted = torch.max(self(x), 1)
        return predicted

    @staticmethod
    def logit(x):
        # Despite the name, this is the sigmoid (i.e., the inverse logit)
        return 1/(1 + torch.exp(-x))
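
A minimal smoke test (a sketch, assuming the dataset supports slice indexing the way testset[:] is used below):

model = Logistic(trainset[0][0].shape[0])
x0, _ = trainset[0:4]            # small batch of feature vectors
probs = model(x0)                # shape (4, 2); columns are P(y=0) and P(y=1)
preds = model.predict(x0)        # hard 0/1 decisions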

The fair classification problem

We define the fair classification problem using the Logistic model, the trainset, and a logistic loss (see obj_function). We then include two (asymmetrical) demographic parity constraints, one for women and another for men. The constraint specification rhs is passed as a parameter; rhs=None constructs an unconstrained problem.

Note that since demographic parity is not differentiable (it is the expected value of an indicator function), the constraints use a sigmoidal approximation when primal is True (see csl.problem.ConstrainedLearningProblem for more details).
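
For intuition, here is a tiny illustration (not from the original code) of the hard indicator versus the steep-sigmoid surrogate that DemographicParity below uses when primal is True (the factor 8 matches the constraint code):

p = torch.linspace(0, 1, 5)            # predicted probabilities of class 1
hard = (p > 0.5).float()               # indicator 1{yhat = 1}: not differentiable
soft = torch.sigmoid(8*(p - 0.5))      # smooth surrogate used when primal=True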

class fairClassification(csl.ConstrainedLearningProblem):
    def __init__(self, rhs = None):
        self.model = Logistic(trainset[0][0].shape[0])
        self.data = trainset
        self.obj_function = self.loss

        if rhs is not None:
            # Gender
            self.constraints = [self.DemographicParity(self, gender_idx, 0),
                                self.DemographicParity(self, gender_idx, 1)]
            self.rhs = [rhs, rhs]

        super().__init__()

    def loss(self, batch_idx):
        # Evaluate objective: logistic loss plus a small ridge penalty
        x, y = self.data[batch_idx]
        yhat = self.model(x)

        return F.cross_entropy(yhat, y) + 1e-3*(self.model.parameters[0]**2 + self.model.parameters[1].norm()**2)

    class DemographicParity:
        def __init__(self, problem, protected_idx, protected_value):
            self.problem = problem
            self.protected_idx = protected_idx
            self.protected_value = protected_value

        def __call__(self, batch_idx, primal):
            x, y = self.problem.data[batch_idx]

            group_idx = (x[:, self.protected_idx].squeeze() == self.protected_value)

            if primal:
                yhat = self.problem.model(x)
                pop_indicator = torch.sigmoid(8*(yhat[:,1] - 0.5))
                group_indicator = torch.sigmoid(8*(yhat[group_idx,1] - 0.5))
            else:
                yhat = self.problem.model.predict(x)
                pop_indicator = yhat.float()
                group_indicator = yhat[group_idx].float()

            return -(group_indicator.mean() - pop_indicator.mean())

problems = {
    'unconstrained': fairClassification(),
    'constrained': fairClassification(rhs = 0.01),
    }

Solving the constrained learning problem

We can now solve our constrained learning problem by constructing a primal-dual solver and using it to solve each problem in problems. Note the use of csl.solver_base.PrimalDualBase.reset() before each solve. We save the results in solutions.

solver_settings = {'iterations': 700,
                   'batch_size': None,
                   'primal_solver': lambda p: torch.optim.Adam(p, lr=0.2),
                   'dual_solver': lambda p: torch.optim.Adam(p, lr=0.001),
                   }
solver = csl.PrimalDual(solver_settings)

solutions = {}
for key, problem in problems.items():
    solver.reset()
    solver.solve(problem)
    solver.plot()

    solutions[key] = {'model': problem.model,
                      'lambdas': problem.lambdas,
                      'solver_state': solver.state_dict}

Testing the solutions
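
Finally, we evaluate each solution on the test set: overall accuracy, the predicted prevalence of positive decisions, and the absolute and relative demographic disparity for women and men.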

def accuracy(pred, y):
    correct = (pred == y).sum().item()
    return correct/pred.shape[0]

def disparity(x, model, protected_idx, protected_value):
    pred = model.predict(x)

    pop_prev = pred.float().mean().item()

    # fullset holds the same test samples as x, in the same order, but as a
    # DataFrame, so group membership can be read off from it
    group_idx = (fullset[:][0].iloc[:,protected_idx].squeeze() == protected_value)

    group_prev = pred[group_idx].float().mean().item()

    disparity_value = group_prev - pop_prev
    rel_disparity_value = disparity_value/pop_prev

    return disparity_value, rel_disparity_value

for key, solution in solutions.items():
    print(f'Model: {key}')
    with torch.no_grad():
        x_test, y_test = testset[:]
        yhat = solution['model'].predict(x_test)

        acc_test = accuracy(yhat, y_test)

        disparity_f, rel_disparity_f = disparity(x_test, solution['model'], gender_idx, 0)
        disparity_m, rel_disparity_m = disparity(x_test, solution['model'], gender_idx, 1)

        print(f'Test accuracy: {100*acc_test:.2f}')
        print(f'Predicted population prevalence: {100*yhat.float().mean().item():.2f}')
        print(f'Female disparity: {100*disparity_f:.2f} | {100*rel_disparity_f:.2f}')
        print(f'Male disparity: {100*disparity_m:.2f} | {100*rel_disparity_m:.2f}')

Robustness

We will use the CIFAR-10 dataset to demonstrate how we can explicitly build accurate models with robustness requirements. The goal is to achieve the best nominal performance possible while satisfying a constraint on the adversarial loss.
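
Informally, and with notation loosely following [CR, NeurIPS'20], the problem being solved is

$$ \min_{\theta} \; \mathbb{E}\left[\ell(f_\theta(x), y)\right] \quad \text{s.t.} \quad \mathbb{E}\left[\max_{\|\delta\|_\infty \leq \epsilon} \ell(f_\theta(x + \delta), y)\right] \leq c, $$

where $\ell$ is the cross-entropy loss, $\epsilon$ is the perturbation magnitude (eps below), and $c$ is the constraint level (the rhs passed to robustLoss). The inner maximization is approximated with a PGD attack from foolbox.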

You can find more information in [CR, NeurIPS’20].

For this example, you will need to download the CIFAR-10 dataset as PyTorch tensors, as described in csl.datasets.datasets.CIFAR10, and place it in a folder named data.

You will also need to install the foolbox module:

pip install foolbox

You can try the full code on GitHub.

Basic setup

import foolbox

import torch
import torchvision
import torch.nn.functional as F

from resnet import ResNet18

import numpy as np

import copy

import sys, os
sys.path.append(os.path.abspath('../'))

import csl, csl.datasets

# Perturbation magnitude
eps = 0.02

# Use GPU if available
theDevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


####################################
# FUNCTIONS                        #
####################################
def accuracy(yhat, y):
    _, predicted = torch.max(yhat, 1)
    correct = (predicted == y).sum().item()
    return correct/yhat.shape[0]

def preprocess(img):
    mean = torch.tensor([0.4914, 0.4822, 0.4465], dtype = img.dtype, device=theDevice).reshape((3, 1, 1))
    std = torch.tensor([0.2023, 0.1994, 0.2010], dtype = img.dtype, device=theDevice).reshape((3, 1, 1))
    return (img - mean) / std

Load data

We will keep a balanced 2% subset of the training data for validation, just to keep things realistic. We use the subset parameter to do so (see csl.datasets.datasets.CIFAR10).

n_train = 4900
n_valid = 100

target = csl.datasets.CIFAR10(root = 'data', train = True)[:][1]

label_idx = [np.flatnonzero(target == label) for label in range(0,10)]
label_idx = [np.random.RandomState(seed=42).permutation(idx) for idx in label_idx]
train_subset = [idx[:n_train] for idx in label_idx]
train_subset = np.array(train_subset).flatten()

train_transform = torchvision.transforms.Compose([
    csl.datasets.utils.RandomFlip(),
    csl.datasets.utils.RandomCrop(size=32,padding=4),
    csl.datasets.utils.ToTensor(device=theDevice)
    ])

trainset = csl.datasets.CIFAR10(root = 'data', train = True, subset = train_subset,
                                transform = train_transform,
                                target_transform = csl.datasets.utils.ToTensor(device=theDevice))

valid_subset = [idx[n_train:n_train+n_valid] for idx in label_idx]
valid_subset = np.array(valid_subset).flatten()
validset = csl.datasets.CIFAR10(root = 'data', train = True, subset = valid_subset,
                                transform = csl.datasets.utils.ToTensor(device=theDevice),
                                target_transform = csl.datasets.utils.ToTensor(device=theDevice))
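
As an illustrative check (not in the original script), the validation split should contain exactly n_valid samples of each class:

_, y_valid = validset[:]
print(torch.bincount(y_valid.long().cpu()))   # expect n_valid (= 100) per class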

Defining the constrained learning problem

There are two noteworthy things to be careful about when encoding the constraint:

  • foolbox has side effects: it modifies the gradients of the parameters (even though it doesn't need to), so you need to save those gradients and reload them later, as done in adversarialLoss below

  • ResNets use batch normalization, which you should take into account only when optimizing the primal. So you need to put the model back into train mode a bit earlier for the primal update.

class robustLoss(csl.ConstrainedLearningProblem):
    def __init__(self, rhs):
        self.model = csl.PytorchModel(ResNet18().to(theDevice))
        self.data = trainset
        self.batch_size = 256

        self.obj_function = self.obj_fun

        # Constraints
        self.constraints = [self.adversarialLoss]
        self.rhs = [rhs]

        self.foolbox_model = foolbox.PyTorchModel(self.model.model, bounds=(0, 1),
                                                  device=theDevice,
                                                  preprocessing = dict(mean=[0.4914, 0.4822, 0.4465],
                                                                       std=[0.2023, 0.1994, 0.2010],
                                                                       axis=-3))
        self.attack = foolbox.attacks.LinfPGD(rel_stepsize = 1/3, abs_stepsize = None,
                                              steps = 5, random_start = True)

        super().__init__()

    def obj_fun(self, batch_idx):
        x, y = self.data[batch_idx]

        yhat = self.model(preprocess(x))

        return 0.1*self._loss(yhat, y)

    def adversarialLoss(self, batch_idx, primal):
        x, y = self.data[batch_idx]

        # Attack
        self.model.eval()

        # Save gradients before adversarial runs
        saved_grad = [copy.deepcopy(p.grad) for p in self.model.parameters]

        # Dual is computed in a no_grad() environment
        x_processed, _, _ = self.attack(self.foolbox_model, x, y, epsilons = eps)

        # Reload gradients
        for p,g in zip(self.model.parameters, saved_grad):
            p.grad = g

        if primal:
            self.model.train()
            yhat = self.model(preprocess(x_processed))
            loss = self._loss(yhat, y)
        else:
            with torch.no_grad():
                yhat = self.model(preprocess(x_processed))
                loss = self._loss(yhat, y)
            self.model.train()

        return loss

    @staticmethod
    def _loss(yhat, y):
        return F.cross_entropy(yhat, y)

Setting up a validation hook

We kept some validation data to see how the model performs on adversarial samples during training. For that, we set up a validation hook that we plug in as a user-defined stopping criterion (see csl.solver_base.PrimalDualBase). We could have the solver stop depending on the value of the validation accuracy, but here we will just let the solver do its thing and always return False.

def validation_hook(problem, solver_state):
    # Run the (expensive) adversarial validation once every adv_epoch calls.
    # The countdown must persist across calls, so it is stored on the function
    # itself rather than in a local variable (which would reset on every call).
    adv_epoch = 10
    if not hasattr(validation_hook, '_adv_epoch'):
        validation_hook._adv_epoch = adv_epoch
    _adv_epoch = validation_hook._adv_epoch

    batch_idx = np.arange(0, len(validset)+1, problem.batch_size)
    if batch_idx[-1] < len(validset):
        batch_idx = np.append(batch_idx, len(validset))

    # Validate
    acc = 0
    acc_adv = 0
    problem.model.eval()
    for batch_start, batch_end in zip(batch_idx, batch_idx[1:]):
        x, y = validset[batch_start:batch_end]
        with torch.no_grad():
            yhat = problem.model(preprocess(x))
            acc += accuracy(yhat, y)*(batch_end - batch_start)/len(validset)

        # Attack
        if _adv_epoch == 1:
            adversarial, _, _ = problem.attack(problem.foolbox_model, x, y, epsilons = eps)
            with torch.no_grad():
                yhat_adv = problem.model(preprocess(adversarial))
                acc_adv += accuracy(yhat_adv, y)*(batch_end - batch_start)/len(validset)
    problem.model.train()

    # Results
    if _adv_epoch > 1:
        print(f"Validation accuracy: {acc*100:.2f} / Dual variables: {[lambda_value.item() for lambda_value in problem.lambdas]}")
        validation_hook._adv_epoch -= 1
    else:
        print(f"Validation accuracy: {acc*100:.2f} / Adversarial accuracy: {acc_adv*100:.2f}")
        validation_hook._adv_epoch = adv_epoch

    # Never request early stopping
    return False

Solving the constrained learning problem

We’ve done most of the work above, so now we just need to call the constructors and solve the problem.

problem = robustLoss(rhs=0.7)

solver_settings = {'iterations': 400,
                   'verbose': 1,
                   'batch_size': 128,
                   'primal_solver': lambda p: torch.optim.Adam(p, lr=0.01),
                   'lr_p_scheduler': None,
                   'dual_solver': lambda p: torch.optim.Adam(p, lr=0.001),
                   'lr_d_scheduler': None,
                   'device': theDevice,
                   'STOP_USER_DEFINED': validation_hook,
                   }
solver = csl.SimultaneousPrimalDual(solver_settings)

solver.solve(problem)
solver.plot()
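
If you want to reuse the trained network later, you can persist the underlying PyTorch module. This is an illustrative sketch (the filename is arbitrary); problem.model.model is the wrapped torch module, as used when constructing the foolbox model.

torch.save(problem.model.model.state_dict(), 'robust_resnet18.pt')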

Testing

We can now test the results using a stronger attack than the one used during training.

# Test data
testset = csl.datasets.CIFAR10(root = 'data', train = False,
                               transform = csl.datasets.utils.ToTensor(device=theDevice),
                               target_transform = csl.datasets.utils.ToTensor(device=theDevice))

# Adversarial attack
problem.model.eval()
foolbox_model = foolbox.PyTorchModel(problem.model.model, bounds=(0, 1),
                                     device=theDevice,
                                     preprocessing = dict(mean=[0.4914, 0.4822, 0.4465],
                                                          std=[0.2023, 0.1994, 0.2010],
                                                          axis=-3))
attack = foolbox.attacks.LinfPGD(rel_stepsize = 1/30, abs_stepsize = None,
                                 steps = 50, random_start = True)
epsilon_test = np.linspace(0.01,0.06,7)

# Prepare batches
batch_idx = np.arange(0, len(testset)+1, problem.batch_size)
if batch_idx[-1] < len(testset):
    batch_idx = np.append(batch_idx, len(testset))

n_total = 0
acc_test = 0
acc_adv = np.zeros(epsilon_test.shape[0])
success_adv = np.zeros_like(acc_adv)

for batch_start, batch_end in zip(batch_idx, batch_idx[1:]):
    x_test, y_test = testset[batch_start:batch_end]

    # Nominal accuracy (no gradients needed for evaluation)
    with torch.no_grad():
        yhat = problem.model(preprocess(x_test))
    acc_test += accuracy(yhat, y_test)*(batch_end - batch_start)

    # Adversarial accuracy: the attack itself needs gradients, so it must run
    # outside of no_grad()
    adversarials, _, success = attack(foolbox_model, x_test, y_test, epsilons = epsilon_test)
    for ii, adv in enumerate(adversarials):
        with torch.no_grad():
            yhat_adv = problem.model(preprocess(adv))
        acc_adv[ii] += accuracy(yhat_adv, y_test)*(batch_end - batch_start)
        success_adv[ii] += torch.sum(success[ii]).item()

    n_total += batch_end - batch_start

acc_test /= n_total
acc_adv /= n_total
success_adv /= n_total

print('====== TEST ======')
print(f'Test accuracy: {100*acc_test:.2f}')
print(f'Adversarial accuracy: {100*acc_adv}')
print(f'Adversarial success: {100*success_adv}')
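
A quick way to visualize the robustness-accuracy trade-off (illustrative, not part of the original script):

import matplotlib.pyplot as plt

plt.plot(epsilon_test, 100*acc_adv, marker='o')
plt.axvline(eps, linestyle='--', color='gray')   # attack budget used during training
plt.xlabel('attack magnitude (L-inf)')
plt.ylabel('adversarial accuracy (%)')
plt.show()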