Source code for scrnatools.plotting._gene_heatmap

"""
Plots a heatmap of expression of genes.
From scrnatools package

Created on Tue March 28 17:47:46 2023

@author: joe germino (joe.germino@ucsf.edu), nolan horner (nolan.horner@ucsf.edu)
"""
# external package imports
from anndata import AnnData
from typing import List
import seaborn as sns
import matplotlib.pyplot as plt
import scrnatools

# scrnatools package imports
from .._configs import configs
from .._utils import check_path

logger = configs.create_logger(__name__.split('_', 1)[1])

# -------------------------------------------------------function----------------------------------------------------- #


[docs]def gene_heatmap( adata: AnnData, gene_list: List[str], obs_key: str, layer: str = "X", obs_values: List[str] = ['All'], cbar_args: List[int] = None, swap_axes: bool = False, cell_size: int = 30, save_path: str = None, dpi: int = 300, *args, **kwargs ): """Plots a heatmap of expression of genes. Args: adata (AnnData): The dataset containing the gene expression and cell data. gene_list (List[str]): A list of genes to plot. obs_key (str): The categorical grouping to display on the heatmap. layer (str, optional): The type of the expression data to use for the heatmap, can be a layer in 'adata.layers' or 'X' to use the data stored in adata.X. Defaults to "X". obs_values (List[str], optional): Values from obs key group to display on heatmap. Defaults to ['All']. cbar_args (List[int], optional): List of integers to position color bar on heatmap. [x position, y position, width, height]. 'None' automatically adjusts to right middle of heatmap.. Defaults to None. swap_axes (bool, optional): AWhether to swap x and y axes. Defaults to False. cell_size (int, optional): Integer specifying the size of each cell in the heatmap. Defaults to 30. save_path (str, optional): The path to save the figure. Defaults to None. dpi (int, optional): The resolution of the saved image. Defaults to 300. Raises: ValueError: If a gene in 'gene_list' provided is not in provided AnnData layer. ValueError: If 'gene_list' provided less than 2. ValueError: If the 'obs_values' provided less than 2. ValueError: If a value in 'obs_values' provided is not in provided AnnData layer obs key. """ invalid_genes = [i for i in gene_list if i not in adata.var_names] if len(invalid_genes) > 0: raise ValueError(f"Genes not found in data: {invalid_genes}") if len(gene_list) < 2: raise ValueError( f"Please choose two or more genes. {len(gene_list)} genes were entered.") if layer == "raw": raw_adata = adata.raw.to_adata() expression_matrix = scrnatools.tl.get_expression_matrix( raw_adata[:, gene_list], gene_data="X" ) else: expression_matrix = scrnatools.tl.get_expression_matrix( adata[:, gene_list], gene_data=layer ) expression_matrix[obs_key] = adata.obs[obs_key] # matrix with mean expression levels of selected genes for each cell type expression_matrix = expression_matrix.groupby(obs_key).mean() if obs_values != ['All']: # include only cell types from cell_type_list if len(obs_values) < 2: raise ValueError( f"Please choose more than one obs value. Or use '[All]' to see all {obs_key}s") for i in obs_values: if i not in expression_matrix.index: raise ValueError( f"{i} is not in adata's obs_key {obs_key}.\nPossible {obs_key}s: {list(expression_matrix.index.categories)}" ) for j in expression_matrix.index: if j not in obs_values: expression_matrix = expression_matrix.drop(j) if swap_axes: expression_matrix = expression_matrix.transpose() figdpi = plt.rcParams['figure.dpi'] totalWidth = plt.rcParams['figure.subplot.right'] - \ plt.rcParams['figure.subplot.left'] totalHeight = plt.rcParams['figure.subplot.top'] - \ plt.rcParams['figure.subplot.bottom'] nrows, ncols = expression_matrix.shape figWidth = (ncols*cell_size/figdpi)/totalWidth figHeight = (nrows*cell_size/figdpi)/totalHeight sns.set_theme(context="paper", style="white", ) cg = sns.clustermap( expression_matrix, figsize=(figWidth, figHeight), *args, **kwargs ) # calculate the size of the heatmap axes axWidth = (ncols*cell_size)/(figWidth*figdpi) axHeight = (nrows*cell_size)/(figHeight*figdpi) # resize heatmap ax_heatmap_og_pos = cg.ax_heatmap.get_position() cg.ax_heatmap.set_position( [ax_heatmap_og_pos.x0, ax_heatmap_og_pos.y0, axWidth, axHeight] ) # layout of colorbar if cbar_args is None: cbar_x_pos = 0.85 + ax_heatmap_og_pos.x0 cbar_y_pos = ax_heatmap_og_pos.y0 cbar_width = 25/(figWidth*figdpi) cbar_height = axHeight cbar_args = [cbar_x_pos, cbar_y_pos, cbar_width, cbar_height] # set the layout cg.ax_cbar.set_position(cbar_args) cg.ax_cbar.grid(False) cg.ax_row_dendrogram.set_visible(False) cg.ax_col_dendrogram.set_visible(False) cg.ax_heatmap.yaxis.tick_left() cg.ax_heatmap.set_ylabel("") cg.ax_heatmap.set_xlabel("") cg.ax_heatmap.grid(False) cg.ax_heatmap.tick_params(axis='y', labelrotation=0) cg.ax_heatmap.tick_params(axis='x', labelrotation=90) cg.ax_heatmap.tick_params(bottom=True, left=True) if save_path is not None: if "/" not in save_path: save_path = f"./{save_path}" check_path(save_path.rsplit("/", 1)[0]) logger.info(f"Saving figure to {save_path}") plt.savefig(save_path, dpi=dpi, facecolor="white", bbox_inches="tight") plt.show()