import hashlib
import itertools
import json
import os
import pickle
import re
import shutil
import ssl
import sys
import urllib.request
import warnings
from typing import Dict, List, Optional

import h5py as h5
import numpy as np
import pandas as pd
import s3fs
import scipy.stats as st
from IPython.display import FileLink, HTML, Markdown, display
from lightgbm import LGBMRegressor
from matplotlib import pyplot as plt
from sklearn.metrics import auc, roc_auc_score, roc_curve
from tqdm import tqdm
# Endpoints used to list and download Enrichr gene set libraries.
LIBRARY_LIST_URL = "https://maayanlab.cloud/speedrichr/api/listlibs"
LIBRARY_DOWNLOAD_URL = "https://maayanlab.cloud/Enrichr/geneSetLibrary?mode=text&libraryName="
# Bucket holding the PrismEXP correlation matrices and the trained model
# (read anonymously by the S3 helpers and loadPrismXModel).
S3_URL = "s3://mssm-prismx-100/"
# Example GMT download link: the KEGG_2019_Mouse library.
GMT_EXAMPLE = LIBRARY_DOWNLOAD_URL + "KEGG_2019_Mouse"

# Use the selected user input
GENE = 'MED23'
LIBRARY = 'GO_Biological_Process_2018'
ENRICHR_MODE = True
def loadGenesS3():
    """Return the upper-cased gene symbols stored in ``correlation_0.h5`` on S3.

    Reads the ``meta/genes`` dataset of ``S3_URL + "correlation_0.h5"`` using
    anonymous S3 access. Order is preserved. Callers (see loadCorrelationS3)
    use a gene's position in this list as its row index in the correlation
    matrices.
    """
    s3 = s3fs.S3FileSystem(anon=True)
    # Close the S3 file object explicitly as well as the HDF5 handle:
    # h5py does not close a file-like object that is passed to it.
    with s3.open(S3_URL + "correlation_0.h5", "rb") as remote:
        with h5.File(remote, "r", lib_version="latest") as f:
            return [s.decode("UTF-8").upper() for s in f["meta/genes"]]
def loadCorrelationS3(gene, genes, cormat):
    """Return one gene's correlation row from a PrismEXP correlation matrix on S3.

    Args:
        gene: gene symbol; matched case-insensitively against *genes*.
        genes: upper-cased gene list from loadGenesS3(). A gene's position in
            this list is used as its row index in ``data/correlation``.
        cormat: matrix identifier; the file read is
            ``S3_URL + "correlation_<cormat>.h5"``.

    Returns:
        The row as a float64 numpy array.

    Raises:
        ValueError: if *gene* is not in *genes* (raised before any download).
    """
    idx = genes.index(gene.upper())
    s3 = s3fs.S3FileSystem(anon=True)
    # Close both the S3 file object and the HDF5 handle.
    with s3.open(S3_URL + "correlation_" + str(cormat) + ".h5", "rb") as remote:
        with h5.File(remote, "r", lib_version="latest") as f:
            return np.array(f["data/correlation"][idx, :]).astype(np.float64)
def geneCorrelation(gene):
    """Collect *gene*'s correlation rows from all PrismEXP matrices on S3.

    Matrices ``correlation_0`` ... ``correlation_99`` are fetched first,
    followed by ``correlation_global``. A tqdm progress bar reports each fetch.

    Returns:
        A DataFrame indexed by the gene symbols from loadGenesS3(). It has one
        column per matrix, labelled 0..100 in fetch order (column 100 is
        "global").
    """
    matrix_ids = [*range(100), "global"]
    all_genes = loadGenesS3()
    progress = tqdm(matrix_ids)
    rows = []
    for matrix_id in progress:
        progress.set_description("Retrieve %s" % matrix_id)
        rows.append(loadCorrelationS3(gene, all_genes, matrix_id))
    # Each fetched row becomes one column of the result.
    return pd.DataFrame(np.column_stack(rows), index=all_genes)
def getDataPath() -> str:
    """Return the ``data/`` directory next to this module (with a trailing slash).

    NOTE(review): relies on ``__file__``, which is not defined when this code
    runs inside a notebook cell.
    """
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, "data/")
def listLibraries():
    """Return the ``"library"`` entry of the JSON served at LIBRARY_LIST_URL."""
    payload = loadJSON(LIBRARY_LIST_URL)
    return payload["library"]
