Source code for Bering.plotting.plot_post

import os
import random
import torch
import logging
import warnings
import numpy as np
import pandas as pd
from typing import Sequence

import leidenalg as la
from sklearn.metrics import adjusted_rand_score

import matplotlib as mpl
import matplotlib.pyplot as plt

from torch_geometric.data import Data 

from ._settings import _PLOT_SETTINGS, _GET_CMAPS
from ._settings import PLOT_KEYS as PLT_KEYS

from ..graphs import BuildGraph, BuildGraph_fromRaw
from ..segment import find_clusters_predictedLinks
from ..objects import Bering_Graph as BrGraph
from ._plot_elements import _raw_spots, _raw_cell_types, _raw_cell_types_addPatch
from ._plot_elements import _predicted_cell_types, _predicted_probability, _draw_cells_withStaining

# Suppress every Python warning as soon as this module is imported.
# NOTE(review): this filter is process-wide, so it also hides warnings raised
# by unrelated code - confirm that this is intended.
warnings.filterwarnings("ignore")
logger = logging.getLogger(__name__)

# Run the plotting setup hook from ._settings (its effects are not visible in
# this file) and fetch the colour list that is later sliced to give one colour
# per label.
_PLOT_SETTINGS()
CMAP = _GET_CMAPS()

def _get_extended_window(spots, location, window_width = 50.0, window_height = 50.0):
    '''
    Return a copy of the spots lying strictly inside a rectangle centred on ``location``.

    The rectangle reaches ``window_width`` to either side of the centre along
    ``x`` and ``window_height`` to either side along ``y``; spots exactly on
    the border are excluded. ``spots`` must provide ``x`` and ``y`` columns
    and is not modified.
    '''
    center_x, center_y = location
    left, bottom = center_x - window_width, center_y - window_height
    right, top = center_x + window_width, center_y + window_height
    logger.info(f'loc_x, loc_y: {center_x}, {center_y}')
    logger.info(f'minx, miny, maxx, maxy: {left}, {bottom}, {right}, {top}')

    # Strict inequalities on both axes: border spots are left out.
    inside_x = (spots['x'] > left) & (spots['x'] < right)
    inside_y = (spots['y'] > bottom) & (spots['y'] < top)
    return spots.loc[inside_x & inside_y, :].copy()

def _get_graph(bg, location, Spots, n_neighbors = 10, window_size = 50.0):
    '''
    Crop the spots around ``location`` and build a neighbour graph from them.

    Parameters
    ----------
    bg
        Object handed to ``BuildGraph_fromRaw``; a copy of ``bg.features`` is
        passed along with it.
    location
        ``(x, y)`` centre of the window.
    Spots
        Spot table with ``x`` and ``y`` columns. It is only read.
    n_neighbors
        Forwarded to ``BuildGraph_fromRaw``.
    window_size
        Distance from the centre to each window border, used for both axes.

    Returns
    -------
    tuple
        ``(window_spots, graph)``: a copy of the cropped rows and the graph
        after ``.cpu()``.

    Raises
    ------
    Exception
        When ``BuildGraph_fromRaw`` raises an ``AssertionError``; the original
        error is kept as ``__cause__``.
    '''
    # The cropping helper only reads the table and returns a copy of the
    # selected rows, so the full table does not need to be duplicated here.
    window_spots = _get_extended_window(Spots, location, window_size, window_size)

    # build neighbor graph
    try:
        graph = BuildGraph_fromRaw(bg, window_spots, bg.features.copy(), n_neighbors = n_neighbors).cpu()
    except AssertionError as err:
        # Chain the cause so the failing assertion stays visible in the traceback.
        raise Exception('No enough transcripts in this window') from err

    return window_spots, graph

def _prediction_nodes(graph, trainer_node, n_labels, prob_threshold = 0.3):
    preds_q = trainer_node.predict(graph, device = 'cpu').cpu()
    max_probs, predictions = torch.max(preds_q, dim = 1)

    back_indices = torch.where(max_probs <= prob_threshold)[0].unsqueeze(1)
    predictions[back_indices] = n_labels - 1

    return predictions, max_probs

