-
Notifications
You must be signed in to change notification settings - Fork 0
/
eda.py
129 lines (108 loc) · 4.7 KB
/
eda.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# eda.py - Utility script for exploratory data analysis
# Import necessary libraries
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset
from scipy import stats
import os
# Histogram settings per dataset: number of bins and the (min, max) x-range.
# The plotted values are summary lengths in whitespace-delimited words.
HISTOGRAM_CONFIG = {
    'cnn_dailymail': {
        'bins': 40,
        'range': (0, 1300),
    },
    'xsum': {
        'bins': 20,
        'range': (0, 80),
    },
}

# Per-dataset loading options:
#   version     - config/version name forwarded to load_dataset
#                 (a falsy value means the name alone is passed)
#   article_key - column holding the source document
#   summary_key - column holding the reference summary
DATASET_CONFIG = {
    'cnn_dailymail': {
        'version': '3.0.0',
        'article_key': 'article',
        'summary_key': 'highlights',
    },
    'xsum': {
        'version': None,
        'article_key': 'document',
        'summary_key': 'summary',
    },
}

# Default sample size. NOTE(review): not referenced in the visible code.
SAMPLE_SIZE = 10000
def load_and_print_dataset_info(dataset_name, version=None):
    """Load a dataset with ``load_dataset`` and print the size of every split.

    Args:
        dataset_name: Name forwarded to ``load_dataset``.
        version: Optional config/version name. When falsy, ``load_dataset``
            is called with the dataset name only.

    Returns:
        Whatever ``load_dataset`` returned, or ``None`` if loading or the size
        report raised any exception. The error is printed, not re-raised.
    """
    try:
        call_args = (dataset_name, version) if version else (dataset_name,)
        loaded = load_dataset(*call_args)
        print(f"\n{dataset_name} dataset:")
        for split_name, split_rows in loaded.items():
            print(f"{split_name} size: {len(split_rows)}")
    except Exception as e:
        # Best-effort loader: report the failure and let the caller skip it.
        print(f"Error loading {dataset_name} dataset. Check dataset name, version, or network connection. Detailed Error: {e}")
        return None
    return loaded
def print_example_texts_and_summaries(dataset, article_key, summary_key, num_examples=3):
    """Print the first few article/summary pairs of a dataset split.

    Args:
        dataset: Object whose ``dataset[column]`` yields that column's values
            in row order (e.g. a dict of lists or a column-indexable split).
        article_key: Column holding the source text.
        summary_key: Column holding the reference summary.
        num_examples: Maximum number of pairs to print. Fewer are printed when
            the split is shorter; zero or negative prints nothing.
    """
    from itertools import islice

    # Fetch each column once, outside the loop, instead of re-indexing the
    # whole column on every iteration. For a Hugging Face split, column
    # access presumably materialises the full column; verify for the
    # installed `datasets` version.
    articles = dataset[article_key]
    summaries = dataset[summary_key]

    # islice/zip stop early on short splits rather than raising IndexError.
    pairs = islice(zip(articles, summaries), max(num_examples, 0))
    for number, (example_text, example_summary) in enumerate(pairs, 1):
        print(f"\nExample text {number}:")
        print(example_text)
        print(f"\nExample summary {number}:")
        print(example_summary)
def calculate_summary_lengths(dataset, summary_key):
    """Return the whitespace-delimited word count of every summary.

    Args:
        dataset: Object whose ``dataset[summary_key]`` yields summary strings.
        summary_key: Column holding the summaries.

    Returns:
        A list of ints, one per summary, in the column's order.
    """
    def word_count(text):
        # str.split() with no argument splits on runs of any whitespace.
        return len(text.split())

    return list(map(word_count, dataset[summary_key]))
def display_statistics(data, dataset_name):
    """Print the mean, median and mode of a sequence of numbers.

    Args:
        data: Sequence of numeric values (in this script, summary word counts).
        dataset_name: Label used in the printed heading.

    For empty ``data``, prints the heading plus a notice instead of the NaN
    values and runtime warnings that numpy/scipy would otherwise produce.
    """
    heading = f"\nStatistics for {dataset_name}:"
    if len(data) == 0:
        print(heading)
        print("No data available.")
        return

    # Depending on the SciPy version, ``stats.mode(...).mode`` is either a
    # scalar or a length-1 array. ravel + [0] yields the value in both cases.
    mode = np.ravel(stats.mode(data).mode)[0]
    mean = np.mean(data)
    median = np.median(data)

    print(heading)
    print(f"Mean: {mean}")
    print(f"Median: {median}")
    print(f"Mode: {mode}")
