Source code for csl.problem

# -*- coding: utf-8 -*-
"""Constrained learning problem base class

In csl, constrained learning problems are specified by inheriting from
`ConstrainedLearningProblem` and defining its attributes.

"""

import torch
import numpy as np
from csl.utils import _batches


class ConstrainedLearningProblem:
    """Constrained learning problem base class.

    Constrained learning problems are defined by inheriting from
    `ConstrainedLearningProblem` and defining its attributes:

    - ``model``: underlying model to train
    - ``data``: data with which to train the model
    - ``batch_size`` (optional): maximum number of data points to load
      into memory at once
    - ``obj_function``: objective function or training loss
    - ``constraints`` (optional): average constraints
    - ``rhs`` (optional): right-hand side of the average constraints
    - ``pointwise`` (optional): pointwise constraints
    - ``pointwise_rhs`` (optional): right-hand side of the pointwise constraints

    A detailed description of each of these attributes is given below.

    Attributes
    ----------
    model : `callable`
        Model used to solve the constrained learning problem. The model
        must have an attribute ``parameters`` and a method ``__call__``
        as specified below.

        - ``parameters``: model parameters (`list` [`torch.tensor`]
          with ``requires_grad=True``)
        - ``__call__(x)``: takes a data batch ``x`` and evaluates the
          output of the model for each data point in ``x`` (`callable`)
    data : `list`
        Training data. Must define the methods ``__len__`` and
        ``__getitem__``:

        - ``__len__``: returns the size of the dataset (`callable`)
        - ``__getitem__``: returns element(s) of the dataset (`callable`)
    batch_size : `int`
        Internal batch size used to evaluate empirical averages.

        .. note:: This has no effect on the training batch size. It is
            only used internally to avoid running out of memory when
            evaluating quantities that require a full pass over the
            dataset.
    obj_function : `callable`
        Objective function. Takes a list of indices (`list` [`int`])
        defining a mini-batch and returns the objective function value
        (`torch.tensor`, (1, )) over that mini-batch.
    constraints : `list` [`callable`]
        Functions defining the average constraints. Each function takes

        - ``batch_idx``: list of indices defining a mini-batch
          (`list` [`int`])
        - ``primal``: `True` if the constraint is being evaluated for
          the primal update and `False` otherwise (`bool`)

        and returns the average constraint value (`torch.tensor`, (1, ))
        over that mini-batch.
    rhs : `list` [`float`]
        List containing the right-hand side of each average constraint.
    pointwise : `list` [`callable`]
        Functions defining the pointwise constraints. Each function takes

        - ``batch_idx``: list of indices defining a mini-batch
          (`list` [`int`])
        - ``primal``: `True` if the constraint is being evaluated for
          the primal update and `False` otherwise (`bool`)

        and returns the pointwise constraint value
        (`torch.tensor`, (``len(batch_idx)``, )) for each point in the
        mini-batch.
    pointwise_rhs : `list` [`torch.tensor`, (N, )]
        List containing the right-hand side of each pointwise constraint.

    Notes
    -----
    **The primal flag**

    When working with non-differentiable constraints, a smooth
    approximation can be used during the primal computation to enable
    gradient updates. If this approximation is good enough, i.e., if the
    minimum of the smooth function is a good approximation of the minimum
    of the non-differentiable function, then certain guarantees can be
    given on the solutions obtained by the primal-dual iterations. The
    purpose of this flag is to allow such smooth approximations to be
    used when minimizing the Lagrangian by gradient descent, while the
    original, non-differentiable loss is used during the dual update to
    compute supergradients of the Lagrangian with respect to the dual
    variables.
    **An example problem**

    To pose a constrained learning problem, inherit from
    `ConstrainedLearningProblem` and define its attributes before
    initializing the base class by calling ``super().__init__()``.

    .. code-block:: python
        :linenos:

        import torch
        import torch.nn.functional as F
        import csl

        ####################################
        #               MODEL              #
        ####################################
        class Logistic:
            def __init__(self, n_features):
                self.parameters = [torch.zeros(1, dtype=torch.float, requires_grad=True),
                                   torch.zeros([n_features, 1], dtype=torch.float, requires_grad=True)]

            def __call__(self, x):
                if len(x.shape) == 1:
                    x = x.unsqueeze(1)
                yhat = self.logit(torch.mm(x, self.parameters[1]) + self.parameters[0])
                return torch.cat((1 - yhat, yhat), dim=1)

            def predict(self, x):
                _, predicted = torch.max(self(x), 1)
                return predicted

            @staticmethod
            def logit(x):
                return 1/(1 + torch.exp(-x))

        ####################################
        #              PROBLEM             #
        ####################################
        class fairClassification(csl.ConstrainedLearningProblem):
            def __init__(self):
                self.model = Logistic(data[0][0].shape[0])
                self.data = data
                self.obj_function = self.loss

                # Demographic parity
                self.constraints = [self.demographic_parity]
                self.rhs = [0.1]

                super().__init__()

            def loss(self, batch_idx):
                # Evaluate objective
                x, y = self.data[batch_idx]
                yhat = self.model(x)
                return F.cross_entropy(yhat, y)

            def demographic_parity(self, batch_idx, primal):
                protected_idx = 3
                x, y = self.data[batch_idx]
                group_idx = (x[:, protected_idx] == 1)

                if primal:
                    # Sigmoid approximation of the indicator function
                    yhat = self.model(x)
                    pop_indicator = torch.sigmoid(8*(yhat[:, 1] - 0.5))
                    group_indicator = torch.sigmoid(8*(yhat[group_idx, 1] - 0.5))
                else:
                    # Indicator function
                    yhat = self.model.predict(x)
                    pop_indicator = yhat.float()
                    group_indicator = yhat[group_idx].float()

                return pop_indicator.mean() - group_indicator.mean()

    """

    def __init__(self):
        # Check subclass definition
        if not hasattr(self, 'model'):
            raise Exception('Your CSL problem must have a model.')
        if not hasattr(self, 'data'):
            raise Exception('Your CSL problem must have data.')
        if not hasattr(self, 'obj_function'):
            raise Exception('Your CSL problem must have an objective function.')

        # Finish initializing the problem
        model_device = next(iter(self.model.parameters)).device

        if not hasattr(self, 'batch_size'):
            self.batch_size = None

        if not hasattr(self, 'data_size'):
            self.data_size = len(self.data)

        if not hasattr(self, 'constraints'):
            # Takes batch indices, returns a scalar average value
            self.constraints = []
            self.rhs = []
            self.lambdas = []
        else:
            self.lambdas = [torch.tensor(0, dtype=torch.float,
                                         requires_grad=False,
                                         device=model_device)
                            for _ in self.constraints]

        if not hasattr(self, 'pointwise'):
            # Takes batch indices, returns a vector with one element per data point
            self.pointwise = []
            self.pointwise_rhs = []
            self.mus = []
        else:
            self.mus = [torch.zeros_like(rhs, dtype=torch.float,
                                         requires_grad=False,
                                         device=model_device)
                        for rhs in self.pointwise_rhs]
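    # Illustrative sketch (not part of the original source): after the
    # subclass calls ``super().__init__()``, one dual variable is created per
    # constraint, on the same device as the model. For the hypothetical
    # ``fairClassification`` problem from the class docstring:
    #
    #     problem = fairClassification()
    #     problem.lambdas  # [tensor(0.)] -- one multiplier per average constraint
    #     problem.mus      # []           -- no pointwise constraints defined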
    def lagrangian(self, batch_idx=None):
        """Evaluate the Lagrangian (and its gradient)

        Parameters
        ----------
        batch_idx : `list` [`int`], optional
            Indices of the batch. The default is `None`, which evaluates
            over the full dataset. The evaluation is then done in batches
            of size ``batch_size`` to avoid loading the full dataset
            into memory.

        Returns
        -------
        L : `float`
            Lagrangian value.
        obj_value : `float`
            Objective value.
        constraint_slacks : `list` [`torch.tensor`, (1, )]
            Slacks of the average constraints.
        pointwise_slacks : `list` [`torch.tensor`, (``len(batch_idx)``, )]
            Slacks of the pointwise constraints.

        """
        if batch_idx is not None:
            L, obj_value, constraint_slacks, pointwise_slacks = self._lagrangian(batch_idx)
        else:
            # Initialization
            L = 0
            obj_value = 0
            constraint_slacks = [0]*len(self.constraints)
            pointwise_slacks = [torch.zeros([0])]*len(self.pointwise)

            # Compute over the whole dataset in batches
            for batch_start, batch_end in _batches(self.data_size, self.batch_size):
                L_batch, obj_value_batch, constraint_slacks_batch, pointwise_slacks_batch = \
                    self._lagrangian(np.arange(batch_start, batch_end))

                L += L_batch*(batch_end - batch_start)/self.data_size
                obj_value += obj_value_batch*(batch_end - batch_start)/self.data_size

                for ii, slack in enumerate(constraint_slacks_batch):
                    constraint_slacks[ii] += slack*(batch_end - batch_start)/self.data_size

                for ii, slack in enumerate(pointwise_slacks_batch):
                    pointwise_slacks[ii] = torch.cat((pointwise_slacks[ii], slack))

        return L, obj_value, constraint_slacks, pointwise_slacks
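    # Illustrative sketch (not part of the original source): since
    # ``lagrangian()`` calls ``backward()`` internally whenever gradients
    # are enabled, a primal (minimization) step only needs a standard
    # optimizer over the model parameters. ``problem`` and ``primal_step``
    # are hypothetical names:
    #
    #     optimizer = torch.optim.SGD(problem.model.parameters, lr=primal_step)
    #     optimizer.zero_grad()
    #     L, obj_value, avg_slacks, pt_slacks = problem.lagrangian()
    #     optimizer.step()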
    def objective(self, batch_idx=None):
        """Evaluate the objective function

        Parameters
        ----------
        batch_idx : `list` [`int`], optional
            Indices of the batch. The default is `None`, which evaluates
            over the full dataset. The evaluation is then done in batches
            of size ``batch_size`` to avoid loading the full dataset
            into memory.

        Returns
        -------
        obj_value : `float`
            Objective value.

        """
        if batch_idx is not None:
            obj_value = self.obj_function(batch_idx).item()
        else:
            obj_value = 0
            for batch_start, batch_end in _batches(self.data_size, self.batch_size):
                obj_value += self.obj_function(range(batch_start, batch_end)).item() \
                             * (batch_end - batch_start)/self.data_size

        return obj_value
    def slacks(self, batch_idx=None):
        """Evaluate the constraint slacks

        Parameters
        ----------
        batch_idx : `list` [`int`], optional
            Indices of the batch. The default is `None`, which evaluates
            over the full dataset. The evaluation is then done in batches
            of size ``batch_size`` to avoid loading the full dataset
            into memory.

        Returns
        -------
        constraint_slacks : `list` [`float`]
            Constraint violation of the average constraints.
        pointwise_slacks : `list` [`torch.tensor`, (``len(batch_idx)``, )]
            Constraint violation of the pointwise constraints.

        """
        if batch_idx is not None:
            constraint_slacks = self._constraint_slacks(batch_idx)
            pointwise_slacks = self._pointwise_slacks(batch_idx)
        else:
            constraint_slacks = [0]*len(self.constraints)
            pointwise_slacks = [torch.zeros([0])]*len(self.pointwise)

            for batch_start, batch_end in _batches(self.data_size, self.batch_size):
                for ii, s in enumerate(self._constraint_slacks(range(batch_start, batch_end))):
                    constraint_slacks[ii] += s*(batch_end - batch_start)/self.data_size

                for ii, s in enumerate(self._pointwise_slacks(range(batch_start, batch_end))):
                    pointwise_slacks[ii] = torch.cat((pointwise_slacks[ii], s))

        return constraint_slacks, pointwise_slacks
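    # Illustrative sketch (not part of the original source): the slacks are
    # supergradients of the dual function, so a projected dual ascent step
    # can be written directly in terms of ``slacks()``. ``problem`` and
    # ``dual_step`` are hypothetical names:
    #
    #     with torch.no_grad():
    #         avg_slacks, pt_slacks = problem.slacks()
    #         for lam, s in zip(problem.lambdas, avg_slacks):
    #             lam.add_(dual_step*s).clamp_(min=0)
    #         for mu, s in zip(problem.mus, pt_slacks):
    #             mu.add_(dual_step*s).clamp_(min=0)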
    ###########################################################################
    ####                        PRIVATE FUNCTIONS                          ####
    ###########################################################################
    def _constraint_slacks(self, batch_idx):
        """Evaluate the slacks of the average constraints over a batch

        Parameters
        ----------
        batch_idx : `list` [`int`]
            Indices of the batch.

        Returns
        -------
        slacks_value : `list` [`float`]
            Constraint violation of the average constraints.

        """
        slacks_value = [ell(batch_idx, primal=False) - c
                        for ell, c in zip(self.constraints, self.rhs)]
        return slacks_value

    def _pointwise_slacks(self, batch_idx):
        """Evaluate the slacks of the pointwise constraints over a batch

        Parameters
        ----------
        batch_idx : `list` [`int`]
            Indices of the batch.

        Returns
        -------
        slacks_value : `list` [`torch.tensor`, (``len(batch_idx)``, )]
            Constraint violation of the pointwise constraints.

        """
        slacks_value = [ell(batch_idx, primal=False) - c[batch_idx]
                        for ell, c in zip(self.pointwise, self.pointwise_rhs)]
        return slacks_value

    def _lagrangian(self, batch_idx):
        """Evaluate the Lagrangian over a batch

        Parameters
        ----------
        batch_idx : `list` [`int`]
            Indices of the batch.

        Returns
        -------
        L : `float`
            Lagrangian value.
        obj_value : `float`
            Objective value.
        constraint_slacks : `list` [`torch.tensor`, (1, )]
            Slacks of the average constraints.
        pointwise_slacks : `list` [`torch.tensor`, (``len(batch_idx)``, )]
            Slacks of the pointwise constraints.

        """
        L = 0
        constraint_slacks = []
        pointwise_slacks = []

        # Objective value
        obj_value = self.obj_function(batch_idx)
        if torch.is_grad_enabled():
            obj_value.backward()
        L += obj_value.item()

        # Dualized average constraints
        for lambda_value, ell, c in zip(self.lambdas, self.constraints, self.rhs):
            slack = ell(batch_idx, primal=True) - c
            dualized_slack = lambda_value*slack

            if torch.is_grad_enabled():
                dualized_slack.backward()

            constraint_slacks += [slack]
            L += dualized_slack.item()

        # Dualized pointwise constraints
        for mu_value, ell, c in zip(self.mus, self.pointwise, self.pointwise_rhs):
            slack = ell(batch_idx, primal=True) - c[batch_idx]
            dualized_slack = torch.dot(mu_value[batch_idx], slack)/len(batch_idx)

            if torch.is_grad_enabled():
                dualized_slack.backward()

            pointwise_slacks += [slack]
            L += dualized_slack.item()

        return L, obj_value.item(), constraint_slacks, pointwise_slacks
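
###############################################################################
# Illustrative sketch (not part of the original module): a self-contained toy
# problem and a single primal-dual iteration. The dataset, model, constraint,
# and step sizes below are assumptions made purely for this example.
###############################################################################
if __name__ == "__main__":

    class _ToyData:
        """64 points from a noisy linear model with two features."""
        def __init__(self, n=64):
            torch.manual_seed(0)
            self.x = torch.randn(n, 2)
            self.y = self.x @ torch.tensor([1.0, -1.0]) + 0.1*torch.randn(n)

        def __len__(self):
            return len(self.x)

        def __getitem__(self, idx):
            return self.x[idx], self.y[idx]

    class _LinearModel:
        def __init__(self, n_features):
            self.parameters = [torch.zeros(n_features, requires_grad=True)]

        def __call__(self, x):
            return x @ self.parameters[0]

    class _ToyProblem(ConstrainedLearningProblem):
        """Least squares subject to a bound on the mean prediction."""
        def __init__(self, data):
            self.model = _LinearModel(2)
            self.data = data
            self.obj_function = self.mse
            # Average constraint: mean prediction <= 0.1 (differentiable,
            # so the primal flag can be ignored here)
            self.constraints = [lambda idx, primal: self.model(self.data[idx][0]).mean()]
            self.rhs = [0.1]
            super().__init__()

        def mse(self, batch_idx):
            x, y = self.data[batch_idx]
            return ((self.model(x) - y)**2).mean()

    problem = _ToyProblem(_ToyData())
    optimizer = torch.optim.SGD(problem.model.parameters, lr=0.1)

    # Primal step: lagrangian() backpropagates internally, so the
    # optimizer can step right after it
    optimizer.zero_grad()
    L, obj_value, _, _ = problem.lagrangian()
    optimizer.step()

    # Dual step: projected supergradient ascent on the multiplier
    with torch.no_grad():
        avg_slacks, _ = problem.slacks()
        for lam, s in zip(problem.lambdas, avg_slacks):
            lam.add_(0.01*s).clamp_(min=0)

    print(f"L = {L:.4f}, objective = {obj_value:.4f}")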