csl: PyTorch-based constrained learning

The csl module provides a common interface to specify and solve constrained learning problems using PyTorch.

Model wrappers

class csl.PytorchModel(model)[source]

Bases: object

PyTorch model wrapper for constrained learning problems.

Variables
  • model (torch.nn.Module) – A PyTorch model.

  • parameters (list [torch.tensor]) – Model parameters. Obtained directly from the underlying PyTorch module as model.parameters(). Setting parameters, however, expects a module state dictionary obtained by calling model.state_dict().

predict(x)[source]

Evaluate model prediction

Predicts the label of each data point in x. Assumes the neural network has one output per class and returns the class corresponding to the largest output.

Parameters

x (torch.tensor) – Input data
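
A minimal usage sketch (the network and input below are placeholders):

import torch
import csl

net = torch.nn.Linear(10, 3)       # placeholder network: 10 features, 3 classes
wrapped = csl.PytorchModel(net)

x = torch.randn(4, 10)             # placeholder batch of 4 data points
labels = wrapped.predict(x)        # class with the largest output per point

# Assigning to `parameters` expects a module state dictionary:
wrapped.parameters = net.state_dict()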

Constrained learning problem

class csl.ConstrainedLearningProblem[source]

Bases: object

Constrained learning problem base class.

Constrained learning problems are defined by inheriting from ConstrainedLearningProblem and defining its attributes:

  • model: underlying model to train

  • data: data with which to train the model

  • batch_size (optional): maximum number of data points to load into memory at once

  • obj_function: objective function or training loss

  • constraints (optional): average constraints

  • rhs (optional): right-hand side of average constraints

  • pointwise (optional): pointwise constraints

  • pointwise_rhs (optional): right-hand side of pointwise constraints

A detailed description of each of these attributes is given below.

Variables
  • model (callable) –

    Model used to solve the constrained learning problem. The model must have an attribute parameters and a method __call__ as specified below.

    • parameters: model parameters (list [torch.tensor] with requires_grad=True)

    • __call__(x): takes a data batch x and evaluates the output of the model for each data point in x (callable)

  • data (list) –

    Training data. Must define the methods __len__ and __getitem__ (see the sketch after this list):

    • __len__: Returns the size of the dataset (callable)

    • __getitem__: Returns element(s) from the dataset (callable)

  • batch_size (int) –

    Internal batch size used to evaluate empirical averages.

    Note: this has no effect on the training batch size. It is only used internally to avoid running out of memory when evaluating quantities that require a full pass over the dataset.

  • obj_function (callable) – Objective function. Takes a list of indices (list [int]) defining a mini-batch and returns the objective function value (torch.tensor, (1, )) over that mini-batch.

  • constraints (list [callable]) –

    Functions defining the average constraints. Each function takes

    • batch_idx: list of indices defining a mini-batch (list [int])

    • primal: True if constraint is being evaluated for primal update or False otherwise (bool)

    and returns the average constraint value (torch.tensor, (1, )) over that mini-batch.

  • rhs (list [float]) – List containing the right-hand side of each average constraint.

  • pointwise (list [callable]) –

    Functions defining the pointwise constraints. Each function takes

    • batch_idx: list of indices defining a mini-batch (list [int])

    • primal: True if constraint is being evaluated for primal update or False otherwise (bool)

    and returns the pointwise constraint value (torch.tensor, (len(batch_idx), )) for each point in the mini-batch.

  • pointwise_rhs (list [torch.tensor, (N, )]) – List containing the right-hand side of each pointwise constraint.
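
For concreteness, here is a minimal sketch of an in-memory data object satisfying these requirements; following the example below, __getitem__ also accepts a list of indices:

import torch

class TensorData:
    """Minimal in-memory dataset: feature matrix X and label vector y."""
    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        # idx can be an int, a slice, or a list of indices
        return self.X[idx], self.y[idx]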

Notes

The primal flag

When working with non-differentiable constraints, a smooth approximation can be used during the primal computation to enable gradient updates. If this approximation is good enough, i.e., if the minimum of the smooth function is a good approximation of the minimum of the non-differentiable function, then certain guarantees can be given on the solutions obtained by the primal-dual iterations.

The purpose of this flag is to allow for these alternative smooth approximations to be used when minimizing the Lagrangian using gradient descent. The original, non-differentiable loss is then used during the dual update to compute the supergradients of the Lagrangian with respect to the dual variables.

An example problem

To pose a constrained learning problem, inherit from ConstrainedLearningProblem and define its attributes before initializing the base class by calling super().__init__().

import torch
import torch.nn.functional as F
import csl

####################################
# MODEL                            #
####################################
class Logistic:
    def __init__(self, n_features):
        self.parameters = [torch.zeros(1, dtype=torch.float, requires_grad=True),
                           torch.zeros([n_features, 1], dtype=torch.float, requires_grad=True)]

    def __call__(self, x):
        if len(x.shape) == 1:
            x = x.unsqueeze(1)

        # Probability of the positive class
        yhat = self.sigmoid(torch.mm(x, self.parameters[1]) + self.parameters[0])

        # Per-class probabilities: [P(y = 0), P(y = 1)]
        return torch.cat((1 - yhat, yhat), dim=1)

    def predict(self, x):
        _, predicted = torch.max(self(x), 1)
        return predicted

    @staticmethod
    def sigmoid(x):
        # Logistic sigmoid
        return 1/(1 + torch.exp(-x))

####################################
# PROBLEM                          #
####################################
class fairClassification(csl.ConstrainedLearningProblem):
    def __init__(self):
        # `data` is assumed to be defined in the enclosing scope
        self.model = Logistic(data[0][0].shape[0])
        self.data = data
        self.obj_function = self.loss

        # Demographic parity
        self.constraints = [ self.demographic_parity ]
        self.rhs = [ 0.1 ]

        super().__init__()

    def loss(self, batch_idx):
        # Evaluate objective: negative log-likelihood
        # (the model outputs probabilities, not logits)
        x, y = self.data[batch_idx]
        yhat = self.model(x)

        return F.nll_loss(torch.log(yhat), y)

    def demographic_parity(self, batch_idx, primal):
        protected_idx = 3
        x, y = self.data[batch_idx]
        group_idx = (x[:, protected_idx] == 1)

        if primal:
            # Sigmoid approximation of the indicator function
            yhat = self.model(x)
            pop_indicator = torch.sigmoid(8*(yhat[:, 1] - 0.5))
            group_indicator = torch.sigmoid(8*(yhat[group_idx, 1] - 0.5))
        else:
            # Indicator function
            yhat = self.model.predict(x)
            pop_indicator = yhat.float()
            group_indicator = yhat[group_idx].float()

        return pop_indicator.mean() - group_indicator.mean()
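
Once the problem is defined, it can be handed to one of the solvers described below. A minimal sketch, using the TensorData object from above as placeholder data:

data = TensorData(torch.randn(100, 5),           # placeholder features
                  torch.randint(0, 2, (100,)))   # placeholder labels

problem = fairClassification()
solver = csl.PrimalDual()    # alias of csl.PrimalThenDual
solver.solve(problem)
solver.plot()                # trace plots of the primal-dual iterations
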
lagrangian(batch_idx=None)[source]

Evaluate Lagrangian (and its gradient)

Parameters

batch_idx (list [int], optional) – Indices of the batch. The default is None, which evaluates over the full dataset. The evaluation is done in batches of size batch_size to avoid loading the full dataset into memory.

Returns

  • L (float) – Lagrangian value.

  • obj_value (float) – Objective value.

  • constraints_slacks (list [torch.tensor, (1, )]) – Slacks of average constraints

  • pointwise_slacks (list [torch.tensor, (len(batch_idx), )]) – Slacks of pointwise constraints
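
For example, a full-dataset evaluation (batched internally according to batch_size) is simply:

L, obj, avg_slacks, ptw_slacks = problem.lagrangian()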

objective(batch_idx=None)[source]

Evaluate the objective function

Parameters

batch_idx (list [int], optional) – Indices of the batch. The default is None, which evaluates over the full dataset. The evaluation is done in batches of size batch_size to avoid loading the full dataset into memory.

Returns

obj_value – Objective value.

Return type

float

slacks(batch_idx=None)[source]

Evaluate constraint slacks

Parameters

batch_idx (list [int], optional) – Indices of the batch. The default is None, which evaluates over the full dataset. The evaluation is done in batches of size batch_size to avoid loading the full dataset into memory.

Returns

  • constraint_slacks (list [float]) – Constraint violation of the average constraints.

  • pointwise_slacks (list [torch.tensor, (len(batch_idx), )]) – Constraint violation of the pointwise constraints.
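
A sketch of a simple feasibility check (slacks are non-positive when the corresponding constraint is satisfied, cf. constraint_feas below):

avg_slacks, ptw_slacks = problem.slacks()
avg_feasible = all(s <= 0 for s in avg_slacks)          # average constraints
ptw_feasible = all((s <= 0).all() for s in ptw_slacks)  # pointwise constraints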

Solvers

class csl.PrimalThenDual(user_settings={})[source]

Bases: csl.solver_base.PrimalDualBase

__init__(user_settings={})[source]

Primal-then-dual solver.

Updates the primal variables using a full pass over the dataset, then updates the dual variables using another full pass over the dataset.

Parameters

user_settings (dict, optional) – Dictionary containing solver settings. See SolverSettings for basic solver settings and defaults.

Additional specific settings:

  • batch_size: Mini-batch size. The default is None (uses full dataset at once).

  • shuffle: Shuffle dataset before batching. The default is True.

  • dual_period: Epoch period of the dual update (the dual variables are updated once every dual_period epochs). The default is 1 (one dual update per primal epoch).
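
For instance, to update the dual variables every other epoch using mini-batches of 128 points (values are illustrative):

solver = csl.PrimalThenDual({'iterations': 200,
                             'batch_size': 128,
                             'dual_period': 2})
solver.solve(problem)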

csl.PrimalDual

alias of csl.solvers.PrimalThenDual

class csl.SimultaneousPrimalDual(user_settings={})[source]

Bases: csl.solver_base.PrimalDualBase

__init__(user_settings={})[source]

Simultaneous primal-dual solver.

For each batch, update primal then update dual.

Parameters

user_settings (dict, optional) – Dictionary containing solver settings. See SolverSettings for basic solver settings and defaults.

Additional specific settings:

  • batch_size: Mini-batch size. The default is None (uses full dataset at once).

  • shuffle: Shuffle dataset before batching. The default is True.
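
For example:

solver = csl.SimultaneousPrimalDual({'batch_size': 64, 'shuffle': True})
solver.solve(problem)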

Base solver

class csl.solver_base.SolverSettings(specific_settings={})[source]

Bases: object

Primal-dual solver settings

Variables

settings (dict, optional) –

Dictionary containing the solver settings. Valid keys and default values for the base solver are listed below. Specific solvers may have additional settings.

  • iterations: Maximum number of iterations. The default is 100.

  • primal_solver: Constructor for the PyTorch optimizer used to solve the primal problem (takes parameters, returns a torch.optim optimizer). The default is Adam.

  • lr_p_scheduler: Primal step size scheduler. The default is None (no decay).

  • dual_solver: Constructor for the PyTorch optimizer used to solve the dual problem (takes parameters, returns a torch.optim optimizer). The default is Adam.

  • lr_d_scheduler: Dual step size scheduler. The default is None (no decay).

  • logger: A logging object. The default outputs directly to the console.

  • verbose: Period of log printing. Every iteration is printed when logger is at DEBUG level. Default is iterations/10. Set to 0 to deactivate.

  • device: Device used for computations. The default is GPU, if available, and CPU otherwise.

  • COMPUTE_TRUE_DGAP: Compute actual primal and dual values instead of using approximations. Note: this effectively requires an extra pass over the whole dataset per iteration. The default is False.

  • STOP_DIVERGENCE: Maximum value allowed before declaring that the algorithm has diverged. The default is 1e4.

  • STOP_PVAL: Primal value threshold. The default is None.

  • STOP_PGRAD: Primal gradient squared norm threshold. The default is None.

  • STOP_ABS_DGAP: Absolute duality gap threshold. The default is None.

  • STOP_DGRAD: Dual gradient squared norm threshold. The default is None.

  • STOP_REL_DGAP: Relative duality gap threshold. The default is None.

  • STOP_ABS_FEAS: Absolute feasibility threshold. The default is None.

  • STOP_REL_FEAS: Relative feasibility threshold. The default is None.

  • STOP_NFEAS: Proportion of feasible constraints threshold. The default is None.

  • STOP_PATIENCE: Iterations with no update threshold. The default is None.

  • STOP_USER_DEFINED: User-defined function. Takes problem and the solver state_dict and returns True to stop or False to continue. The default is None.
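
As an illustration, a user-defined criterion keyed on the duality gap estimate might look like this sketch (the threshold is arbitrary; see the state_dict entries documented for PrimalDualBase below):

def small_gap(problem, state_dict):
    # stop once the (estimated) duality gap is small
    gap = state_dict.get('duality_gap')
    return gap is not None and abs(gap) < 1e-3

solver = csl.PrimalThenDual({'STOP_USER_DEFINED': small_gap})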

Raises

ValueError – When trying to get or set a non-existent setting.

__init__(specific_settings={})[source]

Primal-dual solver settings constructor

Parameters

specific_settings (dict, optional) – Solver-specific global settings with default values. The default is {}.

display()[source]

Display effective setting values

initialize(settings={})[source]

Initialize settings

Define global settings and initialize variable settings if not defined (namely, verbose and logger).

Parameters

settings (dict, optional) – User settings to override the default global settings. The default is {}.

override(local_settings)[source]

Mask global setting values

Override the global settings without modifying them.

Parameters

local_settings (dict) – Setting values to override.

class csl.solver_base.PrimalDualBase(settings)[source]

Bases: object

Primal-dual base solver

Variables
  • primal_solver (torch.optim) – Primal problem solver

  • primal_step_size (torch.optim.lr_scheduler) – Primal problem step size scheduler

  • dual_solver (torch.optim) – Dual problem solver

  • dual_step_size (torch.optim.lr_scheduler) – Dual problem step size scheduler

  • state_dict (dict) – Internal solver state

  • settings (SolverSettings) – Solver settings

Notes

Stopping criteria

By default, the solver stops only once it reaches the maximum number of iterations or if divergence is detected (using the threshold STOP_DIVERGENCE). When defined, other stopping modes are:

  • Primal absolute optimality: based on STOP_PVAL or STOP_PGRAD (either or both, depending on which are defined). Applies only to unconstrained problems.

  • Primal absolute optimality and absolute feasibility: as above, but also checks STOP_ABS_FEAS (for average constraints) and STOP_NFEAS (for pointwise constraints), depending on which are defined.

  • Primal-dual absolute optimality and absolute feasibility: checks optimality using STOP_ABS_DGAP or both STOP_PGRAD and STOP_DGRAD. Feasibility is checked using STOP_ABS_FEAS (for average constraints) and STOP_NFEAS (for pointwise constraints), depending on which are defined.

  • Primal-dual relative optimality and relative feasibility: checks optimality using STOP_REL_DGAP and feasibility using STOP_REL_FEAS (for average constraints) and STOP_NFEAS (for pointwise constraints).

  • Stalled: stops if neither the primal value nor any constraint violation has improved over the last STOP_PATIENCE iterations.

  • User-defined criterion: stops if STOP_USER_DEFINED returns True.
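
For example, to stop once both an absolute duality gap and feasibility tolerances are met (values are illustrative):

solver = csl.PrimalDual({'STOP_ABS_DGAP': 1e-3,   # optimality
                         'STOP_ABS_FEAS': 1e-3,   # average constraints
                         'STOP_NFEAS': 0.95})     # pointwise constraints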

Implementing a new solver

Subclasses must:

  • Call PrimalDualBase.__init__ with a SolverSettings object that includes its solver-specific settings (if any).

  • Define primal_dual_update

Solver states

Specific solvers (and users) are free to modify state_dict and to add entries to it. This can be useful, for instance, to save the internal state of the user-defined stopping criterion STOP_USER_DEFINED, which can also serve as a validation hook (see the sketch after the list below). Use a unique capitalized prefix for custom entries to avoid interfering with the normal operation of the primal-dual solver (unless you know what you are doing). PrimalDualBase has the following internal states:

  • iteration (int): Iteration number

  • no_update_iterations (int): Number of iterations without improvement remaining before the solver gives up and stops early. Undefined if STOP_PATIENCE is None. Initial value: STOP_PATIENCE.

  • primal_solver (dict): state_dict of the primal problem solver

  • dual_solver (dict): state_dict of the dual problem solver

  • primal_value (float): Current primal value, i.e., objective function value

  • primal_grad_norm (float): Squared norm of the primal gradient.

  • lagrangian_value (float): Current value of the Lagrangian. If the problem is convex, converges to the value of the dual function.

  • dual_grad_norm (float): Squared norm of the dual gradient (constraint slacks).

  • duality_gap (float): Duality gap, i.e., P - D

  • rel_duality_gap (float): Relative duality gap, i.e., (P - D)/P

  • constraint_feas (np.array): Constraint violation of average constraints. Non-positive if the solution satisfies the constraint.

  • constraint_rel_feas (np.array): Relative constraint violation of average constraints, i.e., slack divided by right-hand side. Non-positive if the solution satisfies the constraint.

  • constraint_nfeas (np.array): Proportion of feasible average constraints.

  • lambdas_max (float): Maximum value of dual variables (average constraints).

  • pointwise_feas (np.array): Constraint violation of pointwise constraints. Non-positive if the solution satisfies the constraint.

  • pointwise_nfeas (np.array): Proportion of feasible points for each pointwise constraint.

  • mus_max (float): Maximum value of dual variables (pointwise constraints).

  • primal_value_log (np.array): Primal value across iterations.

  • lagrangian_value_log (np.array): Lagrangian value across iterations.

  • lambdas_log (np.array): Dual variables (average constraints) across iterations.

  • feas_log (np.array): Constraint violation (average constraints) across iterations.

  • rel_feas_log (np.array): Relative constraint violation (average constraints) across iterations.

  • mus_log (np.array): Dual variables (pointwise constraints) across iterations.

  • nfeas_log (np.array): Proportion of feasible pointwise constraints across iterations.

  • HAS_CONSTRAINTS (bool): True if learning problem has constraints and False otherwise.

  • N_AVG_CONSTRAINTS (int): Number of average constraints of learning problem.

  • N_PTW_CONSTRAINTS (int): Number of pointwise constraints of learning problem.

Except for iteration, states ending in _log, and flags (capitalized states), the value of the state in the previous iteration can be accessed by appending _prev to the state variable name.
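
As a sketch, a validation hook that stashes its own state under a unique capitalized prefix could look like this (evaluate_accuracy is a hypothetical user-defined helper):

def validation_hook(problem, state_dict):
    acc = evaluate_accuracy(problem.model)         # hypothetical helper
    best = state_dict.get('MYHOOK_best_acc', 0.0)  # unique prefix: MYHOOK_
    state_dict['MYHOOK_best_acc'] = max(best, acc)
    return False                                   # never request early stopping

solver = csl.PrimalDual({'STOP_USER_DEFINED': validation_hook})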

Warning

The values of primal_value, lagrangian_value, duality_gap, rel_duality_gap, constraint_feas, constraint_rel_feas, pointwise_feas, pointwise_nfeas should be considered estimates unless COMPUTE_TRUE_DGAP is True. This is particularly an issue when the solver uses mini-batches rather than operating over the full dataset (e.g., SGD). In these cases, the primal and/or dual variables are updated between batches, so their average is not representative of the current performance. When COMPUTE_TRUE_DGAP is True, the base solver does an extra pass over the dataset in order to re-evaluate these quantities for the current model. This may considerably increase the computation time of each epoch.

__init__(settings)[source]

Primal-dual base solver constructor

Parameters

settings (SolverSettings) – Solver settings.

plot()[source]

Trace plots of solver

If the problem has no constraints, displays a trace plot of the primal value estimate (see PrimalDualBase).

If the problem has constraints, displays trace plots of the primal value, Lagrangian value, relative duality gap, dual variables, and feasibility (constraint violation of the average constraints and proportion of feasible pointwise constraints).

Returns

  • fig (matplotlib.figure.Figure) – Matplotlib figure handle.

  • axes (matplotlib.axes.Axes) – Matplotlib axes handle.

primal_dual_update(problem)[source]

Primal-dual update

Parameters

problem (csl.ConstrainedLearningProblem) – Constrained learning problem.

Returns

  • primal_value_est (float) – Estimate of the primal value.

  • primal_grad_norm_est (float) – Estimate of the primal gradient squared norm.

  • constraint_slacks_est (list [torch.tensor, (1, )] or None) – Estimate of the value of average constraints or None if there was no dual update.

  • pointwise_slacks_est (list [torch.tensor, (N, )] or None) – Estimate of the value of pointwise constraints or None if there was no dual update.

  • dual_grad_norm_est (float or None) – Estimate of the dual gradient squared norm or None if there was no dual update.

reset()[source]

Reset the constrained learning solver

Allows a single solver object to be used to solve multiple constrained learning problems.

solve(problem, **kwargs)[source]

Solve constrained learning problem

Parameters
  • problem (csl.ConstrainedLearningProblem) – Constrained learning problem to solve.

  • **kwargs (dict, optional) – Temporary settings to override the global solver settings for current run.
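
For instance, to temporarily cap the iterations and silence logging for a single run (a sketch, assuming settings are passed as keyword arguments):

solver.solve(problem, iterations=50, verbose=0)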