'''Source code for ``Bering.training._trainer_edge``: trainer for the edge classification model.'''

import logging
import numpy as np
        
import torch
import torch.optim as optim
from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score

from ..models import GCN, EdgeClf
from ._settings import TRAIN_KEYS as TR_KEYS
# Module-level logger; the trainer reports per-batch progress through it at INFO level.
logger = logging.getLogger(__name__)

class TrainerEdge:
    '''
    Trainer for edge classification model.

    Only ``model`` is optimised. ``nodeclf_model`` is always switched to
    evaluation mode and is used purely to produce latent node representations
    that are fed to ``model``.

    Parameters
    ----------
    model
        Edge classification model
    nodeclf_model
        Node classification model. This model is used to get the latent
        representation from the node embeddings.
    num_pos_edges
        Number of positive edges
    num_neg_edges
        Number of negative edges
    lr
        Learning rate
    weight_decay
        Weight decay
    weight_posEdge
        Weight for positive edges. Stored on the trainer; the loss computed
        in :meth:`update` and :meth:`validate` is currently unweighted.
    weight_negEdge
        Weight for negative edges. Stored on the trainer; the loss computed
        in :meth:`update` and :meth:`validate` is currently unweighted.
    '''

    def __init__(
        self,
        model: EdgeClf,
        nodeclf_model: GCN,
        num_pos_edges: int,
        num_neg_edges: int,
        lr: float = TR_KEYS.LEARNING_RATE,
        weight_decay: float = TR_KEYS.WEIGHT_DECAY,
        weight_posEdge: float = TR_KEYS.LOSS_WEIGHTS_POSEDGES,
        weight_negEdge: float = TR_KEYS.LOSS_WEIGHTS_NEGEDGES,
    ):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Move and cast the model first so the optimizer is built from the
        # parameters in their final device/dtype.
        self.model = model.to(self.device)
        self.model.double()

        # NOTE(review): nodeclf_model is not moved to self.device here;
        # presumably the caller has already placed it there - confirm.
        self.nodeclf_model = nodeclf_model

        self.lr = lr
        self.weight_decay = weight_decay
        self.optimizer = optim.Adam(
            self.model.parameters(), lr=lr, weight_decay=weight_decay
        )

        self.weight_posEdge = weight_posEdge
        self.weight_negEdge = weight_negEdge
        self.num_pos_edges = num_pos_edges
        self.num_neg_edges = num_neg_edges

    def _forward(self, batch_data, image):
        '''Run the node encoder and the edge model on one batch; return the edge model's output.'''
        # The node model is not part of the optimizer, so its latent is
        # computed without autograd: no graph is built through it and no
        # stale gradients accumulate on its parameters.
        with torch.no_grad():
            z = self.nodeclf_model.get_latent(batch_data)

        # The edge model receives its own copy of the image (or None).
        image_input = None if image is None else torch.clone(image)
        return self.model(
            z, batch_data, self.num_pos_edges, self.num_neg_edges, image_input
        )

    def update(
        self,
        loader,
        image: torch.Tensor,
    ):
        '''
        Update the model on the training set.

        Parameters
        ----------
        loader
            Training data loader: ``torch_geometric.data.DataLoader``
        image
            image as the input for the image encoder. May be ``None``, in
            which case ``None`` is passed to the model.

        Returns
        -------
        float
            Sum of the per-batch binary cross-entropy losses (``0.0`` if the
            loader yields no batches).
        '''
        self.model.train()
        self.nodeclf_model.eval()
        criterion = torch.nn.BCELoss()

        running_loss = 0.0
        for idx, batch_data in enumerate(loader):
            logger.info('Training batch %d', idx)
            batch_data = batch_data.to(self.device)

            self.optimizer.zero_grad()
            logits, edge_labels, _ = self._forward(batch_data, image)
            loss = criterion(logits.double(), edge_labels.double())
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()

        return running_loss

    @torch.no_grad()
    def validate(
        self,
        loader,
        image: torch.Tensor,
    ):
        '''
        Validate the model on the test set.

        Parameters
        ----------
        loader
            Testing data loader: ``torch_geometric.data.DataLoader``
        image
            image as the input for the image encoder. May be ``None``, in
            which case ``None`` is passed to the model.

        Returns
        -------
        float
            Sum of the per-batch (unweighted) binary cross-entropy losses
            (``0.0`` if the loader yields no batches).
        '''
        self.model.eval()
        self.nodeclf_model.eval()
        criterion = torch.nn.BCELoss()

        running_loss = 0.0
        for idx, batch_data in enumerate(loader):
            logger.info('Testing batch %d', idx)
            batch_data = batch_data.to(self.device)

            logits, edge_labels, _ = self._forward(batch_data, image)
            loss = criterion(logits.double(), edge_labels.double())
            running_loss += loss.item()

        return running_loss

    @torch.no_grad()
    def predict(self, batch_data, image):
        '''
        Predict the edge labels from the input data

        Parameters
        ----------
        batch_data
            Input data
        image
            image as the input for the image encoder. May be ``None``, in
            which case ``None`` is passed to the model.

        Returns
        -------
        auc_score
            ROC AUC of the predicted scores against the edge labels.
        precision
            Average precision score of the predicted scores.
        accu
            Accuracy of the predictions after rounding them to 0/1.
        err_pn
            Number of edges labelled 1 but predicted 0, divided by the total
            number of edges and rounded to two decimals.
        err_np
            Number of edges labelled 0 but predicted 1, divided by the total
            number of edges and rounded to two decimals.
        '''
        self.model.eval()
        batch_data = batch_data.to(self.device)
        self.nodeclf_model.eval()

        preds, y, _ = self._forward(batch_data, image)
        y = y.detach().cpu().numpy()
        preds = preds.detach().cpu().numpy()
        preds_binary = np.round(preds).astype(np.int16)

        auc_score = roc_auc_score(y, preds)
        precision = average_precision_score(y, preds)
        accu = accuracy_score(y, preds_binary)

        # Flatten so the element-wise masks line up one entry per edge.
        y_flat = y.ravel()
        pred_flat = preds_binary.ravel()
        n_edges = len(y)

        # error rate of positive -> negative prediction
        n_pos_as_neg = np.count_nonzero((y_flat == 1) & (pred_flat == 0))
        err_pn = np.round(n_pos_as_neg / n_edges, 2)
        # error rate of negative -> positive prediction
        n_neg_as_pos = np.count_nonzero((y_flat == 0) & (pred_flat == 1))
        err_np = np.round(n_neg_as_pos / n_edges, 2)

        return auc_score, precision, accu, err_pn, err_np