# Source code for Bering.plotting.plot

import os
import random
import torch
import logging
import warnings
import numpy as np
import pandas as pd
from typing import Optional

import leidenalg as la
from sklearn.metrics import adjusted_rand_score

import matplotlib as mpl
import matplotlib.pyplot as plt

from torch_geometric.data import Data 

from ._settings import _PLOT_SETTINGS, _GET_CMAPS
from ._settings import PLOT_KEYS as PLT_KEYS

from ..graphs import BuildGraph, BuildGraph_fromRaw
from ..segment import find_clusters_predictedLinks
from ..objects import Bering_Graph as BrGraph
from ._plot_elements import _raw_segmentation, _raw_cell_types, _raw_cell_types_addPatch
from ._plot_elements import _predicted_cell_types, _predicted_probability, _draw_cells_withStaining, _draw_cells_withStaining_convexhull

# Suppress all Python warnings process-wide as soon as this module is imported.
warnings.filterwarnings("ignore")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Import-time call into ._settings; presumably applies the plotting defaults
# (the helper is not defined in this file — confirm there).
_PLOT_SETTINGS()
# Colour sequence obtained from ._settings; the plotting functions slice it as
# CMAP[:bg.n_labels] to give each label a default colour.
CMAP = _GET_CMAPS()

def _get_cell_centroid_window(spots, cell_metadata, cell_name, zoomout_scale_x = 8.0, zoomout_scale_y = 8.0):
    '''
    Select the spots lying inside a rectangle centred on one cell.

    The rectangle is centred on the cell's (``cx``, ``cy``) and reaches
    ``zoomout_scale_x * dx`` to each side along x and ``zoomout_scale_y * dy``
    to each side along y. All four bounds are exclusive.

    Parameters
    ----------
    spots
        Table with ``x`` and ``y`` columns.
    cell_metadata
        Table indexed by cell name with ``cx``, ``cy``, ``dx`` and ``dy`` columns.
    cell_name
        Index label of the cell the window is centred on.
    zoomout_scale_x
        Multiplier applied to ``dx`` to get the half-width of the window.
    zoomout_scale_y
        Multiplier applied to ``dy`` to get the half-height of the window.

    Returns
    -------
    A copy of the rows of ``spots`` that fall strictly inside the window.
    '''
    centre_x, centre_y, extent_x, extent_y = (
        cell_metadata.loc[cell_name, column] for column in ('cx', 'cy', 'dx', 'dy')
    )
    half_width = zoomout_scale_x * extent_x
    half_height = zoomout_scale_y * extent_y

    # Strict inequalities: spots sitting exactly on the border are left out.
    inside_x = (spots['x'] > centre_x - half_width) & (spots['x'] < centre_x + half_width)
    inside_y = (spots['y'] > centre_y - half_height) & (spots['y'] < centre_y + half_height)

    return spots.loc[inside_x & inside_y, :].copy()

def _get_extended_window(graph, spots, window_width = 50.0, window_height = 50.0):
    '''
    Select the spots inside a rectangle centred on the mean position of a graph's nodes.

    Columns 1 and 2 of ``graph.pos`` are read as the x and y coordinates. Their
    means, rounded to two decimals, give the window centre. ``window_width`` and
    ``window_height`` are the distances from the centre to the border (i.e.
    half-extents), and the border itself is excluded.

    Parameters
    ----------
    graph
        Object exposing a ``pos`` tensor whose columns 1 and 2 are read as x and y.
    spots
        Table with ``x`` and ``y`` columns.
    window_width
        Distance from the centre to the left/right border.
    window_height
        Distance from the centre to the bottom/top border.

    Returns
    -------
    A copy of the rows of ``spots`` that fall strictly inside the window.
    '''
    node_x = graph.pos[:, 1].cpu().numpy()
    node_y = graph.pos[:, 2].cpu().numpy()
    centre_x = np.round(np.mean(node_x), 2)
    centre_y = np.round(np.mean(node_y), 2)

    inside_x = (spots['x'] > centre_x - window_width) & (spots['x'] < centre_x + window_width)
    inside_y = (spots['y'] > centre_y - window_height) & (spots['y'] < centre_y + window_height)

    return spots.loc[inside_x & inside_y, :].copy()

