Skip to content

Commit

Permalink
use the correct kind of transaction for write operations, add sticker…
Browse files Browse the repository at this point in the history
… type enum
  • Loading branch information
avoonix committed Jul 13, 2024
1 parent 01af50d commit 040811d
Show file tree
Hide file tree
Showing 14 changed files with 52 additions and 39 deletions.
2 changes: 2 additions & 0 deletions fuzzle/migrations/3_sticker_type_and_merging/down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- Reverts migration 3: restores the boolean `is_animated` column and removes
-- `sticker_type`.
-- Every ALTER TABLE is a separate statement terminated by `;` -- SQLite does
-- not accept a comma-separated list of alterations.
ALTER TABLE sticker_file ADD COLUMN is_animated BOOLEAN NOT NULL CHECK (is_animated IN (0, 1)) DEFAULT 1;
-- Carry the type information back into the flag before it is dropped:
-- sticker_type 2 is a static sticker, every other value counts as animated.
UPDATE sticker_file SET is_animated = CASE WHEN sticker_type = 2 THEN 0 ELSE 1 END;
ALTER TABLE sticker_file DROP COLUMN sticker_type;
2 changes: 2 additions & 0 deletions fuzzle/migrations/3_sticker_type_and_merging/up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- Replaces the boolean `is_animated` flag with an integer `sticker_type`
-- (0 = animated, 1 = video, 2 = static).
-- The new column is added first so the existing flag can be carried over
-- before it is dropped.
ALTER TABLE sticker_file ADD COLUMN sticker_type INTEGER NOT NULL DEFAULT 0;
-- The old flag cannot tell animated from video stickers, so rows that were
-- flagged as animated stay at 0 until they are re-imported; rows that were
-- not animated become static (2) instead of silently turning into animated.
UPDATE sticker_file SET sticker_type = CASE WHEN is_animated THEN 0 ELSE 2 END;
ALTER TABLE sticker_file DROP COLUMN is_animated;
9 changes: 9 additions & 0 deletions fuzzle/src/database/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,14 @@ pub enum MergeStatus {
NotMerged = 2,
}

/// Kind of a sticker file, stored as an integer in the `sticker_file.sticker_type` column.
///
/// The numeric values are persisted in the database, so existing discriminants
/// must not be renumbered. `0` is also the column default used by the migration
/// that introduces `sticker_type`.
#[derive(PartialEq, Debug, Copy, Clone, Primitive, AsExpression, FromSqlRow)]
#[diesel(sql_type = diesel::sql_types::BigInt)]
pub enum StickerType {
/// Animated sticker (mapped from `StickerFormat::Animated` on import).
Animated = 0,
/// Video sticker (mapped from `StickerFormat::Video` on import).
Video = 1,
/// Static raster sticker (mapped from `StickerFormat::Raster` on import).
Static = 2,
}

