Skip to content

Commit

Permalink
Using new openai crate to generate embeddings
Browse files: browse the repository at this point in the history
  • Loading branch information
brianreicher committed Oct 11, 2023
1 parent 79e9aed commit 9da10aa
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 156 deletions.
156 changes: 9 additions & 147 deletions tokenizer/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tokenizer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
mongodb = "2.0"
ingestion = { path = "../ingestion" }
openai = "1.0.0-alpha.13"
openai-rust = "0.5.1"
4 changes: 2 additions & 2 deletions tokenizer/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ async fn main() -> std::io::Result<()> {
mongo.connect().await?;

// set collection to tokenize
let collection = "github_data"
let oai_key = "generate_2023"
let collection = "github_data";
let oai_key = "generate_2023";

// TODO: create a new collection for each repo, insert documents into sub collections
let tokenizer: OpenAIClient = OpenAIClient::new(oai_key, mongo, collection);
Expand Down
13 changes: 7 additions & 6 deletions tokenizer/src/oai.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use ingestion::mongo_utils::MongoDriver;
use mongodb::bson::{doc, Document};
use openai::{Language, OpenaiClient};
use openai_rust::Client;
use std::error::Error;
use std::env;


pub struct OpenAIClient {
oai_client: OpenaiClient,
oai_client: Client,
mongo_model: &MongoDriver,
model: Language,
}
Expand All @@ -15,7 +15,7 @@ impl OpenAIClient {

pub fn new(openai_api_key: &str, mongo_model: &MongoDriver) -> Self {
OpenAIClient {
oai_client: OpenaiClient::new(openai_api_key),
oai_client: Client::new(openai_api_key),
mongo_model: mongo_model,
model: Language::English,
}
Expand All @@ -32,10 +32,11 @@ impl OpenAIClient {
match result {
Ok(document) => {
if let Some(text) = document.get_str("text") {

let tokens = openai.tokenize(&language_model, text)?;
let args: openai_rust::embeddings::EmbeddingsArguments = openai_rust::embeddings::EmbeddingsArguments::new("text-embedding-ada-002", text.to_owned());
let embedding: Vec<openai_rust::embeddings::EmbeddingsData> = self.oai_client.create_embeddings(args).await.unwrap().data;

let update_doc = doc! {
"$set": { "tokens": tokens }
"$set": { "embedding": embedding }
};

col.update_one(document, update_doc, None).await?;
Expand Down

0 comments on commit 9da10aa

Please sign in to comment.