def _get_graph(bg, cell_name, Spots, cell_metadata, n_neighbors = 10, **kwargs):
    '''
    Build the neighbour graph of the spots surrounding one cell.

    Parameters
    ----------
    bg
        Bering graph object; it is handed to ``BuildGraph_fromRaw`` together
        with a copy of ``bg.features``.
    cell_name
        Index label of the cell in ``cell_metadata`` to centre the window on.
    Spots
        Table of spots; a copy is cropped to the window.
    cell_metadata
        Per-cell table used to locate the window.
    n_neighbors
        Forwarded to ``BuildGraph_fromRaw``.
    **kwargs
        Forwarded to ``_get_cell_centroid_window``
        (e.g. ``zoomout_scale_x`` / ``zoomout_scale_y``).

    Returns
    -------
    ``(window_spots, graph)``: the cropped spots and the graph moved to CPU.

    Raises
    ------
    Exception
        When graph construction raises an ``AssertionError``.
    '''
    spots_in_window = _get_cell_centroid_window(Spots.copy(), cell_metadata, cell_name, **kwargs)

    # The graph is built with BuildGraph_fromRaw (not BuildGraph) from the cropped spots.
    try:
        knn_graph = BuildGraph_fromRaw(
            bg,
            spots_in_window,
            bg.features.copy(),
            n_neighbors = n_neighbors,
        ).cpu()
    except AssertionError:
        raise Exception('No enough transcripts in this window')
    else:
        return spots_in_window, knn_graph

def _prediction_nodes(graph, trainer_node, n_labels, prob_threshold = 0.3):
    '''
    Predict a label for every node and send low-confidence nodes to the last label index.

    Parameters
    ----------
    graph
        Input passed to ``trainer_node.predict``.
    trainer_node
        Object whose ``predict(graph, device = 'cpu')`` returns a 2-D tensor of
        per-node, per-label scores.
    n_labels
        Number of labels; nodes below the threshold get index ``n_labels - 1``
        (the callers in this module treat that index as background).
    prob_threshold
        Nodes whose best score is less than or equal to this value are reassigned.

    Returns
    -------
    ``(predictions, max_probs)``: label index per node and the best score per node.
    '''
    class_scores = trainer_node.predict(graph, device = 'cpu').cpu()

    best = torch.max(class_scores, dim = 1)
    max_probs, predictions = best.values, best.indices

    # Threshold is inclusive: a score equal to prob_threshold is also reassigned.
    low_confidence = max_probs <= prob_threshold
    predictions[low_confidence] = n_labels - 1

    return predictions, max_probs

