Bering.training.TrainerNode

class Bering.training.TrainerNode(model=<class 'Bering.models._gnn.GCN'>, lr=0.001, weight_decay=0.0005, weight_seg=1.0, weight_bg=1.0)

Trainer for a node classification model.

Parameters:
  • model (Optional[Module]) – Node classification model to train; defaults to Bering.models._gnn.GCN

  • lr (float) – Learning rate for the optimizer

  • weight_decay (float) – Weight decay (L2 regularization) coefficient for the optimizer

  • weight_seg (float) – Weight for segmented transcripts in the loss function

  • weight_bg (float) – Weight for background transcripts in the loss function
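
The two loss weights rescale how much each transcript contributes to the classification loss. The sketch below is a minimal, framework-free illustration of a class-weighted binary cross-entropy. It assumes label 1 marks a segmented transcript and label 0 a background transcript; the helper name weighted_bce and this label encoding are hypothetical and not taken from Bering's source.

```python
import math
from typing import Sequence


def weighted_bce(
    probs: Sequence[float],
    labels: Sequence[int],
    weight_seg: float = 1.0,
    weight_bg: float = 1.0,
    eps: float = 1e-12,
) -> float:
    """Class-weighted binary cross-entropy (hypothetical sketch).

    probs:  predicted probability that each transcript is segmented.
    labels: 1 for a segmented transcript, 0 for background (assumed encoding).
    """
    if len(probs) != len(labels):
        raise ValueError("probs and labels must have the same length")
    total, weight_sum = 0.0, 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        if y == 1:
            w, term = weight_seg, -math.log(p)        # penalty for a segmented transcript
        else:
            w, term = weight_bg, -math.log(1.0 - p)   # penalty for a background transcript
        total += w * term
        weight_sum += w
    if weight_sum == 0:
        raise ValueError("at least one sample must have a non-zero weight")
    # Weighted mean: divide by the summed weights of the samples present.
    return total / weight_sum
```

Raising weight_bg above weight_seg makes misclassified background transcripts cost more, which can help when background transcripts are under-represented. Dividing by the summed weights mirrors the convention of PyTorch's torch.nn.CrossEntropyLoss with class weights and reduction='mean'; whether Bering normalizes the same way is not stated here.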

Methods

predict(batch_data[, device])

Predict the class probabilities of the input data.
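
How predict converts model outputs into probabilities is not documented here. A common approach for node classifiers is a row-wise softmax over per-node logits. The sketch below shows that step in plain Python, assuming the model emits one logit per class for each node; softmax_rows is a hypothetical helper, not part of Bering.

```python
import math
from typing import List, Sequence


def softmax_rows(logits: Sequence[Sequence[float]]) -> List[List[float]]:
    """Row-wise softmax: one probability vector per node (hypothetical helper)."""
    probs = []
    for row in logits:
        m = max(row)  # subtract the row maximum for numerical stability
        exps = [math.exp(x - m) for x in row]
        s = sum(exps)
        probs.append([e / s for e in exps])
    return probs
```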

update(loader)

Update the model on the training set.

validate(loader)

Validate the model on the test set.
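
A typical way to drive the trainer is to alternate update on the training loader and validate on the test loader for a fixed number of epochs. The sketch below relies only on the method names documented above. The fit helper, the per-epoch structure, and the choice to record whatever the two methods return are illustrative assumptions, since their return values are not documented here.

```python
from typing import Any, Iterable, List, Tuple


def fit(
    trainer: Any,
    train_loader: Iterable,
    test_loader: Iterable,
    epochs: int = 10,
) -> List[Tuple[int, Any, Any]]:
    """Alternate training and validation passes (hypothetical helper).

    Works with any object exposing update(loader) and validate(loader),
    e.g. a Bering.training.TrainerNode instance.
    """
    history = []
    for epoch in range(epochs):
        train_out = trainer.update(train_loader)  # assumed: one pass over the training set
        val_out = trainer.validate(test_loader)   # evaluate on the held-out set
        history.append((epoch, train_out, val_out))
    return history
```

With an existing TrainerNode instance trainer and your own loaders, fit(trainer, train_loader, test_loader, epochs=20) would run the loop and return one (epoch, update result, validate result) tuple per epoch.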