Source code for Bering.models._edge_process

import logging
import collections
import torch
import torch.nn.functional as F 

import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader

from ._gnn import GCN

logger = logging.getLogger(__name__)

[docs]def _sample_edges( data: DataLoader, num_pos_edges: int, num_neg_edges: int, num_edges_perG: int = 10000 ): ''' Sample random positive & negative edges as the input of decoder. Parameters ---------- data Input data loader; Node coordinates and cell ids are stored in data.pos num_pos_edges Number of positive edges num_neg_edges Number of negative edges num_edges_perG Number of edges per graph ''' # init N = data.pos.shape[0] cells = data.pos[:,4] # get all edges if hasattr(data, "ptr"): batch_size = len(data.ptr) - 1 num_nodes = data.ptr[1:] - data.ptr[:-1] num_cum_nodes = data.ptr else: batch_size = 1 num_nodes = torch.tensor([N], dtype=torch.long, device="cuda") num_cum_nodes = torch.tensor([0, N], dtype=torch.long, device="cuda") # get edges per graph edges = torch.rand((2, batch_size, num_edges_perG), dtype=torch.float, device="cuda") # for each graph, we sample num_edges_perG edges edges_graph_indices = torch.LongTensor( list(range(batch_size))).repeat_interleave(edges.shape[-1]).reshape(1,batch_size,edges.shape[-1] ).cuda() # add graph indices to each edge edges = (edges * num_nodes[None, :, None]).long() edges = edges + num_cum_nodes[None, :-1, None] edges = torch.cat([edges, edges_graph_indices], dim = 0) # 3 x batch_size x num_edges edges = edges.flatten(-2, -1) edges = edges[:, edges[0] != edges[1]] # get all labels srcs = edges[0]; dsts = edges[1] pos_edge_index = edges[:, torch.where((cells[srcs] == cells[dsts]) & (cells[srcs] != -1) & (cells[dsts] != -1))[0]] neg_edge_index = edges[:, torch.where((cells[srcs] != cells[dsts]) & (cells[srcs] != -1) & (cells[dsts] != -1))[0]] # ensure equal number of pos / neg edges num_min_edges = min(pos_edge_index.shape[1], neg_edge_index.shape[1]) num_pos_edges = min(num_min_edges, num_pos_edges) num_neg_edges = min(num_min_edges, num_neg_edges) logger.info( f'Numer of total positive edges: {pos_edge_index.shape[1]}; negative edges: {neg_edge_index.shape[1]}; final edges for pos/neg types: {num_pos_edges}, {num_neg_edges}' ) # format pos/neg edges pos_edge_index = pos_edge_index[:, torch.randint(0, pos_edge_index.shape[1], (num_pos_edges,), dtype = torch.long)] pos_edge_graph_indices = torch.cat([pos_edge_index[2], pos_edge_index[2]]) pos_edge_index = torch.cat([pos_edge_index[:2, :], pos_edge_index[[1,0], :]], dim = -1) neg_edge_index = neg_edge_index[:, torch.randint(0, neg_edge_index.shape[1], (num_neg_edges,), dtype = torch.long)] neg_edge_graph_indices = torch.cat([neg_edge_index[2], neg_edge_index[2]]) neg_edge_index = torch.cat([neg_edge_index[:2, :], neg_edge_index[[1,0], :]], dim = -1) edge_index = torch.cat((pos_edge_index, neg_edge_index), axis = -1) edge_labels = torch.cat(( torch.ones((pos_edge_index.shape[1], ), dtype = torch.float64, device = 'cuda'), torch.zeros((neg_edge_index.shape[1], ), dtype = torch.float64, device = 'cuda'), )) edge_graph_index = torch.cat([pos_edge_graph_indices, neg_edge_graph_indices]) return edge_index, edge_labels, edge_graph_index
[docs]def _binning_coordinates( src_coords_x: torch.Tensor, dst_coords_x: torch.Tensor, image_binsize: int = 5, min_image_size: int = 5, max_image_size: int = 40, ): ''' Binning the distance between source and destination nodes in 1d (either x or y) Input: source and destination node coordinates in height (or width) with shape (n,) Output: minimal and maximal coordinates in height (or width) for source and destination. New coordinates after binning Parameters ---------- src_coords_x 1d coordinates of source nodes dst_coords_x 1d coordinates of destination nodes image_binsize bin size of the image in height and width min_image_size minimal image size in height and width max_image_size maximal image size in height and width ''' # input src coordinates (n x 2) and dst coordinates (n x 2); all are integars # get min max values minx = torch.min(src_coords_x, dst_coords_x) maxx = torch.max(src_coords_x, dst_coords_x) dist_x = maxx - minx # small distances smalldist_indices = torch.where(dist_x < image_binsize)[0] smalldist = dist_x[smalldist_indices] gap_to_minimal = min_image_size - smalldist gap_half1 = gap_to_minimal // 2 gap_half2 = gap_to_minimal - gap_half1 minx[smalldist_indices] = minx[smalldist_indices] - gap_half1 maxx[smalldist_indices] = maxx[smalldist_indices] + gap_half2 dist_x = maxx - minx # binning dist_x_appro = torch.round(torch.clip(dist_x, min_image_size, max_image_size) / image_binsize).long() * image_binsize gap_x = dist_x_appro - dist_x gap_half1 = gap_x // 2 gap_half2 = gap_x - gap_half1 minx = minx - gap_half1 maxx = maxx + gap_half2 return minx, maxx, dist_x_appro