Source code for Bering.training.train

import os
import logging
import datetime
import numpy as np
# from rich.progress import track
from tqdm import tqdm
from typing import Sequence

import torch
from torch_geometric.loader import DataLoader

from ._record import record_init, record, is_notebook
from ._trainer_node import TrainerNode
from ._trainer_edge import TrainerEdge
from ..models import GCN, BaselineMLP, EdgeClf
from ..objects import Bering_Graph as BrGraph

logger = logging.getLogger(__name__)

class EarlyStopper:
    '''
    Early stopping helper driven by validation loss values.

    The stopper remembers the lowest loss it has seen. A loss that is larger
    than ``lowest + min_delta`` increments an internal counter; a loss that
    sets a new minimum resets the counter to zero. Losses that fall between
    the minimum and ``minimum + min_delta`` leave the counter untouched.

    Parameters
    ----------
    patience
        Counter value at which :meth:`early_stop` starts returning ``True``.
    min_delta
        Margin above the lowest loss that a new loss has to exceed in order
        to be counted.
    '''

    def __init__(
        self,
        patience: int = 6,
        min_delta: float = 0.025,
    ):
        self.min_delta = min_delta
        self.patience = patience
        # No loss observed yet: any first finite loss becomes the minimum.
        self.min_validation_loss = np.inf
        self.counter = 0

    def early_stop(self, validation_loss):
        '''
        Update the stopper with a new loss and report whether to stop.

        Parameters
        ----------
        validation_loss
            Validation loss value

        Returns
        -------
        ``True`` once the counter has reached ``patience``, else ``False``.
        '''
        lowest = self.min_validation_loss

        # New minimum: remember it and start counting from scratch.
        if validation_loss < lowest:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        # Written as ``not (a > b)`` rather than ``a <= b`` so that a NaN loss
        # is ignored here, exactly as a failed ``>`` comparison would be.
        if not (validation_loss > lowest + self.min_delta):
            return False

        self.counter += 1
        if self.counter >= self.patience:
            return True
        return False
def _trainNode(
    trainer: TrainerNode,
    train_loader: DataLoader,
    test_loader: DataLoader,
    trainer_type: str,
    plot_folder: str,
    epoches: int = 200,
    early_stop: bool = True,
    early_stop_patience: int = 5,
    early_stop_min_delta: float = 0.025,
    plot_ax_size: float = 5.0,
):
    '''
    Training procedure for node classifier

    Each epoch calls ``trainer.update(train_loader)`` followed by
    ``trainer.validate(test_loader)``. Every fifth epoch and on the final
    epoch the state is passed to ``record``. When ``early_stop`` is set and
    the :class:`EarlyStopper` fires, the loop ends immediately, before any
    ``record`` call for that epoch.

    Parameters
    ----------
    trainer
        Node trainer that is updated and validated
    train_loader
        Loader passed to ``trainer.update`` and to ``record``
    test_loader
        Loader passed to ``trainer.validate`` and to ``record``
    trainer_type
        Label passed to ``record_init`` / ``record``; also the prefix of the plot file name
    plot_folder
        Folder passed to ``record`` as ``plot_folder``
    epoches
        Maximal number of epoches
    early_stop
        Whether the early stopper is consulted
    early_stop_patience
        ``patience`` of the :class:`EarlyStopper`
    early_stop_min_delta
        ``min_delta`` of the :class:`EarlyStopper`
    plot_ax_size
        Passed to ``record`` as ``ax_size``

    Returns
    -------
    The same ``trainer`` object that was passed in
    '''
    record_init(trainer, trainer_type)
    stopper = EarlyStopper(early_stop_patience, early_stop_min_delta)
    progress = tqdm(range(epoches), desc = 'Training node classifier', colour='blue')

    record_every = 5
    final_epoch = epoches - 1
    # Outside notebooks every record call plots; inside notebooks only the last one does.
    plotting = not is_notebook()

    for epoch in progress:
        train_loss = trainer.update(train_loader)  # loss function = CrossEntropy
        validation_loss = trainer.validate(test_loader)

        if early_stop and stopper.early_stop(validation_loss):
            break

        reached_end = epoch == final_epoch
        if not reached_end and epoch % record_every != 0:
            continue
        if reached_end:
            plotting = True

        record(
            trainer, None, trainer_type, epoch, 'node',
            train_loss, validation_loss,
            train_loader, test_loader,
            logger,
            plot_folder = plot_folder,
            plotting = plotting,
            ax_size = plot_ax_size,
            plot_name = f'{trainer_type}_node.png'
        )

    return trainer