def loadLibrary(library: str, overwrite: bool = False) -> str:
    """Download an Enrichr gene set library to ``gmts/<library>.gmt``, or reuse it.

    An existing ``gmts/<library>.gmt`` is reused unless *overwrite* is True.
    Either way, the file is then parsed with readGMT and the number of gene
    sets and unique genes is printed.

    Args:
        library: Enrichr library name, appended to LIBRARY_DOWNLOAD_URL.
        overwrite: force a fresh download even if the file is cached.

    Returns:
        ``"gmts/" + library``. Note that this omits the ``.gmt`` extension;
        that form is kept for compatibility with existing callers.
    """
    # NOTE(review): this disables HTTPS certificate verification for the whole
    # process, not just this download. It is kept because later
    # urllib.request.urlretrieve calls (e.g. loadPrismXModel) may depend on it.
    ssl._create_default_https_context = ssl._create_unverified_context
    gmt_path = "gmts/" + library + ".gmt"
    # Check the same path we download to; `overwrite` forces a refresh.
    if overwrite or not os.path.exists(gmt_path):
        os.makedirs("gmts", exist_ok=True)
        print("Download Enrichr geneset library")
        urllib.request.urlretrieve(LIBRARY_DOWNLOAD_URL + library, gmt_path)
    else:
        print("File cached. To reload use loadLibrary(\"" + library + "\", overwrite=True) instead.")
    lib, _, ugenes = readGMT(gmt_path)
    print("# genesets: " + str(len(lib)) + "\n# unique genes: " + str(len(ugenes)))
    return "gmts/" + library
def printLibraries():
    """Print each available Enrichr library as ``<index> - <name>``."""
    for i, name in enumerate(listLibraries()):
        print(str(i) + " - " + name)
def loadJSON(url):
    """Fetch *url* and parse the UTF-8 response body as JSON.

    An unverified SSL context is passed to ``urlopen``, so certificate checks
    are skipped for this request only.
    """
    context = ssl._create_unverified_context()
    req = urllib.request.Request(url)
    # Use the response as a context manager so the connection is released.
    with urllib.request.urlopen(req, context=context) as response:
        body = response.read()
    return json.loads(body.decode("utf-8"))
def readGMT(gmtFile: str, backgroundGenes: Optional[List[str]] = None, verbose=False) -> List:
    """Parse a GMT gene set file.

    Each non-blank line is ``<set name>\\t<description>\\t<gene>\\t<gene>...``.
    The whole line is upper-cased. Any ``,<suffix>`` after a gene symbol (for
    example ``GENE,1.0``) is removed. Blank lines and empty gene tokens are
    skipped.

    Args:
        gmtFile: path to the GMT file (read as UTF-8).
        backgroundGenes: when it holds more than one gene, each gene set is
            restricted to genes in this collection (compared upper-cased).
            ``None`` (the default) or a single-element list disables
            filtering, matching the former ``[""]`` default.
        verbose: print a one-line summary when True.

    Returns:
        ``[library, rev_library, ugenes]``:
        - ``library`` maps set name -> list of genes, in file order.
        - ``rev_library`` maps gene -> list of the set names containing it.
        - ``ugenes`` is the sorted list of unique genes.
    """
    weight_suffix = re.compile(",.*")
    background = None
    if backgroundGenes is not None and len(backgroundGenes) > 1:
        # Build the upper-cased lookup set once, rather than a list per line.
        background = {g.upper() for g in backgroundGenes}

    library = {}
    with open(gmtFile, "r", encoding="utf-8") as handle:
        for line in handle:
            stripped = line.strip()
            if not stripped:
                continue  # a blank line would otherwise create a "" gene set
            fields = stripped.upper().split("\t")
            genes = [weight_suffix.sub("", value) for value in fields[2:]]
            genes = [g for g in genes if g]
            if background is not None:
                genes = [g for g in genes if g in background]
            library[fields[0]] = genes

    ugenes = sorted(set(itertools.chain.from_iterable(library.values())))
    rev_library = {ug: [] for ug in ugenes}
    for set_name, genes in library.items():
        for gene in genes:
            rev_library[gene].append(set_name)
    if verbose:
        print("Library loaded. Library contains " + str(len(library)) + " gene sets. " + str(len(ugenes)) + " unique genes found.")
    return [library, rev_library, ugenes]
# Load the gene set library, restricted to genes present in the PrismEXP
# correlation data, then fetch GENE's correlations from every matrix.
if ENRICHR_MODE:
    # LIBRARY is an Enrichr library name; this downloads gmts/<LIBRARY>.gmt.
    libraryPath = loadLibrary(LIBRARY)