def Plot_SliceImages(
    bg: BrGraph,
):
    '''
    Plot the whole slice, drawing each segmented cell as one point coloured by its label,
    and save the figure.

    The figure is written to
    ``PLT_KEYS.FOLDER_NAME + '/' + PLT_KEYS.FILE_PREFIX_RAWSLICE + 'labels' + PLT_KEYS.FILE_FORMAT``.

    Parameters
    ----------
    bg: BrGraph
        Bering Graph object. Reads ``bg.segmented`` (cell table with ``cx`` / ``cy``)
        and ``bg.spots_seg`` (spot table with ``segmented`` / ``labels``);
        ``bg.label_to_col`` is filled with default colours when it is empty.
    '''
    # Default label -> colour mapping, created only when none exists yet.
    # NOTE(review): the 'background' entry is assumed to belong to this default
    # initialisation — confirm against the upstream source.
    if len(bg.label_to_col) == 0:
        default_colours = dict(zip(bg.labels, CMAP[:bg.n_labels]))
        default_colours['background'] = '#C0C0C0'
        bg.label_to_col = default_colours

    # One row per (cell, label) pair, indexed by cell id, so labels can be looked up per cell.
    # Assumes each segmented cell carries a single label — TODO confirm.
    cells = bg.segmented.copy()
    labels_by_cell = (
        bg.spots_seg.copy()
        .drop_duplicates(['segmented', 'labels'])
        .set_index('segmented')
    )
    cells['labels'] = labels_by_cell.loc[cells.index.values, 'labels'].values

    centroid_x, centroid_y = cells.cx.values, cells.cy.values
    cell_labels = cells['labels']

    fig, ax = plt.subplots(figsize = (PLT_KEYS.AX_WIDTH_SLICE, PLT_KEYS.AX_HEIGHT_SLICE))
    ax = _raw_cell_types(
        None,
        centroid_x,
        centroid_y,
        ax,
        cell_labels,
        bg.label_to_col,
        s = PLT_KEYS.SIZE_PT_CELL_ONSLICE,
    )
    fig.tight_layout()

    # SAVE
    file_stem = PLT_KEYS.FILE_PREFIX_RAWSLICE + 'labels'
    output_name = PLT_KEYS.FOLDER_NAME + '/' + file_stem + PLT_KEYS.FILE_FORMAT
    fig.savefig(output_name, bbox_inches = 'tight')
