# Source code for csl.models

# -*- coding: utf-8 -*-
"""Model wrappers for csl module

"""

import torch

class PytorchModel:
    """PyTorch model wrapper for constrained learning problems.

    Attributes
    ----------
    model : `torch.nn.Module`
        A PyTorch model. Assigning to ``model`` replaces the wrapped module.
    parameters : `list` [`torch.tensor`]
        Model parameters. Obtained directly from the underlying PyTorch
        module as ``model.parameters()``. Setting ``parameters``, however,
        expects a module state dictionary obtained by calling
        ``model.state_dict``.

    Notes
    -----
    Any other attribute read or written on the wrapper is forwarded to the
    wrapped ``model``.
    """

    def __init__(self, model):
        # Store directly in the instance dict: ``self.model = model`` would be
        # routed through ``__setattr__`` below.
        self.__dict__['model'] = model

    def __call__(self, x):
        """Evaluate model output

        Parameters
        ----------
        x : `torch.tensor`
            Input data
        """
        return self.model(x)

    def predict(self, x):
        """Evaluate model prediction

        Predicts the label of each data point in ``x``. Assumes the neural
        network has one output per class and returns the class corresponding
        to the largest output.

        Parameters
        ----------
        x : `torch.tensor`
            Input data
        """
        # torch.max over dim 1 returns (values, indices); keep the indices.
        _, predicted = torch.max(self(x), 1)
        return predicted

    def __getattr__(self, attr):
        """Get model attribute

        Passed directly to the underlying PyTorch model except for
        ``parameters`` which returns a generator over the model parameters.

        Parameters
        ----------
        attr : `str`
            Attribute name

        Raises
        ------
        AttributeError
            If the wrapper holds no ``model`` yet (e.g. an instance created
            without ``__init__`` by :mod:`copy` or :mod:`pickle`), or if the
            underlying model has no attribute ``attr``.
        """
        # ``__getattr__`` only runs after normal lookup has failed. Reading
        # ``self.model`` here when ``model`` is missing from ``__dict__`` would
        # call ``__getattr__('model')`` again and recurse without end, so
        # look it up in the instance dict explicitly.
        try:
            model = self.__dict__['model']
        except KeyError:
            raise AttributeError(
                '{!r} object has no attribute {!r}'.format(
                    type(self).__name__, attr)
            ) from None
        if attr == 'parameters':
            return model.parameters()
        return getattr(model, attr)

    def __setattr__(self, attr, value):
        """Set model attribute

        Assigning ``model`` replaces the wrapped PyTorch model. Assigning
        ``parameters`` expects a complete model state dictionary and loads
        it into the model. Any other attribute is passed directly to the
        underlying PyTorch model.

        Parameters
        ----------
        attr : `str`
            Attribute name
        value
            Value to assign
        """
        if attr == 'model':
            # Keep reads and writes of ``model`` consistent: forwarding this
            # would set a ``model`` attribute *on* the old module instead.
            self.__dict__['model'] = value
        elif attr == 'parameters':
            self.model.load_state_dict(value)
        else:
            setattr(self.model, attr, value)