Skip to content

Commit

Permalink
Hill-climber based gene selection (WIP).
Browse files Browse the repository at this point in the history
  • Loading branch information
orenbenkiki committed Sep 13, 2024
1 parent 309442c commit a331922
Show file tree
Hide file tree
Showing 3 changed files with 628 additions and 166 deletions.
109 changes: 109 additions & 0 deletions src/contracts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,17 @@ export block_total_UMIs_vector
export cell_axis
export gene_axis
export gene_block_fraction_matrix
export gene_block_is_local_marker_matrix
export gene_block_is_local_predictive_factor_matrix
export gene_block_local_r2_matrix
export gene_block_local_rms_matrix
export gene_block_program_mean_log_fraction_matrix
export gene_block_total_UMIs_matrix
export gene_divergence_vector
export gene_factor_priority_vector
export gene_global_r2_vector
export gene_global_rms_vector
export gene_is_correlated_vector
export gene_is_forbidden_factor_vector
export gene_is_global_predictive_factor_vector
export gene_is_lateral_vector
Expand All @@ -33,6 +40,7 @@ export metacell_axis
export metacell_block_vector
export metacell_total_UMIs_vector
export metacell_type_vector
export program_gene_fraction_regularization_scalar
export type_axis
export type_color_vector

Expand Down Expand Up @@ -165,6 +173,16 @@ function gene_is_forbidden_factor_vector(expectation::ContractExpectation)::Pair
(expectation, Bool, "A mask of genes that are forbidden from being used as predictive factors.")
end

"""
function gene_is_correlated_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification}
A mask of genes that are correlated with other gene(s). We typically search for groups of genes that act together. Genes
that have no correlation with other genes aren't useful for this sort of analysis, even if they are marker genes.
"""
function gene_is_correlated_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification}
return ("gene", "is_correlated") => (expectation, Bool, "A mask of genes that are correlated with other gene(s).")
end

