-
Notifications
You must be signed in to change notification settings - Fork 0
/
6_thematic_coherence_prediction.R
135 lines (115 loc) · 4 KB
/
6_thematic_coherence_prediction.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
library(tidyverse)
library(progress)
library(openai)
# Load the text pairs to be rated, forcing `year` to be read as text rather
# than letting readr guess a numeric type, then build the user prompt for each
# pair.
# NOTE(review): paste()'s default single-space separator is part of the prompt
# ("Text 1: <a> \nText 2: <b>"); presumably this must match the format the
# fine-tuned model was trained on — confirm before changing it.
coherence_rating_sample <- read_csv(
  "./data/final_sample.csv",
  col_types = cols(year = "c")
) %>%
  mutate(input = paste("Text 1:", text_a, "\nText 2:", text_b))
# Read the fine-tuning validation set (JSON Lines; each record is expected to
# carry a `messages` array of {role, content} objects) into a tibble.
data <- jsonlite::stream_in(file("rawdata/validation_data.jsonl"))
df <- as_tibble(data)

# Flatten to one row per conversation, with one column per message role.
# Rows are keyed by the record they came from instead of assuming that every
# conversation has exactly three messages. Under that assumption, a single
# record with a different message count would shift every later message
# into the wrong conversation.
flat_df <- df %>%
  mutate(.conversation_id = row_number()) %>%
  unnest(messages) %>%
  pivot_wider(names_from = role, values_from = content) %>%
  select(-.conversation_id)

# The system prompt sent with every prediction is taken from the first
# validation record. Fail loudly here rather than sending a NULL/NA prompt to
# the API for every pair.
if (!"system" %in% names(flat_df)) {
  stop("No 'system' message found in rawdata/validation_data.jsonl",
       call. = FALSE)
}
system_message <- flat_df$system[[1]]
stopifnot(is.character(system_message), length(system_message) == 1L)
# Ask the fine-tuned chat model to rate one text pair and parse the rating.
#
# Args:
#   user_input:     user prompt for the pair (the "Text 1: ... Text 2: ..." text).
#   system_message: system prompt sent ahead of the user prompt.
#   model:          model identifier; defaults to the fine-tuned model used here.
#   temperature:    sampling temperature (0 by default).
#
# Returns a single number: the digits inside the first <result>...</result>
# tag of the reply. Returns NA_real_ when the API call errors (a warning is
# raised) or when the reply contains no such tag.
get_prediction <- function(user_input, system_message,
                           model = "ft:gpt-4o-mini-2024-07-18:uniurb::9ovTt9Dp",
                           temperature = 0) {
  tryCatch({
    response <- openai::create_chat_completion(
      model = model,
      messages = list(
        list(role = "system", content = system_message),
        list(role = "user", content = user_input)
      ),
      temperature = temperature
    )
    rating <- str_extract(
      response$choices$message.content,
      "(?<=<result>)\\d+(?=</result>)"
    )
    # `[1]` guarantees a length-1 result. If the reply carries no content,
    # str_extract() returns a zero-length vector; that is not an error, so
    # tryCatch would not catch it. map_dbl() in the caller would then abort
    # the entire run.
    as.numeric(rating)[1]
  }, error = function(e) {
    warning(paste("Error in prediction:", conditionMessage(e)), call. = FALSE)
    NA_real_
  })
}
# Rate every text pair with the model, advancing a console progress bar after
# each request. Ratings are collected into a double vector aligned with the
# rows of `coherence_rating_sample`.
total_inputs <- nrow(coherence_rating_sample)

pb <- progress_bar$new(
  total = total_inputs,
  format = "[:bar] :percent ETA: :eta",
  width = 60,
  clear = FALSE
)

# Rate one pair, then tick the bar (the tick happens whether or not the
# prediction succeeded, since get_prediction() returns NA on failure).
rate_pair_and_tick <- function(pair_prompt) {
  pair_rating <- get_prediction(pair_prompt, system_message)
  pb$tick()
  pair_rating
}

