+420 -69 lines changed
Cargo.toml
@@ -15,6 +15,7 @@ fancy-regex = "0.11.0"
 regex = "1.8.3"
 rustc-hash = "1.1.0"
 bstr = "1.5.0"
+base64 = "0.21.7"
 
 [features]
 lua54 = ["mlua/lua54"]
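
The new base64 dependency is pinned at 0.21.7, a version where the crate's free decode function still exists but is deprecated in favor of the Engine API. A minimal sketch of the non-deprecated form, assuming the standard alphabet:

    use base64::engine::general_purpose::STANDARD;
    use base64::Engine;

    fn main() {
        // "SGVsbG8=" is the standard base64 encoding of b"Hello".
        let bytes = STANDARD.decode("SGVsbG8=").unwrap();
        assert_eq!(bytes, b"Hello");
    }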
@@ -2,8 +2,11 @@ use fancy_regex::Regex;
 use mlua::prelude::*;
 use rustc_hash::FxHashMap as HashMap;
 use std::collections::HashSet;
+use std::fs::File;
+use std::io::{BufRead, BufReader};
 use std::sync::{Arc, Mutex};
 use std::thread;
+use base64;
 
 #[cfg(feature = "multithreading")]
 const MAX_NUM_THREADS: usize = 128;
@@ -191,12 +194,12 @@ pub fn tiktoken_core(lua: &mlua::Lua) -> LuaResult<LuaTable> {
 
     let _new = lua.create_function(
         move |_,
-              (encoder, special_tokens_encoder, pattern): (
-            HashMap<LuaString, usize>,
+              (encoder_path, special_tokens_encoder, pattern): (
+            String,
             HashMap<String, usize>,
             String,
         )| {
-            new(&*state, encoder, special_tokens_encoder, pattern);
+            new(&*state, encoder_path, special_tokens_encoder, pattern);
             Ok(())
         },
     )?;
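
With this hunk the Lua-facing constructor takes a path to an encoder file instead of a token table, so the full vocabulary no longer has to cross the Lua/Rust boundary. A self-contained sketch of registering and calling a function with the same (path, special tokens, pattern) shape through mlua; the global name tiktoken_new and the file name are illustrative, not this module's real API:

    use mlua::prelude::*;

    fn main() -> LuaResult<()> {
        let lua = Lua::new();
        // Stand-in with the new (path, special tokens, pattern) argument shape.
        let new = lua.create_function(
            |_, (path, _specials, _pattern): (String, LuaTable, String)| {
                println!("would load encoder ranks from {path}");
                Ok(())
            },
        )?;
        lua.globals().set("tiktoken_new", new)?;
        lua.load(r#"tiktoken_new("cl100k_base.tiktoken", {}, ".")"#).exec()
    }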
@@ -210,14 +213,21 @@ pub fn tiktoken_core(lua: &mlua::Lua) -> LuaResult<LuaTable> {
 
 fn new(
     state: &State,
-    iencoder: HashMap<LuaString, usize>,
+    encoder_path: String,
     special_tokens_encoder: HashMap<String, usize>,
     pattern: String,
 ) {
-    let encoder: HashMap<Vec<u8>, usize> = iencoder
-        .into_iter()
-        .map(|(k, v)| (k.as_bytes().to_vec(), v))
-        .collect();
+    let mut encoder: HashMap<Vec<u8>, usize> = HashMap::default();
+    // Read the encoder file: each line is a base64-encoded token and its rank, separated by a space.
+    let file = File::open(encoder_path).unwrap();
+    let reader = BufReader::new(file);
+    for line in reader.lines() {
+        let line = line.unwrap();
+        let mut parts = line.split_whitespace();
+        let token = base64::decode(parts.next().unwrap().as_bytes()).unwrap();
+        let rank = parts.next().unwrap().parse().unwrap();
+        encoder.insert(token, rank);
+    }
     let regex = Regex::new(&pattern)
         .map_err(|e| mlua::Error::external(e))
         .unwrap();
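
The loading loop unwraps at every step, so a missing file or a malformed line panics inside the Lua host. A hedged variant under the same line format (one base64-encoded token and one integer rank, space-separated) that propagates errors instead; load_encoder is a hypothetical helper, and base64::decode is kept to match the diff even though 0.21 deprecates it:

    use std::fs::File;
    use std::io::{self, BufRead, BufReader};

    #[allow(deprecated)] // base64::decode matches the diff; 0.21 prefers the Engine API
    fn load_encoder(path: &str) -> io::Result<Vec<(Vec<u8>, usize)>> {
        let reader = BufReader::new(File::open(path)?);
        let mut pairs = Vec::new();
        for line in reader.lines() {
            let line = line?;
            let mut parts = line.split_whitespace();
            let (Some(tok), Some(rank)) = (parts.next(), parts.next()) else {
                return Err(io::Error::new(io::ErrorKind::InvalidData, "malformed line"));
            };
            let token = base64::decode(tok)
                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
            let rank = rank
                .parse::<usize>()
                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
            pairs.push((token, rank));
        }
        Ok(pairs)
    }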
@@ -230,11 +240,6 @@ fn new(
         .map_err(|e| mlua::Error::external(e))
         .unwrap()
     };
-    let decoder: HashMap<usize, Vec<u8>> = encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
-    assert!(
-        encoder.len() == decoder.len(),
-        "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
-    );
     let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
         .iter()
         .map(|(k, v)| (*v, k.as_bytes().to_vec()))
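
This hunk also drops the assert that caught duplicate ranks: two tokens sharing an index would previously produce a smaller decoder and trip the length check. If that guarantee still matters, a sketch of an equivalent check over the freshly loaded pairs (the helper name is hypothetical):

    use std::collections::HashSet;

    // Reject an encoder file that assigns the same rank to two tokens,
    // mirroring the intent of the removed assert.
    fn check_unique_ranks(pairs: &[(Vec<u8>, usize)]) {
        let mut seen: HashSet<usize> = HashSet::new();
        for (_, rank) in pairs {
            assert!(seen.insert(*rank), "duplicate rank {rank} in encoder file");
        }
    }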
@@ -245,7 +250,8 @@ fn new(
     *core_bpe_lock = Some(CoreBPENative {
         encoder,
         special_tokens_encoder,
-        decoder,
+        // empty decoder
+        decoder: HashMap::default(),
         special_tokens_decoder,
         regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
         special_regex_tls: (0..MAX_NUM_THREADS)
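
The struct now starts with an empty decoder, so rank-to-bytes decoding only works if the reverse map is built elsewhere. A sketch of rebuilding it on demand from the encoder, which is exactly the inversion the removed eager construction performed (the function name is illustrative):

    use rustc_hash::FxHashMap as HashMap;

    // Invert token -> rank into rank -> token, as the removed `decoder` line did.
    fn build_decoder(encoder: &HashMap<Vec<u8>, usize>) -> HashMap<usize, Vec<u8>> {
        encoder.iter().map(|(k, v)| (*v, k.clone())).collect()
    }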