From 5759d4302c98e8090ea86416c9e04f0c3370ed38 Mon Sep 17 00:00:00 2001 From: brianreicher Date: Sat, 18 Nov 2023 17:48:37 -0500 Subject: [PATCH] Add tokenization functions --- tokenizer/src/tokenizer.rs | 104 ++++++++++++++++++++++++++++++------- 1 file changed, 86 insertions(+), 18 deletions(-) diff --git a/tokenizer/src/tokenizer.rs b/tokenizer/src/tokenizer.rs index 75e3f13..d711545 100644 --- a/tokenizer/src/tokenizer.rs +++ b/tokenizer/src/tokenizer.rs @@ -3,6 +3,9 @@ use mongodb::bson::{doc, Document}; use openai_rust::Client; use std::error::Error; use std::env; +use tokio::task; +use tokio::sync::mpsc; +use regex::Regex; pub struct Tokenizer { @@ -21,30 +24,95 @@ impl Tokenizer { } } - pub async fn tokenize_collection(&self, collection: &str) -> std::io::Result<>{ + pub async fn tokenize_collection(&self, collection: &str) -> std::io::Result<()> { let result: Vec = self.mongo_client.get_all_documents(self.collection); - let col: mongodb::Collection= self.mongo_model.client.as_ref() - .unwrap() - .database(&self.mongo_model.db_name) - .collection(collection); - - while let Some(result) = cursor.next().await { + let col: mongodb::Collection = self.mongo_model.client.as_ref() + .unwrap() + .database(&self.mongo_model.db_name) + .collection(collection); + + let (tx, rx) = mpsc::channel::>(32); + + for document in result { + let tx = tx.clone(); + let col_clone: mongodb::Collection = col.clone(); + + task::spawn(async move { + if let Some(text) = document.get_str("text") { + let tokens = tokenize_file(text).await; + + let update_doc: Document = doc! { + "$set": { "tokens": tokens } + }; + + if let Err(e) = col_clone.update_one(document.clone(), update_doc, None).await { + tx.send(Err(format!("Error updating document: {}", e))).await.unwrap(); + } else { + tx.send(Ok(document)).await.unwrap(); + } + } + }); + } + + drop(tx); + + while let Some(result) = rx.recv().await { match result { Ok(document) => { - if let Some(text) = document.get_str("text") { - let args: openai_rust::embeddings::EmbeddingsArguments = openai_rust::embeddings::EmbeddingsArguments::new("text-embedding-ada-002", text.to_owned()); - let embedding: Vec = self.oai_client.create_embeddings(args).await.unwrap().data; - - let update_doc: Document = doc! { - "$set": { "embedding": embedding } - }; - - col.update_one(document, update_doc, None).await?; - } + print!("Processesed document successfully") } - Error(e) => eprintln!("Error processing document: {}", e), + Err(e) => eprintln!("Error processing document: {}", e), } } + Ok(()) } + + async fn tokenize_file(input: &str) -> Vec { + let number_regex = Regex::new(r"\b\d+\b").unwrap(); + let identifier_regex = Regex::new(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b").unwrap(); + let operator_regex = Regex::new(r"[\+\-\*/]").unwrap(); + + let mut tokens: Vec = Vec::new(); + + for mat in number_regex.find_iter(input) { + let token: Token = Token { + token_type: TokenType::Number, + value: mat.as_str().to_owned(), + }; + tokens.push(token); + } + + for mat in identifier_regex.find_iter(input) { + let token = Token { + token_type: TokenType::Identifier, + value: mat.as_str().to_owned(), + }; + tokens.push(token); + } + + for mat in operator_regex.find_iter(input) { + let token = Token { + token_type: TokenType::Operator, + value: mat.as_str().to_owned(), + }; + tokens.push(token); + } + + tokens + } +} + +#[derive(Debug)] +enum TokenType { + Number, + Identifier, + Operator, + // Add more token types as needed +} + +#[derive(Debug)] +struct Token { + token_type: TokenType, + value: String, }