def _trainEdge(
    trainer: TrainerEdge,
    image,
    train_loader: DataLoader,
    test_loader: DataLoader,
    trainer_type: str,
    plot_folder: str,
    epoches: int = 200,
    early_stop: bool = True,
    early_stop_patience: int = 5,
    early_stop_min_delta: float = 0.025,
    plot_ax_size: float = 5.0,
):
    '''
    Training procedure for edge classifier

    Each epoch calls ``trainer.update(train_loader, image)`` followed by
    ``trainer.validate(test_loader, image)``. Every fifth epoch and on the
    final epoch the state is passed to ``record``. When ``early_stop`` is set
    and the :class:`EarlyStopper` fires, the loop ends immediately, before any
    ``record`` call for that epoch.

    Parameters
    ----------
    trainer
        Edge trainer that is updated and validated
    image
        Image forwarded unchanged to ``trainer.update``, ``trainer.validate`` and ``record``
    train_loader
        Loader passed to ``trainer.update`` and to ``record``
    test_loader
        Loader passed to ``trainer.validate`` and to ``record``
    trainer_type
        Label passed to ``record_init`` / ``record``; also the prefix of the plot file name
    plot_folder
        Folder passed to ``record`` as ``plot_folder``
    epoches
        Maximal number of epoches
    early_stop
        Whether the early stopper is consulted
    early_stop_patience
        ``patience`` of the :class:`EarlyStopper`
    early_stop_min_delta
        ``min_delta`` of the :class:`EarlyStopper`
    plot_ax_size
        Passed to ``record`` as ``ax_size``

    Returns
    -------
    The same ``trainer`` object that was passed in
    '''
    record_init(trainer, trainer_type)
    stopper = EarlyStopper(early_stop_patience, early_stop_min_delta)
    progress = tqdm(range(epoches), desc = 'Training edge classifier', colour='red')

    record_every = 5
    final_epoch = epoches - 1
    # Outside notebooks every record call plots; inside notebooks only the last one does.
    plotting = not is_notebook()

    for epoch in progress:
        train_loss = trainer.update(train_loader, image)  # loss function = CrossEntropy
        validation_loss = trainer.validate(test_loader, image)

        if early_stop and stopper.early_stop(validation_loss):
            break

        reached_end = epoch == final_epoch
        if not reached_end and epoch % record_every != 0:
            continue
        if reached_end:
            plotting = True

        # The six-way unpacking is kept although none of the values is used
        # afterwards; the second and fourth are the test AUC / precision.
        _, _auc_test, _, _prec_test, _, _ = record(
            trainer, image, trainer_type, epoch, 'edge',
            train_loss, validation_loss,
            train_loader, test_loader,
            logger,
            plot_folder = plot_folder,
            plotting = plotting,
            plot_name = f'{trainer_type}_edge.png',
            ax_size = plot_ax_size,
        )

    return trainer
