# Source code for Bering.graphs._loader

import random
import logging
import numpy as np
import pandas as pd

import torch 
from torch_geometric.loader import DataLoader

from ..objects import Bering_Graph as BrGraph
from ._graph import BuildGraph
from ._settings import GRAPH_KEYS as G_KEYS

logger = logging.getLogger(__name__)

# Public API: the two loader entry points defined in this module.
__all__ = ["BuildWindowGraphs", "CreateData"]

def _select_training_cells(cell_meta, n_cells_perClass, cell_percentile_from_border):
    """Randomly pick up to ``n_cells_perClass`` non-border cells for every non-background label."""
    labels = np.setdiff1d(cell_meta.labels.unique(), ['background'])

    def _inner_bounds(column):
        # Percentile-based cut-offs used to drop cells close to the image border.
        values = cell_meta[column].values
        return (
            np.percentile(values, cell_percentile_from_border),
            np.percentile(values, 100 - cell_percentile_from_border),
        )

    cx_min, cx_max = _inner_bounds('cx')
    cy_min, cy_max = _inner_bounds('cy')
    inner = cell_meta.loc[
        (cell_meta['cx'] > cx_min) & (cell_meta['cx'] < cx_max)
        & (cell_meta['cy'] > cy_min) & (cell_meta['cy'] < cy_max),
        :,
    ]

    selected = []
    for label in labels:
        cells = inner.loc[inner['labels'] == label, :].index.values
        selected += random.sample(list(cells), min(n_cells_perClass, len(cells)))
    return selected


def _window_centers(cx, cy, d, n_windows_per_cell, window_shift_ratio):
    """Return the (x, y) centres of the windows placed around a cell centred at (cx, cy)."""
    shift = d * window_shift_ratio
    if n_windows_per_cell == 1:
        return [(cx, cy)]
    if n_windows_per_cell == 3:
        return [(cx - shift, cy + shift), (cx, cy), (cx + shift, cy - shift)]
    if n_windows_per_cell == 5:
        # top-left, bottom-left, centroid, top-right, bottom-right
        return [
            (cx - shift, cy + shift),
            (cx - shift, cy - shift),
            (cx, cy),
            (cx + shift, cy + shift),
            (cx + shift, cy - shift),
        ]
    raise ValueError(f'n_windows_per_cell must be one of 1, 3, 5; got {n_windows_per_cell}')


def _spots_in_window(spots, dimension, x_range, y_range, z_range):
    """Return a copy of the spots strictly inside the given (open) window bounds."""
    (xmin, xmax), (ymin, ymax), (zmin, zmax) = x_range, y_range, z_range
    mask = (spots.x > xmin) & (spots.x < xmax) & (spots.y > ymin) & (spots.y < ymax)
    if dimension == '3d':
        mask &= (spots.z > zmin) & (spots.z < zmax)
    elif dimension != '2d':
        raise ValueError(f"bg.dimension must be '2d' or '3d'; got {dimension!r}")
    return spots.loc[mask, :].copy()