else:
    # LIBRARY is a path to a user-supplied GMT file. Copy it to
    # gmts/<name>.gmt and keep only <name> (no directories, no .gmt) in
    # LIBRARY, because LIBRARY is reused below in output file names.
    _gmt_name = os.path.basename(LIBRARY)
    if _gmt_name.endswith(".gmt"):
        _gmt_name = _gmt_name[:-len(".gmt")]
    _gmt_dest = "gmts/" + _gmt_name + ".gmt"
    os.makedirs("gmts", exist_ok=True)
    # Skip the copy when the file is already in place (copy2 would raise).
    if os.path.abspath(LIBRARY) != os.path.abspath(_gmt_dest):
        shutil.copy2(LIBRARY, _gmt_dest)
    LIBRARY = _gmt_name
library, rev_library, ugenes = readGMT("gmts/" + LIBRARY + ".gmt", loadGenesS3())
correlation = geneCorrelation(GENE)
correlation
def getAverageCorrelation(correlation: pd.DataFrame, library: Dict):
    """Average the correlation rows of each gene set's member genes.

    Args:
        correlation: DataFrame indexed by gene symbol (e.g. geneCorrelation()).
        library: mapping gene set name -> list of gene symbols (upper-cased
            before lookup).

    Returns:
        DataFrame with one row per gene set (index = set names, in dict order)
        and integer-labelled columns matching the positions of *correlation*'s
        columns.

    NOTE(review): every listed gene must exist in *correlation*'s index;
    ``.loc`` with missing labels raises in recent pandas.
    """
    set_names = list(library.keys())
    set_means = [
        correlation.loc[[member.upper() for member in library[name]], :].mean(axis=0)
        for name in set_names
    ]
    return pd.DataFrame(np.array(set_means), index=set_names)
# Average correlation with GENE for each gene set: one row per gene set,
# one column per correlation matrix fetched by geneCorrelation.
avgCor = getAverageCorrelation(correlation, library)
# Bare trailing expression: Jupyter renders it as the cell output.
avgCor
def loadPrismXModel():
    """Download the PrismEXP model pickle and return the unpickled object.

    The file is fetched over HTTPS from the bucket named in S3_URL and saved
    to ``model/prismxmodel.pkl``. It is re-downloaded on every call.
    """
    os.makedirs("model", exist_ok=True)
    bucket = S3_URL.replace("s3:", "").replace("/", "")  # "s3://name/" -> "name"
    model_path = "model/prismxmodel.pkl"
    urllib.request.urlretrieve("https://" + bucket + ".s3.amazonaws.com/prismxmodel.pkl", model_path)
    # SECURITY NOTE: unpickling can execute arbitrary code; only load this
    # file from a trusted source.
    with warnings.catch_warnings():
        # Silences any warnings emitted while unpickling (presumably
        # library-version mismatch warnings; TODO confirm).
        warnings.simplefilter("ignore")
        with open(model_path, "rb") as handle:
            model = pickle.load(handle)
    return model
def prismxPredictions(avgcor, model, verbose: bool=False) -> pd.DataFrame:
    """Score every gene set with *model* and add z-scores and p-values.

    Args:
        avgcor: feature table (e.g. from getAverageCorrelation). NaNs are
            replaced by 0 before prediction.
        model: any object with a ``predict(DataFrame)`` method returning one
            score per row.
        verbose: currently unused; kept for interface compatibility.

    Returns:
        DataFrame indexed like *avgcor* with these columns:
        - ``predictions``: the model scores.
        - ``z-score``: scores standardised across all rows (population std,
          ddof=0). Values are NaN if all scores are equal.
        - ``p-value``: one-sided upper-tail probability of the z-score under a
          standard normal.
        - ``bonferroni``: p-value times the number of rows, capped at 1.
    """
    features = avgcor.fillna(0)
    predictions = pd.DataFrame(
        {"predictions": np.asarray(model.predict(features)).ravel()},
        index=features.index,
    )
    scores = predictions["predictions"]
    zscore = (scores - scores.mean()) / scores.std(ddof=0)
    predictions["z-score"] = zscore
    # norm.sf is accurate in the far tail, where 1 - cdf underflows to 0.
    pvalue = st.norm.sf(zscore)
    predictions["p-value"] = pvalue
    predictions["bonferroni"] = np.minimum(pvalue * len(pvalue), 1.0)
    return predictions