predictions <- map_dbl(coherence_rating_sample$input, rate_pair_and_tick)
# Store each pair's model rating alongside its input, then persist the rated
# sample so later analysis can be rerun without querying the API again.
coherence_rating_sample <- mutate(coherence_rating_sample,
                                  model_rating = predictions)
saveRDS(coherence_rating_sample,
        "./data/coherence_rating_sample_rated.rds")
# Mean model rating and row count per group, sorted by the grouping columns.
# Pairs rated 99 (non-codable) are dropped. Note that filter() also drops rows
# whose rating is NA (failed or unparseable predictions), so `n` counts only
# rated, codable pairs.
#
# Args:
#   rated: data frame with a numeric `model_rating` column.
#   ...:   bare grouping column names, used both to group and to sort.
summarise_model_ratings <- function(rated, ...) {
  rated %>%
    filter(model_rating != 99) %>%
    group_by(...) %>%
    summarise(
      mean_prediction = mean(model_rating, na.rm = TRUE),
      n = n(),
      .groups = "drop"
    ) %>%
    arrange(...)
}

# Per-category summary, pooling all years.
statistics_overall <- summarise_model_ratings(coherence_rating_sample, category)
print("Overall statistics (disregarding year):")
print(statistics_overall)

# Per-category summary within each year.
statistics_by_year <- summarise_model_ratings(coherence_rating_sample,
                                              year, category)
print("Statistics by year:")
print(statistics_by_year)

# Save both summaries as CSV.
write_csv(statistics_overall, "./data/prediction_statistics_overall.csv")
write_csv(statistics_by_year, "./data/prediction_statistics_by_year.csv")
# Prepare data for ANOVA and post-hoc tests by year.
# Note: filter() on `!= 99` also drops NA ratings (failed/unparseable predictions).
anova_data <- coherence_rating_sample %>%
  filter(model_rating != 99) %>% # Exclude non-codable pairs
  select(year, category, model_rating)

# One tibble row per year; that year's rows are nested in the `data` column.
nested_by_year <- anova_data %>%
  group_by(year) %>%
  nest() %>%
  ungroup() %>%
  arrange(year)

# A one-way model needs at least two categories. aov() errors otherwise,
# which would abort the rest of the script, so such years are skipped
# with a message instead.
has_two_categories <- map_lgl(nested_by_year$data,
                              ~ n_distinct(.x$category) >= 2)
if (any(!has_two_categories)) {
  message("Skipping ANOVA for year(s) with fewer than two categories: ",
          paste(nested_by_year$year[!has_two_categories], collapse = ", "))
}

# One-way ANOVA of model_rating on category within each year, followed by
# Tukey's HSD. `category` is converted to a factor explicitly because
# TukeyHSD() reports comparisons for factor terms.
anova_results_by_year <- nested_by_year[has_two_categories, ] %>%
  mutate(
    anova = map(data, ~ aov(model_rating ~ category,
                            data = mutate(.x, category = factor(category)))),
    summary = map(anova, base::summary),
    tukey = map(anova, TukeyHSD)
  )

# Print ANOVA and Tukey's HSD results by year. pwalk() is used because this
# step exists only for its printed output. It returns its input invisibly,
# so no stray tibble is auto-printed afterwards.
anova_results_by_year %>%
  select(year, summary, tukey) %>%
  pwalk(function(year, summary, tukey) {
    cat("Year:", year, "\n")
    cat("One-way ANOVA results:\n")
    print(summary)
    cat("Tukey's HSD post-hoc test results:\n")
    print(tukey)
  })
# Box plots of the model's ratings per category, one facet per year. Each
# facet gets its own axis ranges (scales = "free").
prediction_plot_labels <- labs(
  title = "Distribution of Predictions by Category and Year",
  x = "Category",
  y = "Prediction"
)

ggplot(anova_data, aes(category, model_rating)) +
  geom_boxplot() +
  facet_wrap(vars(year), scales = "free") +
  theme_minimal() +
  prediction_plot_labels

# Building the plot with `+` registers it as the last plot, which is what
# gets written to disk here.
ggsave("prediction_distribution_by_year.png", plot = last_plot(),
       width = 12, height = 8)