# Source code for Bering.tools._node_classification

import logging
import torch
import numpy as np
import pandas as pd

from ..objects import Bering_Graph as BrGraph
from ..graphs import BuildGraph_fromRaw

logger = logging.getLogger(__name__)

def _get_pos(
    df_spots: pd.DataFrame,
):
    x = df_spots.x.values
    y = df_spots.y.values
    z = df_spots.z.values
    tps_names = df_spots.index.values

    # pos = np.array([tps_names, x, y]).T
    pos = np.array([tps_names, x, y, z]).T
    pos = torch.from_numpy(pos).double()
    return pos

def _get_node_embedding_prediction_byTiling(bg, df_spots, num_chunks, n_neighbors, beta, dummy = 1e-3):
    '''
    Split spots tables into tiles by coordinates if the number of spots is too large (use 2d chunks here)

    The x and y axes are each cut at evenly spaced percentiles of the spot
    coordinates, giving a ``num_chunks_axis x num_chunks_axis`` grid with
    ``num_chunks_axis = max(1, round(sqrt(num_chunks)))``. A graph is built and
    the node model is run separately for each tile, so graph edges never cross
    tile borders.

    Parameters
    ----------
    bg: BrGraph
        Bering Graph object. Graphs are moved to CPU and ``predict`` is called
        with ``device='cpu'``, so the node model is expected to be on CPU
        (``node_classification`` moves it there before calling this).
    df_spots: pd.DataFrame
        spots table with ``x`` and ``y`` columns
    num_chunks: int
        requested number of tiles; rounded to the nearest perfect square (at least 1)
    n_neighbors: int
        forwarded to ``BuildGraph_fromRaw``; tiles with ``<= n_neighbors + 1``
        spots are skipped
    beta: float
        forwarded to ``BuildGraph_fromRaw``
    dummy: float
        unused; kept for backward compatibility. It used to pad the outermost
        tile edges, which are now open-ended so no spot can fall outside the grid.

    Returns
    -------
    z_all: torch.Tensor
        double tensor of shape ``(len(df_spots), latent_dim)``; rows follow the
        row order of ``df_spots``. Rows of spots in skipped tiles are zero.
    preds_logits: torch.Tensor
        double tensor of shape ``(len(df_spots), bg.n_labels)`` holding the
        output of ``bg.trainer_node.predict``; zero for spots in skipped tiles.
    '''
    x, y = df_spots.x.values, df_spots.y.values
    # Nearest whole grid size; never below 1 so that every spot is assigned to a tile.
    num_chunks_axis = max(1, int(np.round(np.sqrt(max(num_chunks, 1)))))
    num_chunks = num_chunks_axis ** 2
    logger.info(f'Number of chunks for node classification (adjusted): {num_chunks}')

    model = bg.trainer_node.model
    if not hasattr(model, 'num_mlp_layers_remain'):
        latent_dim = model.mlp_layer_dims[1] # GCN
    else:
        latent_dim = model.mlp_layer_dims[-(model.num_mlp_layers_remain + 1)] # MLP
    z_all = torch.zeros((df_spots.shape[0], latent_dim), dtype = torch.double)

    logger.info(f'size of z_all: {z_all.shape}')
    preds_logits = torch.zeros((df_spots.shape[0], bg.n_labels), dtype = torch.double)

    if df_spots.shape[0] == 0:
        # np.percentile cannot compute tile edges on empty input; nothing to predict.
        return z_all, preds_logits

    # Tile edges sit at evenly spaced percentiles so every row/column of tiles holds
    # roughly the same number of spots. Only interior edges are needed:
    # searchsorted(side='right') maps a coordinate v to bin k with edge_k <= v < edge_{k+1},
    # and the first/last bins are open-ended, so each spot lands in exactly one tile
    # independently of the coordinate scale.
    quantiles = np.linspace(0, 100, num_chunks_axis + 1)
    x_bins = np.searchsorted(np.percentile(x, quantiles)[1:-1], x, side = 'right')
    y_bins = np.searchsorted(np.percentile(y, quantiles)[1:-1], y, side = 'right')

    # Inference only: do not record autograd history, otherwise the in-place copies into
    # z_all / preds_logits can keep every tile's computation graph alive.
    with torch.no_grad():
        for i in range(num_chunks_axis):
            for j in range(num_chunks_axis):
                tile_indices = np.where((x_bins == i) & (y_bins == j))[0]

                # presumably a kNN graph with n_neighbors needs more spots than that
                if len(tile_indices) > n_neighbors + 1:
                    df_spots_section = df_spots.iloc[tile_indices, :]
                    logger.info(f'Number of spots in tile (i,j) {(i,j)} is {df_spots_section.shape[0]}')
                    graph_section = BuildGraph_fromRaw(bg, df_spots_section, bg.features.copy(), n_neighbors, beta).cpu()
                    z_section = model.get_latent(graph_section)
                    logger.info(f'size of z_section (1): {z_section.shape}')
                    preds_logits_section = bg.trainer_node.predict(graph_section, device = 'cpu').cpu()
                else:
                    z_section = torch.zeros((len(tile_indices), z_all.shape[1]), dtype = torch.double)
                    logger.info(f'size of z_section (2): {z_section.shape}')
                    preds_logits_section = torch.zeros((len(tile_indices), bg.n_labels), dtype = torch.double)

                # Scatter tile results back to the spots' positions in df_spots.
                z_all[tile_indices, :] = z_section
                preds_logits[tile_indices, :] = preds_logits_section

    return z_all, preds_logits