"""
function gene_factor_priority_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification}
Expand All @@ -188,6 +206,30 @@ function gene_is_global_predictive_factor_vector(expectation::ContractExpectatio
(expectation, Bool, "A mask of globally predictive transcription factors.")
end

"""
function gene_global_rms_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification}
The cross-validation RMS of the global linear approximation for this gene. That is, when using the global predictive
transcription factors, this is the residual mean square error (reduced by the gene's divergence) of the approximation of
the (log base 2) of the gene expression across the metacells. Genes which aren't approximated are given an RMS of 0.
"""
function gene_global_rms_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification}
return ("gene", "global_rms") =>
(expectation, StorageFloat, "The cross-validation RMS of the global linear approximation for this gene.")
end

"""
function gene_global_r2_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification}
The cross-validation R^2 of the global linear approximation for this gene. That is, when using the global predictive
transcription factors, this is the coefficient of determination of the approximation of the (log base 2) of the gene
expression across the metacells. Genes which aren't approximated are given an R^2 of 0.
"""
function gene_global_r2_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification}
return ("gene", "global_r2") =>
(expectation, StorageFloat, "The cross-validation R^2 of the global linear approximation for this gene.")
end

"""
function cell_is_excluded_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification}
Expand Down Expand Up @@ -349,6 +391,47 @@ function gene_block_is_local_predictive_factor_matrix(
(expectation, Bool, "A mask of the predictive factors in each block.")
end

"""
function gene_block_is_local_marker_matrix(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification}
A mask of the marker genes in the environment of each block. That is, for each block, we look at the genes in the
environment of the block that have both a significant maximal expression and a wide range of expressions. This is a
subset of the overall marker genes.
"""
function gene_block_is_local_marker_matrix(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification}
return ("gene", "block", "is_local_marker") =>
(expectation, Bool, "A mask of the marker genes in the environment of each block.")
end

"""
function gene_block_program_mean_log_fraction_matrix(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification}
The mean log of the expression level in the environment of the block. When predicting the expression level of genes, we
actually predict the difference from this mean based on the difference from this mean of the predictive transcription
factors.
"""
function gene_block_program_mean_log_fraction_matrix(
expectation::ContractExpectation,
)::Pair{MatrixKey, DataSpecification}
return ("gene", "block", "program_mean_log_fraction") =>
(expectation, StorageFloat, "The mean log of the expression level in the environment of the block.")
end

"""
function program_gene_fraction_regularization_scalar(expectation::ContractExpectation)::Pair{ScalarKey, DataSpecification}
The regularization used to compute the log base 2 of the gene fractions for the predictive programs.
"""
function program_gene_fraction_regularization_scalar(
expectation::ContractExpectation,
)::Pair{ScalarKey, DataSpecification}
return "program_gene_fraction_regularization" => (
expectation,
StorageFloat,
"The regularization used to compute the log base 2 of the gene fractions for the predictive programs.",
)
end

"""
function gene_block_fraction_matrix(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification}
Expand All @@ -361,4 +444,30 @@ function gene_block_fraction_matrix(expectation::ContractExpectation)::Pair{Matr
(expectation, StorageFloat, "The estimated fraction of the UMIs of each gene in each block.")
end

"""
function gene_block_local_rms_matrix(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification}
The cross-validation RMS of the linear approximation for this gene in each block. That is, when using the local predictive
transcription factors of the block, this is the residual mean square error (reduced by the gene's divergence) of the
approximation of the (log base 2) of the gene expression across the metacells in the neighborhood of the block. Genes
which aren't approximated are given an RMS of 0.
"""
function gene_block_local_rms_matrix(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification}
return ("gene", "block", "local_rms") =>
(expectation, StorageFloat, "The cross-validation RMS of the linear approximation for this gene in each block.")
end

"""
function gene_block_local_r2_matrix(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification}
The cross-validation R^2 of the linear approximation for this gene in each block. That is, when using the local
predictive transcription factors of the block, this is the coefficient of determination of the approximation of the (log
base 2) of the gene expression across the metacells in the neighborhood of the block. Genes which aren't approximated
are given an R^2 of 0.
"""
function gene_block_local_r2_matrix(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification}
return ("gene", "block", "local_r2") =>
(expectation, StorageFloat, "The cross-validation R^2 of the linear approximation for this gene in each block.")
end

end # module
68 changes: 68 additions & 0 deletions src/identify_genes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Identify special genes.
module IdentifyGenes

export compute_genes_divergence!
export identify_correlated_genes!
export identify_marker_genes!

using Daf
Expand Down Expand Up @@ -133,4 +134,71 @@ $(CONTRACT)
return nothing
end

"""
function identify_correlated_genes!(
daf::DafWriter;
gene_fraction_regularization::AbstractFloat = $(DEFAULT.gene_fraction_regularization),
correlation_confidence::AbstractFloat = $(DEFAULT.correlation_confidence),
overwrite::Bool = $(DEFAULT.overwrite),
)::Nothing
Identify genes that are correlated with other gene(s). Such genes are good candidates for looking for groups of genes
that act together. If `overwrite`, will overwrite an existing `is_correlated` mask.
1. Compute the log base 2 of the genes expression in each metacell (using the `gene_fraction_regularization`).
2. Correlate this between all the pairs of genes.
3. For each gene, shuffle its values along all metacells, and again correlate this between all the pairs of genes.
4. Find the maximal absolute correlation for each gene in both cases (that is, strong anti-correlation also counts).
5. Find the `correlation_confidence` quantile correlation of the shuffled data.
6. Identify the genes that have at least that level of correlations in the unshuffled data.
$(CONTRACT)
"""
@logged @computation Contract(
axes = [gene_axis(RequiredInput), metacell_axis(RequiredInput)],
data = [gene_metacell_fraction_matrix(RequiredInput), gene_is_correlated_vector(GuaranteedOutput)],
) function identify_correlated_genes!( # untested
daf::DafWriter;
gene_fraction_regularization::AbstractFloat = GENE_FRACTION_REGULARIZATION,
correlation_confidence::AbstractFloat = 0.99,
overwrite::Bool = false,
rng::AbstractRNG = default_rng(),
)::Nothing
@assert gene_fraction_regularization >= 0
@assert 0 <= correlation_confidence <= 1

n_genes = axis_length(daf, "gene")
log_fractions_of_genes_in_metacells =
daf["/ metacell / gene : fraction % Log base 2 eps $(gene_fraction_regularization)"].array

correlations_between_genes = cor(log_fractions_of_genes_in_metacells)
correlations_between_genes .= abs.(correlations_between_genes)
correlations_between_genes[diagind(correlations_between_genes)] .= 0 # NOJET
correlations_between_genes[isnan.(correlations_between_genes)] .= 0
max_correlations_of_genes = vec(maximum(correlations_between_genes; dims = 1))
@assert length(max_correlations_of_genes) == n_genes

shuffled_log_fractions_of_genes_in_metacells = copy_array(log_fractions_of_genes_in_metacells)
for gene_index in 1:n_genes
@views shuffled_log_fractions_of_metacells_of_gene = shuffled_log_fractions_of_genes_in_metacells[:, gene_index]
shuffle!(rng, shuffled_log_fractions_of_metacells_of_gene)
end

shuffled_correlations_between_genes = cor(shuffled_log_fractions_of_genes_in_metacells)
shuffled_correlations_between_genes .= abs.(shuffled_correlations_between_genes)
shuffled_correlations_between_genes[diagind(shuffled_correlations_between_genes)] .= 0 # NOJET
shuffled_correlations_between_genes[isnan.(shuffled_correlations_between_genes)] .= 0
max_shuffled_correlations_of_genes = vec(maximum(shuffled_correlations_between_genes; dims = 1))

@debug "mean: $(mean(max_shuffled_correlations_of_genes))"
@debug "stdev: $(std(max_shuffled_correlations_of_genes))"
threshold = quantile(max_shuffled_correlations_of_genes, correlation_confidence)
@debug "threshold: $(threshold)"

is_correlated_of_genes = max_correlations_of_genes .>= threshold

set_vector!(daf, "gene", "is_correlated", is_correlated_of_genes; overwrite = overwrite)
return nothing
end

end # module
Loading

0 comments on commit a331922

Please sign in to comment.