def Plot_Spots(
    bg: Optional[BrGraph] = None,
    df_spots_seg: Optional[pd.DataFrame] = None,
    df_spots_unseg: Optional[pd.DataFrame] = None,
):
    '''
    Visualize both segmented and unsegmented spots.

    Segmented spots are drawn per cell type with a random colour each;
    unsegmented spots are drawn in light grey and labelled 'background'.

    Parameters
    ----------
    bg: BrGraph
        Bering Graph object.
        - If bg is not None, then df_spots_seg and df_spots_unseg are ignored.
        - If bg is None, then df_spots_seg and df_spots_unseg must be provided.
    df_spots_seg: pd.DataFrame
        DataFrame of segmented spots (``x``, ``y`` and ``labels`` columns are read)
    df_spots_unseg: pd.DataFrame
        DataFrame of unsegmented spots (``x`` and ``y`` columns are read)

    Raises
    ------
    ValueError
        If ``bg`` is None and either ``df_spots_seg`` or ``df_spots_unseg`` is missing.
    '''
    if bg is not None:
        # The tables are only read below, so no defensive copy is needed.
        df_spots_seg = bg.spots_seg
        df_spots_unseg = bg.spots_unseg
    elif df_spots_seg is None or df_spots_unseg is None:
        raise ValueError(
            'Either `bg` or both `df_spots_seg` and `df_spots_unseg` must be provided'
        )

    x, y = df_spots_seg['x'].values, df_spots_seg['y'].values
    cell_types = df_spots_seg['labels'].values

    fig, ax = plt.subplots(figsize = (5, 5))
    for cell_type in np.unique(cell_types):
        # One mask per cell type, shared by the x and y selections.
        in_type = cell_types == cell_type
        ax.scatter(x[in_type], y[in_type], s = 0.03, label = cell_type, color = np.random.rand(3,))

    xb, yb = df_spots_unseg['x'].values, df_spots_unseg['y'].values
    ax.scatter(xb, yb, color = '#DCDCDC', alpha = 0.2, s = 0.015, label = 'background')

    # Attach the legend to the axes drawn on rather than to pyplot's "current" axes.
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles, labels, loc = 'upper right', fontsize = 8, markerscale = 15)
def Plot_Classification(
    bg: BrGraph,
    cell_name: str,
    n_neighbors: int = 10,
    min_prob: float = 0.3,
    zoomout_scale: float = 8.0,
):
    '''
    Plot node classification results for the window around one cell and save the figure.

    A 2 x 2 figure is drawn: raw segmentation, raw annotations, predicted
    annotations (titled with the fraction of spots whose predicted label equals
    the raw label) and the prediction probability with a colour bar. The figure
    is written to
    ``PLT_KEYS.FOLDER_NAME + '/' + PLT_KEYS.FILE_PREFIX_CLASSIFICATION + str(cell_name) + PLT_KEYS.FILE_FORMAT``.

    Parameters
    ----------
    bg: BrGraph
        Bering Graph object
    cell_name: str
        Name of the cell to plot
    n_neighbors: int
        Number of neighbors to build the knn graph
    min_prob: float
        Minimum probability threshold to classify a valid cell type (otherwise background)
    zoomout_scale: float
        Zoom out scale (relative to the cell diameter) to show the region

    Returns
    -------
    - df_window_raw: ``pd.DataFrame``
        DataFrame of the raw spots in the window
    - df_window_pred: ``pd.DataFrame``
        DataFrame of the spots in the window that were not predicted as background
    - predictions: ``np.ndarray``
        Array of predicted label indices
    '''
    # Default label -> colour mapping, created only when none exists yet.
    # NOTE(review): the 'background' entry is assumed to belong to this default
    # initialisation — confirm against the upstream source.
    if len(bg.label_to_col) == 0:
        default_colours = dict(zip(bg.labels, CMAP[:bg.n_labels]))
        default_colours['background'] = '#C0C0C0'
        bg.label_to_col = default_colours

    # BUILD GRAPHS
    all_spots = bg.spots_all.copy()
    segmented_cells = bg.segmented.copy()
    df_window_raw, graph = _get_graph(
        bg,
        cell_name,
        all_spots,
        segmented_cells,
        n_neighbors = n_neighbors,
        zoomout_scale_x = zoomout_scale,
        zoomout_scale_y = zoomout_scale,
    )
    graph = graph.cpu()

    pred_tensor, prob_tensor = _prediction_nodes(
        graph,
        bg.trainer_node,
        bg.n_labels,
        prob_threshold = min_prob,
    )
    predictions = pred_tensor.numpy()
    max_probs = prob_tensor.numpy()

    # PREPARATION
    # Column 0 of graph.pos is used to look the nodes up in the spot table;
    # columns 1 and 2 are the plotting coordinates.
    node_pos = graph.pos
    spot_ids = node_pos[:, 0].cpu().numpy()
    x = node_pos[:, 1].cpu().numpy()
    y = node_pos[:, 2].cpu().numpy()

    seg_types = all_spots.loc[spot_ids, 'groups'].values
    raw_labels = all_spots.loc[spot_ids, 'labels'].values
    pred_labels = np.array([bg.label_indices_dict[label_idx] for label_idx in predictions])
    label_to_col = bg.label_to_col

    accuracy = np.sum(raw_labels == pred_labels) / len(raw_labels)

    # PLOTS
    fig, axes = plt.subplots(
        figsize = (PLT_KEYS.AX_WIDTH * 2, PLT_KEYS.AX_HEIGHT * 2),
        ncols = 2,
        nrows = 2,
        sharex = True,
        sharey = True,
    )
    # First image channel is used as the backdrop when an image is available.
    dapi = bg.image_raw[0, :, :] if bg.image_raw is not None else None

    ax_segmentation = _raw_segmentation(dapi, x, y, axes[0, 0], seg_types)
    ax_raw_types = _raw_cell_types(dapi, x, y, axes[0, 1], raw_labels, label_to_col)
    ax_pred_types = _predicted_cell_types(dapi, x, y, axes[1, 0], pred_labels, label_to_col)
    ax_probability, prob_plot = _predicted_probability(dapi, x, y, axes[1, 1], max_probs)

    ax_pred_types.set_title(f'Predicted Annotations (Accu={accuracy:.2f})')
    plt.colorbar(prob_plot, ax = ax_probability, shrink = 0.8)
    fig.tight_layout()

    # SAVE
    file_stem = PLT_KEYS.FILE_PREFIX_CLASSIFICATION + str(cell_name)
    output_name = PLT_KEYS.FOLDER_NAME + '/' + file_stem + PLT_KEYS.FILE_FORMAT
    fig.savefig(output_name, bbox_inches = 'tight')

    # FOREGROUND WINDOW
    # The last label index marks background; keep every other spot.
    foreground_rows = np.where(predictions != bg.n_labels - 1)[0]
    df_window_pred = df_window_raw.iloc[foreground_rows, :].copy()

    return df_window_raw, df_window_pred, predictions
