Bering.training.Training

Bering.training.Training(bg, node_gcnq_hidden_dims=[256, 128, 64, 32, 16], node_mlp_hidden_dims=[16, 32, 32], node_lr=0.001, node_weight_decay=0.0005, node_foreground_weight=1.0, node_background_weight=1.0, node_epoches=50, node_early_stop=False, node_early_stop_patience=5, node_early_stop_delta=0.05, edge_distance_type='rbf', edge_rbf_start=0, edge_rbf_stop=64, edge_rbf_n_kernels=64, edge_rbf_learnable=True, edge_image_conv2d_hidden_dims=[6, 16, 32, 64, 128], edge_image_mlp_hidden_dims=[32, 64], edge_decoder_mlp_hidden_dims=[16, 8], edge_num_positive=1000, edge_num_negative=1000, edge_subimage_binsize=5, edge_lr=0.001, edge_weight_decay=0.0005, edge_epoches=50, edge_early_stop=False, edge_early_stop_patience=5, edge_early_stop_delta=0.05, plot_ax_size=5.0, finetune=False, baseline=False)

Train both the node classification and the edge classification models. Training proceeds in two stages: the node classifier is trained first with TrainerNode(), and the edge classifier is then trained with TrainerEdge().
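The two-stage order can be sketched as follows. This is a hypothetical illustration only: StubTrainer and training_sketch are invented names that stand in for TrainerNode()/TrainerEdge() and merely log what they would do, and a SimpleNamespace stands in for a Bering_Graph object. It shows that the edge stage starts only after the node stage returns, and that both trainers end up stored on the graph object.

```python
from types import SimpleNamespace


class StubTrainer:
    """Hypothetical stand-in for TrainerNode()/TrainerEdge(); it only logs."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def train(self, epochs):
        # A real trainer would run `epochs` optimisation epochs here
        self.log.append(f"{self.name}: {epochs} epochs")


def training_sketch(bg, node_epoches=50, edge_epoches=50):
    """Sketch of the two-stage order: node classifier first, then edge classifier."""
    log = []
    # Stage 1: node classifier, stored on the graph object
    bg.trainer_node = StubTrainer("node", log)
    bg.trainer_node.train(node_epoches)
    # Stage 2: edge classifier, started only after stage 1 returns
    bg.trainer_edge = StubTrainer("edge", log)
    bg.trainer_edge.train(edge_epoches)
    return log


bg = SimpleNamespace()  # hypothetical stand-in for a Bering_Graph object
print(training_sketch(bg, node_epoches=3, edge_epoches=2))
# prints ['node: 3 epochs', 'edge: 2 epochs']
```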

Parameters:
  • bg (Bering_Graph) – Bering_Graph object

  • node_gcnq_hidden_dims (Sequence[int]) – List of hidden layer dimensions for the graph convolutional layers in Bering.models.GCN()

  • node_mlp_hidden_dims (Sequence[int]) – List of hidden layer dimensions for the MLP in Bering.models.GCN()

  • node_lr (float) – Learning rate for the node classifier. See TrainerNode()

  • node_weight_decay (float) – Weight decay for the node classifier. See TrainerNode()

  • node_foreground_weight (float) – Loss weight for segmented (foreground) transcripts in the node classifier. See TrainerNode()

  • node_background_weight (float) – Loss weight for background transcripts in the node classifier. See TrainerNode()

  • node_epoches (int) – Number of training epochs for the node classifier

  • node_early_stop (bool) – Whether to use early stopping for the node classifier. See EarlyStopper()

  • node_early_stop_patience (int) – Maximum number of consecutive times the loss may exceed min_loss + min_delta before training of the node classifier stops. See EarlyStopper()

  • node_early_stop_delta (float) – Tolerance (min_delta) above the minimum loss seen so far; a new loss greater than min_loss + min_delta increments the early-stopping counter for the node classifier. See EarlyStopper()

  • edge_distance_type (str) – Type of distance feature used by the edge classifier (default 'rbf'). See EdgeClf()

  • edge_rbf_start (int) – Start of the range of RBF kernel centers (mu) for the edge classifier. See EdgeClf()

  • edge_rbf_stop (int) – End of the range of RBF kernel centers (mu) for the edge classifier. See EdgeClf()

  • edge_rbf_n_kernels (int) – Number of RBF kernels for the edge classifier. See EdgeClf()

  • edge_rbf_learnable (bool) – Whether the RBF kernel parameters are learnable. See EdgeClf()

  • edge_image_conv2d_hidden_dims (Sequence[int]) – List of hidden dimensions for the Conv2d layers of the image encoder in the edge classifier. See EdgeClf()

  • edge_image_mlp_hidden_dims (Sequence[int]) – List of hidden layer dimensions for the MLP of the image encoder in the edge classifier. See EdgeClf()

  • edge_decoder_mlp_hidden_dims (Sequence[int]) – List of hidden layer dimensions for the MLP decoder of the edge classifier. See EdgeClf()

  • edge_num_positive (int) – Number of positive edges used to train the edge classifier. See forward()

  • edge_num_negative (int) – Number of negative edges used to train the edge classifier. See forward()

  • edge_subimage_binsize (int) – Bin size of the subimages used by the edge classifier. See EdgeClf()

  • edge_lr (float) – Learning rate for the edge classifier. See TrainerEdge()

  • edge_weight_decay (float) – Weight decay for the edge classifier. See TrainerEdge()

  • edge_epoches (int) – Number of training epochs for the edge classifier

  • edge_early_stop (bool) – Whether to use early stopping for the edge classifier. See EarlyStopper()

  • edge_early_stop_patience (int) – Maximum number of consecutive times the loss may exceed min_loss + min_delta before training of the edge classifier stops. See EarlyStopper()

  • edge_early_stop_delta (float) – Tolerance (min_delta) above the minimum loss seen so far; a new loss greater than min_loss + min_delta increments the early-stopping counter for the edge classifier. See EarlyStopper()

  • plot_ax_size (float) – Size of the plot axes for the node and edge classifiers. See record()

  • finetune (bool) – Whether to fine-tune the model.

      - If True, the model is fine-tuned starting from pretrained Bering.models.GCN() and Bering.models.EdgeClf() models.

      - If False, the model is trained from scratch with Bering.models.GCN() and Bering.models.EdgeClf().

  • baseline (bool) – Whether to use the baseline model. If True, the node classifier is trained with Bering.models.BaselineMLP(); if False, it is trained with Bering.models.GCN().
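node_foreground_weight and node_background_weight scale the contribution of segmented and background transcripts to the node classifier's loss. The sketch below is a hypothetical illustration of such class weighting, not Bering's loss code: weighted_nll is an invented name, it assumes class 0 denotes background, and it uses the weighted-mean reduction that PyTorch's CrossEntropyLoss applies when given per-class weights.

```python
import math


def weighted_nll(probs, labels, fg_weight=1.0, bg_weight=1.0, background_class=0):
    """Hypothetical class-weighted negative log-likelihood.

    probs[i][c] is the predicted probability of class c for transcript i and
    labels[i] its true class. Assumes `background_class` marks background
    transcripts; every other class counts as foreground (segmented).
    """
    total, norm = 0.0, 0.0
    for p, y in zip(probs, labels):
        # Pick the weight for this transcript based on its true label
        w = bg_weight if y == background_class else fg_weight
        total -= w * math.log(p[y])
        norm += w
    # Weighted mean: divide by the sum of the weights actually used
    return total / norm


probs = [[0.5, 0.5], [0.25, 0.75]]  # toy predictions for two transcripts
labels = [0, 1]                     # one background, one foreground
print(weighted_nll(probs, labels, fg_weight=2.0, bg_weight=1.0))
```

With both weights at 1.0 this reduces to the ordinary mean negative log-likelihood; raising one weight makes errors on that group of transcripts cost more.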
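The early-stopping parameters (node_early_stop_patience/node_early_stop_delta and their edge_* counterparts) describe a patience counter. The following is a hypothetical sketch of that rule as worded above, not the source of Bering's EarlyStopper(): EarlyStopperSketch is an invented name, and it assumes training stops once the counter reaches patience.

```python
class EarlyStopperSketch:
    """Hypothetical early-stopping rule matching the patience/delta
    descriptions above; not the source of Bering's EarlyStopper()."""

    def __init__(self, patience=5, min_delta=0.05):
        self.patience = patience
        self.min_delta = min_delta
        self.min_loss = float("inf")
        self.counter = 0

    def early_stop(self, loss):
        """Record one loss value; return True when training should stop."""
        if loss < self.min_loss:
            self.min_loss = loss  # new best loss
        if loss > self.min_loss + self.min_delta:
            self.counter += 1  # one more consecutive "bad" check
        else:
            self.counter = 0  # within tolerance: the streak is broken
        return self.counter >= self.patience


stopper = EarlyStopperSketch(patience=2, min_delta=0.05)
for epoch, loss in enumerate([1.0, 0.8, 0.9, 0.95, 0.7]):
    if stopper.early_stop(loss):
        print(f"stopping at epoch {epoch}")  # prints "stopping at epoch 3"
        break
```

Here the losses 0.9 and 0.95 both exceed 0.8 + 0.05, so the second consecutive excess triggers the stop; a loss such as 0.84 in between would have reset the counter.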
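edge_rbf_start, edge_rbf_stop and edge_rbf_n_kernels define a bank of radial basis function (RBF) kernels for the edge distance feature. The sketch below is a hypothetical Gaussian RBF expansion, not EdgeClf()'s implementation: rbf_expand is an invented name, and it assumes evenly spaced centers from start to stop (inclusive) with each kernel's width equal to the center spacing. Per its description, edge_rbf_learnable=True would instead make the kernel parameters trainable.

```python
import math


def rbf_expand(distance, start=0.0, stop=64.0, n_kernels=64):
    """Hypothetical Gaussian RBF expansion of a scalar distance.

    Assumes centers mu_k are evenly spaced from start to stop (inclusive)
    and each kernel is exp(-gamma * (d - mu_k)**2) with sigma = spacing.
    """
    if n_kernels < 2:
        raise ValueError("n_kernels must be at least 2")
    step = (stop - start) / (n_kernels - 1)
    centers = [start + k * step for k in range(n_kernels)]
    gamma = 1.0 / (2.0 * step ** 2)  # fixed width here, not learned
    return [math.exp(-gamma * (distance - mu) ** 2) for mu in centers]


features = rbf_expand(4.0, start=0, stop=8, n_kernels=5)  # centers 0, 2, 4, 6, 8
print(features.index(max(features)))  # prints 2 (the kernel centered at 4)
```

The expansion turns one distance into a smooth vector of length n_kernels, with the largest response from the kernel whose center is nearest the distance.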

Returns:

  • bg.trainer_node – Node classifier: TrainerNode()

  • bg.trainer_edge – Edge classifier: TrainerEdge()