csl: PyTorch-based constrained learning
The csl module provides a common interface to specify and solve constrained learning problems using PyTorch.
Model wrappers
class csl.PytorchModel(model)
Bases: object

PyTorch model wrapper for constrained learning problems.

- Variables
  - model (torch.nn.Module) – A PyTorch model.
  - parameters (list [torch.tensor]) – Model parameters, obtained directly from the underlying PyTorch module as model.parameters(). Setting parameters, however, expects a module state dictionary obtained by calling model.state_dict().
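For instance, a standard torch.nn.Module can be wrapped as follows (a minimal sketch; the network architecture is illustrative):

```python
import torch

import csl

# Any torch.nn.Module works; this two-layer classifier is only a placeholder.
net = torch.nn.Sequential(
    torch.nn.Linear(10, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 2),
)

model = csl.PytorchModel(net)

# parameters is exposed as a list of tensors taken from net.parameters()
print(len(model.parameters))
```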
Constrained learning problem
class csl.ConstrainedLearningProblem
Bases: object

Constrained learning problem base class.

Constrained learning problems are defined by inheriting from ConstrainedLearningProblem and defining its attributes:
- model: underlying model to train
- data: data with which to train the model
- batch_size (optional): maximum number of data points to load into memory at once
- obj_function: objective function or training loss
- constraints (optional): average constraints
- rhs (optional): right-hand side of the average constraints
- pointwise (optional): pointwise constraints
- pointwise_rhs (optional): right-hand side of the pointwise constraints

A detailed description of each of these attributes is given below.
- Variables
  - model (callable) – Model used to solve the constrained learning problem. The model must have an attribute parameters and a method __call__ as specified below:
    - parameters: model parameters (list [torch.tensor] with requires_grad=True)
    - __call__(x): takes a data batch x and evaluates the output of the model for each data point in x (callable)
  - data (list) – Training data. Must define the methods __len__ and __getitem__ (a minimal sketch satisfying this interface is given after this list):
    - __len__: returns the size of the dataset (callable)
    - __getitem__: returns element(s) from the dataset (callable)
  - batch_size (int) – Internal batch size used to evaluate empirical averages. Note: this has no effect on the training batch size. It is only used internally to avoid running out of memory when evaluating quantities that require a full pass over the dataset.
  - obj_function (callable) – Objective function. Takes a list of indices (list [int]) defining a mini-batch and returns the objective function value (torch.tensor, (1, )) over that mini-batch.
  - constraints (list [callable]) – Functions defining the average constraints. Each function takes
    - batch_idx: list of indices defining a mini-batch (list [int])
    - primal: True if the constraint is being evaluated for the primal update, False otherwise (bool)
    and returns the average constraint value (torch.tensor, (1, )) over that mini-batch.
  - rhs (list [float]) – List containing the right-hand side of each average constraint.
  - pointwise (list [callable]) – Functions defining the pointwise constraints. Each function takes
    - batch_idx: list of indices defining a mini-batch (list [int])
    - primal: True if the constraint is being evaluated for the primal update, False otherwise (bool)
    and returns the pointwise constraint value (torch.tensor, (len(batch_idx), )) for each point in the mini-batch.
  - pointwise_rhs (list [torch.tensor, (N, )]) – List containing the right-hand side of each pointwise constraint.
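Any object implementing this interface can serve as data. Below is a minimal in-memory sketch (the class name and the (x, y) batch format are illustrative; the latter matches the example problem further down):

```python
import torch

class InMemoryDataset:
    """Minimal dataset implementing the documented interface."""

    def __init__(self, x, y):
        self.x = x  # features, shape (N, n_features)
        self.y = y  # labels, shape (N, )

    def __len__(self):
        # Size of the dataset
        return self.x.shape[0]

    def __getitem__(self, idx):
        # idx is a list of indices defining a mini-batch
        return self.x[idx], self.y[idx]
```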
Notes
The primal flag
When working with non-differentiable constraints, a smooth approximation can be used during the primal computation to enable gradient updates. If this approximation is good enough, i.e., if the minimum of the smooth function is a good approximation of the minimum of the non-differentiable function, then certain guarantees can be given on the solutions obtained by the primal-dual iterations.
The purpose of this flag is to allow for these alternative smooth approximations to be used when minimizing the Lagrangian using gradient descent. The original, non-differentiable loss is then used during the dual update to compute the supergradients of the Lagrangian with respect to the dual variables.
An example problem
To pose a constrained learning problem, inherit from ConstrainedLearningProblem and define its attributes before initializing the base class by calling super().__init__():

```python
import torch
import torch.nn.functional as F

import csl

####################################
# MODEL                            #
####################################
class Logistic:
    def __init__(self, n_features):
        self.parameters = [torch.zeros(1, dtype=torch.float, requires_grad=True),
                           torch.zeros([n_features, 1], dtype=torch.float, requires_grad=True)]

    def __call__(self, x):
        if len(x.shape) == 1:
            x = x.unsqueeze(1)
        yhat = self.logit(torch.mm(x, self.parameters[1]) + self.parameters[0])
        return torch.cat((1 - yhat, yhat), dim=1)

    def predict(self, x):
        _, predicted = torch.max(self(x), 1)
        return predicted

    @staticmethod
    def logit(x):
        return 1/(1 + torch.exp(-x))


####################################
# PROBLEM                          #
####################################
class fairClassification(csl.ConstrainedLearningProblem):
    def __init__(self):
        self.model = Logistic(data[0][0].shape[0])
        self.data = data
        self.obj_function = self.loss

        # Demographic parity
        self.constraints = [self.demographic_parity]
        self.rhs = [0.1]

        super().__init__()

    def loss(self, batch_idx):
        # Evaluate objective
        x, y = self.data[batch_idx]
        yhat = self.model(x)
        return F.cross_entropy(yhat, y)

    def demographic_parity(self, batch_idx, primal):
        protected_idx = 3
        x, y = self.data[batch_idx]
        group_idx = (x[:, protected_idx] == 1)

        if primal:
            # Sigmoid approximation of indicator function
            yhat = self.model(x)
            pop_indicator = torch.sigmoid(8*(yhat[:, 1] - 0.5))
            group_indicator = torch.sigmoid(8*(yhat[group_idx, 1] - 0.5))
        else:
            # Indicator function
            yhat = self.model.predict(x)
            pop_indicator = yhat.float()
            group_indicator = yhat[group_idx].float()

        return pop_indicator.mean() - group_indicator.mean()
```
lagrangian(batch_idx=None)
Evaluate the Lagrangian (and its gradient).

- Parameters
  - batch_idx (list [int], optional) – Indices of the batch. The default is None, which evaluates over the full dataset. The evaluation is done in batches of size batch_size to avoid loading the full dataset into memory.
- Returns
  - L (float) – Lagrangian value.
  - obj_value (float) – Objective value.
  - constraints_slacks (list [torch.tensor, (1, )]) – Slacks of the average constraints.
  - pointwise_slacks (list [torch.tensor, (len(batch_idx), )]) – Slacks of the pointwise constraints.
objective(batch_idx=None)
Evaluate the objective function.

- Parameters
  - batch_idx (list [int], optional) – Indices of the batch. The default is None, which evaluates over the full dataset. The evaluation is done in batches of size batch_size to avoid loading the full dataset into memory.
- Returns
  - obj_value – Objective value.
- Return type
  - float
slacks(batch_idx=None)
Evaluate the constraint slacks.

- Parameters
  - batch_idx (list [int], optional) – Indices of the batch. The default is None, which evaluates over the full dataset. The evaluation is done in batches of size batch_size to avoid loading the full dataset into memory.
- Returns
  - constraint_slacks (list [float]) – Constraint violation of the average constraints.
  - pointwise_slacks (list [torch.tensor, (len(batch_idx), )]) – Constraint violation of the pointwise constraints.
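Assuming a problem instance such as the fairClassification example above, these three methods can be used for a quick full-dataset inspection. A sketch based on the documented signatures:

```python
problem = fairClassification()  # from the example above; assumes `data` exists

# Full-dataset estimates, evaluated internally in chunks of batch_size
obj_value = problem.objective()
L, obj_value, avg_slacks, ptw_slacks = problem.lagrangian()

# Constraint slacks only; non-positive values indicate satisfied constraints
avg_slacks, ptw_slacks = problem.slacks()
```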
Solvers
class csl.PrimalThenDual(user_settings={})
Bases: csl.solver_base.PrimalDualBase

__init__(user_settings={})
Primal-then-dual solver.
Updates the primal variables using a full pass over the dataset, then updates the dual variables using a full pass over the dataset.

- Parameters
  - user_settings (dict, optional) – Dictionary containing solver settings. See SolverSettings for the basic solver settings and defaults. Additional solver-specific settings:
    - batch_size: Mini-batch size. The default is None (uses the full dataset at once).
    - shuffle: Shuffle the dataset before batching. The default is True.
    - dual_period: Epoch period of the dual update (update the dual once every dual_period epochs). The default is 1 (run once per primal epoch).
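For instance, base settings (see SolverSettings below) and the solver-specific settings above can be mixed in a single dictionary. A sketch with illustrative values:

```python
import csl

solver = csl.PrimalThenDual({
    'iterations': 200,   # base setting (see SolverSettings)
    'batch_size': 128,   # solver-specific: mini-batch size
    'shuffle': True,     # solver-specific: shuffle before batching
    'dual_period': 2,    # solver-specific: dual update every 2 epochs
})
```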
csl.PrimalDual
Alias of csl.solvers.PrimalThenDual.
class csl.SimultaneousPrimalDual(user_settings={})
Bases: csl.solver_base.PrimalDualBase

__init__(user_settings={})
Simultaneous primal-dual solver.
For each batch, updates the primal variables, then updates the dual variables.

- Parameters
  - user_settings (dict, optional) – Dictionary containing solver settings. See SolverSettings for the basic solver settings and defaults. Additional solver-specific settings:
    - batch_size: Mini-batch size. The default is None (uses the full dataset at once).
    - shuffle: Shuffle the dataset before batching. The default is True.
Base solver
class csl.solver_base.SolverSettings(specific_settings={})
Bases: object

Primal-dual solver settings.

- Variables
  - settings (dict, optional) – Dictionary containing the solver settings. The valid keys for the base solver and their default values are listed below. Specific solvers may have additional settings.
    - iterations: Maximum number of iterations. The default is 100.
    - primal_solver: Constructor for the PyTorch optimizer used to solve the primal problem (takes parameters, returns torch.optim). The default is ADAM.
    - lr_p_scheduler: Primal step size scheduler. The default is None (no decay).
    - dual_solver: Constructor for the PyTorch optimizer used to solve the dual problem (takes parameters, returns torch.optim). The default is ADAM.
    - lr_d_scheduler: Dual step size scheduler. The default is None (no decay).
    - logger: A logging object. The default outputs directly to the console.
    - verbose: Period of log printing. Every iteration is printed when the logger is at DEBUG level. The default is iterations/10. Set to 0 to deactivate.
    - device: Device used for computations. The default is the GPU, if available, and the CPU otherwise.
    - COMPUTE_TRUE_DGAP: Compute actual primal and dual values instead of using approximations. Note: this effectively requires an extra pass over the whole dataset per iteration. The default is False.
    - STOP_DIVERGENCE: Maximum value allowed before declaring that the algorithm has diverged. The default is 1e4.
    - STOP_PVAL: Primal value threshold. The default is None.
    - STOP_PGRAD: Primal gradient squared norm threshold. The default is None.
    - STOP_ABS_DGAP: Absolute duality gap threshold. The default is None.
    - STOP_DGRAD: Dual gradient squared norm threshold. The default is None.
    - STOP_REL_DGAP: Relative duality gap threshold. The default is None.
    - STOP_ABS_FEAS: Absolute feasibility threshold. The default is None.
    - STOP_REL_FEAS: Relative feasibility threshold. The default is None.
    - STOP_NFEAS: Proportion-of-feasible-constraints threshold. The default is None.
    - STOP_PATIENCE: Iterations-with-no-update threshold. The default is None.
    - STOP_USER_DEFINED: User-defined stopping function. Takes the problem and the solver state_dict and returns True to stop or False to continue. The default is None.
- Raises
  - ValueError – When trying to get or set a non-existent setting.
__init__(specific_settings={})
Primal-dual solver settings constructor.

- Parameters
  - specific_settings (dict, optional) – Solver-specific settings and their default values. The default is {}.
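For instance, a solver with the specific settings of PrimalThenDual might construct its settings object as follows (a sketch; the key names are taken from the solvers above):

```python
from csl.solver_base import SolverSettings

# Solver-specific settings and their defaults, added to the base settings
settings = SolverSettings(specific_settings={
    'batch_size': None,   # full dataset at once
    'shuffle': True,      # shuffle before batching
    'dual_period': 1,     # dual update once per primal epoch
})
```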
class csl.solver_base.PrimalDualBase(settings)
Bases: object

Primal-dual base solver.
- Variables
  - primal_solver (torch.optim) – Primal problem solver.
  - primal_step_size (torch.optim.lr_scheduler) – Primal problem step size scheduler.
  - dual_solver (torch.optim) – Dual problem solver.
  - dual_step_size (torch.optim.lr_scheduler) – Dual problem step size scheduler.
  - state_dict (dict) – Internal solver state.
  - settings (SolverSettings) – Solver settings.
Notes
Stopping criteria
By default, the solver stops only once it reaches the maximum number of iterations or if divergence is detected (using the threshold STOP_DIVERGENCE). When defined, the other stopping modes are:
- Primal absolute optimality: based on STOP_PVAL or STOP_PGRAD (either or both, depending on which are defined). Applies only to unconstrained problems.
- Primal absolute optimality and absolute feasibility: as above, but also checks STOP_ABS_FEAS (on average constraints) and STOP_NFEAS (for pointwise constraints), depending on which are defined.
- Primal-dual absolute optimality and absolute feasibility: checks optimality using STOP_ABS_DGAP or both STOP_PGRAD and STOP_DGRAD. Feasibility is checked using STOP_ABS_FEAS (on average constraints) and STOP_NFEAS (for pointwise constraints), depending on which are defined.
- Primal-dual relative optimality and relative feasibility: checks optimality using STOP_REL_DGAP and feasibility using STOP_REL_FEAS (for average constraints) and STOP_NFEAS (for pointwise constraints).
- Stalled: triggers when neither the primal value nor any constraint violation has improved over the span of STOP_PATIENCE iterations.
- User-defined criterion: stops if STOP_USER_DEFINED returns True (a sketch is given after this list).
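A user-defined criterion is a callable with the documented (problem, state_dict) signature. A minimal sketch (the target value is arbitrary; the primal_value key follows the state list below):

```python
import csl

def stop_below_target(problem, state_dict):
    # Stop once the estimated primal value drops below an arbitrary target
    value = state_dict['primal_value']
    return value is not None and value < 0.05

solver = csl.SimultaneousPrimalDual({'STOP_USER_DEFINED': stop_below_target})
```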
Implementing a new solver
Subclasses must:
- Call PrimalDualBase.__init__ with a SolverSettings object that includes their solver-specific settings (if any).
- Define primal_dual_update.
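A structural sketch of such a subclass is shown below. The update logic is left as a stub, and how the built-in solvers merge user-provided settings is not shown in this section, so only the two documented requirements are illustrated:

```python
from csl.solver_base import PrimalDualBase, SolverSettings

class MyPrimalDualSolver(PrimalDualBase):
    def __init__(self):
        # Requirement 1: pass a SolverSettings object (extended with any
        # solver-specific settings) to the base constructor.
        super().__init__(SolverSettings())

    def primal_dual_update(self, problem):
        # Requirement 2: perform one primal-dual iteration and return the
        # five estimates documented under primal_dual_update below.
        primal_value_est, primal_grad_norm_est = 0.0, 0.0  # stub values
        # The last three entries may be None when no dual update was performed
        return primal_value_est, primal_grad_norm_est, None, None, None
```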
Solver states
Specific solvers (and users) are free to modify and add to state_dict. This can be useful to save internal states of the user-defined stopping criterion STOP_USER_DEFINED, which can also be used as a validation hook (a sketch follows this list). You should use a unique capitalized prefix in order to avoid interfering with the normal operation of the primal-dual solver (unless you know what you are doing). PrimalDualBase has the following internal states:
- iteration (int): Iteration number.
- no_update_iterations (int): Number of iterations without updates left until the solver gives up and stops early. Undefined if STOP_PATIENCE is None. Initial value: STOP_PATIENCE.
- primal_solver (dict): state_dict of the primal problem solver.
- dual_solver (dict): state_dict of the dual problem solver.
- primal_value (float): Current primal value, i.e., objective function value.
- primal_grad_norm (float): Squared norm of the primal gradient.
- lagrangian_value (float): Current value of the Lagrangian. If the problem is convex, converges to the value of the dual function.
- dual_grad_norm (float): Squared norm of the dual gradient (constraint slacks).
- duality_gap (float): Duality gap, i.e., P - D.
- rel_duality_gap (float): Relative duality gap, i.e., (P - D)/P.
- constraint_feas (np.array): Constraint violation of the average constraints. Non-positive if the solution satisfies the constraint.
- constraint_rel_feas (np.array): Relative constraint violation of the average constraints, i.e., slack divided by right-hand side. Non-positive if the solution satisfies the constraint.
- constraint_nfeas (np.array): Proportion of feasible average constraints.
- lambdas_max (float): Maximum value of the dual variables (average constraints).
- pointwise_feas (np.array): Constraint violation of the pointwise constraints. Non-positive if the solution satisfies the constraint.
- pointwise_nfeas (np.array): Proportion of feasible points for each pointwise constraint.
- mus_max (float): Maximum value of the dual variables (pointwise constraints).
- primal_value_log (np.array): Primal value across iterations.
- lagrangian_value_log (np.array): Lagrangian value across iterations.
- lambdas_log (np.array): Dual variables (average constraints) across iterations.
- feas_log (np.array): Constraint violation (average constraints) across iterations.
- rel_feas_log (np.array): Relative constraint violation (average constraints) across iterations.
- mus_log (np.array): Dual variables (pointwise constraints) across iterations.
- nfeas_log (np.array): Proportion of feasible pointwise constraints across iterations.
- HAS_CONSTRAINTS (bool): True if the learning problem has constraints and False otherwise.
- N_AVG_CONSTRAINTS (int): Number of average constraints of the learning problem.
- N_PTW_CONSTRAINTS (int): Number of pointwise constraints of the learning problem.
Except for iteration, states ending in _log, and flags (capitalized states), the value of a state in the previous iteration can be accessed by appending _prev to the state variable name.
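For example, a validation hook passed as STOP_USER_DEFINED could record its own best value under a capitalized "MYVAL_" prefix (hypothetical key names; a sketch only):

```python
def validation_hook(problem, state_dict):
    # Custom state lives under a unique capitalized prefix ("MYVAL_") so it
    # cannot interfere with the solver's internal states.
    best = state_dict.get('MYVAL_best_primal', float('inf'))
    current = state_dict['primal_value']
    if current is not None:
        state_dict['MYVAL_best_primal'] = min(best, current)
    return False  # never stops early; used purely as a monitoring hook
```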
Warning
The values of primal_value, lagrangian_value, duality_gap, rel_duality_gap, constraint_feas, constraint_rel_feas, pointwise_feas, and pointwise_nfeas should be considered estimates unless COMPUTE_TRUE_DGAP is True. This is particularly an issue when the solver uses batches rather than operating over the full dataset (e.g., SGD). In these cases, the primal and/or dual variables are updated between batches, so their average is not representative of the current performance. When COMPUTE_TRUE_DGAP is True, the base solver does an extra pass through the dataset in order to re-evaluate these quantities for the current model. This may considerably increase the computation time of each epoch.
__init__(settings)
Primal-dual base solver constructor.

- Parameters
  - settings (SolverSettings) – Solver settings.
plot()
Trace plots of the solver.
If the problem has no constraints, displays a trace plot of the primal value estimate (see PrimalDualBase).
If the problem has constraints, displays trace plots of the primal value, Lagrangian value, relative duality gap, dual variables, and feasibility (constraint violation of the average constraints and proportion of feasible pointwise constraints).

- Returns
  - fig (matplotlib.figure.Figure) – Matplotlib figure handle.
  - axes (matplotlib.axes.Axes) – Matplotlib axes handle.
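The returned handles can be used to adjust or save the plots after solving, e.g.:

```python
fig, axes = solver.plot()          # display trace plots
fig.savefig('solver_traces.png')   # persist using the returned figure handle
```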
primal_dual_update(problem)
Primal-dual update.

- Parameters
  - problem (csl.ConstrainedLearningProblem) – Constrained learning problem.
- Returns
  - primal_value_est (float) – Estimate of the primal value.
  - primal_grad_norm_est (float) – Estimate of the primal gradient squared norm.
  - constraint_slacks_est (list [torch.tensor, (1, )] or None) – Estimate of the value of the average constraints, or None if there was no dual update.
  - pointwise_slacks_est (list [torch.tensor, (N, )] or None) – Estimate of the value of the pointwise constraints, or None if there was no dual update.
  - dual_grad_norm_est (float or None) – Estimate of the dual gradient squared norm, or None if there was no dual update.