Source code for scrnatools.qc._scrublet

"""
Filters out doublets using scrublet.
From scrnatools package

Created on Mon Jan 10 15:57:46 2022

@author: joe germino (joe.germino@ucsf.edu)
"""
# external package imports
from anndata import AnnData
import matplotlib.pyplot as plt
import scrublet as scr
import pandas as pd

# scrnatools package imports
from .._configs import configs

logger = configs.create_logger(__name__.split('_', 1)[1])

# -------------------------------------------------------function----------------------------------------------------- #


def scrublet(
    adata: AnnData,
    raw_counts_layer: str,
    doublet_threshold: float = 0.2,
    batch_key: str = None,
    subset: bool = False,
) -> AnnData:
    """Scores cells for doublets with scrublet and labels/filters them.

    Scrublet is run separately on every batch found in 'adata.obs[batch_key]'
    and a score histogram is plotted for each batch. The per-cell scores are
    then written back to 'adata.obs' and thresholded.

    Side effects on the 'adata' object that is passed in:
        * 'adata.obs["scrublet_score"]' is set to the per-cell doublet score.
        * 'adata.obs["scrublet_called_doublet"]' is set to
          'scrublet_score > doublet_threshold'.
        * If 'batch_key' is None, a helper column
          'adata.obs["scrublet_batch_key"]' holding the string "None" is
          added and left on the object.

    Args:
        adata (AnnData): The dataset to process.
        raw_counts_layer (str): A key in 'adata.layers' pointing to the raw
            counts data.
        doublet_threshold (float, optional): The doublet score threshold to
            call doublets/singlets on. Defaults to 0.2.
        batch_key (str, optional): A column in 'adata.obs' annotating
            different batches of data. If 'None' treats all cells as coming
            from the same batch. Defaults to None.
        subset (bool, optional): If True, the returned object is 'adata'
            indexed down to the cells not called as doublets; the object the
            caller passed in is annotated but not itself filtered. If False,
            the annotated 'adata' is returned unchanged in size. Defaults to
            False.

    Returns:
        AnnData: The dataset provided with doublets filtered/labeled based on
        doublet_threshold.
    """
    # Setup
    if batch_key is None:
        adata.obs["scrublet_batch_key"] = "None"
        batch_key = "scrublet_batch_key"

    # Calculate doublet scores for each batch independently.
    # The grid is switched off only while the histograms are created, so the
    # caller's global matplotlib settings are restored afterwards.
    batch_predictions = []
    with plt.rc_context({"axes.grid": False}):
        for key in adata.obs[batch_key].unique():
            subset_data = adata[adata.obs[batch_key] == key]
            counts_matrix = subset_data.layers[raw_counts_layer]
            scrub = scr.Scrublet(counts_matrix, expected_doublet_rate=0.06)
            # Only the scores are kept; calls are re-derived below from
            # 'doublet_threshold' rather than from scrublet's own prediction.
            doublet_scores, _ = scrub.scrub_doublets(
                min_counts=2,
                min_cells=3,
                min_gene_variability_pctl=85,
                n_prin_comps=30,
            )
            # Apply the requested threshold before plotting the histogram.
            scrub.call_doublets(threshold=doublet_threshold)
            scrub.plot_histogram()
            batch_predictions.append(
                pd.DataFrame(
                    {"scrublet_score": doublet_scores},
                    index=subset_data.obs.index,
                )
            )

    # Concatenate once instead of growing a DataFrame inside the loop.
    if batch_predictions:
        scrublet_predictions = pd.concat(batch_predictions)
    else:
        # No batches (no cells): keep the column well-defined and empty.
        scrublet_predictions = pd.DataFrame(
            columns=["scrublet_score"], dtype=float
        )

    # Add scrublet scores to adata.obs (aligned on the obs index) and call
    # cells with a score > doublet_threshold as doublets
    adata.obs["scrublet_score"] = scrublet_predictions["scrublet_score"]
    adata.obs["scrublet_called_doublet"] = (
        adata.obs["scrublet_score"] > doublet_threshold
    )
    num_doublets = adata.obs["scrublet_called_doublet"].sum()
    num_cells = len(adata)
    # Guard the percentage against an empty dataset.
    pct_doublets = round(num_doublets / num_cells * 100, 3) if num_cells else 0.0
    logger.info(
        f"{pct_doublets}% of cells classified as doublets ({num_doublets} cells)"
    )
    if subset:
        adata = adata[~adata.obs["scrublet_called_doublet"]]
    return adata