From a331922b6b9c002374f3d92f946be6c3a11c0372 Mon Sep 17 00:00:00 2001 From: Oren Ben-Kiki Date: Wed, 11 Sep 2024 07:36:04 +0300 Subject: [PATCH] Hill-climber based gene selection (WIP). --- src/contracts.jl | 109 ++++++++ src/identify_genes.jl | 68 +++++ src/programs.jl | 617 ++++++++++++++++++++++++++++++------------ 3 files changed, 628 insertions(+), 166 deletions(-) diff --git a/src/contracts.jl b/src/contracts.jl index 13472de..39ec34f 100644 --- a/src/contracts.jl +++ b/src/contracts.jl @@ -18,10 +18,17 @@ export block_total_UMIs_vector export cell_axis export gene_axis export gene_block_fraction_matrix +export gene_block_is_local_marker_matrix export gene_block_is_local_predictive_factor_matrix +export gene_block_local_r2_matrix +export gene_block_local_rms_matrix +export gene_block_program_mean_log_fraction_matrix export gene_block_total_UMIs_matrix export gene_divergence_vector export gene_factor_priority_vector +export gene_global_r2_vector +export gene_global_rms_vector +export gene_is_correlated_vector export gene_is_forbidden_factor_vector export gene_is_global_predictive_factor_vector export gene_is_lateral_vector @@ -33,6 +40,7 @@ export metacell_axis export metacell_block_vector export metacell_total_UMIs_vector export metacell_type_vector +export program_gene_fraction_regularization_scalar export type_axis export type_color_vector @@ -165,6 +173,16 @@ function gene_is_forbidden_factor_vector(expectation::ContractExpectation)::Pair (expectation, Bool, "A mask of genes that are forbidden from being used as predictive factors.") end +""" + function gene_is_correlated_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification} + +A mask of genes that are correlated with other gene(s). We typically search for groups of genes that act together. Genes +that have no correlation with other genes aren't useful for this sort of analysis, even if they are marker genes. +""" +function gene_is_correlated_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification} + return ("gene", "is_correlated") => (expectation, Bool, "A mask of genes that are correlated with other gene(s).") +end + """ function gene_factor_priority_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification} @@ -188,6 +206,30 @@ function gene_is_global_predictive_factor_vector(expectation::ContractExpectatio (expectation, Bool, "A mask of globally predictive transcription factors.") end +""" + function gene_global_rms_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification} + +The cross-validation RMS of the global linear approximation for this gene. That is, when using the global predictive +transcription factors, this is the residual mean square error (reduced by the gene's divergence) of the approximation of +the (log base 2) of the gene expression across the metacells. Genes which aren't approximated are given an RMS of 0. +""" +function gene_global_rms_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification} + return ("gene", "global_rms") => + (expectation, StorageFloat, "The cross-validation RMS of the global linear approximation for this gene.") +end + +""" + function gene_global_r2_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification} + +The cross-validation R^2 of the global linear approximation for this gene. That is, when using the global predictive +transcription factors, this is the coefficient of determination of the approximation of the (log base 2) of the gene +expression across the metacells. Genes which aren't approximated are given an R^2 of 0. +""" +function gene_global_r2_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification} + return ("gene", "global_r2") => + (expectation, StorageFloat, "The cross-validation R^2 of the global linear approximation for this gene.") +end + """ function cell_is_excluded_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification} @@ -349,6 +391,47 @@ function gene_block_is_local_predictive_factor_matrix( (expectation, Bool, "A mask of the predictive factors in each block.") end +""" + function gene_block_is_local_marker_matrix(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification} + +A mask of the marker genes in the environment of each block. That is, for each block, we look at the genes in the +environment of the block that have both a significant maximal expression and a wide range of expressions. This is a +subset of the overall marker genes. +""" +function gene_block_is_local_marker_matrix(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification} + return ("gene", "block", "is_local_marker") => + (expectation, Bool, "A mask of the marker genes in the environment of each block.") +end + +""" + function gene_block_program_mean_log_fraction_matrix(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification} + +The mean log of the expression level in the environment of the block. When predicting the expression level of genes, we +actually predict the difference from this mean based on the difference from this mean of the predictive transcription +factors. +""" +function gene_block_program_mean_log_fraction_matrix( + expectation::ContractExpectation, +)::Pair{MatrixKey, DataSpecification} + return ("gene", "block", "program_mean_log_fraction") => + (expectation, StorageFloat, "The mean log of the expression level in the environment of the block.") +end + +""" + function program_gene_fraction_regularization_scalar(expectation::ContractExpectation)::Pair{ScalarKey, DataSpecification} + +The regularization used to compute the log base 2 of the gene fractions for the predictive programs. +""" +function program_gene_fraction_regularization_scalar( + expectation::ContractExpectation, +)::Pair{ScalarKey, DataSpecification} + return "program_gene_fraction_regularization" => ( + expectation, + StorageFloat, + "The regularization used to compute the log base 2 of the gene fractions for the predictive programs.", + ) +end + """ function gene_block_fraction_matrix(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification} @@ -361,4 +444,30 @@ function gene_block_fraction_matrix(expectation::ContractExpectation)::Pair{Matr (expectation, StorageFloat, "The estimated fraction of the UMIs of each gene in each block.") end +""" + function gene_block_local_rms_matrix(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification} + +The cross-validation RMS of the linear approximation for this gene in each block. That is, when using the local predictive +transcription factors of the block, this is the residual mean square error (reduced by the gene's divergence) of the +approximation of the (log base 2) of the gene expression across the metacells in the neighborhood of the block. Genes +which aren't approximated are given an RMS of 0. +""" +function gene_block_local_rms_matrix(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification} + return ("gene", "block", "local_rms") => + (expectation, StorageFloat, "The cross-validation RMS of the linear approximation for this gene in each block.") +end + +""" + function gene_block_local_r2_matrix(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification} + +The cross-validation R^2 of the linear approximation for this gene in each block. That is, when using the local +predictive transcription factors of the block, this is the coefficient of determination of the approximation of the (log +base 2) of the gene expression across the metacells in the neighborhood of the block. Genes which aren't approximated +are given an R^2 of 0. +""" +function gene_block_local_r2_matrix(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification} + return ("gene", "block", "local_r2") => + (expectation, StorageFloat, "The cross-validation R^2 of the linear approximation for this gene in each block.") +end + end # module diff --git a/src/identify_genes.jl b/src/identify_genes.jl index 7b0242b..c1feaed 100644 --- a/src/identify_genes.jl +++ b/src/identify_genes.jl @@ -4,6 +4,7 @@ Identify special genes. module IdentifyGenes export compute_genes_divergence! +export identify_correlated_genes! export identify_marker_genes! using Daf @@ -133,4 +134,71 @@ $(CONTRACT) return nothing end +""" + function identify_correlated_genes!( + daf::DafWriter; + gene_fraction_regularization::AbstractFloat = $(DEFAULT.gene_fraction_regularization), + correlation_confidence::AbstractFloat = $(DEFAULT.correlation_confidence), + overwrite::Bool = $(DEFAULT.overwrite), + )::Nothing + +Identify genes that are correlated with other gene(s). Such genes are good candidates for looking for groups of genes +that act together. If `overwrite`, will overwrite an existing `is_correlated` mask. + + 1. Compute the log base 2 of the genes expression in each metacell (using the `gene_fraction_regularization`). + 2. Correlate this between all the pairs of genes. + 3. For each gene, shuffle its values along all metacells, and again correlate this between all the pairs of genes. + 4. Find the maximal absolute correlation for each gene in both cases (that is, strong anti-correlation also counts). + 5. Find the `correlation_confidence` quantile correlation of the shuffled data. + 6. Identify the genes that have at least that level of correlations in the unshuffled data. + +$(CONTRACT) +""" +@logged @computation Contract( + axes = [gene_axis(RequiredInput), metacell_axis(RequiredInput)], + data = [gene_metacell_fraction_matrix(RequiredInput), gene_is_correlated_vector(GuaranteedOutput)], +) function identify_correlated_genes!( # untested + daf::DafWriter; + gene_fraction_regularization::AbstractFloat = GENE_FRACTION_REGULARIZATION, + correlation_confidence::AbstractFloat = 0.99, + overwrite::Bool = false, + rng::AbstractRNG = default_rng(), +)::Nothing + @assert gene_fraction_regularization >= 0 + @assert 0 <= correlation_confidence <= 1 + + n_genes = axis_length(daf, "gene") + log_fractions_of_genes_in_metacells = + daf["/ metacell / gene : fraction % Log base 2 eps $(gene_fraction_regularization)"].array + + correlations_between_genes = cor(log_fractions_of_genes_in_metacells) + correlations_between_genes .= abs.(correlations_between_genes) + correlations_between_genes[diagind(correlations_between_genes)] .= 0 # NOJET + correlations_between_genes[isnan.(correlations_between_genes)] .= 0 + max_correlations_of_genes = vec(maximum(correlations_between_genes; dims = 1)) + @assert length(max_correlations_of_genes) == n_genes + + shuffled_log_fractions_of_genes_in_metacells = copy_array(log_fractions_of_genes_in_metacells) + for gene_index in 1:n_genes + @views shuffled_log_fractions_of_metacells_of_gene = shuffled_log_fractions_of_genes_in_metacells[:, gene_index] + shuffle!(rng, shuffled_log_fractions_of_metacells_of_gene) + end + + shuffled_correlations_between_genes = cor(shuffled_log_fractions_of_genes_in_metacells) + shuffled_correlations_between_genes .= abs.(shuffled_correlations_between_genes) + shuffled_correlations_between_genes[diagind(shuffled_correlations_between_genes)] .= 0 # NOJET + shuffled_correlations_between_genes[isnan.(shuffled_correlations_between_genes)] .= 0 + max_shuffled_correlations_of_genes = vec(maximum(shuffled_correlations_between_genes; dims = 1)) + + @debug "mean: $(mean(max_shuffled_correlations_of_genes))" + @debug "stdev: $(std(max_shuffled_correlations_of_genes))" + threshold = quantile(max_shuffled_correlations_of_genes, correlation_confidence) + @debug "threshold: $(threshold)" + + is_correlated_of_genes = max_correlations_of_genes .>= threshold + + set_vector!(daf, "gene", "is_correlated", is_correlated_of_genes; overwrite = overwrite) + return nothing +end + end # module diff --git a/src/programs.jl b/src/programs.jl index 005c5b6..17d511c 100644 --- a/src/programs.jl +++ b/src/programs.jl @@ -8,6 +8,7 @@ export compute_global_predictive_factors! export compute_blocks! export compute_blocks_vicinities! export compute_local_predictive_factors! +export compute_programs! using ..Contracts using ..Defaults @@ -46,6 +47,7 @@ end gene_fraction_regularization::AbstractFloat = $(DEFAULT.gene_fraction_regularization), max_principal_components::Integer = $(DEFAULT.max_principal_components), factors_per_principal_component::Real = $(DEFAULT.factors_per_principal_component), + overwrite::Bool = $(DEFAULT.overwrite), )::Nothing Given the transcription factors in the data, identify an ordered subset of these to use as candidates for predicting the @@ -71,6 +73,7 @@ $(CONTRACT) gene_fraction_regularization::AbstractFloat = 2 * GENE_FRACTION_REGULARIZATION, max_principal_components::Integer = 40, factors_per_principal_component::Real = 2, + overwrite::Bool = false, )::Nothing @assert gene_fraction_regularization >= 0 @assert max_principal_components > 0 @@ -117,13 +120,13 @@ $(CONTRACT) factor_genes_order = sortperm(minimal_rank_of_factor_genes) rank_of_factor_genes = invperm(factor_genes_order) - n_top_genes = Int(round(n_principal_components * factors_per_principal_component)) + n_top_genes = min(Int(round(n_principal_components * factors_per_principal_component)), length(factor_genes)) top_genes = factor_genes[factor_genes_order[1:n_top_genes]] @debug "top $(n_top_genes) ranked transcription factors: $(join(context.names_of_genes[top_genes], ", "))" factor_priority_of_genes = zeros(UInt16, context.n_genes) factor_priority_of_genes[factor_genes] .= max.((n_top_genes + 1) .- rank_of_factor_genes, 0) - set_vector!(daf, "gene", "factor_priority", factor_priority_of_genes) + set_vector!(daf, "gene", "factor_priority", factor_priority_of_genes; overwrite = overwrite) return nothing end @@ -160,13 +163,14 @@ end function compute_global_predictive_factors!( daf::DafWriter; gene_fraction_regularization::AbstractFloat = $(DEFAULT.gene_fraction_regularization), - cross_validation::Integer = $(DEFAULT.cross_validation), + cross_validation_parts::Integer = $(DEFAULT.cross_validation_parts), rng::AbstractRNG = default_rng(), + overwrite::Bool = $(DEFAULT.overwrite), )::Nothing Given a prioritized subset of the factor genes likely to be predictive, identify a subset of these genes which best predict the values of the rest of the genes, using non-negative least-squares approximation using the log (base 2) of -the genes expression (using the `gene_fraction_regularization`). To avoid over-fitting, we use `cross_validation`. +the genes expression (using the `gene_fraction_regularization`). To avoid over-fitting, we use `cross_validation_parts`. We require each additional used factor to improve (reduce) the RMS of the cross validation error by at least `min_rms_improvement` (on average). Factors with higher prioriry are allowed a lower improvement and factors with a @@ -186,6 +190,8 @@ $(CONTRACT) gene_divergence_vector(RequiredInput), gene_factor_priority_vector(RequiredInput), gene_is_global_predictive_factor_vector(GuaranteedOutput), + gene_global_rms_vector(GuaranteedOutput), + gene_global_r2_vector(GuaranteedOutput), ], ) function compute_global_predictive_factors!( # untested daf::DafWriter; @@ -193,13 +199,14 @@ $(CONTRACT) compute_factor_priority_of_genes!, :gene_fraction_regularization, ), - min_rms_improvement::AbstractFloat = 2e-2, + min_rms_improvement::AbstractFloat = 1e-2, rms_priority_improvement::Real = 1, - cross_validation::Integer = 5, + cross_validation_parts::Integer = 5, rng::AbstractRNG = default_rng(), + overwrite::Bool = false, )::Nothing @assert gene_fraction_regularization >= 0 - @assert cross_validation > 1 + @assert cross_validation_parts > 1 @assert min_rms_improvement >= 0 @assert 0 <= rms_priority_improvement <= 2 @@ -210,9 +217,9 @@ $(CONTRACT) rms_priority_improvement = rms_priority_improvement, ) - global_predictive_genes = identify_predictive_genes!(; + global_predictive_genes, cross_validation = identify_predictive_genes!(; context = context, - cross_validation = cross_validation, + cross_validation_parts = cross_validation_parts, included_genes = 1:(context.n_genes), included_metacells = 1:(context.n_metacells), core_metacells = 1:(context.n_metacells), @@ -223,8 +230,13 @@ $(CONTRACT) is_global_predictive_factor_of_genes = zeros(Bool, context.n_genes) is_global_predictive_factor_of_genes[global_predictive_genes] .= true - set_vector!(daf, "gene", "is_global_predictive_factor", is_global_predictive_factor_of_genes) + set_vector!(daf, "gene", "global_rms", cross_validation.rms_of_genes; overwrite = overwrite) + set_vector!(daf, "gene", "global_r2", cross_validation.r2_of_genes; overwrite = overwrite) + set_vector!(daf, "gene", "is_global_predictive_factor", is_global_predictive_factor_of_genes; overwrite = overwrite) + @debug "global predictive factors: [ $(join(context.names_of_genes[global_predictive_genes], ", ")) ]" + @debug "mean cross-validation genes RMS: $(mean(cross_validation.rms_of_genes))" + @debug "mean cross-validation genes R^2: $(mean(cross_validation.r2_of_genes))" return nothing end @@ -280,28 +292,35 @@ function compute_rms_cost_factor_of_genes(; # untested return rms_cost_factor_of_genes end +@kwdef struct CrossValidation + mean_rms::AbstractFloat + rms_of_genes::AbstractVector{<:AbstractFloat} + r2_of_genes::AbstractVector{<:AbstractFloat} +end + function identify_predictive_genes!(; # untested context::Context, - cross_validation::Integer, + cross_validation_parts::Integer, included_genes::Union{UnitRange{<:Integer}, AbstractVector{<:Integer}}, included_metacells::Union{UnitRange{<:Integer}, AbstractVector{<:Integer}}, core_metacells::Union{UnitRange{<:Integer}, AbstractVector{<:Integer}}, candidate_factors::CandidateFactors, rng::AbstractRNG, -)::AbstractVector{<:Integer} +)::Tuple{AbstractVector{<:Integer}, CrossValidation} predictive_genes = Int32[] - predictive_cost = Maybe{Float32}[nothing] + predictive_cost = Maybe{Float32}[nothing, nothing, nothing] op_index = 2 op_success = [true, true] last_added_gene_index = nothing last_removed_gene_index = nothing + cross_validation = nothing while sum(op_success) > 0 op_index = 1 + op_index % 2 if op_index == 1 - last_added_gene_index = try_add_factors!(; + last_added_gene_index, new_cross_validation = try_add_factors!(; context = context, - cross_validation = cross_validation, + cross_validation_parts = cross_validation_parts, included_genes = included_genes, included_metacells = included_metacells, core_metacells = core_metacells, @@ -313,9 +332,9 @@ function identify_predictive_genes!(; # untested ) op_success[op_index] = last_added_gene_index !== nothing else - last_removed_gene_index = try_remove_factors!(; + last_removed_gene_index, new_cross_validation = try_remove_factors!(; context = context, - cross_validation = cross_validation, + cross_validation_parts = cross_validation_parts, included_genes = included_genes, included_metacells = included_metacells, core_metacells = core_metacells, @@ -327,14 +346,18 @@ function identify_predictive_genes!(; # untested ) op_success[op_index] = last_removed_gene_index !== nothing end + if new_cross_validation !== nothing + cross_validation = new_cross_validation + end end - return predictive_genes + @assert cross_validation !== nothing + return predictive_genes, cross_validation end # NOJET function try_add_factors!(; # untested context::Context, - cross_validation::Integer, + cross_validation_parts::Integer, included_genes::Union{UnitRange{<:Integer}, AbstractVector{<:Integer}}, included_metacells::Union{UnitRange{<:Integer}, AbstractVector{<:Integer}}, core_metacells::Union{UnitRange{<:Integer}, AbstractVector{<:Integer}}, @@ -343,34 +366,37 @@ function try_add_factors!(; # untested predictive_genes::AbstractVector{<:Integer}, predictive_cost::Vector{Maybe{Float32}}, last_removed_gene_index::Maybe{<:Integer}, -)::Maybe{<:Integer} +)::Tuple{Maybe{<:Integer}, Maybe{CrossValidation}} last_added_gene_index = nothing + cross_validation = nothing for (test_index, factor_gene_index) in enumerate(candidate_factors.ordered_factor_genes) - if factor_gene_index != last_removed_gene_index && - !(factor_gene_index in predictive_genes) && - try_add_factor!(; - context = context, - cross_validation = cross_validation, - included_genes = included_genes, - included_metacells = included_metacells, - core_metacells = core_metacells, - candidate_factors = candidate_factors, - rng = rng, - predictive_genes = predictive_genes, - predictive_cost = predictive_cost, - last_added_gene_index = last_added_gene_index, - test_index = test_index, - factor_gene_index = factor_gene_index, - ) - last_added_gene_index = factor_gene_index + if factor_gene_index != last_removed_gene_index && !(factor_gene_index in predictive_genes) + new_cross_validation = try_add_factor!(; + context = context, + cross_validation_parts = cross_validation_parts, + included_genes = included_genes, + included_metacells = included_metacells, + core_metacells = core_metacells, + candidate_factors = candidate_factors, + rng = rng, + predictive_genes = predictive_genes, + predictive_cost = predictive_cost, + last_added_gene_index = last_added_gene_index, + test_index = test_index, + factor_gene_index = factor_gene_index, + ) + if new_cross_validation !== nothing + cross_validation = new_cross_validation + last_added_gene_index = factor_gene_index + end end end - return last_added_gene_index + return (last_added_gene_index, cross_validation) end function try_add_factor!(; # untested context::Context, - cross_validation::Integer, + cross_validation_parts::Integer, included_genes::Union{UnitRange{<:Integer}, AbstractVector{<:Integer}}, included_metacells::Union{UnitRange{<:Integer}, AbstractVector{<:Integer}}, core_metacells::Union{UnitRange{<:Integer}, AbstractVector{<:Integer}}, @@ -378,33 +404,36 @@ function try_add_factor!(; # untested rng::AbstractRNG, predictive_genes::AbstractVector{<:Integer}, predictive_cost::Vector{Maybe{Float32}}, - last_added_gene_index::Maybe{<:Integer}, # NOLINT - test_index::Integer, # NOLINT + last_added_gene_index::Maybe{<:Integer}, + test_index::Integer, factor_gene_index::Integer, -)::Bool +)::Maybe{CrossValidation} push!(predictive_genes, factor_gene_index) order_predictive_genes!(; predictive_genes = predictive_genes, candidate_factors = candidate_factors) - # print( - # "$(length(predictive_genes) - 1)" * - # (last_added_gene_index === nothing ? "" : "+") * - # " > $(test_index)" * - # " / $(length(candidate_factors.ordered_factor_genes)) " * - # join( - # [ - # (gene_index == last_added_gene_index ? "+" : "") * - # (gene_index == factor_gene_index ? "?+" : "") * - # context.names_of_genes[gene_index] for gene_index in predictive_genes - # ], - # " ", - # ) * - # " ~ $(predictive_cost[1] === nothing ? "NA" : @sprintf("%.5f", predictive_cost[1]))" * - # " ...\e[0K\r", - # ) - - cross_validation_rms = compute_cross_validation_rms(; + print("\n\e[1A") + print( + "\e7$(length(predictive_genes) - 1)" * + (last_added_gene_index === nothing ? "" : "+") * + " > $(test_index)" * + " / $(length(candidate_factors.ordered_factor_genes)) " * + join( + [ + (gene_index == last_added_gene_index ? "+" : "") * + (gene_index == factor_gene_index ? "?+" : "") * + context.names_of_genes[gene_index] for gene_index in predictive_genes + ], + " ", + ) * + " ~ $(predictive_cost[1] === nothing ? "NA" : @sprintf("%.5f", predictive_cost[1]))" * + " RMS $(predictive_cost[2] === nothing ? "NA" : @sprintf("%.5f", predictive_cost[2]))" * + " R^2 $(predictive_cost[3] === nothing ? "NA" : @sprintf("%.5f", predictive_cost[3]))" * + " ...\e[J\e8", + ) + + cross_validation = compute_cross_validation(; context = context, - cross_validation = cross_validation, + cross_validation_parts = cross_validation_parts, predictive_genes = predictive_genes, included_genes = included_genes, included_metacells = included_metacells, @@ -413,7 +442,7 @@ function try_add_factor!(; # untested ) rms_cost_factor = reduce(*, candidate_factors.rms_cost_factor_of_genes[predictive_genes]) - cost = cross_validation_rms * rms_cost_factor + cost = cross_validation.mean_rms * rms_cost_factor if predictive_cost[1] === nothing improvement = 1.0 @@ -423,18 +452,20 @@ function try_add_factor!(; # untested if improvement > 0.0 predictive_cost[1] = cost - return true + predictive_cost[2] = cross_validation.mean_rms + predictive_cost[3] = mean(cross_validation.r2_of_genes[included_genes]) + return cross_validation else filter!(predictive_genes) do predictive_gene return predictive_gene != factor_gene_index end - return false + return nothing end end function try_remove_factors!(; # untested context::Context, - cross_validation::Integer, + cross_validation_parts::Integer, included_genes::Union{UnitRange{<:Integer}, AbstractVector{<:Integer}}, included_metacells::Union{UnitRange{<:Integer}, AbstractVector{<:Integer}}, core_metacells::Union{UnitRange{<:Integer}, AbstractVector{<:Integer}}, @@ -443,35 +474,40 @@ function try_remove_factors!(; # untested predictive_genes::AbstractVector{<:Integer}, predictive_cost::Vector{Maybe{Float32}}, last_added_gene_index::Maybe{<:Integer}, -)::Maybe{<:Integer} +)::Tuple{Maybe{<:Integer}, Maybe{CrossValidation}} last_removed_gene_index = nothing + cross_validation = nothing predictive_index = 0 while predictive_index < length(predictive_genes) predictive_index += 1 factor_gene_index = predictive_genes[length(predictive_genes) + 1 - predictive_index] - if factor_gene_index !== last_added_gene_index && try_remove_factor!(; - context = context, - cross_validation = cross_validation, - included_genes = included_genes, - included_metacells = included_metacells, - core_metacells = core_metacells, - candidate_factors = candidate_factors, - rng = rng, - predictive_genes = predictive_genes, - predictive_cost = predictive_cost, - last_removed_gene_index = last_removed_gene_index, - predictive_index = predictive_index, - factor_gene_index = factor_gene_index, - ) - last_removed_gene_index = factor_gene_index + if factor_gene_index !== last_added_gene_index + new_cross_validation = try_remove_factor!(; + context = context, + cross_validation_parts = cross_validation_parts, + included_genes = included_genes, + included_metacells = included_metacells, + core_metacells = core_metacells, + candidate_factors = candidate_factors, + rng = rng, + predictive_genes = predictive_genes, + predictive_cost = predictive_cost, + last_removed_gene_index = last_removed_gene_index, + predictive_index = predictive_index, + factor_gene_index = factor_gene_index, + ) + if new_cross_validation !== nothing + last_removed_gene_index = factor_gene_index + cross_validation = new_cross_validation + end end end - return last_removed_gene_index + return last_removed_gene_index, cross_validation end function try_remove_factor!(; # untested context::Context, - cross_validation::Integer, + cross_validation_parts::Integer, included_genes::Union{UnitRange{<:Integer}, AbstractVector{<:Integer}}, included_metacells::Union{UnitRange{<:Integer}, AbstractVector{<:Integer}}, core_metacells::Union{UnitRange{<:Integer}, AbstractVector{<:Integer}}, @@ -479,35 +515,38 @@ function try_remove_factor!(; # untested rng::AbstractRNG, predictive_genes::AbstractVector{<:Integer}, predictive_cost::Vector{Maybe{Float32}}, - last_removed_gene_index::Maybe{<:Integer}, # NOLINT - predictive_index::Integer, # NOLINT + last_removed_gene_index::Maybe{<:Integer}, + predictive_index::Integer, factor_gene_index::Integer, -)::Bool - # print( - # "$(length(predictive_genes))" * - # (last_removed_gene_index === nothing ? "" : "-") * - # " < $(predictive_index)" * - # " / $(length(predictive_genes)) " * - # join( - # [ - # (gene_index == last_removed_gene_index ? "-" : "") * - # (gene_index == factor_gene_index ? "?-" : "") * - # context.names_of_genes[gene_index] for gene_index in candidate_factors.ordered_factor_genes if - # gene_index in predictive_genes || gene_index == last_removed_gene_index - # ], - # " ", - # ) * - # " ~ $(@sprintf("%.5f", predictive_cost[1]))" * - # " ...\e[0K\r", - # ) +)::Maybe{CrossValidation} + print("\n\e[1A") + print( + "\e7$(length(predictive_genes))" * + (last_removed_gene_index === nothing ? "" : "-") * + " < $(predictive_index)" * + " / $(length(predictive_genes)) " * + join( + [ + (gene_index == last_removed_gene_index ? "-" : "") * + (gene_index == factor_gene_index ? "?-" : "") * + context.names_of_genes[gene_index] for gene_index in candidate_factors.ordered_factor_genes if + gene_index in predictive_genes || gene_index == last_removed_gene_index + ], + " ", + ) * + " ~ $(@sprintf("%.5f", predictive_cost[1]))" * + " RMS $(@sprintf("%.5f", predictive_cost[2]))" * + " R^2 $(@sprintf("%.5f", predictive_cost[3]))" * + " ...\e[J\e8", + ) filter!(predictive_genes) do predictive_gene return predictive_gene != factor_gene_index end - cross_validation_rms = compute_cross_validation_rms(; + cross_validation = compute_cross_validation(; context = context, - cross_validation = cross_validation, + cross_validation_parts = cross_validation_parts, predictive_genes = predictive_genes, included_genes = included_genes, included_metacells = included_metacells, @@ -516,16 +555,18 @@ function try_remove_factor!(; # untested ) rms_cost_factor = reduce(*, candidate_factors.rms_cost_factor_of_genes[predictive_genes]) - cost = cross_validation_rms * rms_cost_factor + cost = cross_validation.mean_rms * rms_cost_factor improvement = predictive_cost[1] - cost if improvement >= 0 predictive_cost[1] = cost - return true + predictive_cost[2] = cross_validation.mean_rms + predictive_cost[3] = mean(cross_validation.r2_of_genes[included_genes]) + return cross_validation else push!(predictive_genes, factor_gene_index) order_predictive_genes!(; predictive_genes = predictive_genes, candidate_factors = candidate_factors) - return false + return nothing end end @@ -536,26 +577,27 @@ function order_predictive_genes!(; predictive_genes::Vector{<:Integer}, candidat return nothing end -function compute_cross_validation_rms(; # untested +function compute_cross_validation(; # untested context::Context, - cross_validation::Integer, + cross_validation_parts::Integer, predictive_genes::AbstractVector{<:Integer}, included_genes::Union{UnitRange{<:Integer}, AbstractVector{<:Integer}}, included_metacells::Union{UnitRange{<:Integer}, AbstractVector{<:Integer}}, core_metacells::Union{UnitRange{<:Integer}, AbstractVector{<:Integer}}, rng::AbstractRNG, -)::AbstractFloat +)::CrossValidation core_metacells_indices = collect(core_metacells) shuffle!(rng, core_metacells_indices) - floor_chunk_size = max(div(length(core_metacells_indices), cross_validation), 1) + floor_chunk_size = max(div(length(core_metacells_indices), cross_validation_parts), 1) n_chunks = ceil(length(core_metacells_indices) / floor_chunk_size) chunk_size = length(core_metacells_indices) / n_chunks included_metacells_mask = zeros(Bool, context.n_metacells) included_metacells_mask[included_metacells] .= true - rms_of_core_metacells = Vector{Float32}(undef, length(core_metacells_indices)) + rms_of_genes = zeros(Float32, context.n_genes) + r2_of_genes = zeros(Float32, context.n_genes) for chunk_index in 1:n_chunks first_left_out_position = Int(round((chunk_index - 1) * chunk_size)) + 1 @@ -576,14 +618,22 @@ function compute_cross_validation_rms(; # untested included_metacells = included_metacells_indices, ) - rms_of_core_metacells[left_out_positions] .= rms_of_metacells_by_least_squares(; + evaluate_genes_by_least_squares(; context = context, least_squares = least_squares, chunk_metacells = chunk_metacells_indices, + rms_of_genes = rms_of_genes, + r2_of_genes = r2_of_genes, ) end - return mean(rms_of_core_metacells) + rms_of_genes ./= n_chunks + r2_of_genes ./= n_chunks + return CrossValidation(; + mean_rms = mean(rms_of_genes[included_genes]), + rms_of_genes = rms_of_genes, + r2_of_genes = r2_of_genes, + ) end @kwdef struct LeastSquares @@ -680,35 +730,67 @@ function prepare_data(; # untested ) end -function rms_of_metacells_by_least_squares(; # untested +function evaluate_genes_by_least_squares(; # untested context::Context, least_squares::LeastSquares, chunk_metacells::AbstractVector{<:Integer}, -)::AbstractVector{<:AbstractFloat} - residual_log_fractions_in_chunk_metacells_of_included_genes = predict_by_least_squares(; + rms_of_genes::AbstractVector{<:AbstractFloat}, + r2_of_genes::AbstractVector{<:AbstractFloat}, +)::Nothing + n_included_genes = length(least_squares.included_genes) + + divergence_of_included_genes = context.divergence_of_genes[least_squares.included_genes] + log_fractions_in_chunk_metacells_of_included_genes = + transposer(context.log_fractions_of_genes_in_metacells[least_squares.included_genes, chunk_metacells]) + predicted_log_fractions_in_chunk_metacells_of_included_genes = predict_by_least_squares(; context = context, least_squares = least_squares, selected_metacells = chunk_metacells, + divergence_of_included_genes = divergence_of_included_genes, ) - residual_log_fractions_in_chunk_metacells_of_included_genes .-= - transpose(context.log_fractions_of_genes_in_metacells[least_squares.included_genes, chunk_metacells]) - - rms_of_metacells = vec( - mean( - residual_log_fractions_in_chunk_metacells_of_included_genes .* - residual_log_fractions_in_chunk_metacells_of_included_genes; - dims = 2, - ), - ) - @assert_vector(rms_of_metacells, length(chunk_metacells)) - return rms_of_metacells + residual_log_fractions_in_chunk_metacells_of_included_genes = + predicted_log_fractions_in_chunk_metacells_of_included_genes .- + log_fractions_in_chunk_metacells_of_included_genes + residual_log_fractions_in_chunk_metacells_of_included_genes .*= transpose(1.0 .- divergence_of_included_genes) + residual_log_fractions_in_chunk_metacells_of_included_genes .*= + residual_log_fractions_in_chunk_metacells_of_included_genes + + rms_of_included_genes = vec(sqrt.(mean(residual_log_fractions_in_chunk_metacells_of_included_genes; dims = 1))) + @assert_vector(rms_of_included_genes, n_included_genes) + rms_of_genes[least_squares.included_genes] .+= rms_of_included_genes + + @threads for included_gene_position in 1:n_included_genes + included_gene_index = least_squares.included_genes[included_gene_position] + @views log_fractions_in_chunk_metacells_of_included_gene = + log_fractions_in_chunk_metacells_of_included_genes[:, included_gene_position] + @views predicted_log_fractions_in_chunk_metacells_of_included_gene = + predicted_log_fractions_in_chunk_metacells_of_included_genes[:, included_gene_position] + correlation_of_gene = vcor( + log_fractions_in_chunk_metacells_of_included_gene, + predicted_log_fractions_in_chunk_metacells_of_included_gene, + ) + if isnan(correlation_of_gene) + if minimum(log_fractions_in_chunk_metacells_of_included_gene) == + maximum(log_fractions_in_chunk_metacells_of_included_gene) && + minimum(predicted_log_fractions_in_chunk_metacells_of_included_gene) == + maximum(predicted_log_fractions_in_chunk_metacells_of_included_gene) + correlation_of_gene = 1 + else + correlation_of_gene = 0 + end + end + r2_of_genes[included_gene_index] += correlation_of_gene * correlation_of_gene + end + + return nothing end function predict_by_least_squares(; # untested context::Context, least_squares::LeastSquares, selected_metacells::AbstractVector{<:Integer}, + divergence_of_included_genes::AbstractVector{<:AbstractFloat}, )::AbstractMatrix{<:AbstractFloat} relative_log_fractions_in_selected_metacells_of_predictive_genes, _, _ = prepare_data(; context = context, @@ -727,7 +809,6 @@ function predict_by_least_squares(; # untested Columns, ) - divergence_of_included_genes = context.divergence_of_genes[least_squares.included_genes] predicted_log_fractions_in_selected_metacells_of_included_genes ./= transpose(1.0 .- divergence_of_included_genes) predicted_log_fractions_in_selected_metacells_of_included_genes .+= transpose(least_squares.mean_log_fractions_of_included_genes) @@ -748,6 +829,7 @@ end gene_fraction_regularization::AbstractFloat = $(DEFAULT.gene_fraction_regularization), fold_confidence::AbstractFloat = $(DEFAULT.fold_confidence), max_block_span::Real = $(DEFAULT.max_block_span), + overwrite::Bool = $(DEFAULT.overwrite), )::Nothing Given a set of transcription factors that can be used to predict the rest of the genes across the whole manifold, group @@ -782,7 +864,8 @@ $(CONTRACT) :gene_fraction_regularization, ), fold_confidence::AbstractFloat = 0.9, - max_block_span::Real = function_default(identify_marker_genes!, :min_marker_gene_range_fold), + max_block_span::Real = max(function_default(identify_marker_genes!, :min_marker_gene_range_fold) - 1, 1), + overwrite::Bool = false, )::Nothing @assert min_significant_gene_UMIs >= 0 @assert gene_fraction_regularization >= 0 @@ -804,8 +887,10 @@ $(CONTRACT) block_names = group_names(daf, "metacell", blocks.metacells_of_blocks; prefix = "B") add_axis!(daf, "block", block_names) - set_vector!(daf, "metacell", "block", block_names[blocks.blocks_of_metacells]) - set_matrix!(daf, "block", "block", "distance", blocks.distances_between_blocks) + set_vector!(daf, "metacell", "block", block_names[blocks.blocks_of_metacells]; overwrite = overwrite) + distances_between_blocks = blocks.distances_between_blocks + @assert distances_between_blocks !== nothing + set_matrix!(daf, "block", "block", "distance", distances_between_blocks; overwrite = overwrite) return nothing end @@ -869,7 +954,7 @@ end blocks_of_metacells::AbstractVector{<:Integer} metacells_of_blocks::AbstractVector{<:AbstractVector{<:Integer}} n_blocks::Integer - distances_between_blocks::AbstractMatrix{<:AbstractFloat} + distances_between_blocks::Maybe{AbstractMatrix{<:AbstractFloat}} end function compute_blocks_by_confidence(; # untested @@ -1055,10 +1140,11 @@ end min_rms_improvement::AbstractFloat = $(DEFAULT.min_rms_improvement), rms_priority_improvement::Real = $(DEFAULT.rms_priority_improvement), block_rms_bonus::AbstractFloat = $(DEFAULT.block_rms_bonus), - cross_validation::Integer = $(DEFAULT.cross_validation), + cross_validation_parts::Integer = $(DEFAULT.cross_validation_parts), min_blocks_in_neighborhood::Integer = $(DEFAULT.min_blocks_in_neighborhood), min_metacells_in_neighborhood::Integer = $(DEFAULT.min_metacells_in_neighborhood), rng::AbstractRNG = default_rng(), + overwrite::Bool = $(DEFAULT.overwrite), )::Nothing Given a partition of the metacells into distinct blocks, compute for each block its immediate neighborhood (of at least @@ -1094,13 +1180,14 @@ $(CONTRACT) min_rms_improvement::AbstractFloat = function_default(compute_global_predictive_factors!, :min_rms_improvement), rms_priority_improvement::Real = function_default(compute_global_predictive_factors!, :rms_priority_improvement), block_rms_bonus::AbstractFloat = 5e-4, - cross_validation::Integer = function_default(compute_global_predictive_factors!, :cross_validation), + cross_validation_parts::Integer = function_default(compute_global_predictive_factors!, :cross_validation_parts), min_blocks_in_neighborhood::Integer = 4, min_metacells_in_neighborhood::Integer = 100, rng::AbstractRNG = default_rng(), + overwrite::Bool = false, )::Nothing @assert gene_fraction_regularization >= 0 - @assert cross_validation > 1 + @assert cross_validation_parts > 1 @assert min_rms_improvement >= 0 @assert 0 <= rms_priority_improvement <= 2 @assert block_rms_bonus >= 0 @@ -1123,7 +1210,7 @@ $(CONTRACT) @views block_is_in_environment = block_block_is_in_environment[:, block_index] compute_vicinity_of_block!(; context = context, - cross_validation = cross_validation, + cross_validation_parts = cross_validation_parts, block_rms_bonus = block_rms_bonus, min_blocks_in_neighborhood = min_blocks_in_neighborhood, min_metacells_in_neighborhood = min_metacells_in_neighborhood, @@ -1136,8 +1223,22 @@ $(CONTRACT) ) end - set_matrix!(daf, "block", "block", "is_in_neighborhood", SparseMatrixCSC(block_block_is_in_neighborhood)) - set_matrix!(daf, "block", "block", "is_in_environment", SparseMatrixCSC(block_block_is_in_environment)) + set_matrix!( + daf, + "block", + "block", + "is_in_neighborhood", + SparseMatrixCSC(block_block_is_in_neighborhood); + overwrite = overwrite, + ) + set_matrix!( + daf, + "block", + "block", + "is_in_environment", + SparseMatrixCSC(block_block_is_in_environment); + overwrite = overwrite, + ) return nothing end @@ -1146,7 +1247,7 @@ function load_blocks(daf::DafReader)::Blocks # untested n_blocks = axis_length(daf, "block") blocks_of_metacells = axis_indices(daf, "block", get_vector(daf, "metacell", "block")) metacells_of_blocks = [findall(blocks_of_metacells .== block_index) for block_index in 1:n_blocks] - distances_between_blocks = get_matrix(daf, "block", "block", "distance") + distances_between_blocks = get_matrix(daf, "block", "block", "distance"; default = nothing) return Blocks(; blocks_of_metacells = blocks_of_metacells, @@ -1159,7 +1260,7 @@ end function compute_vicinity_of_block!(; # untested context::Context, block_rms_bonus::AbstractFloat, - cross_validation::Integer, + cross_validation_parts::Integer, min_blocks_in_neighborhood::Integer, min_metacells_in_neighborhood::Integer, global_predictive_genes::AbstractVector{<:Integer}, @@ -1169,7 +1270,9 @@ function compute_vicinity_of_block!(; # untested block_is_in_neighborhood::AbstractVector{Bool}, block_is_in_environment::AbstractVector{Bool}, )::Nothing - distances_between_others_and_block = blocks.distances_between_blocks[:, block_index] + distances_between_blocks = blocks.distances_between_blocks + @assert distances_between_blocks !== nothing + distances_between_others_and_block = distances_between_blocks[:, block_index] @assert_vector(distances_between_others_and_block, blocks.n_blocks) ordered_block_indices = sortperm(distances_between_others_and_block) @@ -1187,7 +1290,7 @@ function compute_vicinity_of_block!(; # untested context = context, blocks = blocks, block_rms_bonus = block_rms_bonus, - cross_validation = cross_validation, + cross_validation_parts = cross_validation_parts, global_predictive_genes = global_predictive_genes, rng = rng, block_index = block_index, @@ -1239,7 +1342,7 @@ function compute_environment_of_block(; # untested context::Context, blocks::Blocks, block_rms_bonus::AbstractFloat, - cross_validation::Integer, + cross_validation_parts::Integer, global_predictive_genes::AbstractVector{<:Integer}, rng::AbstractRNG, block_index::Integer, @@ -1262,7 +1365,7 @@ function compute_environment_of_block(; # untested context = context, blocks = blocks, block_rms_bonus = block_rms_bonus, - cross_validation = cross_validation, + cross_validation_parts = cross_validation_parts, global_predictive_genes = global_predictive_genes, rng = rng, block_index = block_index, @@ -1341,7 +1444,7 @@ function cost_of_environment(; # untested context::Context, blocks::Blocks, block_rms_bonus::AbstractFloat, - cross_validation::Integer, + cross_validation_parts::Integer, global_predictive_genes::AbstractVector{<:Integer}, rng::AbstractRNG, block_index::Integer, @@ -1357,9 +1460,9 @@ function cost_of_environment(; # untested environment_metacells_mask = region_metacells_mask(blocks, ordered_block_indices[1:n_blocks_in_environment]) environment_metacells = findall(environment_metacells_mask) - rms = compute_cross_validation_rms(; + cross_validation = compute_cross_validation(; context = context, - cross_validation = cross_validation, + cross_validation_parts = cross_validation_parts, predictive_genes = global_predictive_genes, included_genes = 1:(context.n_genes), included_metacells = environment_metacells, @@ -1367,16 +1470,22 @@ function cost_of_environment(; # untested rng = rng, ) - cost = rms * (1 - block_rms_bonus * sum(1 ./ (1:(n_blocks_in_environment - n_blocks_in_neighborhood)))) + cost = + cross_validation.mean_rms * + (1 - block_rms_bonus * sum(1 ./ (1:(n_blocks_in_environment - n_blocks_in_neighborhood)))) + print("\n\e[1A") print( - "$(sum(region_metacells_mask(blocks, ordered_block_indices[1:1]))) mcs" * + "\e7$(sum(region_metacells_mask(blocks, ordered_block_indices[1:1]))) mcs" * " @ block $(block_index)" * " in $(sum(region_metacells_mask(blocks, ordered_block_indices[1:n_blocks_in_neighborhood]))) mcs" * " @ $(n_blocks_in_neighborhood) neighborhood" * " in $(sum(region_metacells_mask(blocks, ordered_block_indices[1:n_blocks_in_environment]))) mcs" * " @ $(n_blocks_in_environment) environment" * - " : $(@sprintf("%.5f", cost)) ...\e[0K\r", + " : $(@sprintf("%.5f", cost))" * + " RMS $(@sprintf("%.5f", cross_validation.mean_rms))" * + " R^2 $(@sprintf("%.5f", mean(cross_validation.r2_of_genes)))" * + " ...\e[J\e8", ) return cost @@ -1388,10 +1497,11 @@ end gene_fraction_regularization::AbstractFloat = $(DEFAULT.gene_fraction_regularization), min_rms_improvement::AbstractFloat = $(DEFAULT.min_rms_improvement), rms_priority_improvement::Real = $(DEFAULT.rms_priority_improvement), - cross_validation::Integer = $(DEFAULT.cross_validation), + cross_validation_parts::Integer = $(DEFAULT.cross_validation_parts), min_marker_gene_range_fold::Real = $(DEFAULT.min_marker_gene_range_fold), min_marker_gene_max_fraction::AbstractFloat = $(DEFAULT.min_marker_gene_max_fraction), rng::AbstractRNG = default_rng(), + overwrite::Bool = $(DEFAULT.overwrite), )::Nothing Having computed the neighborhoods and environments, then for each block, figure out the set of transcription factors for @@ -1400,8 +1510,11 @@ predictive factors will be different from the set of global predictive factors ( When computing this set, we only consider the genes that are marker genes within the environment (as per [`identify_marker_genes!`](@ref), using `min_marker_gene_max_fraction` and a tighter `min_marker_gene_range_fold`. + +$(CONTRACT) """ @logged @computation Contract( + is_relaxed = true, axes = [gene_axis(RequiredInput), metacell_axis(RequiredInput), block_axis(RequiredInput)], data = [ gene_metacell_fraction_matrix(RequiredInput), @@ -1409,10 +1522,12 @@ When computing this set, we only consider the genes that are marker genes within gene_factor_priority_vector(RequiredInput), gene_is_global_predictive_factor_vector(RequiredInput), metacell_block_vector(RequiredInput), - block_block_distance_matrix(RequiredInput), block_block_is_in_neighborhood_matrix(RequiredInput), block_block_is_in_environment_matrix(RequiredInput), + gene_block_is_local_marker_matrix(GuaranteedOutput), gene_block_is_local_predictive_factor_matrix(GuaranteedOutput), + gene_block_local_rms_matrix(GuaranteedOutput), + gene_block_local_r2_matrix(GuaranteedOutput), ], ) function compute_local_predictive_factors!( # untested daf::DafWriter; @@ -1422,20 +1537,21 @@ When computing this set, we only consider the genes that are marker genes within ), min_rms_improvement::AbstractFloat = function_default(compute_global_predictive_factors!, :min_rms_improvement), rms_priority_improvement::Real = function_default(compute_global_predictive_factors!, :rms_priority_improvement), - cross_validation::Integer = function_default(compute_global_predictive_factors!, :cross_validation), + cross_validation_parts::Integer = function_default(compute_global_predictive_factors!, :cross_validation_parts), min_marker_gene_range_fold::Real = max( function_default(identify_marker_genes!, :min_marker_gene_range_fold) - 1, - 0, + 1, ), min_marker_gene_max_fraction::AbstractFloat = function_default( identify_marker_genes!, :min_marker_gene_max_fraction, ), rng::AbstractRNG = default_rng(), + overwrite::Bool = false, )::Nothing @assert gene_fraction_regularization >= 0 @assert min_rms_improvement >= 0 - @assert cross_validation > 0 + @assert cross_validation_parts > 0 @assert min_marker_gene_range_fold >= 0 @assert min_marker_gene_max_fraction >= 0 @@ -1454,18 +1570,24 @@ When computing this set, we only consider the genes that are marker genes within block_block_is_in_environment = get_matrix(daf, "block", "block", "is_in_environment") gene_block_is_local_predictive_factor = zeros(Bool, context.n_genes, blocks.n_blocks) + gene_block_is_local_marker = zeros(Bool, context.n_genes, blocks.n_blocks) + gene_block_local_rms = zeros(Float32, context.n_genes, blocks.n_blocks) + gene_block_local_r2 = zeros(Float32, context.n_genes, blocks.n_blocks) for block_index in 1:(blocks.n_blocks) blocks_in_neighborhood = findall(block_block_is_in_neighborhood[:, block_index]) blocks_in_environment = findall(block_block_is_in_environment[:, block_index]) @views block_local_predictive_factors_mask = gene_block_is_local_predictive_factor[:, block_index] + @views block_local_markers_mask = gene_block_is_local_marker[:, block_index] + @views block_local_rms_of_genes = gene_block_local_rms[:, block_index] + @views block_local_r2_of_genes = gene_block_local_r2[:, block_index] compute_local_predictive_factors_of_block(; context = context, global_predictive_genes = global_predictive_genes, candidate_factors = candidate_factors, blocks = blocks, gene_fraction_regularization = gene_fraction_regularization, - cross_validation = cross_validation, + cross_validation_parts = cross_validation_parts, min_marker_gene_range_fold = min_marker_gene_range_fold, min_marker_gene_max_fraction = min_marker_gene_max_fraction, rng = rng, @@ -1473,16 +1595,36 @@ When computing this set, we only consider the genes that are marker genes within blocks_in_neighborhood = blocks_in_neighborhood, blocks_in_environment = blocks_in_environment, block_local_predictive_factors_mask = block_local_predictive_factors_mask, + block_local_markers_mask = block_local_markers_mask, + block_local_rms_of_genes = block_local_rms_of_genes, + block_local_r2_of_genes = block_local_r2_of_genes, + overwrite = overwrite, ) end - return set_matrix!( + set_matrix!( daf, "gene", "block", "is_local_predictive_factor", - SparseMatrixCSC(gene_block_is_local_predictive_factor), + SparseMatrixCSC(gene_block_is_local_predictive_factor); + overwrite = overwrite, ) + + set_matrix!( + daf, + "gene", + "block", + "is_local_marker", + SparseMatrixCSC(gene_block_is_local_marker); + overwrite = overwrite, + ) + + set_matrix!(daf, "gene", "block", "local_rms", SparseMatrixCSC(gene_block_local_rms); overwrite = overwrite) + + set_matrix!(daf, "gene", "block", "local_r2", SparseMatrixCSC(gene_block_local_r2); overwrite = overwrite) + + return nothing end function compute_local_predictive_factors_of_block(; # untested @@ -1491,7 +1633,7 @@ function compute_local_predictive_factors_of_block(; # untested candidate_factors::CandidateFactors, blocks::Blocks, gene_fraction_regularization::AbstractFloat, - cross_validation::Integer, + cross_validation_parts::Integer, min_marker_gene_range_fold::Real, min_marker_gene_max_fraction::AbstractFloat, rng::AbstractRNG, @@ -1499,6 +1641,10 @@ function compute_local_predictive_factors_of_block(; # untested blocks_in_neighborhood::AbstractVector{<:Integer}, blocks_in_environment::AbstractVector{<:Integer}, block_local_predictive_factors_mask::Union{AbstractVector{Bool}, BitVector}, + block_local_markers_mask::Union{AbstractVector{Bool}, BitVector}, + block_local_rms_of_genes::AbstractVector{<:AbstractFloat}, + block_local_r2_of_genes::AbstractVector{<:AbstractFloat}, + overwrite::Bool, )::Nothing neighborhood_metacells_mask = region_metacells_mask(blocks, blocks_in_neighborhood) neighborhood_metacells = findall(neighborhood_metacells_mask) @@ -1513,14 +1659,16 @@ function compute_local_predictive_factors_of_block(; # untested min_marker_gene_range_fold = min_marker_gene_range_fold, min_marker_gene_max_fraction = min_marker_gene_max_fraction, environment_metacells_mask = environment_metacells_mask, + overwrite = overwrite, ), ) + block_local_markers_mask .= environment_genes_mask environment_genes_mask[global_predictive_genes] .= true environment_genes = findall(environment_genes_mask) - local_predictive_genes = identify_predictive_genes!(; + local_predictive_genes, cross_validation = identify_predictive_genes!(; context = context, - cross_validation = cross_validation, + cross_validation_parts = cross_validation_parts, included_genes = environment_genes, included_metacells = environment_metacells, core_metacells = neighborhood_metacells, @@ -1528,9 +1676,13 @@ function compute_local_predictive_factors_of_block(; # untested rng = rng, ) - @debug "Block $(block_index) : [ $(join(context.names_of_genes[local_predictive_genes], ", ")) ]" + @debug "Block $(block_index) Predictive : [ $(join(context.names_of_genes[local_predictive_genes], ", ")) ]" + @debug "Block $(block_index) Markers : $(length(environment_genes))" + @debug "Block $(block_index) RMS : $(mean(cross_validation.rms_of_genes[environment_genes])) R^2 : $(mean(cross_validation.r2_of_genes[environment_genes]))" block_local_predictive_factors_mask[local_predictive_genes] .= true + block_local_rms_of_genes .= cross_validation.rms_of_genes + block_local_r2_of_genes .= cross_validation.r2_of_genes reused = 0 added = 0 @@ -1561,9 +1713,10 @@ function marker_genes_of_environment(; # untested min_marker_gene_range_fold::Real, min_marker_gene_max_fraction::AbstractFloat, environment_metacells_mask::Union{AbstractVector{Bool}, BitVector}, + overwrite::Bool, )::Union{AbstractVector{Bool}, BitVector} chain = chain_writer([context.daf, MemoryDaf(; name = "environment")]; name = "mask_chain") - set_vector!(chain, "metacell", "is_in_environment", environment_metacells_mask) + set_vector!(chain, "metacell", "is_in_environment", environment_metacells_mask; overwrite = overwrite) adapter( # NOJET chain; input_axes = ["metacell" => "/metacell & is_in_environment", "gene" => "="], @@ -1583,4 +1736,136 @@ function marker_genes_of_environment(; # untested return get_vector(chain, "gene", "is_marker").array end +""" + function compute_programs!( + daf::DafWriter; + gene_fraction_regularization::AbstractFloat = $(DEFAULT.gene_fraction_regularization), + overwrite::Bool = $(DEFAULT.overwrite), + )::Nothing + +Having computed, for each block, the set of transcription factors for best approximating the gene expression of the +metacells in each neighborhood based on the environment, then compute the actual coefficients for doing this prediction. +If `overwrite`, will overwrite existing data. + +The program for predicting the value of a gene `i` based on the values of the predictive transcription factors `j` is +expressed as `G_i = MG_i + Sum P_ij (F_j - MF_j)`, where `MG_i` is the mean of the (log base 2 of the) expression of the +gene `i` in the environment, and similarly `MF_j` is the mean of the (log base 2 of the) expression of the factor `j` in +the same environment; and the `P_ij` is the coefficient for computing the (log base 2 of the) expression of gene `i` +using the (log base 2 of the) expression of the factor `F_j`. + +We store the `gene_fraction_regularization` used to compute the log base 2 of the fractions as a scalar called +`program_gene_fraction_regularization`, and the means as the per-block-per-gene matrix `program_mean_log_fraction`. The +coefficients themselves are stored as a set of matrices `_program_coefficient` which hold the `P_ij` matrix +(columns for transcription factors, rows for genes). + +$(CONTRACT) + +**gene, gene @ _program_coefficient**::AbstractFloat (guaranteed): The (sparse) coefficient for predicting +the (log base 2) expression of each (row) gene based on the (log base 2) expression of each (column) transcription +factor. +""" +@logged @computation Contract( + is_relaxed = true, + axes = [gene_axis(RequiredInput), metacell_axis(RequiredInput), block_axis(RequiredInput)], + data = [ + gene_metacell_fraction_matrix(RequiredInput), + gene_divergence_vector(RequiredInput), + metacell_block_vector(RequiredInput), + block_block_is_in_environment_matrix(RequiredInput), + gene_block_is_local_predictive_factor_matrix(RequiredInput), + program_gene_fraction_regularization_scalar(GuaranteedOutput), + gene_block_program_mean_log_fraction_matrix(GuaranteedOutput), + ], +) function compute_programs!( # untested + daf::DafWriter; + gene_fraction_regularization::AbstractFloat = function_default( + compute_factor_priority_of_genes!, + :gene_fraction_regularization, + ), + overwrite::Bool = false, +)::Nothing + @assert gene_fraction_regularization >= 0 + + context = load_context(daf; gene_fraction_regularization = gene_fraction_regularization) + + blocks = load_blocks(daf) + block_block_is_in_environment = get_matrix(daf, "block", "block", "is_in_environment").array + gene_block_is_local_predictive_factor = get_matrix(daf, "gene", "block", "is_local_predictive_factor").array + names_of_blocks = axis_array(daf, "block") + + set_scalar!(daf, "program_gene_fraction_regularization", gene_fraction_regularization; overwrite = overwrite) + + n_genes = axis_length(daf, "gene") + n_blocks = axis_length(daf, "block") + mean_log_fraction_of_genes_in_blocks = zeros(Float32, n_genes, n_blocks) + + for block_index in 1:(blocks.n_blocks) + block_name = names_of_blocks[block_index] + blocks_in_environment = findall(block_block_is_in_environment[:, block_index]) + environment_metacells_mask = region_metacells_mask(blocks, blocks_in_environment) + + @debug "Block $(block_name) ..." + + @views block_local_predictive_factors_mask = gene_block_is_local_predictive_factor[:, block_index] + predictive_genes = findall(block_local_predictive_factors_mask) + least_squares = solve_least_squares(; + context = context, + predictive_genes = predictive_genes, + included_genes = 1:n_genes, + included_metacells = findall(environment_metacells_mask), + ) + + mean_log_fraction_of_genes_in_blocks[:, block_index] .= least_squares.mean_log_fractions_of_included_genes + + coefficients_of_included_genes_of_predictive_genes = + SparseMatrixCSC(least_squares.coefficients_of_included_genes_of_predictive_genes) + + empty_sparse_matrix!( + daf, + "gene", + "gene", + "$(block_name)_program_coefficient", + eltype(coefficients_of_included_genes_of_predictive_genes), + nnz(coefficients_of_included_genes_of_predictive_genes); + overwrite = overwrite, + ) do empty_colptr, empty_rowval, empty_nzval + empty_rowval .= coefficients_of_included_genes_of_predictive_genes.rowval + empty_nzval .= coefficients_of_included_genes_of_predictive_genes.nzval + + next_predictive_gene_position = 1 + next_predictive_gene_index = predictive_genes[1] + empty_colptr[1] = 1 + for gene_index in 1:n_genes + if gene_index < next_predictive_gene_index + empty_colptr[gene_index + 1] = empty_colptr[gene_index] + else + @assert gene_index == next_predictive_gene_index + empty_colptr[gene_index + 1] = + coefficients_of_included_genes_of_predictive_genes.colptr[next_predictive_gene_position + 1] + next_predictive_gene_position += 1 + if next_predictive_gene_position <= length(predictive_genes) + next_predictive_gene_index = predictive_genes[next_predictive_gene_position] + else + next_predictive_gene_index = n_genes + 1 + end + end + end + @assert next_predictive_gene_position == length(predictive_genes) + 1 + @assert next_predictive_gene_index == n_genes + 1 + @assert empty_colptr[n_genes + 1] == length(empty_nzval) + 1 + end + end + + set_matrix!( + daf, + "gene", + "block", + "program_mean_log_fraction", + mean_log_fraction_of_genes_in_blocks; + overwrite = overwrite, + ) + + return nothing +end + end # module