Source code for Bering.models._edgeclf

import logging
import numpy as np
from typing import Sequence, Union, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F 
from torch_geometric.nn import MLP

from ._image_model import ImageEncoder
from ._utils import GaussianSmearing

# Module-level logger for this file (parameter counts and timing messages of EdgeClf).
logger = logging.getLogger(__name__)
from ._edge_process import _sample_edges, _binning_coordinates

def _get_image_graph(
    pos: torch.Tensor,
    image: torch.Tensor,
    src_coords: torch.Tensor,
    dst_coords: torch.Tensor,
    conv2d_padding: int = 10,
):
    '''
    Extract the graph image for edge embedding; recalculate source and destination node coordinates

    The image is cropped to the bounding box of the graph's node positions
    (columns 1 and 2 of ``pos``, used as x and y), widened by
    ``conv2d_padding`` pixels on each side and clipped to the image extent.
    Edge endpoint coordinates are shifted into the frame of that crop and
    rounded to integers.

    Parameters
    ----------
    pos
        Node positions of one graph; columns 1 and 2 are read as x and y.
    image
        Image tensor, sliced as ``image[:, :, y, x]`` (assumed 4D, e.g.
        batch x channel x H x W -- TODO confirm).
    src_coords, dst_coords
        Edge endpoint coordinates; column 0 is x and column 1 is y. Any further
        columns are only rounded. These tensors are not modified.
    conv2d_padding
        Margin in pixels added around the node bounding box.

    Returns
    -------
    image_graph
        The cropped slice of ``image``.
    src_coords, dst_coords
        New ``long`` tensors with the shifted, rounded coordinates.
    '''
    pos_graph = pos[:, [1, 2]]
    height, width = int(image.shape[-2]), int(image.shape[-1])

    # Bounding box of the graph's nodes, widened by the padding and clipped to the image.
    xmin = max(0, int(torch.min(pos_graph[:, 0])) - conv2d_padding)
    xmax = min(int(torch.max(pos_graph[:, 0])) + conv2d_padding, width)
    ymin = max(0, int(torch.min(pos_graph[:, 1])) - conv2d_padding)
    ymax = min(int(torch.max(pos_graph[:, 1])) + conv2d_padding, height)
    image_graph = image[:, :, ymin:ymax, xmin:xmax]

    # Shift copies into the crop's frame so the caller's tensors are not modified.
    src_coords = src_coords.clone()
    dst_coords = dst_coords.clone()
    for coords in (src_coords, dst_coords):
        coords[:, 0] -= xmin
        coords[:, 1] -= ymin

    return image_graph, torch.round(src_coords).long(), torch.round(dst_coords).long()
def _get_binned_coordinates(
    src_coords: torch.Tensor,
    dst_coords: torch.Tensor,
    image_binsize: int,
    min_image_size: int,
    max_image_size: int,
):
    '''
    Bin the x and y extents of every edge's subimage.

    ``_binning_coordinates`` is applied to the x column (0) and the y column (1)
    of the endpoint coordinates. Its per-axis "binned size" outputs are stacked
    into one row per edge.

    Parameters
    ----------
    src_coords, dst_coords
        Edge endpoint coordinates; column 0 is x and column 1 is y.
    image_binsize, min_image_size, max_image_size
        Forwarded unchanged to ``_binning_coordinates``.

    Returns
    -------
    minx, maxx, miny, maxy
        Per-edge crop bounds along x and y, as returned by ``_binning_coordinates``.
    avail_bins
        The distinct rows of ``dist_bins_2d``.
    dist_bins_2d
        Per-edge (x, y) binned sizes, stacked along dim 1.
    '''
    binning_options = dict(
        image_binsize=image_binsize,
        min_image_size=min_image_size,
        max_image_size=max_image_size,
    )
    minx, maxx, x_bins = _binning_coordinates(src_coords[:, 0], dst_coords[:, 0], **binning_options)
    miny, maxy, y_bins = _binning_coordinates(src_coords[:, 1], dst_coords[:, 1], **binning_options)

    dist_bins_2d = torch.stack((x_bins, y_bins), dim=1)
    avail_bins = torch.unique(dist_bins_2d, dim=0)
    return minx, maxx, miny, maxy, avail_bins, dist_bins_2d
