Bering.training.TrainerEdge

class Bering.training.TrainerEdge(model, nodeclf_model, num_pos_edges, num_neg_edges, lr=0.001, weight_decay=0.0005, weight_posEdge=1.0, weight_negEdge=2.0)

Trainer for the edge classification model.

Parameters:
  • model (EdgeClf) – Edge classification model to be trained

  • nodeclf_model (GCN) – Node classification model. It is used to compute latent representations from the node embeddings.

  • num_pos_edges (int) – Number of positive edges

  • num_neg_edges (int) – Number of negative edges

  • lr (float) – Learning rate of the optimizer

  • weight_decay (float) – Weight decay (L2 regularization) of the optimizer

  • weight_posEdge (float) – Weight for positive edges in the loss function

  • weight_negEdge (float) – Weight for negative edges in the loss function
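The two edge weights point to a class-weighted loss. The following is a minimal sketch, assuming a weighted binary cross-entropy: each positive edge's term is scaled by weight_posEdge, each negative edge's term by weight_negEdge, and the result is averaged over all edges. The function name weighted_edge_bce and its arguments are hypothetical and are not taken from Bering's source. The actual loss may be computed differently.

```python
import math
from typing import Sequence


def weighted_edge_bce(
    probs: Sequence[float],
    labels: Sequence[int],
    weight_pos_edge: float = 1.0,
    weight_neg_edge: float = 2.0,
    eps: float = 1e-7,
) -> float:
    """Mean binary cross-entropy with a per-class weight on each edge's term.

    Hypothetical illustration of how weight_posEdge / weight_negEdge might
    enter the loss; this is an assumption, not Bering's implementation.
    """
    if len(probs) != len(labels) or len(probs) == 0:
        raise ValueError("probs and labels must be non-empty and of equal length")
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        if y == 1:
            # Positive edge: penalize low predicted probability.
            total += -weight_pos_edge * math.log(p)
        else:
            # Negative edge: penalize high predicted probability.
            total += -weight_neg_edge * math.log(1.0 - p)
    return total / len(probs)
```

With the default weights, weighted_edge_bce([0.9, 0.1], [1, 0]) evaluates to 1.5 × (-ln 0.9) ≈ 0.158. Under this assumption, a misclassified negative edge costs twice as much as an equally misclassified positive edge, which pushes the model away from predicting spurious edges.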

Methods

predict(batch_data, image)

Predict the edge labels for a batch of input data.

update(loader, image)

Update the model on the training set.

validate(loader, image)

Validate the model on the test set.
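The methods above suggest a conventional epoch loop. The following is a hypothetical driver, run_epochs, that relies only on the documented names and signatures update(loader, image) and validate(loader, image). It assumes a trainer built as TrainerEdge(model, nodeclf_model, num_pos_edges, num_neg_edges) and training and test loaders prepared elsewhere. Passing the same image to both calls is also an assumption. What update and validate return is not documented here, so the driver simply records their return values for each epoch.

```python
from typing import Any, Dict, List, Protocol


class EdgeTrainerLike(Protocol):
    """Structural type covering the two TrainerEdge methods used below."""

    def update(self, loader: Any, image: Any) -> Any: ...

    def validate(self, loader: Any, image: Any) -> Any: ...


def run_epochs(
    trainer: EdgeTrainerLike,
    train_loader: Any,
    test_loader: Any,
    image: Any,
    n_epochs: int,
) -> Dict[str, List[Any]]:
    """Hypothetical loop: one update pass and one validation pass per epoch.

    Return values of update/validate are stored as-is because their
    contents (losses, metrics, ...) are not documented.
    """
    history: Dict[str, List[Any]] = {"train": [], "val": []}
    for _ in range(n_epochs):
        # Train on the training loader.
        history["train"].append(trainer.update(train_loader, image))
        # Evaluate on the test loader (same image assumed).
        history["val"].append(trainer.validate(test_loader, image))
    return history
```

Because run_epochs depends only on the method names, any object that exposes update and validate with these signatures can be passed to it. This also makes the loop easy to exercise with a stub trainer.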