Bering.training.TrainerEdge
- class Bering.training.TrainerEdge(model, nodeclf_model, num_pos_edges, num_neg_edges, lr=0.001, weight_decay=0.0005, weight_posEdge=1.0, weight_negEdge=2.0)
Trainer for the edge classification model.
- Parameters:
model (EdgeClf) – Edge classification model.
nodeclf_model (GCN) – Node classification model, used to obtain latent representations from the node embeddings.
num_pos_edges (int) – Number of positive edges.
num_neg_edges (int) – Number of negative edges.
lr (float) – Learning rate (default 0.001).
weight_decay (float) – Weight decay (default 0.0005).
weight_posEdge (float) – Weight for positive edges in the loss function (default 1.0).
weight_negEdge (float) – Weight for negative edges in the loss function (default 2.0).
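The two loss weights set how much errors on positive and negative edges cost relative to each other. The sketch below assumes a weighted binary cross-entropy over raw edge logits, with label 1 for a positive edge and 0 for a negative one. The loss Bering actually uses is not stated on this page, and the names `weighted_edge_bce` and `softplus` are hypothetical. It relies only on the standard library and a numerically stable softplus, using the identities -log(sigmoid(z)) = softplus(-z) and -log(1 - sigmoid(z)) = softplus(z).

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def weighted_edge_bce(logits, labels, weight_pos=1.0, weight_neg=2.0):
    """Mean weighted binary cross-entropy over a batch of edges.

    logits: raw edge scores (before the sigmoid).
    labels: 1 for a positive edge, 0 for a negative edge.
    weight_pos / weight_neg stand in for weight_posEdge / weight_negEdge
    (an assumption about how TrainerEdge uses them).
    """
    if len(logits) != len(labels) or not logits:
        raise ValueError("logits and labels must be non-empty and equally long")
    total = 0.0
    for z, y in zip(logits, labels):
        if y == 1:
            # -log(sigmoid(z)) for a positive edge
            total += weight_pos * softplus(-z)
        else:
            # -log(1 - sigmoid(z)) for a negative edge
            total += weight_neg * softplus(z)
    # Average over all edges in the batch
    return total / len(logits)
```

With the defaults (weight_posEdge=1.0, weight_negEdge=2.0) and both logits at 0, one positive and one negative edge give (ln 2 + 2 ln 2) / 2 = 1.5 ln 2 ≈ 1.04, so a confidently wrong negative edge costs twice as much as an equally wrong positive one.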
Methods
predict(batch_data, image) – Predict the edge labels from the input data.
update(loader, image) – Update the model on the training set.
validate(loader, image) – Validate the model on the test set.
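The three methods suggest the usual epoch loop: call update on the training loader, then validate on the test loader, and use predict for inference. The sketch below shows that calling pattern with a hypothetical `ToyEdgeTrainer` that only mimics the method names and arguments; it is not Bering code. It assumes that a loader is a list of batches, that each edge is a (feature, label) pair, that validate returns an accuracy, and that predict returns 0/1 labels by thresholding a probability at 0.5. None of these return types are documented above, and the toy model ignores `image`.

```python
import math
import random


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


class ToyEdgeTrainer:
    """Hypothetical stand-in with the same method names as TrainerEdge.

    Each edge is a (feature, label) pair and the "model" is a one-feature
    logistic regression. The `image` argument is accepted but ignored.
    """

    def __init__(self, lr: float = 0.5):
        self.w = 0.0
        self.b = 0.0
        self.lr = lr

    def predict(self, batch_data, image=None):
        # Threshold the positive-edge probability at 0.5 to get 0/1 labels.
        return [1 if sigmoid(self.w * x + self.b) >= 0.5 else 0 for x, _ in batch_data]

    def update(self, loader, image=None):
        # One epoch of mini-batch gradient descent on binary cross-entropy.
        for batch in loader:
            grad_w = grad_b = 0.0
            for x, y in batch:
                err = sigmoid(self.w * x + self.b) - y  # dLoss/dz for sigmoid + BCE
                grad_w += err * x
                grad_b += err
            self.w -= self.lr * grad_w / len(batch)
            self.b -= self.lr * grad_b / len(batch)

    def validate(self, loader, image=None):
        # Accuracy: fraction of edges whose predicted label matches the true one.
        correct = total = 0
        for batch in loader:
            preds = self.predict(batch, image)
            correct += sum(p == y for p, (_, y) in zip(preds, batch))
            total += len(batch)
        return correct / total


def make_loader(n_edges, batch_size, seed):
    """Toy data: label 1 when the feature is positive, split into batches."""
    rng = random.Random(seed)
    edges = []
    for _ in range(n_edges):
        x = rng.uniform(-3.0, 3.0)
        edges.append((x, 1 if x > 0 else 0))
    return [edges[i:i + batch_size] for i in range(0, n_edges, batch_size)]


train_loader = make_loader(200, 16, seed=0)
test_loader = make_loader(100, 16, seed=1)

trainer = ToyEdgeTrainer(lr=0.5)
history = []
for epoch in range(20):
    trainer.update(train_loader, image=None)                    # train on the training set
    history.append(trainer.validate(test_loader, image=None))  # evaluate on the test set

labels = trainer.predict(test_loader[0], image=None)  # 0/1 label per edge in one batch
```

In TrainerEdge the model would presumably be the EdgeClf/GCN pair, with the optimizer configured from lr and weight_decay; the toy keeps only the control flow of training, validation and prediction.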