def Plot_Classification_Post(
    bg: BrGraph,
    location: Sequence[float],
    n_neighbors: int = 10,
    min_prob: float = 0.3,
    window_size: float = 50.0,
):
    '''
    Plot original spots and newly-segmented spots

    Crops the spots around ``location``, predicts a label for every spot in
    the window and saves a three-panel figure: raw spots, predicted labels and
    the highest class probability per spot.

    Parameters
    ----------
    bg
        Bering object. This function reads ``spots_all``, ``trainer_node``,
        ``labels``, ``n_labels``, ``label_indices_dict``, ``label_to_col`` and
        ``image_raw`` from it.
    location
        ``(x, y)`` centre of the window.
    n_neighbors
        Number of neighbours used when building the window graph.
    min_prob
        Spots whose highest class probability is not above this value are
        assigned the last label index by ``_prediction_nodes``.
    window_size
        Distance from the centre to each window border.

    Returns
    -------
    tuple
        ``(df_window_raw, df_window_pred, predictions)``: all spots in the
        window, the rows whose predicted label is not ``'background'``, and
        the predicted label indices as a numpy array.
    '''
    # BUILD GRAPHS
    if not bg.label_to_col:
        bg.label_to_col = dict(zip(bg.labels, CMAP[:bg.n_labels]))
    # 'background' is a possible predicted label, so it always needs a colour;
    # setdefault keeps a colour the caller may already have chosen.
    bg.label_to_col.setdefault('background', '#C0C0C0')

    # _get_graph crops a copy of the window, so the full table is passed as-is.
    df_window_raw, graph = _get_graph(
        bg,
        location,
        bg.spots_all,
        n_neighbors = n_neighbors,
        window_size = window_size,
    )

    predictions, max_probs = _prediction_nodes(
        graph,
        bg.trainer_node,
        bg.n_labels,
        prob_threshold = min_prob,
    )
    predictions = predictions.numpy()
    max_probs = max_probs.numpy()

    # PREPARATION
    # Columns 1 and 2 of graph.pos are used as the plotting coordinates.
    x = graph.pos[:, 1].cpu().numpy()
    y = graph.pos[:, 2].cpu().numpy()
    pred_labels = np.array([bg.label_indices_dict[i] for i in predictions])
    label_to_col = bg.label_to_col

    # PLOTS
    fig, axes = plt.subplots(
        figsize = (PLT_KEYS.AX_WIDTH * 3, PLT_KEYS.AX_HEIGHT * 1),
        ncols = 3,
        nrows = 1,
        sharex = True,
        sharey = True,
    )

    # The first image channel is used as the backdrop when an image exists.
    dapi = bg.image_raw[0, :, :] if bg.image_raw is not None else None

    axes[0] = _raw_spots(dapi, x, y, axes[0])
    axes[1] = _predicted_cell_types(dapi, x, y, axes[1], pred_labels, label_to_col)
    axes[1].set_title('Predicted Annotations')
    axes[2], prob_plot = _predicted_probability(dapi, x, y, axes[2], max_probs)
    plt.colorbar(prob_plot, ax = axes[2])
    fig.tight_layout()

    # SAVE
    output_name = f'{PLT_KEYS.FOLDER_NAME}/{PLT_KEYS.FILE_PREFIX_CLASSIFICATION}x_{(location[0]):.2f}_y_{(location[1]):.2f}{PLT_KEYS.FILE_FORMAT}'
    # Make sure the target folder exists so savefig cannot fail on a missing directory.
    os.makedirs(os.path.dirname(output_name) or '.', exist_ok = True)
    fig.savefig(output_name, bbox_inches = 'tight')
    # NOTE(review): the figure is left open, unlike Plot_Segmentation_Post,
    # which closes its figure after saving - confirm whether that is intended.

    # FOREGROUND WINDOW
    # Assumes graph nodes are in the same order as the rows of df_window_raw - TODO confirm.
    foreground = np.where(pred_labels != 'background')[0]
    df_window_pred = df_window_raw.iloc[foreground, :].copy()

    return df_window_raw, df_window_pred, predictions
