"""
Plots violinplot subplots of expression of genes.
From scrnatools package
Created on Wed March 29 17:28:15 2023
@author: joe germino (joe.germino@ucsf.edu), nolan horner (nolan.horner@ucsf.edu)
"""
# external package imports
from anndata import AnnData
from typing import List, Tuple
import seaborn as sns
import matplotlib.pyplot as plt
import math
import scrnatools
# scrnatools package imports
from .._configs import configs
from .._utils import check_path
# Module logger, created through the package's configs.create_logger helper.
# The name passed is __name__ with everything up to and including its first
# "_" removed (raises IndexError if __name__ contains no "_"). Presumably this
# strips the package path down to the private module's own name, e.g.
# "..._gene_violinplot" -> "gene_violinplot" — TODO confirm intent.
logger = configs.create_logger(__name__.split('_', 1)[1])
# -------------------------------------------------------function----------------------------------------------------- #
def _is_all(values) -> bool:
    """Return True if *values* is the ``['All']`` sentinel meaning "every value".

    ``list()`` makes the check safe for numpy arrays, pandas Index/Series and
    Categoricals. For those, ``values == ['All']`` compares element-wise, and a
    following ``if`` would raise "truth value of an array is ambiguous".
    """
    return list(values) == ['All']


def _expand_values(adata, key, values):
    """Return *values*, or every non-missing value of ``adata.obs[key]`` if it is ['All'].

    Expanded values keep their order of first appearance. Missing values (NaN)
    are dropped so they never become an empty plot category.
    """
    if _is_all(values):
        return adata.obs[key].dropna().unique()
    return values


def _check_values_in_obs(adata, key, values, role):
    """Raise ValueError for the first entry of *values* absent from ``adata.obs[key]``.

    *role* ("x_key" or "hue_key") is only used to word the error message.
    """
    observed = set(adata.obs[key].unique())
    for value in values:
        if value not in observed:
            raise ValueError(
                f"{value} is not in adata's {role} {key}.\n"
                f"Possible {key}s: {list(adata.obs[key].unique())}"
            )


def _check_values_have_cells(values, column, key):
    """Raise ValueError for the first entry of *values* that has no cells in *column*.

    This catches values that exist in adata.obs but lose all of their cells once
    the data is subset to the combined x/hue selection.
    """
    present = set(column.unique())
    for value in values:
        if value not in present:
            raise ValueError(
                f"{value} in {key} has no cells after subsetting to the selected "
                "x and hue values."
            )


def _grid_shape(n_plots, ncols, nrows, max_input_len, min_input_len):
    """Return ``(ncols, nrows)`` for *n_plots* subplots, filling in whichever is None.

    When both are None, fewer columns are used as the number of violins per
    subplot (x values, hue values) grows, so wide subplots get more room.
    """
    for name, value in (("ncols", ncols), ("nrows", nrows)):
        if value is not None and value < 1:
            raise ValueError(f"{name} must be a positive integer, got {value}.")
    if ncols is None and nrows is None:
        if max_input_len > 5:
            ncols = 1
        elif max_input_len > 2:
            ncols = 3 if min_input_len <= 2 else 2
        else:
            ncols = 5 if min_input_len <= 2 else 3
        nrows = math.ceil(n_plots / ncols)
    elif ncols is None:
        ncols = math.ceil(n_plots / nrows)
    elif nrows is None:
        nrows = math.ceil(n_plots / ncols)
    if ncols * nrows < n_plots:
        raise ValueError(
            "Number of rows and columns must fit number of genes in gene list (nrows * ncols >= length of gene list).")
    return ncols, nrows


