-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path train.py
39 lines (31 loc) · 944 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from itertools import chain
import anndata as ad
import itertools
import networkx as nx
import pandas as pd
import scanpy as sc
import scglue
import seaborn as sns
from matplotlib import rcParams
# Train a GLUE model that integrates preprocessed scRNA-seq and scATAC-seq
# data, guided by a prior feature-regulation graph, then persist it to disk.
# Every path below is relative to the current working directory.

# --- Plot styling ---------------------------------------------------------
# scglue's publication preset, then a square 4x4 default figure size.
scglue.plot.set_publication_params()
rcParams["figure.figsize"] = (4, 4)

# --- Load preprocessed inputs --------------------------------------------
rna = ad.read_h5ad("./data/rna-pp.h5ad")
atac = ad.read_h5ad("./data/atac-pp.h5ad")
guidance = nx.read_graphml("./data/guidance.graphml.gz")

# --- Register each modality with scglue ----------------------------------
# Both modalities use the "NB" model restricted to highly variable features.
# RNA additionally reads from the "counts" layer; each modality names its own
# obsm representation via use_rep (X_pca for RNA, X_lsi for ATAC).
_dataset_specs = (
    (rna, {"use_layer": "counts", "use_rep": "X_pca"}),
    (atac, {"use_rep": "X_lsi"}),
)
for _adata, _extra_kws in _dataset_specs:
    scglue.models.configure_dataset(
        _adata, "NB", use_highly_variable=True, **_extra_kws
    )
del _dataset_specs, _adata, _extra_kws

# --- Restrict the guidance graph to highly variable features -------------
# subgraph() returns a read-only view of `guidance`; .copy() materialises it
# as an independent graph before it is handed to training.
_hvf_nodes = chain(
    rna.var.query("highly_variable").index,
    atac.var.query("highly_variable").index,
)
guidance_hvf = guidance.subgraph(_hvf_nodes).copy()
del _hvf_nodes

# --- Fit and persist ------------------------------------------------------
# The modality keys ("rna", "atac") name the domains inside the model.
# fit_kws={"directory": "glue"} is forwarded to the fitting routine.
_modalities = {"rna": rna, "atac": atac}
glue = scglue.models.fit_SCGLUE(
    _modalities,
    guidance_hvf,
    fit_kws={"directory": "glue"},
)
glue.save("glue.dill")