class EdgeClf(nn.Module):
    r'''
    Edge classifier model which learns node classification embedding, image embedding and distance kernel

    Parameters
    ----------
    n_node_latent_features
        Number of latent features from node classification model
    image
        Image tensor for computing the conv2d embedding. Only its shape is used
        here, to build the image encoder.
    image_model
        Whether to use image model
    decoder_mlp_layer_dims
        List of hidden layer dimensions for MLP
    distance_type
        Type of distance features. Options are None and 'rbf'. 'positional' is
        reserved but not implemented and raises ``NotImplementedError``.
    rbf_start
        Start of RBF kernel parameter \mu. Refer to :func:`~GaussianSmearing`
    rbf_stop
        Stop of RBF kernel parameter \mu. Refer to :func:`~GaussianSmearing`
    rbf_n_kernels
        Number of kernels in RBF kernel. Refer to :func:`~GaussianSmearing`
    rbf_learnable
        Whether to learn the RBF kernel in backpropagation. Refer to :func:`~GaussianSmearing`
    encoder_image_layer_dims_conv2d
        List of hidden layer dimensions for CNN in image encoder. Refer to :func:`~ImageEncoder`
    encoder_image_layer_dims_mlp
        List of hidden layer dimensions for MLP in image encoder. Refer to :func:`~ImageEncoder`
    subimage_binsize
        Binning size of subimage of edges
    max_subimage_size
        Maximal size of subimage of edges after crop
    min_subimage_size
        Minimal size of subimage of edges after crop
    '''

    def __init__(
        self,
        n_node_latent_features: int,
        image: Union[torch.Tensor, np.ndarray],
        image_model: bool = True,
        decoder_mlp_layer_dims: Sequence[int] = [16, 8],
        distance_type: Optional[str] = 'rbf',
        rbf_start: float = 0,
        rbf_stop: float = 64,
        rbf_n_kernels: int = 64,
        rbf_learnable: bool = True,
        encoder_image_layer_dims_conv2d: Sequence[int] = [6, 16, 32, 64, 128],
        encoder_image_layer_dims_mlp: Sequence[int] = [32, 64],
        subimage_binsize: int = 5,
        max_subimage_size: int = 40,
        min_subimage_size: int = 5,
    ):
        super().__init__()

        if distance_type == 'positional':
            # forward() has no code path producing positional features, so a decoder
            # sized for them would fail with a shape mismatch; fail early instead.
            raise NotImplementedError(
                "distance_type='positional' is not implemented by EdgeClf.forward; "
                "use None or 'rbf'."
            )

        # RBF distance kernel
        self.distance_type = distance_type
        if self.distance_type == 'rbf':
            self.rbf_learnable = rbf_learnable
            self.rbf_kernel = GaussianSmearing(
                start=rbf_start,
                stop=rbf_stop,
                num_kernel=rbf_n_kernels,
                centered=False,
                learnable=rbf_learnable,
            )

        # image encoder
        if (image is not None) and image_model:
            self.image_model = True
            self.image_binsize = subimage_binsize
            self.max_image_size = max_subimage_size
            self.min_image_size = min_subimage_size
            # Pass copies so the (mutable) default lists can never be altered downstream.
            self.encoder_image = ImageEncoder(
                image_dims=image.shape,
                cnn_layer_dims=list(encoder_image_layer_dims_conv2d),
                mlp_layer_dims=list(encoder_image_layer_dims_mlp),
            )
            self.n_image_features = encoder_image_layer_dims_mlp[-1]
            num_parameters_image = self._num_trainable_parameters(self.encoder_image)
        else:
            self.image_model = False
            self.encoder_image = None
            num_parameters_image = 0

        # n latent embeddings: both endpoint embeddings + optional RBF + optional image features
        n_latent_features_ = n_node_latent_features * 2
        if distance_type == 'rbf':
            n_latent_features_ += rbf_n_kernels
        if self.encoder_image is not None:
            n_latent_features_ += encoder_image_layer_dims_mlp[-1]

        # FC decoder producing one logit per edge
        self.decoder = MLP(
            [n_latent_features_] + list(decoder_mlp_layer_dims) + [1],
            act='relu',
            norm='batch_norm',
        )

        # parameters
        num_parameters_decoder = self._num_trainable_parameters(self.decoder)
        logger.info(f'Number of CNN parameters is {num_parameters_image}')
        if self.distance_type == 'rbf':
            num_parameters_rbf = self._num_trainable_parameters(self.rbf_kernel)
            logger.info(f'Number of RBF kernel parameters is {num_parameters_rbf}')
        logger.info(f'Number of MLP decoder parameters is {num_parameters_decoder}')

    @staticmethod
    def _num_trainable_parameters(module: nn.Module) -> int:
        '''Count the elements of ``module``'s parameters that require gradients.'''
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    def forward(
        self,
        z_node: torch.Tensor,
        data: torch.Tensor,
        num_pos_edges: int,
        num_neg_edges: int,
        image: torch.Tensor,
        conv2d_padding: int = 10,
    ):
        '''
        Run the decoder model from latent space z. Before running the decoder,
        random positive and negative edges are generated as the input.

        Parameters
        -----------
        z_node
            Latent features from pretrained node classification (n samples x n latent features)
        data
            Input batch of several graphs. It must expose ``pos`` (node positions;
            columns 1-3 are used as coordinates, the last column as a node weight)
            and ``ptr`` (node offsets of each graph).
        num_pos_edges
            Number of positive edges
        num_neg_edges
            Number of negative edges
        image
            Image tensor for computing the conv2d embedding
        conv2d_padding
            add paddings in the conv2d embedding

        Returns
        -------
        pred_labels
            Sigmoid edge probabilities (1-D), concatenated graph by graph.
        edge_labels_combined
            Labels of the sampled edges, in the same order.
        weights
            Per-edge weights (product of the last ``data.pos`` column at both
            endpoints), in the same order and of the same length.
        '''
        # sample random edges each time
        edge_index, edge_labels, edge_graph_indices = _sample_edges(data, num_pos_edges, num_neg_edges)

        edge_attr_list, edge_label_list, weight_list = [], [], []
        for graph_index in torch.unique(edge_graph_indices):
            edge_ids = torch.where(edge_graph_indices == graph_index)[0]
            src = edge_index[0, edge_ids]
            dst = edge_index[1, edge_ids]

            # Collected for every graph so the weights line up with all returned edges.
            weight_list.append(data.pos[src, -1] * data.pos[dst, -1])

            # get attributes
            edge_attr = torch.cat([z_node[src], z_node[dst]], dim=-1)
            src_coords = data.pos[src, :][:, [1, 2, 3]]  # 3d
            dst_coords = data.pos[dst, :][:, [1, 2, 3]]

            if self.distance_type == 'rbf':
                edge_attr_rbf = self.rbf_kernel(x=src_coords, y=dst_coords)
                edge_attr = torch.cat((edge_attr, edge_attr_rbf), dim=-1)

            if self.image_model:
                edge_attr_image = self._image_edge_features(
                    data, graph_index, image, src_coords, dst_coords, conv2d_padding, reference=edge_attr,
                )
                edge_attr = torch.cat([edge_attr, edge_attr_image], dim=-1)

            edge_attr_list.append(edge_attr)
            edge_label_list.append(edge_labels[edge_ids])

        edge_attr_combined = torch.cat(edge_attr_list, dim=0)
        edge_labels_combined = torch.cat(edge_label_list)
        weights = torch.cat(weight_list)

        pred_labels = self.decoder(edge_attr_combined)
        # squeeze only the logit dim so a single edge still yields a 1-D tensor
        pred_labels = torch.sigmoid(pred_labels).squeeze(-1)
        return pred_labels, edge_labels_combined, weights

    def _image_edge_features(
        self,
        data,
        graph_index: torch.Tensor,
        image: torch.Tensor,
        src_coords: torch.Tensor,
        dst_coords: torch.Tensor,
        conv2d_padding: int,
        reference: torch.Tensor,
    ) -> torch.Tensor:
        '''
        Encode the subimage spanned by each edge of one graph.

        Returns a tensor of shape (number of edges, ``self.n_image_features``)
        allocated with the dtype and device of ``reference``.
        '''
        import time

        # get conv2d embeddings
        t0 = time.time()
        pos_graph = data.pos[data.ptr[graph_index]:data.ptr[graph_index + 1], :]
        image_graph, src_coords, dst_coords = _get_image_graph(
            pos_graph, image, src_coords, dst_coords, conv2d_padding
        )
        image_graph = self.encoder_image.get_conv2d_embedding(image_graph)
        t1 = time.time()
        logger.info('---Get image graph time: %.5f s', t1 - t0)

        # binning coordinates
        minx, maxx, miny, maxy, avail_bins, dist_bins_2d = _get_binned_coordinates(
            src_coords, dst_coords, self.image_binsize, self.min_image_size, self.max_image_size
        )
        t2 = time.time()
        logger.info('---Get all binned coordinates time: %.5f s', t2 - t1)

        # avail_bins holds the unique rows of dist_bins_2d, so every edge belongs to
        # exactly one bin and every row of this buffer is written below.
        edge_attr_image = reference.new_empty((src_coords.shape[0], self.n_image_features))
        for avail_bin in avail_bins:
            t3 = time.time()
            bin_indices = torch.where((dist_bins_2d == avail_bin).all(dim=1))[0]
            # torch.cat requires equal crop sizes within a bin -- presumably what binning guarantees.
            subimages = torch.cat(
                [image_graph[:, :, miny[j]:maxy[j], minx[j]:maxx[j]] for j in bin_indices],
                dim=0,
            )
            t4 = time.time()
            logger.info('--------bin size: %s, num of subimages: %d', avail_bin, len(bin_indices))
            logger.info('--------Concat / read time: %.5f s', t4 - t3)
            edge_attr_image[bin_indices, :] = self.encoder_image(subimages)
            t5 = time.time()
            logger.info('--------encode subimages time: %.5f s', t5 - t4)
        return edge_attr_image