def gene_violinplot(
    adata: AnnData,
    gene_list: List[str],
    x_key: str,
    layer: str = "X",
    x_values: List[str] = ['All'],
    hue_key: str = None,
    hue_values: List[str] = ['All'],
    ncols: int = None,
    nrows: int = None,
    save_path: str = None,
    dpi: int = 300,
    fig_size: Tuple[float, float] = (0.5, 2.5),
    *args, **kwargs
):
    """Plots violinplot subplots of expression of genes.

    Draws one subplot per gene in 'gene_list', showing expression grouped by the
    'x_key' column of 'adata.obs', optionally split by the 'hue_key' column. The
    figure is shown with plt.show() and optionally saved to 'save_path'.

    Args:
        adata (AnnData): The dataset containing the gene expression and cell data.
        gene_list (List[str]): A list of genes to plot (must be in 'adata.var_names').
        x_key (str): The categorical 'adata.obs' column to display on the x axis of the violinplot.
        layer (str, optional): The layer containing expression data to use for the violinplot, can be a layer in 'adata.layers' or 'X' to use the data stored in adata.X. Defaults to "X".
        x_values (List[str], optional): Values from 'x_key' to display, in this order. ['All'] uses every non-missing value in order of first appearance. Defaults to ['All'].
        hue_key (str, optional): The categorical 'adata.obs' column to color the grouped violin plots by. Values will appear in the legend of the first subplot. If 'hue_key' is None no hue splitting of x values will occur. Defaults to None.
        hue_values (List[str], optional): Values from 'hue_key' to display, in this order. ['All'] uses every non-missing value. Defaults to ['All'].
        ncols (int, optional): Number of subplot columns. Computed automatically if None. Defaults to None.
        nrows (int, optional): Number of subplot rows. Computed automatically if None. Defaults to None.
        save_path (str, optional): The path to save the figure. Defaults to None.
        dpi (int, optional): The resolution of the saved image. Defaults to 300.
        fig_size (Tuple[float, float], optional): Scaling factors: figure width is ncols * len(x_values) * max(2, len(hue_values)) * fig_size[0] and height is nrows * fig_size[1]. Defaults to (0.5, 2.5).
        *args, **kwargs: Forwarded to seaborn.violinplot.

    Raises:
        ValueError: If 'gene_list' is empty.
        ValueError: If a gene in 'gene_list' is not in 'adata.var_names'.
        ValueError: If fewer than 2 'x_values' are provided (and it is not ['All']).
        ValueError: If 'hue_values' is provided but 'hue_key' is not.
        ValueError: If fewer than 2 'hue_values' are provided (and it is not ['All']) when 'hue_key' is provided.
        ValueError: If 'ncols' or 'nrows' is less than 1, or ncols * nrows is less than the length of 'gene_list'.
        ValueError: If a value in 'x_values' / 'hue_values' is not in 'adata.obs[x_key]' / 'adata.obs[hue_key]'.
        ValueError: If a selected x or hue value has no cells after subsetting to the combined selection.
    """
    # check inputs
    if len(gene_list) == 0:
        raise ValueError("Please provide at least one gene to plot.")
    invalid_genes = [i for i in gene_list if i not in adata.var_names]
    if invalid_genes:
        raise ValueError(f"Genes not found in data: {invalid_genes}")
    if len(x_values) < 2 and not _is_all(x_values):
        raise ValueError(
            f"Please choose more than one x value. Or use ['All'] to see all {x_key}s"
        )
    if hue_key is None:
        if not _is_all(hue_values):
            raise ValueError("Please enter a hue_key.")
    elif len(hue_values) < 2 and not _is_all(hue_values):
        raise ValueError(
            f"Please choose more than one hue value. Or use ['All'] to see all {hue_key}s"
        )

    # Expand ['All'] and validate the requested values against the full data
    # *before* any figure exists, so a bad value neither leaks an open pyplot
    # figure nor costs an expression-matrix computation.
    x_values = _expand_values(adata, x_key, x_values)
    _check_values_in_obs(adata, x_key, x_values, "x_key")
    if hue_key is not None:
        hue_values = _expand_values(adata, hue_key, hue_values)
        _check_values_in_obs(adata, hue_key, hue_values, "hue_key")

    # format figure grid; without a hue, hue_values stays ['All'] (length 1)
    ncols, nrows = _grid_shape(
        len(gene_list),
        ncols,
        nrows,
        max_input_len=max(len(x_values), len(hue_values)),
        min_input_len=min(len(x_values), len(hue_values)),
    )

    # subset data on x values (and hue values) -> expression matrix
    cell_mask = adata.obs[x_key].isin(x_values)
    if hue_key is not None:
        cell_mask = cell_mask & adata.obs[hue_key].isin(hue_values)
    subset_adata = adata[cell_mask, gene_list]
    # presumably a DataFrame with one column per gene, indexed like obs — TODO confirm
    expression_matrix = scrnatools.tl.get_expression_matrix(
        subset_adata, gene_data=layer)
    expression_matrix[x_key] = subset_adata.obs[x_key]
    _check_values_have_cells(x_values, expression_matrix[x_key], x_key)
    if hue_key is not None:
        expression_matrix[hue_key] = subset_adata.obs[hue_key]
        _check_values_have_cells(hue_values, expression_matrix[hue_key], hue_key)

    fig = plt.figure(figsize=(
        ncols * len(x_values) * max(2, len(hue_values)) * fig_size[0],
        nrows * fig_size[1],
    ))

    # generate violin plots
    sns.set_theme(context="paper", style="white")
    for plt_num, gene in enumerate(gene_list):
        plt.subplot(nrows, ncols, plt_num + 1)
        # 'data' is passed by keyword: seaborn 0.11 treated the first positional
        # argument as 'x'. NOTE(review): seaborn >= 0.13 deprecates 'scale' in
        # favor of 'density_norm'.
        ax = sns.violinplot(
            *args,
            data=expression_matrix,
            x=x_key,
            y=gene,
            hue=hue_key,
            scale='width',
            width=0.9,
            order=x_values,
            hue_order=hue_values if hue_key is not None else None,
            **kwargs
        )
        if hue_key is not None:
            # show the hue legend only once, on the first subplot
            if plt_num > 0:
                plt.legend([], [], frameon=False)
            else:
                plt.legend(loc='upper right')
        plt.xticks(rotation=90)
        ax.grid(False)
        ax.set(xlabel=None)
        ax.set_title(f"{gene}")
        ax.tick_params(bottom=True, left=True)
    fig.tight_layout()
    if save_path is not None:
        # a bare filename is saved in the current directory
        if "/" not in save_path:
            save_path = f"./{save_path}"
        check_path(save_path.rsplit("/", 1)[0])
        logger.info(f"Saving figure to {save_path}")
        plt.savefig(save_path, dpi=dpi, facecolor="white")
    plt.show()