Module OPTIMA.lightning.training

Collection of classes and functions specific to the training of Lightning models.

Expand source code
# -*- coding: utf-8 -*-
"""Collection of classes and functions specific to the training of Lightning models."""
import os
import logging

import numpy as np
import torch.utils.data
from lightning.pytorch.callbacks import Callback

import OPTIMA.core.training
import OPTIMA.lightning.inputs


class EarlyStopperForLightningTuning(OPTIMA.core.training.EarlyStopperForTuning, Callback):
    """_summary_.

    Returns
    -------
    _type_
        _description_
    """

    def __init__(self, *args, run_config, model_config, inputs_train, inputs_val, **kwargs):
        """_summary_.

        Parameters
        ----------
        *args : _type_
            _description_
        run_config : _type_
            _description_
        model_config : _type_
            _description_
        inputs_train : _type_
            _description_
        inputs_val : _type_
            _description_
        **kwargs : _type_
            _description_

        Returns
        -------
        _type_
            _description_
        """
        # we need to convert numpy arrays to dataloaders
        inputs_train_tensor = torch.from_numpy(inputs_train.copy()).type(torch.float)
        inputs_val_tensor = torch.from_numpy(inputs_val.copy()).type(torch.float)
        dl_train = OPTIMA.lightning.inputs.get_dataloader(run_config, model_config, inputs_train_tensor)
        dl_val = OPTIMA.lightning.inputs.get_dataloader(run_config, model_config, inputs_val_tensor)

        # provide the dataloaders to the EarlyStopper
        OPTIMA.core.training.EarlyStopperForTuning.__init__(
            self, *args, inputs_train=dl_train, inputs_val=dl_val, **kwargs
        )
        Callback.__init__(self)

    def on_validation_end(self, trainer, pl_module):
        """_summary_.

        Parameters
        ----------
        trainer : _type_
            _description_
        pl_module : _type_
            _description_

        Returns
        -------
        _type_
            _description_
        """
        logs = {}
        for metric, value in trainer.callback_metrics.items():
            logs[metric] = value.detach().cpu().numpy()

        super().at_epoch_end(trainer.current_epoch, logs=logs, pl_module=pl_module, trainer=trainer)

    def on_train_end(self, trainer, pl_module) -> None:
        """_summary_.

        Parameters
        ----------
        trainer : _type_
            _description_
        pl_module : _type_
            _description_

        Returns
        -------
        None
            _description_
        """
        super().finalize()

    def get_train_val_metric_names(self, metric, **kwargs: dict):
        """_summary_.

        Parameters
        ----------
        metric : _type_
            _description_
        **kwargs : dict
            _description_

        Returns
        -------
        _type_
            _description_
        """
        return f"train_{metric}", f"val_{metric}"

    def get_weights(self, pl_module, **kwargs) -> list[np.ndarray]:
        """_summary_.

        Parameters
        ----------
        pl_module : _type_
            _description_
        **kwargs : _type_
            _description_

        Returns
        -------
        list[np.ndarray]
            _description_
        """
        return pl_module.state_dict()  # ist ein pytorch befehl
        # return self.trainer.model.state_dict()

    def set_weights(self, weights: list[np.ndarray], pl_module, **kwargs):
        """_summary_.

        Parameters
        ----------
        weights : list[np.ndarray]
            _description_
        pl_module : _type_
            _description_
        **kwargs : _type_
            _description_

        Returns
        -------
        _type_
            _description_
        """
        pl_module.load_state_dict(weights)

    def save_model(self, output_dir: str, model_name: str, trainer, **kwargs):
        """_summary_.

        Parameters
        ----------
        output_dir : str
            _description_
        model_name : str
            _description_
        trainer : _type_
            _description_
        **kwargs : _type_
            _description_

        Returns
        -------
        _type_
            _description_
        """
        try:
            trainer.save_checkpoint(filepath=os.path.join(output_dir, f"{model_name}.ckpt"))
        except BlockingIOError:
            logging.warning(
                "BlockingIOError: [Errno 11] Unable to create file (unable to lock file, errno = 11, error message "
                "= 'Resource temporarily unavailable'). Skipping the save of this checkpoint!"
            )
        except OSError:
            logging.warning("OSError detected. Skipping the save of this checkpoint!")

    def predict(
        self, inputs: torch.utils.data.DataLoader, trainer, pl_module, verbose: int = 0, **kwargs
    ) -> np.ndarray:
        """_summary_.

        Parameters
        ----------
        inputs : torch.utils.data.DataLoader
            _description_
        trainer : _type_
            _description_
        pl_module : _type_
            _description_
        verbose : int
            _description_ (Default value = 0)
        **kwargs : _type_
            _description_

        Returns
        -------
        np.ndarray
            _description_
        """
        # cannot call the trainer's predict() method here, i.e. cannot do "pred_list = trainer.predict(pl_module, inputs)"
        # as this causes a weird error with the dataloader:
        # "lightning.fabric.utilities.exceptions.MisconfigurationException: `train_dataloader` must be implemented to be
        # used with the Lightning Trainer"???
        # Instead, make predictions manually --> need to move the inputs to the correct device!
        pred_list = []
        for x in inputs:
            pred_list.append(pl_module.predict_step(x.to(pl_module.device)))
        return np.concatenate([t.detach().cpu().numpy() for t in pred_list], axis=0)

    def stop_training(self, trainer, **kwargs) -> None:
        """Mark the training for termination due to Early Stopping.

        Parameters
        ----------
        trainer : lightning.pytorch.Trainer
            The Trainer running the training; its ``should_stop`` flag is set.
        **kwargs : dict
            Additional keyword arguments; unused.
        """
        trainer.should_stop = True

Classes

class EarlyStopperForLightningTuning (*args, run_config, model_config, inputs_train, inputs_val, **kwargs)

Early stopper connecting OPTIMA's hyperparameter tuning to the Lightning Trainer.

It subclasses both OPTIMA.core.training.EarlyStopperForTuning and Lightning's Callback so that the generic early-stopping logic can be driven by the Trainer's callback hooks.

Construct the early stopper and convert the numpy inputs to dataloaders.

Parameters

*args : tuple
Positional arguments forwarded to EarlyStopperForTuning.__init__.
run_config : object
Run configuration forwarded to OPTIMA.lightning.inputs.get_dataloader.
model_config : dict
Model configuration containing the hyperparameters, forwarded to OPTIMA.lightning.inputs.get_dataloader.
inputs_train : np.ndarray
Training inputs; converted to a float tensor and wrapped in a dataloader.
inputs_val : np.ndarray
Validation inputs; converted to a float tensor and wrapped in a dataloader.
**kwargs : dict
Keyword arguments forwarded to EarlyStopperForTuning.__init__.
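
A minimal usage sketch (run_config, model_config, the input arrays and the remaining EarlyStopperForTuning arguments are assumed to be defined elsewhere; model is a hypothetical LightningModule):

import lightning.pytorch as pl

# the numpy inputs are converted to dataloaders internally
early_stopper = EarlyStopperForLightningTuning(
    run_config=run_config,
    model_config=model_config,
    inputs_train=x_train,
    inputs_val=x_val,
)

# the stopper hooks into the training as an ordinary Trainer callback
trainer = pl.Trainer(max_epochs=100, callbacks=[early_stopper])
trainer.fit(model)
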
Ancestors

OPTIMA.core.training.EarlyStopperForTuning
lightning.pytorch.callbacks.Callback

Methods

def get_train_val_metric_names(self, metric, **kwargs: dict)

Build the names of the training and validation versions of a metric.

Parameters

metric : str
Base name of the metric, e.g. "loss".
**kwargs : dict
Additional keyword arguments; unused.

Returns

tuple[str, str]
The metric names "train_<metric>" and "val_<metric>".
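
The mapping is purely name-based, e.g. (stopper being an EarlyStopperForLightningTuning instance):

stopper.get_train_val_metric_names("loss")      # -> ("train_loss", "val_loss")
stopper.get_train_val_metric_names("accuracy")  # -> ("train_accuracy", "val_accuracy")
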
def get_weights(self, pl_module, **kwargs) ‑> list[numpy.ndarray]

Fetch the current weights of the model.

Parameters

pl_module : lightning.pytorch.LightningModule
The Lightning module whose weights are fetched.
**kwargs : dict
Additional keyword arguments; unused.

Returns

dict
The module's state_dict(), passed through opaquely and restored again via set_weights.
def on_train_end(self, trainer, pl_module) ‑> None

Finalize the early stopper once the training has ended.

Parameters

trainer : lightning.pytorch.Trainer
The Trainer running the training.
pl_module : lightning.pytorch.LightningModule
The Lightning module being trained.
def on_validation_end(self, trainer, pl_module)

Collect the metrics of the finished epoch and forward them to the early stopper.

All entries in trainer.callback_metrics are detached, moved to the CPU and converted to numpy before being handed to at_epoch_end.

Parameters

trainer : lightning.pytorch.Trainer
The Trainer running the training.
pl_module : lightning.pytorch.LightningModule
The Lightning module being trained.
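
The metrics originate from the LightningModule's self.log calls. A hypothetical module illustrating the naming contract expected by get_train_val_metric_names (train_/val_ prefixes):

import torch
import lightning.pytorch as pl

class LitModel(pl.LightningModule):
    """Hypothetical module; only the self.log calls matter here."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).pow(2).mean()
        # aggregated per epoch and exposed via trainer.callback_metrics
        self.log("train_loss", loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        val_loss = self(batch).pow(2).mean()
        # picked up by on_validation_end and forwarded to at_epoch_end
        self.log("val_loss", val_loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())
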
def predict(self, inputs: torch.utils.data.dataloader.DataLoader, trainer, pl_module, verbose: int = 0, **kwargs) ‑> numpy.ndarray

Calculate the model's predictions for all batches provided by inputs.

The Trainer's predict() method cannot be used here since it raises a MisconfigurationException; the predictions are therefore made manually, moving each batch to the module's device.

Parameters

inputs : torch.utils.data.DataLoader
Dataloader providing the input batches.
trainer : lightning.pytorch.Trainer
The Trainer running the training; unused here.
pl_module : lightning.pytorch.LightningModule
The Lightning module used to calculate the predictions.
verbose : int
Verbosity level; unused. (Default value = 0)
**kwargs : dict
Additional keyword arguments; unused.

Returns

np.ndarray
The concatenated predictions for all batches.
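
A usage sketch, assuming stopper, trainer and model (e.g. the LitModel above) exist as in the earlier sketches; the dataloader mirrors the plain-tensor loaders built in the constructor:

import numpy as np
import torch
from torch.utils.data import DataLoader

x = torch.from_numpy(np.random.rand(32, 4).astype(np.float32))
loader = DataLoader(x, batch_size=8)

preds = stopper.predict(loader, trainer=trainer, pl_module=model)
assert isinstance(preds, np.ndarray) and preds.shape == (32, 1)
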
def save_model(self, output_dir: str, model_name: str, trainer, **kwargs)

Save a checkpoint of the current model to output_dir/<model_name>.ckpt.

BlockingIOError and OSError exceptions (e.g. when the checkpoint file cannot be locked) are caught and the save is skipped with a warning.

Parameters

output_dir : str
Directory to save the checkpoint in.
model_name : str
Name of the checkpoint file, without the .ckpt extension.
trainer : lightning.pytorch.Trainer
The Trainer whose save_checkpoint method is used to save the model.
**kwargs : dict
Additional keyword arguments; unused.
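
For example, with the hypothetical names from the earlier sketches; the resulting file can be restored with Lightning's standard checkpoint loading:

stopper.save_model("/tmp/checkpoints", "best_model", trainer=trainer)
restored = LitModel.load_from_checkpoint("/tmp/checkpoints/best_model.ckpt")
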
def set_weights(self, weights: list[numpy.ndarray], pl_module, **kwargs)

Load the given weights into the model.

Parameters

weights : dict
State dict previously obtained from get_weights.
pl_module : lightning.pytorch.LightningModule
The Lightning module whose weights are replaced.
**kwargs : dict
Additional keyword arguments; unused.
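
Together with get_weights this allows snapshotting and later restoring the model, e.g. to roll back to the best epoch (model as in the earlier sketches; the weights are an opaque state dict despite the list[np.ndarray] annotation):

snapshot = stopper.get_weights(pl_module=model)  # state dict of the current weights
# ... training continues, the tracked metric degrades ...
stopper.set_weights(snapshot, pl_module=model)   # roll back to the snapshot
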
def stop_training(self, trainer, **kwargs) ‑> None

Mark the training for termination due to Early Stopping.

Parameters

trainer : lightning.pytorch.Trainer
The Trainer running the training; its should_stop flag is set.
**kwargs : dict
Additional keyword arguments; unused.

Inherited members