# Score every gene set with the PrismEXP model and save the full table.
model = loadPrismXModel()
predictions = prismxPredictions(avgCor, model)
predictions.to_csv(GENE + "_" + LIBRARY + "_predictions.tsv", sep="\t")
# The 20 highest-scoring gene sets, shown further below.
top_predictions = predictions.sort_values(by="predictions", ascending=False).head(20)
def plotROC(fpr, tpr, auc):
    """Plot an ROC curve against the diagonal baseline and annotate its AUC.

    The figure is saved as ``<GENE>_<LIBRARY>_ROC.pdf`` and
    ``<GENE>_<LIBRARY>_ROC.png`` (300 dpi) in the working directory and then
    shown.
    """
    out_stem = GENE + "_" + LIBRARY + "_ROC"
    plt.figure(figsize=(12, 7), dpi=300)
    plt.plot([0, 1], [0, 1], linestyle="--", label="baseline")
    plt.plot(fpr, tpr, linestyle="--", label="PrismEXP")
    auc_label = "AUC: " + str(np.round(auc, decimals=3))
    plt.text(0.65, 0.1, auc_label, fontsize=18)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend()
    plt.savefig(out_stem + ".pdf")
    plt.savefig(out_stem + ".png", dpi=300)
    plt.show()
def calculateGeneAUC(prediction: pd.DataFrame, gene: str, rev_library: Dict, minLibSize: int=1) -> float:
    """Measure how well *prediction* ranks the gene sets already known to contain *gene*.

    Gene sets listed for *gene* in *rev_library* are positives, and all other
    rows of *prediction* are negatives. The first column of *prediction* is
    the ranking score. The ROC curve is drawn and saved via plotROC.

    Args:
        prediction: table indexed by gene set name; column 0 is the score.
        gene: gene symbol. It is looked up as given first, then upper-cased
            (readGMT upper-cases all keys).
        rev_library: mapping gene -> list of gene set names containing it.
        minLibSize: minimum number of known gene sets required for *gene*.

    Returns:
        The ROC AUC. Returns 0 when there is not enough information: the gene
        is unknown, has fewer than *minLibSize* known sets, or the gold labels
        are all positive or all negative (AUC undefined).
    """
    known = rev_library.get(gene)
    if known is None:
        known = rev_library.get(gene.upper(), [])
    known = set(known)
    gold = [name in known for name in prediction.index]
    n_pos = sum(gold)
    if len(known) < minLibSize or n_pos == 0 or n_pos == len(gold):
        print("Not enough information to calculate AUC")
        return 0
    fpr, tpr, _ = roc_curve(gold, list(prediction.iloc[:, 0]))
    roc_auc = auc(fpr, tpr)
    plotROC(fpr, tpr, roc_auc)
    # Return the computed AUC value (previously returned the `auc` function).
    return roc_auc
The ROC curve shows how well PrismEXP recovers the gene's previously known annotations. To compute it, all gene sets in the library are ranked by prediction score in descending order; gene sets already annotated with the gene should rank near the top. The AUC can vary by gene and by gene set library.
calculateGeneAUC(predictions, GENE, rev_library)
top_predictions
# Offer the generated files for download from the notebook.
_output_stem = GENE + "_" + LIBRARY
display(FileLink(_output_stem + "_predictions.tsv", result_html_prefix="Download prediction table: "))
display(FileLink(_output_stem + "_ROC.pdf", result_html_prefix="Download PDF: "))
display(FileLink(_output_stem + "_ROC.png", result_html_prefix="Download PNG: "))
[1] Kuleshov, Maxim V., Matthew R. Jones, Andrew D. Rouillard, Nicolas F. Fernandez, Qiaonan Duan, Zichen Wang, Simon Koplev et al. "Enrichr: a comprehensive gene set enrichment analysis web server 2016 update." Nucleic Acids Research 44, no. W1 (2016): W90-W97.
[2] Lachmann, Alexander, Brian M. Schilder, Megan L. Wojciechowicz, Denis Torre, Maxim V. Kuleshov, Alexandra B. Keenan, and Avi Ma’ayan. "Geneshot: search engine for ranking genes from arbitrary text queries." Nucleic Acids Research 47, no. W1 (2019): W571-W577.
[3] Lachmann, Alexander, Denis Torre, Alexandra B. Keenan, Kathleen M. Jagodnik, Hoyjin J. Lee, Lily Wang, Moshe C. Silverstein, and Avi Ma’ayan. "Massive mining of publicly available RNA-seq data from human and mouse." Nature Communications 9, no. 1 (2018): 1-10.
[4] Wang, Yuhang, Fillia S. Makedon, James C. Ford, and Justin Pearlman. "HykGene: a hybrid approach for selecting marker genes for phenotype classification using microarray gene expression data." Bioinformatics 21, no. 8 (2005): 1530-1537.