Bering.training.TrainerNode.predict

TrainerNode.predict(batch_data, device=None)

Predict the class probabilities of the input data.

Parameters:
  • batch_data – Input graph, given as a torch_geometric.data.Data object.

  • device – Device on which to run the model, either ‘cuda’ or ‘cpu’.
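This page does not document how predict computes its output. The sketch below is only a rough illustration of what "class probabilities" usually means for a graph model. It assumes that the model produces one row of raw scores (logits) per node and that each row is turned into probabilities with a softmax. The helpers predict_proba and predict_labels are hypothetical and are not part of Bering's API. The sketch uses plain Python so that it runs without torch.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one node's class logits."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_proba(node_logits: Sequence[Sequence[float]]) -> List[List[float]]:
    """Hypothetical helper: map per-node logits to per-node class probabilities."""
    return [softmax(row) for row in node_logits]


def predict_labels(probs: Sequence[Sequence[float]]) -> List[int]:
    """Hypothetical helper: pick the most probable class per node (first wins on ties)."""
    return [max(range(len(row)), key=row.__getitem__) for row in probs]


# Hypothetical logits for 3 nodes and 2 classes (not real model output)
logits = [[2.0, 0.5], [-1.0, 1.0], [0.0, 0.0]]
probs = predict_proba(logits)
labels = predict_labels(probs)
print(labels)  # [0, 1, 0]
```

Each row of probs sums to 1. Taking the argmax of a row gives a hard class label when one is needed.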