def Plot_Segmentation_Post(
    bg: BrGraph,
    location: Sequence[float],
    df_window_raw: pd.DataFrame = None,
    df_window_pred: pd.DataFrame = None,
    predictions: np.ndarray = None,
    n_neighbors: int = 10,
    window_size: float = 50.0,
    use_image: bool = True,
    pos_thresh: float = 0.6,
    resolution: float = 0.05,
    num_edges_perSpot: int = 300,
    min_prob_nodeclf: float = 0.3,
    n_iters: int = 10,
):
    '''
    Plot the raw spots of a window next to the cells predicted for its
    foreground spots, and save the figure.

    The window can be supplied by the caller (for example the values returned
    by ``Plot_Classification_Post``) or is rebuilt from ``location`` when
    ``df_window_raw`` is ``None``.

    Parameters
    ----------
    bg
        Bering object. This function reads ``spots_all``, ``trainer_node``,
        ``n_labels``, ``label_indices_dict``, ``channels``, ``n_channels`` and
        ``image_raw`` from it.
    location
        ``(x, y)`` centre of the window; also used in the output file name.
    df_window_raw
        All spots in the window. When ``None`` the window, the predictions and
        the foreground table are all computed here.
    df_window_pred
        Foreground spots of the window. When ``None`` it is derived from
        ``df_window_raw`` and ``predictions``.
    predictions
        Predicted label index per row of ``df_window_raw``. Only needed when
        ``df_window_pred`` has to be derived.
    n_neighbors, use_image, pos_thresh, resolution, num_edges_perSpot
        Forwarded to ``find_clusters_predictedLinks``; ``n_neighbors`` is also
        used when the window graph is built here.
    window_size
        Distance from the centre to each window border when the window is
        built here.
    min_prob_nodeclf
        Probability threshold for node classification when predictions are
        computed here.
    n_iters
        Forwarded to ``find_clusters_predictedLinks`` as ``num_iters``.

    Returns
    -------
    None
        The figure is written to disk. Nothing is plotted when the window has
        fewer than ``num_edges_perSpot`` foreground spots.

    Raises
    ------
    ValueError
        If ``df_window_raw`` is given but neither ``df_window_pred`` nor
        ``predictions`` is available to determine the foreground spots.
    '''
    if df_window_raw is None:
        # Build the window and classify its nodes from scratch.
        df_window_raw, graph = _get_graph(
            bg,
            location,
            bg.spots_all,
            n_neighbors = n_neighbors,
            window_size = window_size,
        )
        predictions, _ = _prediction_nodes(
            graph,
            bg.trainer_node,
            bg.n_labels,
            prob_threshold = min_prob_nodeclf,
        )
        predictions = predictions.numpy()
        # Any caller-supplied foreground table belongs to another window.
        df_window_pred = None

    if df_window_pred is None:
        if predictions is None:
            raise ValueError(
                'predictions or df_window_pred must be provided together with df_window_raw'
            )
        # clean background nodes: labels must exist before they are used as a filter
        pred_labels = np.array([bg.label_indices_dict[i] for i in predictions])
        foreground = np.where(pred_labels != 'background')[0]
        df_window_pred = df_window_raw.iloc[foreground, :].copy()

    if df_window_pred.shape[0] < num_edges_perSpot:
        logger.warning(
            'Only %d foreground spots in the window at %s (fewer than num_edges_perSpot = %d); skipping the segmentation plot',
            df_window_pred.shape[0], location, num_edges_perSpot,
        )
        return

    # The clustering result is indexed by resolution; take the requested one.
    cls_predlink = find_clusters_predictedLinks(
        bg,
        df_spots = df_window_pred, # foreground only
        use_image = use_image,
        pos_thresh = pos_thresh,
        resolution = resolution,
        num_edges_perSpot = num_edges_perSpot,
        n_neighbors = n_neighbors,
        num_iters = n_iters,
    )[resolution]

    x_raw, y_raw = df_window_raw.x.values, df_window_raw.y.values
    x_predfore, y_predfore = df_window_pred.x.values, df_window_pred.y.values

    if bg.channels is not None:
        # One panel per image channel plus one panel for the predicted cells.
        fig, axes = plt.subplots(
            figsize = (PLT_KEYS.AX_WIDTH * (bg.n_channels + 1) * 0.8, PLT_KEYS.AX_HEIGHT * 0.8),
            ncols = bg.n_channels + 1,
            nrows = 1,
            sharex = True,
            sharey = True,
        )
        for idx in range(bg.n_channels):
            axes[idx] = _raw_spots(bg.image_raw[idx], x_raw, y_raw, axes[idx], 'Raw (w. ' + bg.channels[idx] + ')')
        axes[-1] = _draw_cells_withStaining(bg.image_raw[0], x_predfore, y_predfore, axes[-1], cls_predlink, 'Pred (w. DAPI)')
    else:
        fig, axes = plt.subplots(
            figsize = (PLT_KEYS.AX_WIDTH * 2 * 0.8, PLT_KEYS.AX_HEIGHT * 0.8),
            ncols = 2,
            nrows = 1,
            sharex = True,
            sharey = True,
        )
        axes[0] = _raw_spots(None, x_raw, y_raw, axes[0])
        axes[1] = _draw_cells_withStaining(None, x_predfore, y_predfore, axes[1], cls_predlink, 'Pred Cells')

    fig.tight_layout()
    output_name = f'{PLT_KEYS.FOLDER_NAME}/{PLT_KEYS.FILE_PREFIX_SEGMENTATION}x_{(location[0]):.2f}_y_{(location[1]):.2f}{PLT_KEYS.FILE_FORMAT}'
    # Make sure the target folder exists so savefig cannot fail on a missing directory.
    os.makedirs(os.path.dirname(output_name) or '.', exist_ok = True)
    fig.savefig(output_name, bbox_inches = 'tight')
    plt.close(fig)