def draw_mean_median_lines(data):
    """Draw a dashed red mean line and a solid green median line on the current axes.

    Each line carries its own legend label (including its value), and the
    legend is built from those labelled artists.

    Args:
        data: Numeric values whose mean and median are marked.
    """
    mean = np.mean(data)
    median = np.median(data)
    # Label the lines themselves. Passing a dict to plt.legend() treats it as
    # a bare list of labels (its keys) and assigns them to the first artists
    # on the axes, i.e. the histogram bars, not these lines.
    plt.axvline(mean, color='r', linestyle='--', label=f"Mean: {mean:.1f}")
    plt.axvline(median, color='g', linestyle='-', label=f"Median: {median:.1f}")
    plt.legend()
def plot_summary_lengths_histogram(data, title, dataset_name, subplot_position):
    """Plot a histogram of summary lengths into one subplot of the current figure.

    Args:
        data: Summary lengths (word counts) to plot.
        title: Title for the subplot.
        dataset_name: Key into HISTOGRAM_CONFIG for the bin count and x-range.
            Datasets without an entry fall back to automatic binning over
            the data's own range.
        subplot_position: ``(nrows, ncols, index)`` tuple passed to ``plt.subplot``.
    """
    # An unconfigured dataset gets an empty config, so the lookups below use
    # defaults instead of failing on None['bins'].
    config = HISTOGRAM_CONFIG.get(dataset_name, {})
    plt.subplot(*subplot_position)
    plt.hist(
        data,
        bins=config.get('bins', 'auto'),
        range=config.get('range'),  # None -> matplotlib uses the data min/max
        edgecolor="k",
        alpha=0.7,
    )
    draw_mean_median_lines(data)
    plt.xlabel("Summary Length")
    plt.ylabel("Frequency")
    plt.title(title)
if __name__ == "__main__":
    # Heading printed before each dataset's example rows. Datasets without an
    # entry get a generic "<name> examples:" heading.
    example_headings = {
        'cnn_dailymail': "CNN/Daily Mail examples:",
        'xsum': "XSum examples:",
    }

    # Load every configured dataset. load_and_print_dataset_info reports
    # failures itself and returns None; those datasets are skipped.
    loaded_datasets = {}
    for dataset_name, config in DATASET_CONFIG.items():
        dataset = load_and_print_dataset_info(dataset_name, config['version'])
        if dataset:
            loaded_datasets[dataset_name] = dataset

    # Show a few article/summary pairs from each training split.
    for dataset_name, dataset in loaded_datasets.items():
        train_data = dataset.get('train')
        if train_data:
            keys = DATASET_CONFIG[dataset_name]
            heading = example_headings.get(dataset_name, f"{dataset_name} examples:")
            print(f"\n{heading}")
            print_example_texts_and_summaries(train_data, keys['article_key'], keys['summary_key'])

    # Summary word-count statistics for each training split.
    summary_lengths = {}
    for dataset_name, dataset in loaded_datasets.items():
        train_data = dataset.get('train')
        if train_data:
            lengths = calculate_summary_lengths(train_data, DATASET_CONFIG[dataset_name]['summary_key'])
            summary_lengths[dataset_name] = lengths
            display_statistics(lengths, dataset_name)

    if not summary_lengths:
        print("No training splits loaded; skipping histogram.")
    else:
        # One subplot row per dataset that actually loaded. A fixed 2x1 grid
        # breaks for a third dataset and leaves an empty panel for one.
        num_plots = len(summary_lengths)
        plt.figure(figsize=(10, 6))
        for index, (dataset_name, lengths) in enumerate(summary_lengths.items(), 1):
            plot_summary_lengths_histogram(lengths, f"Distribution of Summary Lengths ({dataset_name})", dataset_name, (num_plots, 1, index))

        # Apply the layout before saving so the PNG matches the shown figure.
        plt.tight_layout()

        try:
            # exist_ok avoids the check-then-create race of os.path.exists.
            os.makedirs('Outputs', exist_ok=True)
        except OSError as e:
            print(f"Error creating 'Outputs' directory: {e}")
        try:
            plt.savefig('Outputs/histograms_combined.png')
        except Exception as e:
            # Top-level boundary: report and still show the figure.
            print(f"Error saving histogram: {e}")
        plt.show()