def Plot_Segmentation(
    bg: BrGraph,
    cell_name: str,
    df_window_raw: pd.DataFrame = None,
    df_window_pred: pd.DataFrame = None,
    predictions: np.ndarray = None,
    n_neighbors: int = 10,
    zoomout_scale: float = 4.0,
    use_image: bool = True,
    pos_thresh: float = 0.6,
    resolution: float = 0.05,
    num_edges_perSpot: int = 300,
    min_prob_nodeclf: float = 0.3,
    n_iters: int = 10,
    convex_hull: bool = False,
):
    '''
    Plot the segmentation results with cell IDs on the original data and predicted data,
    and save the figure.

    When ``bg.channels`` is set, one "raw" panel is drawn per image channel plus
    one predicted panel on top of the first channel; otherwise one raw and one
    predicted panel are drawn without an image. The figure is written to
    ``PLT_KEYS.FOLDER_NAME + '/' + PLT_KEYS.FILE_PREFIX_SEGMENTATION + str(cell_name) + PLT_KEYS.FILE_FORMAT``.

    Parameters
    ----------
    bg: BrGraph
        Bering Graph object
    cell_name: str
        Name of the cell to plot
    df_window_raw: pd.DataFrame
        DataFrame of the raw spots in the window. When omitted, the window,
        ``predictions`` and ``df_window_pred`` are all recomputed from ``bg``.
    df_window_pred: pd.DataFrame
        DataFrame of the predicted (foreground) spots in the window. When omitted
        it is derived from ``df_window_raw`` and ``predictions``.
    predictions: np.ndarray
        Array of predicted labels. This is used to identify the foreground spots.
        Required whenever ``df_window_raw`` is given.
    n_neighbors: int
        Number of neighbors to build the knn graph
    zoomout_scale: float
        Zoom out scale (relative to the cell diameter) to show the region
    use_image: bool
        Whether to use the image to build the graph
    pos_thresh: float
        Threshold to determine whether the predicted edge is positive in the segmentation step
    resolution: float
        Resolution of Leiden clustering algorithm in the segmentation step
    num_edges_perSpot: int
        Number of nearest edges used to investigate edge labels (positive or negative) for each spot
    min_prob_nodeclf: float
        Minimum probability threshold to classify a valid cell type (otherwise background)
    n_iters: int
        Number of iterations. Each iteration runs on a subset of edges. This is used to avoid memory overflow.
    convex_hull: bool
        Whether to use convex hull to draw the cells

    Returns
    -------
    None. Nothing is plotted (and a warning is logged) when fewer than
    ``num_edges_perSpot`` foreground spots are available.

    Raises
    ------
    ValueError
        If ``df_window_raw`` is given without ``predictions``.
    '''
    # Fail early with a clear message instead of an obscure error further down.
    if df_window_raw is not None and predictions is None:
        raise ValueError('`predictions` must be provided together with `df_window_raw`')

    # The last label index marks background spots.
    background_idx = bg.n_labels - 1

    if df_window_raw is None:
        # Nothing precomputed: crop the window and classify its nodes here.
        # _get_graph works on its own copy of the spot table, so none is made here.
        df_window_raw, graph = _get_graph(
            bg,
            cell_name,
            bg.spots_all,
            bg.segmented,
            n_neighbors = n_neighbors,
            zoomout_scale_x = zoomout_scale,
            zoomout_scale_y = zoomout_scale,
        )
        predictions, _ = _prediction_nodes(
            graph,
            bg.trainer_node,
            bg.n_labels,
            prob_threshold = min_prob_nodeclf,
        )
        predictions = predictions.numpy()
        df_window_pred = None

    predictions = np.asarray(predictions)
    is_foreground = predictions != background_idx
    if df_window_pred is None:
        # clean background nodes
        df_window_pred = df_window_raw.iloc[np.where(is_foreground)[0], :].copy()

    if df_window_pred.shape[0] < num_edges_perSpot:
        logger.warning(
            'Skipping segmentation plot for cell %s: %d foreground spots, fewer than num_edges_perSpot = %d',
            cell_name, df_window_pred.shape[0], num_edges_perSpot,
        )
        return

    cls_predlink = find_clusters_predictedLinks(
        bg,
        df_spots = df_window_pred, # foreground only
        use_image = use_image,
        pos_thresh = pos_thresh,
        resolution = resolution,
        num_edges_perSpot = num_edges_perSpot,
        n_neighbors = n_neighbors,
        num_iters = n_iters,
    )[resolution]

    # Same default colours as the other plot functions, needed by the convex-hull renderer.
    if convex_hull and len(bg.label_to_col) == 0:
        bg.label_to_col = dict(zip(bg.labels, CMAP[:bg.n_labels]))
        bg.label_to_col['background'] = '#C0C0C0'

    # get labels
    cells_raw = df_window_raw.segmented.values
    raw_labels = df_window_raw.labels.values
    # Only foreground spots are drawn in the predicted panel, so only their labels
    # are kept; this keeps labels aligned with x_predfore / cls_predlink.
    pred_labels_fore = np.array([bg.label_indices_dict[i] for i in predictions[is_foreground]])

    x_raw, y_raw = df_window_raw.x.values, df_window_raw.y.values
    x_predfore, y_predfore = df_window_pred.x.values, df_window_pred.y.values

    if bg.channels is not None:
        n_panels = bg.n_channels + 1
        fig, axes = plt.subplots(
            figsize = (PLT_KEYS.AX_WIDTH * n_panels * 0.8, PLT_KEYS.AX_HEIGHT * 0.8),
            ncols = n_panels,
            nrows = 1,
            sharex = True,
            sharey = True,
        )
        raw_panels = [
            (bg.image_raw[idx], 'Raw (w. ' + bg.channels[idx] + ')')
            for idx in range(bg.n_channels)
        ]
        # The predicted panel is drawn over the first image channel.
        pred_image, pred_title = bg.image_raw[0], 'Pred (w. DAPI)'
    else:
        fig, axes = plt.subplots(
            figsize = (PLT_KEYS.AX_WIDTH * 2 * 0.8, PLT_KEYS.AX_HEIGHT * 0.8),
            ncols = 2,
            nrows = 1,
            sharex = True,
            sharey = True,
        )
        raw_panels = [(None, 'Raw')]
        pred_image, pred_title = None, 'Predicted'

    # Raw panels fill every axis but the last; the prediction always goes on the last axis.
    for idx, (image, title) in enumerate(raw_panels):
        axes[idx] = _draw_segmentation_panel(
            image, x_raw, y_raw, axes[idx], cells_raw, raw_labels, bg.label_to_col, title, convex_hull,
        )
    axes[-1] = _draw_segmentation_panel(
        pred_image, x_predfore, y_predfore, axes[-1], cls_predlink, pred_labels_fore, bg.label_to_col, pred_title, convex_hull,
    )

    fig.tight_layout()
    output_name = PLT_KEYS.FOLDER_NAME + '/' + PLT_KEYS.FILE_PREFIX_SEGMENTATION + str(cell_name) + PLT_KEYS.FILE_FORMAT
    fig.savefig(output_name, bbox_inches = 'tight')


def _draw_segmentation_panel(image, x, y, ax, cells, labels, label_to_col, title, convex_hull):
    '''Draw one segmentation panel on ``ax`` with the convex-hull or the default cell renderer and return its result.'''
    if convex_hull:
        return _draw_cells_withStaining_convexhull(image, x, y, ax, cells, labels, label_to_col, title)
    return _draw_cells_withStaining(image, x, y, ax, cells, title)