def node_classification(
    bg: BrGraph,
    df_spots: pd.DataFrame,
    n_neighbors: int = 10,
    beta: float = 1.0,
    prob_threshold: float = 0.3,
    max_num_spots: int = 1500000,  # 1.5 million
    num_chunks: int = 25,
):
    '''
    Node classification for all spots in the slice

    Parameters
    ----------
    bg: BrGraph
        Bering Graph object
    df_spots: pd.DataFrame
        spots table. It can be ``bg.spots_all`` in case of whole slice prediction.
        Predictions are written into new columns of ``bg.spots_all`` by position,
        so ``df_spots`` must have the same number of rows, in the same order,
        as ``bg.spots_all``.
    n_neighbors: int
        number of neighbors for graph construction
    beta: float
        forwarded to ``BuildGraph_fromRaw`` for graph construction
    prob_threshold: float
        minimal threshold of predicted probability for spots to be considered as foreground.
        Spots whose maximum predicted score is ``<= prob_threshold`` get label index
        ``bg.n_labels - 1``.
    max_num_spots: int
        if ``df_spots`` has more rows than this, or the whole-slice graph cannot be
        built, latent embeddings and predictions are computed tile by tile instead
        of on one whole-slice graph.
    num_chunks: int
        number of tiles used for tiled inference. Refer to
        `_get_node_embedding_prediction_byTiling` for details.

    Returns
    -------
    preds_labels: np.array
        predicted labels for all spots
    graph_all: torch_geometric.data.Data or None
        graph for the whole slice, or ``None`` if building it failed

    Side effects
    ------------
    Moves ``bg.trainer_node.model`` to CPU; sets ``bg.graph_all`` (on successful
    graph construction) or ``bg.pos_all`` (on failure), ``bg.z_all`` and
    ``bg.foreground_indices``; adds the ``predicted_node_labels`` and
    ``predicted_probability`` columns to ``bg.spots_all``.
    '''
    # build graph
    logger.info(f'Building Graph for the whole slice')
    try:
        graph_all = BuildGraph_fromRaw(bg, df_spots, bg.features.copy(), n_neighbors, beta).cpu()
        bg.graph_all = graph_all
    except Exception as err:
        # Best effort: the whole-slice graph is optional because tiled inference below
        # can replace it. The original message attributes failure to slice size.
        graph_all = None
        bg.pos_all = _get_pos(bg.spots_all)
        logger.info(f'There are too many spots in the slice. Skip the generation of the whole graph.')
        logger.warning(f'Whole-slice graph construction failed: {err!r}')

    # get latent
    logger.info(f'Get the latent space for all {df_spots.shape[0]} nodes')
    bg.trainer_node.model.to('cpu')
    # Inference only: no autograd history is needed for embeddings or predictions.
    with torch.no_grad():
        if graph_all is not None and df_spots.shape[0] <= max_num_spots:
            bg.z_all = bg.trainer_node.model.get_latent(graph_all)
            preds_logits = bg.trainer_node.predict(graph_all, device = 'cpu').cpu()
        else:
            # Either the slice is too large or the whole graph could not be built.
            logger.info(f'Number of chunks for node classification: {num_chunks}')
            bg.z_all, preds_logits = _get_node_embedding_prediction_byTiling(bg, df_spots, num_chunks, n_neighbors, beta)

    # prediction results
    max_probs, pred_indices = torch.max(preds_logits, dim = 1)
    # Low-confidence spots get the last label index (presumably 'background' — the
    # foreground mask below excludes the 'background' label).
    pred_indices[max_probs <= prob_threshold] = bg.n_labels - 1
    preds_labels = np.array([bg.label_indices_dict[idx] for idx in pred_indices.tolist()])

    bg.spots_all['predicted_node_labels'] = preds_labels
    bg.spots_all['predicted_probability'] = max_probs.detach().cpu().numpy()
    bg.foreground_indices = np.where(bg.spots_all['predicted_node_labels'].values != 'background')[0]

    # output results
    return preds_labels, graph_all