Skip to content

Commit

Permalink
All tokens take penalty.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Feb 26, 2024
1 parent fd19839 commit 5a00ca0
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
name = "ai00_server"
repository = "https://github.com/cgisky1980/ai00_rwkv_server"
rust-version = "1.75"
version = "0.3.19"
version = "0.3.20"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand Down
10 changes: 5 additions & 5 deletions src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ where
backed: Mutex<Trie<Tokens, Arc<B>>>,
max_runtime_batch: usize,
state_chunk_size: usize,
penalty_free_tokens: HashSet<u16>,
_penalty_free_tokens: HashSet<u16>,
}

impl<M, S, B> Runtime<M, S, B>
Expand All @@ -280,7 +280,7 @@ where
let slots = (0..state.num_batch())
.map(|_| SlotState::default())
.collect();
let penalty_free_tokens = (0..u16::MAX)
let _penalty_free_tokens = (0..u16::MAX)
.filter(|&token| {
let word = tokenizer.decode(&[token]).unwrap_or_default();
let word = String::from_utf8_lossy(&word).into_owned();
Expand All @@ -296,7 +296,7 @@ where
backed: Mutex::new(Trie::new()),
max_runtime_batch,
state_chunk_size,
penalty_free_tokens,
_penalty_free_tokens,
}
}

Expand Down Expand Up @@ -527,7 +527,7 @@ where
}
},
};
let penalty_free_tokens = &self.penalty_free_tokens;
// let penalty_free_tokens = &self._penalty_free_tokens;
let outputs = payloads
.par_iter()
.zip_eq(outputs.into_par_iter())
Expand All @@ -541,7 +541,7 @@ where
context
.penalties
.iter()
.filter(|(token, _)| !penalty_free_tokens.contains(token))
// .filter(|(token, _)| !penalty_free_tokens.contains(token))
.for_each(|(token, penalty)| data[*token as usize] -= penalty);
context
.request
Expand Down

0 comments on commit 5a00ca0

Please sign in to comment.