forked from robinhad/ukrainian-tts
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path data_logger.py
41 lines (32 loc) · 1.17 KB
/
data_logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from gradio import utils
import os
import csv
import huggingface_hub
def log_data(hf_token: str, dataset_name: str, private=True):
    """Create a CSV logger backed by a Hugging Face dataset repository.

    Ensures the dataset repository ``dataset_name`` exists on the Hub
    (``exist_ok=True``), clones it into ``flagged/<dataset_name>`` and pulls
    the latest content, including LFS files. The returned callback appends
    rows to ``data_speed.csv`` inside that clone and pushes the change back
    to the Hub.

    Args:
        hf_token: Hugging Face access token used to create, clone and push
            the repository.
        dataset_name: Repository id of the dataset on the Hub.
        private: Whether to create the repository as private.

    Returns:
        A function ``log_function(data)`` that appends ``data`` (an iterable
        of rows) to the CSV, pushes the result, and returns the sample count
        used in the commit message.
    """
    repo_url = huggingface_hub.create_repo(
        repo_id=dataset_name,
        token=hf_token,
        private=private,
        repo_type="dataset",
        exist_ok=True,
    )
    flagging_dir = "flagged"
    dataset_dir = os.path.join(flagging_dir, dataset_name)
    # NOTE(review): the git-based `Repository` helper (and its
    # `use_auth_token` argument) is deprecated in newer huggingface_hub
    # releases; confirm against the pinned version before upgrading.
    repo = huggingface_hub.Repository(
        local_dir=dataset_dir,
        clone_from=repo_url,
        use_auth_token=hf_token,
    )
    repo.git_pull(lfs=True)
    log_file = os.path.join(dataset_dir, "data_speed.csv")

    def log_function(data):
        """Append ``data`` rows to the CSV log and push them to the Hub.

        Args:
            data: Iterable of rows; each row is passed through
                ``gradio.utils.sanitize_list_for_csv`` before being written.

        Returns:
            The number of CSV records in the file minus one (never negative).
            The same number appears in the commit message.
        """
        # Pull first so the append builds on the latest remote state.
        repo.git_pull(lfs=True)
        with open(log_file, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.writer(csvfile)
            for row in data:
                writer.writerow(utils.sanitize_list_for_csv(row))
        # Count CSV records rather than raw text lines, so a quoted field
        # spanning several lines is counted once. newline="" matches the
        # writer above and is what the csv module requires for reading.
        with open(log_file, "r", newline="", encoding="utf-8") as csvfile:
            record_count = sum(1 for _ in csv.reader(csvfile))
        # The "- 1" assumes the first record is a header row.
        # NOTE(review): nothing here writes a header, so confirm the dataset
        # repo already contains one; otherwise this undercounts by one.
        # Clamp so an empty file never yields "Flagged sample #-1".
        line_count = max(record_count - 1, 0)
        repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count))
        return line_count

    return log_function