Skip to content

Commit

Permalink
WIP.
Browse files Browse the repository at this point in the history
  • Loading branch information
orenbenkiki committed Aug 7, 2024
1 parent b3acb34 commit 3c9616e
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 58 deletions.
2 changes: 1 addition & 1 deletion docs/v0.1.0/.documenter-siteinfo.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"documenter":{"julia_version":"1.10.4","generation_timestamp":"2024-08-05T16:26:59","documenter_version":"1.5.0"}}
{"documenter":{"julia_version":"1.10.4","generation_timestamp":"2024-08-05T18:39:24","documenter_version":"1.5.0"}}
28 changes: 28 additions & 0 deletions docs/v0.1.0/contracts.html
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,28 @@ <h2 id="Vectors">
<header>
<a class="docstring-article-toggle-button fa-solid fa-chevron-down" href="javascript:;" title="Collapse docstring">
</a>
<a class="docstring-binding" id="Metacells.Contracts.gene_is_global_predictive_factor_vector" href="#Metacells.Contracts.gene_is_global_predictive_factor_vector">
<code>Metacells.Contracts.gene_is_global_predictive_factor_vector
</code>
</a>
<span class="docstring-category">Function
</span>
</header>
<section>
<div>
<pre>
<code class="language-julia hljs">function gene_is_global_predictive_factor_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification}
</code>
</pre>
<p>A mask of globally predictive transcription factors. That is, knowing the values of all these genes allows us to predict the values of the rest of the genes (but not as well as when using the locally predictive transcription factors).
</p>
</div>
</section>
</article>
<article class="docstring">
<header>
<a class="docstring-article-toggle-button fa-solid fa-chevron-down" href="javascript:;" title="Collapse docstring">
</a>
<a class="docstring-binding" id="Metacells.Contracts.gene_divergence_vector" href="#Metacells.Contracts.gene_divergence_vector">
<code>Metacells.Contracts.gene_divergence_vector
</code>
Expand Down Expand Up @@ -676,6 +698,12 @@ <h2 id="Index">
</a>
</li>
<li>
<a href="contracts.html#Metacells.Contracts.gene_is_global_predictive_factor_vector">
<code>Metacells.Contracts.gene_is_global_predictive_factor_vector
</code>
</a>
</li>
<li>
<a href="contracts.html#Metacells.Contracts.gene_is_lateral_vector">
<code>Metacells.Contracts.gene_is_lateral_vector
</code>
Expand Down
6 changes: 6 additions & 0 deletions docs/v0.1.0/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,12 @@ <h1 id="Index">
</a>
</li>
<li>
<a href="contracts.html#Metacells.Contracts.gene_is_global_predictive_factor_vector">
<code>Metacells.Contracts.gene_is_global_predictive_factor_vector
</code>
</a>
</li>
<li>
<a href="contracts.html#Metacells.Contracts.gene_is_lateral_vector">
<code>Metacells.Contracts.gene_is_lateral_vector
</code>
Expand Down
Binary file modified docs/v0.1.0/objects.inv
Binary file not shown.
2 changes: 1 addition & 1 deletion docs/v0.1.0/search_index.js

Large diffs are not rendered by default.

