# Use NMF (non-negative matrix factorization) instead of hierarchical
# clustering.

# Do not clear the workspace here: `mut.wt` is read from it just below.
# Fail early with a clear message if it has not been loaded.
if (!exists("mut.wt")) {
  stop("`mut.wt` not found: load it before running this script.",
       call. = FALSE)
}
options(stringsAsFactors = FALSE)
a2 <- mut.wt
## Identify driver signatures with NMF and plot them
# Transpose so signatures become rows (assumes mut.wt is samples x
# signatures -- TODO confirm).
nmf.input <- t(a2)
# Drop the "unknown" row; drop = FALSE keeps a matrix even if one row is left.
nmf.input <- nmf.input[setdiff(rownames(nmf.input), "unknown"), ,
                       drop = FALSE]
# Remove signatures that were never assigned to any sample (all-zero rows);
# presumably nmf() cannot use all-zero rows -- confirm.
nmf.input <- nmf.input[rowSums(nmf.input) > 0, , drop = FALSE]
library(pheatmap)
pheatmap(nmf.input)

### Step 1: determine the optimal number of clusters (factorization rank)

library(NMF)
ranks <- 2:10
# Fit one NMF model per candidate rank and keep its consensus matrix and
# cophenetic correlation. nrun = 5 keeps the run time from getting too long.
# (To get verbose/parallel runs, pass .opt = "vp" to nmf() itself.)
estim <- lapply(ranks, function(r) {
  fit <- nmf(nmf.input, r, nrun = 5, seed = 4, method = "lee")
  list(fit = fit, consensus = consensus(fit), coph = cophcor(fit))
})

names(estim) <- paste("rank", ranks)
# Cophenetic coefficient per rank (named "rank 2", "rank 3", ...).
coph <- vapply(estim, "[[", numeric(1), "coph")
coph

png("Cophenetic coefficient for seleting optimal nmf rank.png")
par(cex.axis = 1.5)
plot(ranks, coph,
     xlab = "", ylab = "", type = "b",
     col = "red", lwd = 4, xaxt = "n")
# Put ticks only at the ranks that were evaluated.
axis(side = 1, at = ranks)
title(xlab = "number of clusters", ylab = "Cophenetic coefficient",
      cex.lab = 1.5)
dev.off()

### Step 2: select driver signatures

rank <- 3
seed <- 2019620
# Shorten row names for plotting: replace "Signature" with "Sig".
rownames(nmf.input) <- gsub("Signature", "Sig", rownames(nmf.input))
mut.nmf <- nmf(nmf.input,
               rank = rank,
               seed = seed,
               method = "lee")
# List with one vector of row indices (signatures) per basis component.
index <- extractFeatures(mut.nmf, "max")
# extractFeatures() yields NA for a component with no selected feature;
# drop those so they do not become NA rows when nmf.input is subset later.
sig.order <- unlist(index)
sig.order <- sig.order[!is.na(sig.order)]
if (length(sig.order) == 0L) {
  stop("extractFeatures() selected no signatures for any component.",
       call. = FALSE)
}

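# Example console output of `index` and `sig.order` from the step above:
# > index
# [[1]]
# [1] 2 1
#
# [[2]]
# [1] 5 3 11 4
#
# [[3]]
# [1] 10 9
#
# attr(,"method")
# [1] "max"
# > sig.order
# [1] 2 1 5 3 11 4 10 9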

### Step 3: re-run NMF using only the selected signatures

# drop = FALSE keeps a matrix even when a single signature was selected.
nmf.input2 <- nmf.input[sig.order, , drop = FALSE]
library(pheatmap)
pheatmap(nmf.input2, cluster_rows = TRUE, cluster_cols = FALSE)

mut.nmf2 <- nmf(nmf.input2,
                rank = rank,
                seed = seed,
                method = "lee")
# Extract subtypes: one cluster label per sample (column); later code
# indexes `group` by sample name.
group <- predict(mut.nmf2)
table(group)

### Extra: some visualization functions

# Sample names sorted by their NMF cluster label (reused by later heatmaps).
sample.order <- names(group)[order(group)]
# Colour palette: one colour for each cluster label "1".."4".
jco <- c("#2874C5", "#EABF00", "#C6524A", "#868686")

png(filename = "consensusmap.png")
consensusmap(
  mut.nmf2,
  labRow = NA,
  labCol = NA,
  annCol = data.frame(cluster = group[colnames(nmf.input)]),
  annColors = list(cluster = setNames(jco, seq_along(jco)))
)
dev.off()

# Basis map: the driver signatures of each subtype (dark cells) are clearly
# visible here and correspond to the NMF heatmap below.
png(filename = "basismap.png")
basismap(
  mut.nmf2,
  cexCol = 1,
  cexRow = 0.3,
  annColors = list(setNames(jco, seq_along(jco)))
)
dev.off()

# Heatmap of the selected signatures with samples ordered by NMF cluster,
# written straight to NMF_heatmap.pdf.
aheatmap(
  as.matrix(nmf.input2[, sample.order]),
  Rowv = NA,
  Colv = NA,
  annCol = data.frame(cluster = group[sample.order]),
  annColors = list(cluster = setNames(jco, seq_along(jco))),
  # Blue gradient taken from the example paper.
  color = c("#EAF0FA", "#6081C3", "#3454A7"),
  revC = TRUE,
  cexCol = 0.3,
  cexRow = 0.3,
  filename = "NMF_heatmap.pdf"
)

# Column annotation with one "NMF_<cluster>" label per sample. This relies
# on `group` being in the same order as the columns of nmf.input2.
ac <- data.frame(group = paste0("NMF_", group))
rownames(ac) <- colnames(nmf.input2)

# On screen: clustered columns, then columns ordered by NMF cluster.
pheatmap(nmf.input2, show_colnames = FALSE, annotation_col = ac)
pheatmap(nmf.input2[, sample.order],
         show_colnames = FALSE, cluster_cols = FALSE, annotation_col = ac)

# The same two heatmaps saved to PDF; with `filename` set, pheatmap()
# opens and closes its own output device.
pheatmap(nmf.input2, show_colnames = FALSE, annotation_col = ac,
         filename = "NMF_heatmap1.pdf")
pheatmap(nmf.input2[, sample.order],
         show_colnames = FALSE, cluster_cols = FALSE, annotation_col = ac,
         filename = "NMF_heatmap2.pdf")
# NOTE(review): this does not pair with the PDF calls above (they close
# their own devices); it closes whichever device is current, likely the
# on-screen one used by the interactive heatmaps -- confirm intended.
dev.off()