def Training(
    bg: BrGraph,
    node_gcnq_hidden_dims: Sequence[int] = [256, 128, 64, 32, 16],
    node_mlp_hidden_dims: Sequence[int] = [16, 32, 32],
    node_lr: float = 1e-3,
    node_weight_decay: float = 5e-4,
    node_foreground_weight: float = 1.0,
    node_background_weight: float = 1.0,
    node_epoches: int = 50,
    node_early_stop: bool = False,
    node_early_stop_patience: int = 5,
    node_early_stop_delta: float = 0.05,
    edge_distance_type: str = 'rbf',
    edge_rbf_start: int = 0,
    edge_rbf_stop: int = 64,
    edge_rbf_n_kernels: int = 64,
    edge_rbf_learnable: bool = True,
    edge_image_conv2d_hidden_dims: Sequence[int] = [6, 16, 32, 64, 128],
    edge_image_mlp_hidden_dims: Sequence[int] = [32, 64],
    edge_decoder_mlp_hidden_dims: Sequence[int] = [16, 8],
    edge_num_positive: int = 1000,
    edge_num_negative: int = 1000,
    edge_subimage_binsize: int = 5,
    edge_lr: float = 1e-3,
    edge_weight_decay: float = 5e-4,
    edge_epoches: int = 50,
    edge_early_stop: bool = False,
    edge_early_stop_patience: int = 5,
    edge_early_stop_delta: float = 0.05,
    plot_ax_size: float = 5.0,
    finetune: bool = False,
    baseline: bool = False,
):
    r'''
    Training both node classification and edge classification models.
    The training procedure is done by training node classifier :func:`~TrainerNode` first
    and then training edge classifier :func:`~TrainerEdge` with the node classifier frozen.

    The list-valued defaults are only read (indexed, sliced, copied), never modified in place.

    Parameters
    ----------
    bg
        Bering_Graph object
    node_gcnq_hidden_dims
        List of hidden layer dimensions for GCN in :func:`Bering.models.GCN`
    node_mlp_hidden_dims
        List of hidden layer dimensions for MLP in :func:`Bering.models.GCN`.
        Its second entry is also passed to the edge classifier as ``n_node_latent_features``.
    node_lr
        Learning rate for node classifier in :func:`TrainerNode`
    node_weight_decay
        Weight decay for node classifier in :func:`TrainerNode`
    node_foreground_weight
        Weight for segmented transcripts in loss function for node classifier in :func:`TrainerNode`
    node_background_weight
        Weight for background transcripts in loss function for node classifier in :func:`TrainerNode`
    node_epoches
        Number of epoches for node classifier
    node_early_stop
        Whether to use early stop for node classifier in :func:`~EarlyStopper`
    node_early_stop_patience
        Maximal number of times allowed to have loss greater than min_loss + min_delta
        before stopping for node classifier. See :func:`~EarlyStopper`
    node_early_stop_delta
        Minimal gap of a new loss and minimal loss for adding to a count to counter
        for node classifier. See :func:`~EarlyStopper`
    edge_distance_type
        Distance type for edge classifier. See :func:`~Bering.models.EdgeClf`
    edge_rbf_start
        Start of RBF kernel \mu for edge classifier. See :func:`~Bering.models.EdgeClf`
    edge_rbf_stop
        Stop of RBF kernel \mu for edge classifier. See :func:`~Bering.models.EdgeClf`
    edge_rbf_n_kernels
        Number of RBF kernels for edge classifier. See :func:`~Bering.models.EdgeClf`
    edge_rbf_learnable
        Whether to learn RBF kernels for edge classifier. See :func:`~Bering.models.EdgeClf`
    edge_image_conv2d_hidden_dims
        List of conv2d hidden layer dimensions for image encoder for edge classifier.
        See :func:`~Bering.models.EdgeClf`
    edge_image_mlp_hidden_dims
        List of MLP hidden layer dimensions for image encoder for edge classifier.
        See :func:`~Bering.models.EdgeClf`
    edge_decoder_mlp_hidden_dims
        List of hidden layer dimensions for decoder for edge classifier.
        See :func:`~Bering.models.EdgeClf`
    edge_num_positive
        Number of positive edges for edge classifier. See :func:`~Bering.models.EdgeClf.forward`
    edge_num_negative
        Number of negative edges for edge classifier. See :func:`~Bering.models.EdgeClf.forward`
    edge_subimage_binsize
        Bin size for subimages for edge classifier. See :func:`~Bering.models.EdgeClf`
    edge_lr
        Learning rate for edge classifier. See :func:`TrainerEdge`
    edge_weight_decay
        Weight decay for edge classifier. See :func:`TrainerEdge`
    edge_epoches
        Number of epoches for edge classifier
    edge_early_stop
        Whether to use early stop for edge classifier. See :func:`~EarlyStopper`
    edge_early_stop_patience
        Maximal number of times allowed to have loss greater than min_loss + min_delta
        before stopping for edge classifier. See :func:`~EarlyStopper`
    edge_early_stop_delta
        Minimal gap of a new loss and minimal loss for adding to a count to counter
        for edge classifier. See :func:`~EarlyStopper`
    plot_ax_size
        Size of the plot for node and edge classifiers. See :func:`~record`
    finetune
        Whether to finetune the model.

        - If ``True``, the trainers already stored in ``bg.trainer_node`` and ``bg.trainer_edge``
          are trained further; no new :func:`Bering.models.GCN` or :func:`Bering.models.EdgeClf`
          is created.
        - If ``False``, the model will be trained from scratch with :func:`Bering.models.GCN`
          and :func:`Bering.models.EdgeClf`
    baseline
        Whether to use baseline model. If ``True``, the model will be trained with
        :func:`Bering.models.BaselineMLP`. If ``False``, the model will be trained with
        :func:`Bering.models.GCN`.

    Returns
    -------
    None
        Nothing is returned. The trained trainers are stored on ``bg``:

        - ``bg.trainer_node``: node classifier, :func:`~TrainerNode`
        - ``bg.trainer_edge``: edge classifier, :func:`~TrainerEdge`

    Notes
    -----
    A folder ``figures/performance_<month-day hour-minute-second>`` is created and handed to
    the training loops as plot folder. A raw image, when present, is moved to the GPU with
    ``.cuda()``.
    '''
    # initialize models
    if not finetune:
        if not baseline:
            nodeclf = GCN(
                n_features = bg.n_node_features,
                n_classes = bg.n_labels,
                gcn_hidden_layer_dims = node_gcnq_hidden_dims,
                mlp_hidden_layer_dims = node_mlp_hidden_dims,
            )
        else:
            nodeclf = BaselineMLP(
                n_features = bg.n_node_features,
                n_classes = bg.n_labels,
                # list(...) lets callers mix lists and tuples for the two dimension arguments
                mlp_hidden_layer_dims = list(node_gcnq_hidden_dims[:-1]) + list(node_mlp_hidden_dims),
            )

        edgeclf = EdgeClf(
            n_node_latent_features = node_mlp_hidden_dims[1],
            image = bg.image_raw,
            image_model = (bg.image_raw is not None),
            decoder_mlp_layer_dims = edge_decoder_mlp_hidden_dims,
            distance_type = edge_distance_type,
            rbf_start = edge_rbf_start,
            rbf_stop = edge_rbf_stop,
            rbf_n_kernels = edge_rbf_n_kernels,
            rbf_learnable = edge_rbf_learnable,
            encoder_image_layer_dims_conv2d = edge_image_conv2d_hidden_dims,
            encoder_image_layer_dims_mlp = edge_image_mlp_hidden_dims,
            subimage_binsize = edge_subimage_binsize,
            max_subimage_size = bg.window_size - edge_subimage_binsize,
            min_subimage_size = edge_subimage_binsize,
        )

    # train node clf
    performance_folder = 'figures/performance_' + datetime.datetime.now().strftime('%m-%d %H-%M-%S')
    # makedirs also creates a missing ``figures`` parent and tolerates two runs
    # that start within the same second.
    os.makedirs(performance_folder, exist_ok = True)

    if not finetune:
        bg.trainer_node = TrainerNode(
            nodeclf,
            lr = node_lr,
            weight_decay = node_weight_decay,
            weight_seg = node_foreground_weight,
            weight_bg = node_background_weight,
        )

    # A node model that went through this function before is left fully frozen
    # (see below). Re-enable gradients in that case so that it can be trained
    # again; a model with only some parameters frozen is left as the caller set it up.
    node_params = list(bg.trainer_node.model.parameters())
    if node_params and not any(param.requires_grad for param in node_params):
        for param in node_params:
            param.requires_grad_(True)

    keyword = 'training_GCN' if not baseline else 'training_baseline'
    bg.trainer_node = _trainNode(
        bg.trainer_node,
        bg.train_loader,
        bg.test_loader,
        keyword,
        performance_folder,
        epoches = node_epoches,
        early_stop = node_early_stop,
        early_stop_patience = node_early_stop_patience,
        early_stop_min_delta = node_early_stop_delta,
        plot_ax_size = plot_ax_size,
    )

    # freeze node clf
    # The attribute is ``requires_grad``; assigning to a misspelled name only
    # attaches an unused attribute and leaves the parameters trainable.
    for param in bg.trainer_node.model.parameters():
        param.requires_grad_(False)

    # train edge clf
    if bg.image_raw is None:
        image_ = None
    else:
        image_ = torch.from_numpy(bg.image_raw).double().cuda()
        image_ = image_[None, :, :, :]  # prepend a leading axis of size one

    if not finetune:
        bg.trainer_edge = TrainerEdge(
            edgeclf,
            bg.trainer_node.model,
            lr = edge_lr,
            weight_decay = edge_weight_decay,
            num_pos_edges = edge_num_positive,
            num_neg_edges = edge_num_negative,
        )
    else:
        # Point the existing edge trainer at the freshly fine-tuned node model.
        bg.trainer_edge.nodeclf_model = bg.trainer_node.model

    bg.trainer_edge = _trainEdge(
        bg.trainer_edge,
        image_,
        bg.train_loader,
        bg.test_loader,
        keyword,
        performance_folder,
        epoches = edge_epoches,
        early_stop = edge_early_stop,
        early_stop_patience = edge_early_stop_patience,
        early_stop_min_delta = edge_early_stop_delta,
        plot_ax_size = plot_ax_size,
    )