macro_rules! impl_enum {
($type_name:ty) => {
impl serialize::ToSql<diesel::sql_types::BigInt, Sqlite> for $type_name
Expand Down Expand Up @@ -214,6 +222,7 @@ macro_rules! impl_enum {
}

impl_enum!(MergeStatus);
impl_enum!(StickerType);
impl_enum!(Category);

#[derive(Debug, Serialize, Deserialize, Clone, Default)]
Expand Down
7 changes: 3 additions & 4 deletions fuzzle/src/database/queries/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,9 @@ impl diesel::r2d2::CustomizeConnection<SqliteConnection, diesel::r2d2::Error>
{
fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), diesel::r2d2::Error> {
(|| {
conn.batch_execute(&format!("PRAGMA busy_timeout = {};", Duration::from_secs(30).as_millis()))?;
conn.batch_execute("PRAGMA journal_mode = WAL; PRAGMA synchronous = NORMAL;")?;
conn.batch_execute("PRAGMA foreign_keys = ON;")?;
conn.batch_execute("PRAGMA temp_store = MEMORY;")?;
conn.batch_execute(&format!("PRAGMA busy_timeout = {};", Duration::from_secs(60).as_millis()))?;
conn.batch_execute("PRAGMA journal_mode = WAL; PRAGMA synchronous = NORMAL; PRAGMA wal_autocheckpoint = 1000; PRAGMA wal_checkpoint(TRUNCATE);")?;
conn.batch_execute("PRAGMA foreign_keys = ON; PRAGMA temp_store = MEMORY;")?;
conn.batch_execute(&format!("PRAGMA cache_size = {};", format_cache_size(2 * GB)))?;
conn.batch_execute(&format!("PRAGMA mmap_size = {};", (4 * GB).to_string()))?;
// TODO: need to run optimize before close as well
Expand Down
14 changes: 7 additions & 7 deletions fuzzle/src/database/queries/sticker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use tracing::warn;

use crate::{
database::{
StickerFile, MergeStatus, StickerSet, StickerUser, min_max, query_builder::StickerTagQuery, Order, Sticker, StickerIdStickerFileId
min_max, query_builder::StickerTagQuery, MergeStatus, Order, Sticker, StickerFile, StickerIdStickerFileId, StickerSet, StickerType, StickerUser
},
util::Emoji,
};
Expand Down Expand Up @@ -54,10 +54,10 @@ impl Database {
}

#[tracing::instrument(skip(self), err(Debug))]
pub async fn ___temp___update_is_animated(
pub async fn ___temp___update_sticker_type(
&self,
sticker_id: &str,
is_animated: bool,
sticker_type: StickerType,
) -> Result<(), DatabaseError> {
let updated_rows = diesel::update(sticker_file::table)
.filter(
Expand All @@ -67,7 +67,7 @@ impl Database {
.select((sticker::sticker_file_id)),
),
)
.set(sticker_file::is_animated.eq(is_animated))
.set(sticker_file::sticker_type.eq(sticker_type))
.execute(&mut self.pool.get()?)?;
#[cfg(debug_assertions)]
assert_eq!(updated_rows, 1);
Expand All @@ -93,13 +93,13 @@ impl Database {
&self,
sticker_file_id: &str,
thumbnail_file_id: Option<String>,
is_animated: bool,
sticker_type: StickerType,
) -> Result<(), DatabaseError> {
let q = insert_into(sticker_file::table)
.values((
sticker_file::id.eq(sticker_file_id),
sticker_file::thumbnail_file_id.eq(thumbnail_file_id.clone()),
sticker_file::is_animated.eq(is_animated),
sticker_file::sticker_type.eq(sticker_type),
))
.on_conflict(sticker_file::id);
if let Some(thumbnail_file_id) = thumbnail_file_id {
Expand Down Expand Up @@ -492,7 +492,7 @@ impl Database {
duplicate_file_id: &str,
user_id: Option<i64>,
) -> Result<(), DatabaseError> {
Ok(self.pool.get()?.transaction(|conn| {
Ok(self.pool.get()?.immediate_transaction(|conn| {
let stickers_affected_merge = sql_query("INSERT INTO merged_sticker (canonical_sticker_file_id, removed_sticker_file_id, removed_sticker_id, removed_sticker_set_id, created_by_user_id)
SELECT ?1, sticker_file_id, id, sticker_set_id, ?2 FROM sticker WHERE sticker_file_id = ?3")
.bind::<Text, _>(canonical_file_id)
Expand Down
9 changes: 5 additions & 4 deletions fuzzle/src/database/queries/sticker_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl Database {
title: &str,
created_by_user_id: i64,
) -> Result<(), DatabaseError> {
self.pool.get()?.transaction(|conn| {
self.pool.get()?.immediate_transaction(|conn| {
self.check_removed(id, conn)?;
insert_into(sticker_set::table)
.values((
Expand All @@ -61,7 +61,7 @@ impl Database {
id: &str,
title: &str,
) -> Result<(), DatabaseError> {
self.pool.get()?.transaction(|conn| {
self.pool.get()?.immediate_transaction(|conn| {
self.check_removed(id, conn)?;
insert_into(sticker_set::table)
.values((sticker_set::id.eq(id), sticker_set::title.eq(title)))
Expand All @@ -81,7 +81,7 @@ impl Database {
id: &str,
added_by_user_id: i64,
) -> Result<(), DatabaseError> {
self.pool.get()?.transaction(|conn| {
self.pool.get()?.immediate_transaction(|conn| {
self.check_removed(id, conn)?;
insert_into(sticker_set::table)
.values((
Expand All @@ -98,7 +98,8 @@ impl Database {
fn check_removed(
&self,
set_id: &str,
conn: &mut PooledConnection<diesel::r2d2::ConnectionManager<SqliteConnection>>,
// conn: &mut PooledConnection<diesel::r2d2::ConnectionManager<SqliteConnection>>,
conn: &mut SqliteConnection,
) -> Result<(), DatabaseError> {
let removed: Option<String> = removed_set::table
.filter(removed_set::id.eq(set_id))
Expand Down
11 changes: 6 additions & 5 deletions fuzzle/src/database/queries/sticker_tagging.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl Database {
tag_names: &[String],
user: Option<i64>,
) -> Result<(), DatabaseError> {
self.pool.get()?.transaction(|conn| {
self.pool.get()?.immediate_transaction(|conn| {
for tag in tag_names {
let inserted = insert_into(sticker_file_tag::table)
.values((
Expand All @@ -55,7 +55,7 @@ impl Database {
tag_names: &[String],
user_id: i64,
) -> Result<(), DatabaseError> {
self.pool.get()?.transaction(|conn| {
self.pool.get()?.immediate_transaction(|conn| {
for tag in tag_names {
// TODO: figure out how to do this in a single query

Expand Down Expand Up @@ -95,7 +95,7 @@ impl Database {
tags: &[String],
user: i64,
) -> Result<usize, DatabaseError> {
let affected = self.pool.get()?.transaction(|conn| {
let affected = self.pool.get()?.immediate_transaction(|conn| {
let mut tags_affected = 0;
for tag in tags {
// TODO: translate to proper diesel query?
Expand All @@ -121,7 +121,7 @@ impl Database {
tags: &[String],
user: i64,
) -> Result<usize, DatabaseError> {
let affected = self.pool.get()?.transaction(|conn| {
let affected = self.pool.get()?.immediate_transaction(|conn| {
let mut tags_affected = 0;
for tag in tags {
let result: Vec<(String, Option<i64>)> = sticker_file_tag::table
Expand Down Expand Up @@ -167,7 +167,8 @@ impl Database {
&self,
sticker_file_id: &str,
tag: &str,
conn: &mut PooledConnection<ConnectionManager<SqliteConnection>>,
// conn: &mut PooledConnection<ConnectionManager<SqliteConnection>>,
conn: &mut SqliteConnection,
) -> Result<usize, diesel::result::Error> {
delete(
sticker_file_tag::table
Expand Down
3 changes: 2 additions & 1 deletion fuzzle/src/database/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ diesel::table! {
created_at -> Timestamp,
tags_locked_by_user_id -> Nullable<BigInt>,
thumbnail_file_id -> Nullable<Text>,
is_animated -> Bool,
// is_animated -> Bool,
sticker_type -> BigInt,
}
}

Expand Down
4 changes: 2 additions & 2 deletions fuzzle/src/database/schema_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::{

use crate::tags::Category;

use super::{schema, Blacklist, DatabaseError, DialogState, TagData, UserSettings};
use super::{schema, Blacklist, DatabaseError, DialogState, StickerType, TagData, UserSettings};

#[derive(Queryable, Selectable)]
#[diesel(table_name = schema::sticker_file)]
Expand All @@ -25,7 +25,7 @@ pub struct StickerFile {
pub created_at: chrono::NaiveDateTime,
pub tags_locked_by_user_id: Option<i64>,
pub thumbnail_file_id: Option<String>,
pub is_animated: bool,
pub sticker_type: StickerType,
}

#[derive(Queryable, Selectable)]
Expand Down
4 changes: 2 additions & 2 deletions fuzzle/src/sticker/analysis/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub use histogram::{calculate_color_histogram, create_historgram_image};
pub use measures::{Match, Measures};

use crate::{
background_tasks::BackgroundTaskExt, bot::{report_bot_error, report_internal_error_result, report_periodic_task_error, Bot, BotError, InternalError, RequestContext, UserError}, database::Database, inference::{image_to_clip_embedding, text_to_clip_embedding}, util::Required, Config
background_tasks::BackgroundTaskExt, bot::{report_bot_error, report_internal_error_result, report_periodic_task_error, Bot, BotError, InternalError, RequestContext, UserError}, database::{Database, StickerType}, inference::{image_to_clip_embedding, text_to_clip_embedding}, util::Required, Config
};

use crate::inline::SimilarityAspect;
Expand Down Expand Up @@ -118,7 +118,7 @@ pub async fn analyze_sticker(
let Some(file_info) = file_info else {
return Ok(false);
};
let buf = if !file_info.is_animated {
let buf = if file_info.sticker_type == StickerType::Static {
let sticker = database
.get_sticker_by_id(&sticker_unique_id)
.await?
Expand Down
1 change: 0 additions & 1 deletion fuzzle/src/sticker/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ pub async fn fetch_sticker_file(
bot: Bot,
) -> Result<(Vec<u8>, teloxide::types::File), InternalError> {
let file = bot.get_file(file_id).await?;
tracing::info!(file.path = file.path);
let mut buf = Vec::new();
bot.download_file(&file.path, &mut buf).await?;
if buf.len() == file.size as usize {
Expand Down
13 changes: 10 additions & 3 deletions fuzzle/src/sticker/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ async fn fetch_sticker_and_save_to_db(
bot: Bot,
database: Database,
) -> Result<(), InternalError> {
info!("fetching sticker from set {set_name}");
let emoji = sticker.emoji.map(|e| Emoji::new_from_string_single(&e));

let (buf, file) = fetch_sticker_file(sticker.file.id.clone(), bot.clone()).await?;
Expand All @@ -154,7 +153,11 @@ async fn fetch_sticker_and_save_to_db(
.create_file(
&canonical_file_hash,
sticker.thumb.map(|thumb| thumb.file.id),
!sticker.format.is_raster(),
match sticker.format {
teloxide::types::StickerFormat::Raster => crate::database::StickerType::Static,
teloxide::types::StickerFormat::Animated => crate::database::StickerType::Animated,
teloxide::types::StickerFormat::Video => crate::database::StickerType::Video,
},
)
.await?;
database
Expand Down Expand Up @@ -261,7 +264,11 @@ async fn fetch_sticker_set_and_save_to_db(
continue;
};
database
.___temp___update_is_animated(&sticker.id, !s.format.is_raster())
.___temp___update_sticker_type(&sticker.id, match s.format {
teloxide::types::StickerFormat::Raster => crate::database::StickerType::Static,
teloxide::types::StickerFormat::Animated => crate::database::StickerType::Animated,
teloxide::types::StickerFormat::Video => crate::database::StickerType::Video,
})
.await?;
if s.file.id != sticker.telegram_file_identifier {
// TODO: might be updated too frequently
Expand Down
7 changes: 2 additions & 5 deletions fuzzle/src/sticker/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize};
use crate::{
bot::{Bot, InternalError},
database::{
Database, MergeStatus
Database, MergeStatus, StickerType
},
qdrant::VectorDatabase, util::Required,
};
Expand Down Expand Up @@ -110,7 +110,6 @@ pub async fn automerge(
let already_considered = database
.get_all_merge_candidate_file_ids(&sticker.sticker_file_id)
.await?;
tracing::info!(embedding = ?similar_sticker_file_hashes_1, histogram = ?similar_sticker_file_hashes_2);
let similar_sticker_file_hashes_1 = similar_sticker_file_hashes_1
.into_iter()
.filter(|sticker| !already_considered.contains(&sticker.file_hash))
Expand All @@ -119,7 +118,6 @@ pub async fn automerge(
.into_iter()
.filter(|sticker| !already_considered.contains(&sticker.file_hash))
.collect_vec();
tracing::info!(embedding = ?similar_sticker_file_hashes_1, histogram = ?similar_sticker_file_hashes_2);
let most_similar = similar_sticker_file_hashes_1
.into_iter()
.filter_map(|match_1| {
Expand Down Expand Up @@ -158,7 +156,7 @@ pub async fn automerge(
else {
continue;
};
if sticker_b_file.is_animated || sticker_a_file.id == sticker_b_file.id {
if sticker_b_file.sticker_type != StickerType::Static || sticker_a_file.id == sticker_b_file.id {
continue;
}
let buf_b = get_sticker_file(database.clone(), bot.clone(), &sticker_id_b).await?;
Expand Down Expand Up @@ -224,7 +222,6 @@ pub async fn determine_canonical_sticker_and_merge(
}

// TODO: unit tests
#[tracing::instrument(ret)]
fn harmonic_mean(numbers: Vec<f32>) -> f32 {
let len = numbers.len() as f32;
let mut sum = 0.0;
Expand Down
5 changes: 0 additions & 5 deletions fuzzle/src/text/texts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,6 @@ impl Text {
Markdown::new("I can't add this sticker set\\.")
}

#[must_use]
pub fn sticker_not_found() -> Markdown {
Markdown::new("Sticker not found")
}

#[must_use]
pub fn get_settings_text(settings: &UserSettings) -> Markdown {
let order = match settings.order() {
Expand Down

0 comments on commit 040811d

Please sign in to comment.