Bering.training.Training
- Bering.training.Training(bg, node_gcnq_hidden_dims=[256, 128, 64, 32, 16], node_mlp_hidden_dims=[16, 32, 32], node_lr=0.001, node_weight_decay=0.0005, node_foreground_weight=1.0, node_background_weight=1.0, node_epoches=50, node_early_stop=False, node_early_stop_patience=5, node_early_stop_delta=0.05, edge_distance_type='rbf', edge_rbf_start=0, edge_rbf_stop=64, edge_rbf_n_kernels=64, edge_rbf_learnable=True, edge_image_conv2d_hidden_dims=[6, 16, 32, 64, 128], edge_image_mlp_hidden_dims=[32, 64], edge_decoder_mlp_hidden_dims=[16, 8], edge_num_positive=1000, edge_num_negative=1000, edge_subimage_binsize=5, edge_lr=0.001, edge_weight_decay=0.0005, edge_epoches=50, edge_early_stop=False, edge_early_stop_patience=5, edge_early_stop_delta=0.05, plot_ax_size=5.0, finetune=False, baseline=False)
Train both the node classification and the edge classification models. The node classifier is trained first with TrainerNode(), and the edge classifier is then trained with TrainerEdge().
- Parameters:
bg (Bering_Graph) – Bering_Graph object
node_gcnq_hidden_dims (Sequence[int]) – List of hidden layer dimensions for the GCN in Bering.models.GCN()
node_mlp_hidden_dims (Sequence[int]) – List of hidden layer dimensions for the MLP in Bering.models.GCN()
node_lr (float) – Learning rate for the node classifier. See TrainerNode()
node_weight_decay (float) – Weight decay for the node classifier. See TrainerNode()
node_foreground_weight (float) – Weight of segmented (foreground) transcripts in the loss function of the node classifier. See TrainerNode()
node_background_weight (float) – Weight of background transcripts in the loss function of the node classifier. See TrainerNode()
node_epoches (int) – Number of training epochs for the node classifier
node_early_stop (bool) – Whether to use early stopping for the node classifier. See EarlyStopper()
node_early_stop_patience (int) – Maximum number of consecutive times the loss may exceed min_loss + min_delta before training of the node classifier stops. See EarlyStopper()
node_early_stop_delta (float) – Minimum margin by which a new loss must exceed the lowest loss so far (min_loss) to be counted toward the patience limit of the node classifier. See EarlyStopper()
edge_distance_type (str) – Distance type for the edge classifier. See EdgeClf()
edge_rbf_start (int) – Start of the RBF kernel centers (mu) for the edge classifier. See EdgeClf()
edge_rbf_stop (int) – Stop of the RBF kernel centers (mu) for the edge classifier. See EdgeClf()
edge_rbf_n_kernels (int) – Number of RBF kernels for the edge classifier. See EdgeClf()
edge_rbf_learnable (bool) – Whether the RBF kernels of the edge classifier are learnable. See EdgeClf()
edge_image_conv2d_hidden_dims (Sequence[int]) – List of hidden layer dimensions for the convolutional layers of the image encoder in the edge classifier. See EdgeClf()
edge_image_mlp_hidden_dims (Sequence[int]) – List of hidden layer dimensions for the MLP of the image encoder in the edge classifier. See EdgeClf()
edge_decoder_mlp_hidden_dims (Sequence[int]) – List of hidden layer dimensions for the decoder in the edge classifier. See EdgeClf()
edge_num_positive (int) – Number of positive edges for the edge classifier. See forward()
edge_num_negative (int) – Number of negative edges for the edge classifier. See forward()
edge_subimage_binsize (int) – Bin size of the subimages for the edge classifier. See EdgeClf()
edge_lr (float) – Learning rate for the edge classifier. See TrainerEdge()
edge_weight_decay (float) – Weight decay for the edge classifier. See TrainerEdge()
edge_epoches (int) – Number of training epochs for the edge classifier
edge_early_stop (bool) – Whether to use early stopping for the edge classifier. See EarlyStopper()
edge_early_stop_patience (int) – Maximum number of consecutive times the loss may exceed min_loss + min_delta before training of the edge classifier stops. See EarlyStopper()
edge_early_stop_delta (float) – Minimum margin by which a new loss must exceed the lowest loss so far (min_loss) to be counted toward the patience limit of the edge classifier. See EarlyStopper()
plot_ax_size (float) – Size of the plots for the node and edge classifiers. See record()
finetune (bool) – Whether to fine-tune the models. If True, the pre-trained Bering.models.GCN() and Bering.models.EdgeClf() are fine-tuned. If False, Bering.models.GCN() and Bering.models.EdgeClf() are trained from scratch.
baseline (bool) – Whether to use the baseline model. If True, the model is trained with Bering.models.BaselineMLP(). If False, the model is trained with Bering.models.GCN().
- Returns:
bg.trainer_node – Node classifier: TrainerNode()
bg.trainer_edge – Edge classifier: TrainerEdge()
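node_foreground_weight and node_background_weight control how much segmented and background transcripts contribute to the node classifier's loss. One common way to apply such weights is a per-sample weighted cross-entropy normalised by the total weight (the same convention PyTorch's CrossEntropyLoss uses for class weights with reduction='mean'). The NumPy sketch below is only an illustration under stated assumptions: the helper weighted_node_loss is hypothetical, and it assumes background transcripts carry label 0 while segmented transcripts carry labels >= 1. Bering's TrainerNode() may encode labels and weights differently.

```python
import numpy as np


def weighted_node_loss(logits, labels, foreground_weight=1.0, background_weight=1.0):
    """Weighted cross-entropy over transcripts (hypothetical helper, not Bering's API).

    Assumption: label 0 marks background transcripts and labels >= 1 mark
    segmented (foreground) transcripts.
    """
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels, dtype=int)
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Negative log-likelihood of the true class for each transcript.
    nll = -log_probs[np.arange(len(labels)), labels]
    # Per-transcript weight chosen by background/foreground status.
    weights = np.where(labels == 0, background_weight, foreground_weight)
    # Weighted mean, normalised by the total weight.
    return float((weights * nll).sum() / weights.sum())


logits = np.array([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3], [0.1, 0.2, 3.0]])
labels = np.array([0, 1, 2])  # one background and two segmented transcripts
print(weighted_node_loss(logits, labels, foreground_weight=2.0, background_weight=1.0))
```

Raising node_foreground_weight above node_background_weight in this sketch makes errors on segmented transcripts cost more than errors on background ones.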
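The early-stopping parameters (node_early_stop_patience / node_early_stop_delta and their edge_ counterparts) describe a counter that increases whenever the loss exceeds min_loss + min_delta and stops training once it reaches the patience limit. The sketch below implements that rule in plain Python. The class SimpleEarlyStopper is hypothetical and is not Bering's EarlyStopper(). It also assumes the counter is reset only when a new minimum loss is observed, which is a common design but is not stated above.

```python
import math


class SimpleEarlyStopper:
    """Hypothetical early-stopping helper following the patience/delta rule."""

    def __init__(self, patience: int = 5, min_delta: float = 0.05) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_loss = math.inf

    def early_stop(self, loss: float) -> bool:
        """Return True once the loss has exceeded min_loss + min_delta `patience` times."""
        if loss < self.min_loss:
            # New best loss: remember it and reset the counter (assumed behaviour).
            self.min_loss = loss
            self.counter = 0
        elif loss > self.min_loss + self.min_delta:
            # Loss is worse than the best so far by more than min_delta: count it.
            self.counter += 1
            if self.counter >= self.patience:
                return True
        # Losses within min_delta of the best are neither counted nor reset.
        return False


if __name__ == "__main__":
    stopper = SimpleEarlyStopper(patience=2, min_delta=0.05)
    for epoch, loss in enumerate([1.0, 0.8, 0.9, 0.95, 0.7]):
        if stopper.early_stop(loss):
            print(f"stopping at epoch {epoch}")  # prints "stopping at epoch 3"
            break
```

With the defaults of Training (patience 5, delta 0.05), small fluctuations of less than 0.05 above the best loss are tolerated and never count toward stopping.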
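With edge_distance_type='rbf', a distance is presumably expanded into edge_rbf_n_kernels radial basis features whose centers (mu) run from edge_rbf_start to edge_rbf_stop. The NumPy sketch below shows one common form of such an expansion, exp(-gamma * (d - mu)^2) with evenly spaced, fixed centers. The function rbf_expand and the width choice gamma = 1 / spacing^2 are assumptions for illustration; EdgeClf() may use a different width and, with edge_rbf_learnable=True, may learn the kernel parameters during training.

```python
import numpy as np


def rbf_expand(distances, start=0.0, stop=64.0, n_kernels=64):
    """Expand distances into Gaussian RBF features (hypothetical helper).

    Centers mu are evenly spaced in [start, stop]; the width gamma is an
    assumed choice (inverse squared center spacing), not taken from Bering.
    """
    d = np.asarray(distances, dtype=float).reshape(-1, 1)    # (n_edges, 1)
    mu = np.linspace(start, stop, n_kernels).reshape(1, -1)  # (1, n_kernels)
    spacing = (stop - start) / max(n_kernels - 1, 1)
    gamma = 1.0 / spacing**2
    return np.exp(-gamma * (d - mu) ** 2)                    # (n_edges, n_kernels)


features = rbf_expand([0.0, 10.0, 63.5])
print(features.shape)           # (3, 64)
print(features.argmax(axis=1))  # index of the center closest to each distance
```

Each distance becomes a smooth, localized code over the centers, which a downstream MLP can consume more easily than the raw scalar.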
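edge_num_positive and edge_num_negative set how many positive and negative edges the edge classifier works with. The sketch below assumes, as a labeled assumption, that a positive edge joins two transcripts from the same cell and a negative edge joins transcripts from different cells, and that background transcripts are marked with a negative cell id. The helper sample_edges is hypothetical and is not Bering's forward().

```python
import numpy as np


def sample_edges(cell_ids, num_positive=1000, num_negative=1000, seed=0):
    """Sample positive/negative transcript pairs (hypothetical helper).

    Assumptions: same-cell pairs are positive, different-cell pairs are
    negative, and background transcripts (cell_id < 0) are ignored.
    """
    rng = np.random.default_rng(seed)
    cell_ids = np.asarray(cell_ids)
    idx = np.flatnonzero(cell_ids >= 0)
    # Enumerate all unordered pairs of segmented transcripts (small demo only).
    i, j = np.triu_indices(len(idx), k=1)
    src, dst = idx[i], idx[j]
    same = cell_ids[src] == cell_ids[dst]
    pos = np.flatnonzero(same)
    neg = np.flatnonzero(~same)
    # Draw without replacement, capped by how many pairs of each kind exist.
    pos = rng.choice(pos, size=min(num_positive, len(pos)), replace=False)
    neg = rng.choice(neg, size=min(num_negative, len(neg)), replace=False)
    pos_edges = np.stack([src[pos], dst[pos]], axis=1)
    neg_edges = np.stack([src[neg], dst[neg]], axis=1)
    return pos_edges, neg_edges


cell_ids = np.array([0, 0, 0, 1, 1, -1])  # toy labels; -1 marks a background transcript
pos_edges, neg_edges = sample_edges(cell_ids, num_positive=2, num_negative=2)
print(pos_edges.shape, neg_edges.shape)  # (2, 2) (2, 2)
```

Enumerating all pairs is quadratic in the number of transcripts, so a practical version would presumably draw candidates only among spatially neighboring transcripts in the graph.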