Bering.training.TrainerNode
- class Bering.training.TrainerNode(model=<class 'Bering.models._gnn.GCN'>, lr=0.001, weight_decay=0.0005, weight_seg=1.0, weight_bg=1.0)
Trainer for a node classification model.
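- Parameters:
  - model: model class to train (default: Bering.models._gnn.GCN).
  - lr (float): learning rate (default: 0.001).
  - weight_decay (float): weight decay used during optimization (default: 0.0005).
  - weight_seg (float): loss weight for the segmented (foreground) class (default: 1.0).
  - weight_bg (float): loss weight for the background class (default: 1.0).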
Methods
predict(batch_data[, device]): Predict the class probabilities of the input data.
update(loader): Update the model on the training set.
validate(loader): Validate the model on the test set.
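To illustrate how the three methods fit together, here is a minimal sketch of a trainer with the same method names (update, validate, predict) and the same class-weighting idea suggested by weight_seg and weight_bg. It is hypothetical and is not Bering's implementation: ToyTrainerNode, its logistic-regression classifier, the meaning assumed for weight_seg/weight_bg (per-class weights in a weighted cross-entropy), and the return values of update/validate are all assumptions. It also ignores graph edges, whereas the real trainer wraps a GNN such as GCN.

```python
import numpy as np


class ToyTrainerNode:
    """Hypothetical stand-in mirroring TrainerNode's method names (not Bering code).

    Logistic-regression node classifier: class 1 = segmented, class 0 = background.
    Node features only; no message passing over edges as a real GNN would do.
    """

    def __init__(self, n_features, lr=0.1, weight_decay=0.0005,
                 weight_seg=1.0, weight_bg=1.0, seed=0):
        rng = np.random.default_rng(seed)
        self.w = rng.normal(scale=0.01, size=n_features)
        self.b = 0.0
        self.lr = lr
        self.weight_decay = weight_decay
        # Assumed meaning: per-class weights in the cross-entropy loss.
        self.weight_seg = weight_seg
        self.weight_bg = weight_bg

    def predict(self, batch_data, device=None):
        # `device` only mirrors the documented signature; numpy runs on CPU.
        x, _ = batch_data
        p_seg = 1.0 / (1.0 + np.exp(-(x @ self.w + self.b)))
        return np.column_stack([1.0 - p_seg, p_seg])  # [background, segmented]

    def _loss_and_grads(self, x, y):
        p = self.predict((x, y))[:, 1]
        sample_w = np.where(y == 1, self.weight_seg, self.weight_bg)
        eps = 1e-12
        # Class-weighted binary cross-entropy (L2 term not included in the value).
        loss = -np.mean(sample_w * (y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps)))
        grad_logit = sample_w * (p - y) / len(y)
        # Weight decay added to the gradient as an L2 penalty on the weights.
        grad_w = x.T @ grad_logit + self.weight_decay * self.w
        grad_b = grad_logit.sum()
        return loss, grad_w, grad_b

    def update(self, loader):
        """One pass over the training batches with gradient steps; returns mean loss."""
        losses = []
        for x, y in loader:
            loss, gw, gb = self._loss_and_grads(x, y)
            self.w -= self.lr * gw
            self.b -= self.lr * gb
            losses.append(loss)
        return float(np.mean(losses))

    def validate(self, loader):
        """Evaluate without updating parameters; returns (mean loss, accuracy)."""
        correct, total, losses = 0, 0, []
        for x, y in loader:
            loss, _, _ = self._loss_and_grads(x, y)
            pred = self.predict((x, y)).argmax(axis=1)
            correct += int((pred == y).sum())
            total += len(y)
            losses.append(loss)
        return float(np.mean(losses)), correct / total


# Usage on synthetic, well-separated node features (two clusters).
rng = np.random.default_rng(1)
x = np.vstack([rng.normal(-2, 1, size=(100, 2)), rng.normal(2, 1, size=(100, 2))])
y = np.array([0] * 100 + [1] * 100)
perm = rng.permutation(200)
x, y = x[perm], y[perm]
loader = [(x[i:i + 50], y[i:i + 50]) for i in range(0, 200, 50)]

trainer = ToyTrainerNode(n_features=2, lr=0.5)
for epoch in range(20):
    train_loss = trainer.update(loader)
val_loss, acc = trainer.validate(loader)
probs = trainer.predict(loader[0])
```

The split mirrors the documented roles: update changes parameters, validate only measures, and predict returns per-class probabilities. Raising weight_seg relative to weight_bg in this sketch makes errors on segmented nodes cost more, which is a common way to counter class imbalance.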