forked from LorenzoFramba/Flight_Delay_Prediction
-
Notifications
You must be signed in to change notification settings - Fork 0
/
getData.py
71 lines (50 loc) · 1.85 KB
/
getData.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from pyspark.sql import SQLContext
from pyspark.sql.session import SparkSession
class Data:
    """
    Initiates Spark, builds a session to process the dataset along with checking for a wrong input.

    State available after construction:
      * ``self.spark`` / ``self.sc``: the SparkSession and its SparkContext.
      * ``self.proceed``: True only when the dataset passed every check.
      * ``self.df``: the loaded DataFrame, or None when nothing could be loaded.
    """

    def __init__(self, config):
        """
        Initialization of Spark, Session and Context, then validation and loading of the dataset.

        :param config: object exposing ``dataset`` (file name), ``path`` (directory,
            possibly empty) and ``dataset_size`` (number of rows the user wants to process).
        """
        self.cfg = config
        self.spark, self.sc = self._init_spark()
        self.checkFormatValidity()

    def _init_spark(self):
        """Return the (possibly already running) SparkSession and its SparkContext."""
        spark = SparkSession.builder.appName("FlightArrivalDelay").getOrCreate()
        sc = spark.sparkContext
        return spark, sc

    def checkFormatValidity(self):
        """
        Checks the entry format validity and, for a CSV dataset, loads it.

        Sets ``self.proceed`` (True when the dataset can be used) and ``self.df``
        (the loaded DataFrame, or None when nothing was loaded).
        """
        # Always define df so callers never hit AttributeError on the failure paths.
        self.df = None
        try:
            if '.csv' in self.cfg.dataset:
                # getDataset may flip proceed back to False (dataset_size too large).
                self.proceed = True
                self.df = self.getDataset(self.spark, self.sc)
            else:
                self.proceed = False
                print('File format not correct. It is required to provide a .csv file')
        except ValueError:
            # Loading failed: there is no usable DataFrame, so we must not proceed.
            self.proceed = False
            print("file not compatible")

    def _dataset_location(self):
        """Return the full dataset path, normalising ``cfg.path`` to end with a single '/'."""
        if self.cfg.path != "" and not self.cfg.path.endswith('/'):
            # cfg.path is still updated in place (as before) in case other code reads
            # the trailing-slash form; the endswith guard makes repeated calls idempotent.
            self.cfg.path = self.cfg.path + '/'
        return self.cfg.path + self.cfg.dataset

    def getDataset(self, spark, sc):
        """
        Opens, reads and processes the dataset.

        :param spark: active SparkSession used to read the CSV file.
        :param sc: SparkContext; no longer used, kept for backward compatibility.
        :return: DataFrame read with a header row and inferred column types.
        """
        # Built-in 'csv' source (Spark >= 2) replaces the deprecated SQLContext +
        # 'com.databricks.spark.csv' alias; the reader options are unchanged.
        getDf = spark.read.load(self._dataset_location(),
                                format='csv',
                                header='true',
                                delimiter=',',
                                inferSchema='true')
        row_count = getDf.count()
        # Checks if the user-entered size of data to be processed is not bigger than the actual dataset
        if self.cfg.dataset_size > row_count:
            self.proceed = False
            print("Please, select a dataset_size smaller than ", row_count)
        return getDf