def BuildWindowGraphs(
    bg: BrGraph,
    n_cells_perClass: int = 10,
    window_width: float = G_KEYS.WINDOW_WIDTH,
    window_height: float = G_KEYS.WINDOW_HEIGHT,
    n_neighbors: int = G_KEYS.N_NEIGHBORS,
    min_points: int = G_KEYS.WINDOW_MIN_POINTS,
    use_unsegmented_ratio: float = 0.8,
    max_unsegmented_thresh: float = 0.4,
    cell_percentile_from_border: float = 10,
    window_shift_ratio: float = 0.25,
    n_windows_per_cell: int = 5,
    min_spots_outside: int = 5,
    **kwargs,
):
    """
    Build graphs for originally segmented cells.

    A random subset of cells is selected per cell label, and windows around each
    selected cell are cropped from ``bg.spots_all`` to construct graphs.

    Parameters
    ----------
    bg
        Bering_Graph object
    n_cells_perClass
        Number of cells per cell class for training
    window_width
        Width of each selected region for graph construction (also used as the
        window depth along z for 3D data)
    window_height
        Height of each selected region for graph construction
    n_neighbors
        Number of neighbors in KNN
    min_points
        Minimum number of points in a window
    use_unsegmented_ratio
        Proportion of unsegmented spots kept (randomly sampled) in a window
    max_unsegmented_thresh
        Windows whose proportion of unsegmented spots is at or above this value are skipped
    cell_percentile_from_border
        Remove cells whose centroid lies within this percentile of the cell
        centroid distribution in x or y (i.e. close to the image border)
    window_shift_ratio
        To balance the transcripts within and outside cells, windows are shifted
        away from the cell centroid by this ratio of the cell diameter
    n_windows_per_cell
        Number of windows per cell. Available options are 1, 3, 5
    min_spots_outside
        Minimum number of spots for the second most abundant ``segmented`` value in a window
    **kwargs
        Other arguments for BuildGraph

    Returns
    -------
    None. Sets ``bg.Graphs_golden`` to a list of graphs built by ``BuildGraph``
    for training, and sets ``bg.window_size`` to ``window_width``.

    Raises
    ------
    ValueError
        If ``n_windows_per_cell`` is not 1, 3 or 5, or if ``bg.dimension`` is
        neither ``'2d'`` nor ``'3d'``.
    """
    if n_windows_per_cell not in (1, 3, 5):
        raise ValueError(f'n_windows_per_cell must be one of 1, 3, 5; got {n_windows_per_cell}')

    Spots = bg.spots_all.copy()
    Graphs = []
    bg.window_size = window_width

    selected_cells = _select_training_cells(bg.segmented, n_cells_perClass, cell_percentile_from_border)

    counts = 0
    for cell in selected_cells:
        cx, cy, cz, d = bg.segmented.loc[cell, 'cx'], bg.segmented.loc[cell, 'cy'], bg.segmented.loc[cell, 'cz'], bg.segmented.loc[cell, 'd']

        for xc, yc in _window_centers(cx, cy, d, n_windows_per_cell, window_shift_ratio):
            # The window is shifted in x/y only; its z extent stays centred on the
            # cell's cz with depth window_width.
            window_spots = _spots_in_window(
                Spots,
                bg.dimension,
                (xc - window_width / 2, xc + window_width / 2),
                (yc - window_height / 2, yc + window_height / 2),
                (cz - window_width / 2, cz + window_width / 2),
            )
            if window_spots.shape[0] == 0:
                continue

            # Require at least two distinct values of the 'segmented' column, the
            # second most abundant having >= min_spots_outside spots.
            spots_abun = pd.DataFrame(window_spots.groupby(['segmented']).size(), columns=['counts'])
            spots_abun.sort_values(by=['counts'], ascending=False, inplace=True)
            if spots_abun.shape[0] == 1 or spots_abun.iloc[1, 0] < min_spots_outside:
                continue

            window_seg = window_spots.loc[window_spots['groups'] == 'segmented', :].copy()
            window_unseg = window_spots.loc[window_spots['groups'] == 'unsegmented', :].copy()
            n_seg, n_unseg = window_seg.shape[0], window_unseg.shape[0]
            # No spot tagged 'segmented'/'unsegmented': nothing to train on (and the
            # ratio below would divide by zero).
            if n_seg + n_unseg == 0:
                continue
            if n_unseg / (n_seg + n_unseg) >= max_unsegmented_thresh:
                continue

            # Keep a random subset of the unsegmented spots.
            unseg_indices = random.sample(list(range(n_unseg)), int(n_unseg * use_unsegmented_ratio))
            window_unseg = window_unseg.iloc[unseg_indices, :].copy()
            window_spots = pd.concat([window_seg, window_unseg], axis=0)

            if window_spots.shape[0] < n_neighbors + 1 or window_spots.shape[0] < min_points:
                continue

            counts += 1
            graph = BuildGraph(bg, window_spots, n_neighbors=n_neighbors, **kwargs)
            Graphs.append(graph)

            if counts % 50 == 0:
                logger.info(f'Build Neighbor graphs for {counts} th window (golden truth)')
                logger.info(f'Number of dots in {counts} th window: {window_spots.shape[0]}')
                avg_neighbors = graph.edge_index.shape[1] / graph.x.shape[0]
                logger.info(f'Average number of filtered neighbors: {avg_neighbors:.2f} in the window')

    bg.Graphs_golden = Graphs
    logger.info(f'Number of node features: {bg.n_node_features}')
    logger.info(f'\nTotal number of golden-truth graphs is {len(bg.Graphs_golden)}')
def CreateData(
    bg: BrGraph,
    batch_size: int = G_KEYS.BATCH_SIZE,
    training_ratio: float = G_KEYS.TRAINING_RATIO,
):
    '''
    Create training and testing data loaders from the golden-truth graphs.

    ``bg.Graphs_golden`` is shuffled in place with Python's ``random`` module,
    split according to ``training_ratio``, and wrapped in data loaders. The
    attribute ``bg.Graphs_golden`` is deleted afterwards, so it must be rebuilt
    (e.g. with ``BuildWindowGraphs``) before calling this function again.

    Parameters
    ----------
    bg
        Bering_Graph object
    batch_size
        Batch size for training
    training_ratio
        Ratio of training data to the total data, in ``[0, 1]``

    Returns
    -------
    - ``Bering_Graph.train_loader``: Training data loader (``torch_geometric.loader.DataLoader``)
    - ``Bering_Graph.test_loader``: Testing data loader (``torch_geometric.loader.DataLoader``)

    Raises
    ------
    ValueError
        If ``training_ratio`` is outside ``[0, 1]``.
    '''
    # A negative ratio would yield a negative split index, and slicing would then
    # silently assign all but the last graphs to training.
    if not 0 <= training_ratio <= 1:
        raise ValueError(f'training_ratio must be in [0, 1], got {training_ratio}')

    logger.info(f'Create training and testing datasets (golden truth)')
    dataset = bg.Graphs_golden
    N = len(dataset)
    N_train = int(N * training_ratio)

    # Shuffle once here; the loaders themselves keep a fixed order (shuffle=False).
    random.shuffle(dataset)
    bg.train_loader = DataLoader(dataset[:N_train], batch_size=batch_size, shuffle=False, num_workers=0)
    bg.test_loader = DataLoader(dataset[N_train:], batch_size=batch_size, shuffle=False, num_workers=0)

    del bg.Graphs_golden