18 changes: 15 additions & 3 deletions src/contracts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ levels (based on the number of UMIs used for the estimates) and the [`gene_diver
module Contracts

export block_axis
export block_block_distance
export block_block_distance_matrix
export cell_axis
export gene_axis
export gene_divergence_vector
export gene_is_global_predictive_factor_vector
export gene_is_lateral_vector
export gene_is_marker_vector
export gene_is_transcription_factor_vector
Expand Down Expand Up @@ -145,6 +146,17 @@ function gene_is_transcription_factor_vector(expectation::ContractExpectation)::
return ("gene", "is_transcription_factor") => (expectation, Bool, "A mask of genes that bind to the DNA.")
end

"""
function gene_is_global_predictive_factor_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification}
A mask of globally predictive transcription factors. That is, knowing the values of all these genes allows us to predict the
values of the rest of the genes (but not as well as when using the locally predictive transcription factors).
"""
function gene_is_global_predictive_factor_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification} # untested
return ("gene", "is_global_predictive_factor") =>
(expectation, Bool, "A mask of globally predictive transcription factors.")
end

"""
function cell_is_excluded_vector(expectation::ContractExpectation)::Pair{VectorKey, DataSpecification}
Expand Down Expand Up @@ -226,13 +238,13 @@ function gene_metacell_total_UMIs_matrix(expectation::ContractExpectation)::Pair
end

"""
function block_block_distance(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification}
function block_block_distance_matrix(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification}
The distance (fold factor) between the most different metacell genes between the blocks. This is the fold factor between
the most different gene expression in a pair of metacells, one in each block. This considers only the global predictive
genes.
"""
function block_block_distance(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification} # untested
function block_block_distance_matrix(expectation::ContractExpectation)::Pair{MatrixKey, DataSpecification} # untested
return ("block", "block", "distance") => (
expectation,
StorageFloat,
Expand Down
1 change: 1 addition & 0 deletions src/contracts.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Metacells.Contracts.gene_is_excluded_vector
Metacells.Contracts.gene_is_lateral_vector
Metacells.Contracts.gene_is_marker_vector
Metacells.Contracts.gene_is_transcription_factor_vector
Metacells.Contracts.gene_is_global_predictive_factor_vector
Metacells.Contracts.gene_divergence_vector
Metacells.Contracts.cell_is_excluded_vector
Metacells.Contracts.metacell_total_UMIs_vector
Expand Down
155 changes: 102 additions & 53 deletions src/programs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,17 @@ end
rng::AbstractRNG = default_rng(),
)::Nothing
TODOY
TODOX
"""
@logged @computation function compute_global_predictive_factors!( # untested
@logged @computation Contract(
axes = [gene_axis(RequiredInput), metacell_axis(RequiredInput)],
data = [
gene_metacell_fraction_matrix(RequiredInput),
gene_divergence_vector(RequiredInput),
gene_is_transcription_factor_vector(RequiredInput),
gene_is_global_predictive_factor_vector(GuaranteedOutput),
],
) function compute_global_predictive_factors!( # untested
daf::DafWriter;
gene_fraction_regularization::AbstractFloat = GENE_FRACTION_REGULARIZATION,
max_principal_components::Integer = 30,
Expand Down Expand Up @@ -543,7 +551,7 @@ end

function try_add_factor!(
context::Context,
::Integer,
test_index::Integer,
factor_gene_index::Integer,
last_removed_gene_index::Maybe{<:Integer},
)::Bool
Expand All @@ -568,21 +576,21 @@ function try_add_factor!(
@assert context.cross_validation !== nothing

improvement =
current_cross_validation.mean_rms * (1 + current_n_predictive_genes * 1e-2) -
cross_validation.mean_rms * (1 + n_predictive_genes * 1e-2)
current_cross_validation.mean_rms * (1 + 1e-2) ^ current_n_predictive_genes -
cross_validation.mean_rms * (1 + 1e-2) ^ n_predictive_genes

if improvement > 0
# print(
# stderr,
# "TODOY # $(current_n_predictive_genes) ++ $(test_index) / $(length(context.ordered_factor_gene_indices)) / $(context.data.factor_genes.n_entries) $(context.data.names_of_genes[factor_gene_index]) -> $(context.cross_validation.mean_rms) > $(current_cross_validation.mean_rms) ... \r",
# )
print(
stderr,
"TODOX # $(current_n_predictive_genes) ++ $(test_index) / $(length(context.ordered_factor_gene_indices)) / $(context.data.factor_genes.n_entries) $(context.data.names_of_genes[factor_gene_index]) -> $(context.cross_validation.mean_rms) > $(current_cross_validation.mean_rms) ... \r",
)
return true
end

# print(
# stderr,
# "TODOY # $(current_n_predictive_genes) ?+ $(test_index) / $(length(context.ordered_factor_gene_indices)) / $(context.data.factor_genes.n_entries) $(context.data.names_of_genes[factor_gene_index]) -> $(context.cross_validation.mean_rms) <~ $(current_cross_validation.mean_rms) ... \r",
# )
print(
stderr,
"TODOX # $(current_n_predictive_genes) ?+ $(test_index) / $(length(context.ordered_factor_gene_indices)) / $(context.data.factor_genes.n_entries) $(context.data.names_of_genes[factor_gene_index]) -> $(context.cross_validation.mean_rms) <~ $(current_cross_validation.mean_rms) ... \r",
)

pop_predictive!(context, factor_gene_index)
context.cross_validation = current_cross_validation
Expand All @@ -606,7 +614,7 @@ end

function try_remove_factor!(
context,
::Integer,
predictive_index::Integer,
factor_gene_index::Integer,
last_added_gene_index::Maybe{<:Integer},
)::Bool
Expand Down Expand Up @@ -634,20 +642,20 @@ function try_remove_factor!(
@assert context.cross_validation !== nothing

improvement =
current_cross_validation.mean_rms * (1 + current_n_predictive_genes * 1e-2) -
cross_validation.mean_rms * (1 + n_predictive_genes * 1e-2)
current_cross_validation.mean_rms * (1 + 1e-2) ^ current_n_predictive_genes -
cross_validation.mean_rms * (1 + 1e-2) ^ n_predictive_genes
if improvement >= 0 # -1e-3
# print(
# stderr,
# "TODOY # $(current_n_predictive_genes) -- $(predictive_index) / $(current_n_predictive_genes) / $(context.data.factor_genes.n_entries) $(context.data.names_of_genes[factor_gene_index]) -> $(context.cross_validation.mean_rms) <= $(current_cross_validation.mean_rms) ... \r",
# )
print(
stderr,
"TODOX # $(current_n_predictive_genes) -- $(predictive_index) / $(current_n_predictive_genes) / $(context.data.factor_genes.n_entries) $(context.data.names_of_genes[factor_gene_index]) -> $(context.cross_validation.mean_rms) <= $(current_cross_validation.mean_rms) ... \r",
)
return true
end

# print(
# stderr,
# "TODOY # $(current_n_predictive_genes) ?- $(predictive_index) / $(current_n_predictive_genes) / $(context.data.factor_genes.n_entries) $(context.data.names_of_genes[factor_gene_index]) -> $(context.cross_validation.mean_rms) > $(current_cross_validation.mean_rms) ... \r",
# )
print(
stderr,
"TODOX # $(current_n_predictive_genes) ?- $(predictive_index) / $(current_n_predictive_genes) / $(context.data.factor_genes.n_entries) $(context.data.names_of_genes[factor_gene_index]) -> $(context.cross_validation.mean_rms) > $(current_cross_validation.mean_rms) ... \r",
)

reset_predictive!(context, current_predictive_genes_indices)
context.cross_validation = current_cross_validation
Expand Down Expand Up @@ -769,6 +777,7 @@ function solve_least_squares!(context::Context, selected_metacells::Union{SomeMa
alg = :fnnls,
)
@assert_matrix(coefficients_of_predictive_genes_of_genes, predictive_genes.n_entries, context.data.n_genes, Columns)
# TODOX coefficients_of_predictive_genes_of_genes[coefficients_of_predictive_genes_of_genes .< 1e-2] .= 0

context.least_squares =
least_squares = LeastSquares(;
Expand Down Expand Up @@ -1037,8 +1046,22 @@ end
min_metacells_in_neighborhood::Integer = $(DEFAULT.min_metacells_in_neighborhood),
rng::AbstractRNG = default_rng(),
)::Nothing
TODOX
"""
@logged @computation function compute_local_predictive_factors!( # untested
@logged @computation Contract(
axes = [gene_axis(RequiredInput), metacell_axis(RequiredInput), block_axis(GuaranteedOutput)],
data = [
gene_metacell_fraction_matrix(RequiredInput),
gene_divergence_vector(RequiredInput),
gene_is_transcription_factor_vector(RequiredInput),
gene_is_global_predictive_factor_vector(RequiredInput),
metacell_total_UMIs_vector(RequiredInput),
gene_metacell_total_UMIs_matrix(RequiredInput),
metacell_block_vector(GuaranteedOutput),
block_block_distance_matrix(GuaranteedOutput),
],
) function compute_local_predictive_factors!( # untested
daf::DafWriter;
min_significant_gene_UMIs::Integer = 40,
gene_fraction_regularization::AbstractFloat = GENE_FRACTION_REGULARIZATION,
Expand Down Expand Up @@ -1070,10 +1093,17 @@ end
order_factors!(context)
load_predictive!(context, daf)
compute_confidence!(context, daf)
compute_blocks!(context)

blocks = compute_blocks!(context)
block_names = group_names(daf, "metacell", blocks.metacells_of_blocks; prefix = "B")
blocks_of_metacells = block_names[blocks.blocks_of_metacells]
add_axis!(daf, "block", block_names)
set_vector!(daf, "metacell", "block", blocks_of_metacells)
set_matrix!(daf, "block", "block", "distance", blocks.distances_between_blocks)

compute_environments!(context)

@assert false
return nothing
end

@logged function load_predictive!(context::Context, daf::DafReader)::SomeIndices
Expand All @@ -1098,29 +1128,21 @@ end
total_UMIs_of_genes_in_metacells = get_matrix(daf, "gene", "metacell", "total_UMIs").array
@assert_matrix(total_UMIs_of_genes_in_metacells, context.data.n_genes, context.data.n_metacells, Columns)

confidence_stdevs = quantile(Normal(), context.parameters.fold_confidence)
confidence_fractions_of_genes_in_metacells = # NOJET
confidence_stdevs .*
sqrt.(transpose(total_UMIs_of_metacells) .* context.data.fractions_of_genes_in_metacells) ./
transpose(total_UMIs_of_metacells)

log_decreased_fractions_of_genes_in_metacells =
log2.(
max.(context.data.fractions_of_genes_in_metacells .- confidence_fractions_of_genes_in_metacells, 0.0) .+
context.parameters.gene_fraction_regularization
log_decreased_fractions_of_genes_in_metacells, log_increased_fractions_of_genes_in_metacells =
compute_confidence_log_fraction_of_genes_in_metacells(;
gene_fraction_regularization = context.parameters.gene_fraction_regularization,
fractions_of_genes_in_metacells = context.data.fractions_of_genes_in_metacells,
total_UMIs_of_metacells = total_UMIs_of_metacells,
fold_confidence = context.parameters.fold_confidence,
)

@assert_matrix(
log_decreased_fractions_of_genes_in_metacells,
context.data.n_genes,
context.data.n_metacells,
Columns
)

log_increased_fractions_of_genes_in_metacells =
log2.(
context.data.fractions_of_genes_in_metacells .+ confidence_fractions_of_genes_in_metacells .+
context.parameters.gene_fraction_regularization
)
@assert_matrix(
log_increased_fractions_of_genes_in_metacells,
context.data.n_genes,
Expand All @@ -1138,6 +1160,33 @@ end
return confidence
end

function compute_confidence_log_fraction_of_genes_in_metacells(;
gene_fraction_regularization::AbstractFloat,
fractions_of_genes_in_metacells::AbstractMatrix{<:AbstractFloat},
total_UMIs_of_metacells::AbstractVector{<:Unsigned},
fold_confidence::AbstractFloat,
)::Tuple{AbstractMatrix{<:AbstractFloat}, AbstractMatrix{<:AbstractFloat}}
confidence_stdevs = quantile(Normal(), fold_confidence)

confidence_fractions_of_genes_in_metacells = # NOJET
confidence_stdevs .* sqrt.(transpose(total_UMIs_of_metacells) .* fractions_of_genes_in_metacells) ./
transpose(total_UMIs_of_metacells)

log_decreased_fractions_of_genes_in_metacells =
log2.(
max.(fractions_of_genes_in_metacells .- confidence_fractions_of_genes_in_metacells, 0.0) .+
gene_fraction_regularization
)

log_increased_fractions_of_genes_in_metacells =
log2.(
fractions_of_genes_in_metacells .+ confidence_fractions_of_genes_in_metacells .+
gene_fraction_regularization
)

return (log_decreased_fractions_of_genes_in_metacells, log_increased_fractions_of_genes_in_metacells)
end

@logged function compute_blocks!(context::Context)::Blocks
distances_between_metacells = compute_distances_between_metacells(context)

Expand Down Expand Up @@ -1372,12 +1421,12 @@ end
end
end

# n_block_metacells = sum(blocks.blocks_of_metacells .== block_index)
# @debug "TODOY $(block_index): $(n_block_metacells) metacells => $(depict(environments[block_index]))"
# @debug "TODOY Predictive genes: total: $(total) reused: $(reused) added: $(added) removed: $(removed)"
# @debug "TODOY Variability RMS: global: $(@sprintf("%.5f", environment.global_analysis.variability.mean_rms)) by_global: $(@sprintf("%.5f", environment.by_global_analysis.variability.mean_rms)) local: $(@sprintf("%.5f", environment.local_analysis.variability.mean_rms))"
# @debug "TODOY Analysis RMS: global: $(@sprintf("%.5f", environment.global_analysis.cross_validation.mean_rms)) by_global: $(@sprintf("%.5f", environment.by_global_analysis.cross_validation.mean_rms)) local: $(@sprintf("%.5f", environment.local_analysis.cross_validation.mean_rms))"
# @debug "TODOY Analysis R2 global: $(@sprintf("%.5f", environment.global_analysis.cross_validation.mean_r2)) by_global: $(@sprintf("%.5f", environment.by_global_analysis.cross_validation.mean_r2)) local: $(@sprintf("%.5f", environment.local_analysis.cross_validation.mean_r2))"
n_block_metacells = sum(blocks.blocks_of_metacells .== block_index)
@debug "TODOX $(block_index): $(n_block_metacells) metacells => $(depict(environments[block_index]))"
@debug "TODOX Predictive genes: total: $(total) reused: $(reused) added: $(added) removed: $(removed)"
@debug "TODOX Variability RMS: global: $(@sprintf("%.5f", environment.global_analysis.variability.mean_rms)) by_global: $(@sprintf("%.5f", environment.by_global_analysis.variability.mean_rms)) local: $(@sprintf("%.5f", environment.local_analysis.variability.mean_rms))"
@debug "TODOX Analysis RMS: global: $(@sprintf("%.5f", environment.global_analysis.cross_validation.mean_rms)) by_global: $(@sprintf("%.5f", environment.by_global_analysis.cross_validation.mean_rms)) local: $(@sprintf("%.5f", environment.local_analysis.cross_validation.mean_rms))"
@debug "TODOX Analysis R2 global: $(@sprintf("%.5f", environment.global_analysis.cross_validation.mean_r2)) by_global: $(@sprintf("%.5f", environment.by_global_analysis.cross_validation.mean_r2)) local: $(@sprintf("%.5f", environment.local_analysis.cross_validation.mean_r2))"
end

open("blocks_qc.csv", "w") do file
Expand Down Expand Up @@ -1617,7 +1666,7 @@ end

function analyze_environment(
context::Context,
::Integer,
n_blocks_in_neighborhood::Integer,
next_n_blocks::Integer,
ordered_block_indices::Vector{<:Integer},
)::Analysis
Expand All @@ -1631,10 +1680,10 @@ function analyze_environment(
analyze_left_outs!(context)
analysis = final_analysis!(context)

# print(
# stderr,
# "TODOY $(context.included_metacells.n_entries) metacells in $(n_blocks_in_neighborhood) <= $(next_n_blocks) => RMS: $(analysis.cross_validation.mean_rms) ... \r",
# )
print(
stderr,
"TODOX $(context.included_metacells.n_entries) metacells in $(n_blocks_in_neighborhood) <= $(next_n_blocks) => RMS: $(analysis.cross_validation.mean_rms) ... \r",
)

return analysis
end
Expand Down

0 comments on commit 3c9616e

Please sign in to comment.