Commit

update src code

ccxzhang committed Sep 14, 2023
1 parent 61a1b89 commit e767590
Showing 6 changed files with 100 additions and 14 deletions.
Empty file added src/__init__.py
76 changes: 76 additions & 0 deletions src/google_trends.py
@@ -0,0 +1,76 @@
import os
import json
import logging
# datetime pieces are needed for the daily-limit backoff below
from datetime import datetime, date, time as dtime, timedelta

import pandas as pd
import requests
# !pip install google-api-python-client
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
# local import
from scripts.python.config import GoogleAPIkey

SERVICE_NAME = 'trends'
SERVICE_VERSION = 'v1beta'
_DISCOVERY_SERVICE_URL = 'https://www.googleapis.com/discovery/v1/apis/trends/v1beta/rest'


class GT:
    def __init__(self, _GOOGLE_API_KEY=GoogleAPIkey):
        self.service = build(
            serviceName=SERVICE_NAME,
            version=SERVICE_VERSION,
            discoveryServiceUrl=_DISCOVERY_SERVICE_URL,
            developerKey=_GOOGLE_API_KEY,
            cache_discovery=False)
        self.block_until = None

    def get_health_trends(self, terms, timelineResolution="month"):
        graph = self.service.getTimelinesForHealth(
            terms=terms,
            timelineResolution=timelineResolution
        )

        try:
            response = graph.execute()
            return response

        except HttpError as http_error:
            data = json.loads(http_error.content.decode('utf-8'))
            code = data['error']['code']
            reason = data['error']['errors'][0]['reason']
            if code == 403 and reason == 'dailyLimitExceeded':
                # Quota exhausted: block further requests until midnight.
                self.block_until = datetime.combine(
                    date.today() + timedelta(days=1), dtime.min)
                raise RuntimeError('%s: blocked until %s' %
                                   (reason, self.block_until))
            logging.warning(http_error)
            return []

    def get_graph(self, terms,
                  restrictions_geo,
                  restrictions_startDate="2004-01"):
        graph = self.service.getGraph(
            terms=terms,
            restrictions_geo=restrictions_geo,
            restrictions_startDate=restrictions_startDate
        )

        try:
            response = graph.execute()
            return response

        except HttpError as http_error:
            logging.warning(http_error)
            return []

    @staticmethod
    def to_df(result: dict) -> pd.DataFrame:
        df = pd.json_normalize(result["lines"], meta=["term"], record_path=["points"])
        if "date" in df.columns:
            df["date"] = pd.to_datetime(df["date"])

        return df
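
Taken together, google_trends.py wraps the (restricted) Google Trends discovery API in a small client class. A minimal usage sketch, not part of the commit: the search terms and geo code below are invented, and a valid GoogleAPIkey in scripts/python/config.py is assumed.

# Illustrative only: terms and geo restriction are invented examples.
from src.google_trends import GT

gt = GT()  # builds the discovery client with the configured API key

# Monthly search-volume timelines from the health endpoint.
health = gt.get_health_trends(terms=["dengue", "malaria"])

# Relative search interest for one country since 2004.
graph = gt.get_graph(terms=["dengue"], restrictions_geo="TL")

if graph:  # both methods return [] when an HttpError was only logged
    df = GT.to_df(graph)
    print(df.head())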
13 changes: 13 additions & 0 deletions src/text/utils.py
@@ -4,6 +4,19 @@
from gensim.utils import simple_preprocess
from gensim.models import CoherenceModel

def is_in_word_list(row: str, terms: list) -> bool:
    """
    Check if any of the given terms are present in the input row.

    Args:
        row (str): The input row to search for terms in.
        terms (list): A list of terms to search for in the row.

    Returns:
        bool: True if any of the terms are found in the row, False otherwise.
    """
    return any(word in str(row) for word in terms)

def sent_to_words(sentences):
    for sentence in sentences:
        sentence = re.sub(r'\s', ' ', sentence).strip()
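A quick, hypothetical check of the new helper; the DataFrame and term list are invented, and note that the substring match is case-sensitive.

# Hypothetical data to exercise is_in_word_list; not part of the commit.
import pandas as pd
from src.text.utils import is_in_word_list

articles = pd.DataFrame(
    {"headline": ["coastal floods displace hundreds", "market update"]})
terms = ["flood", "drought"]

# Case-sensitive matching, so the terms are kept lowercase here.
mask = articles["headline"].apply(lambda row: is_in_word_list(row, terms))
print(articles[mask])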
5 changes: 2 additions & 3 deletions src/tsa/mtsmodel.py
@@ -4,8 +4,8 @@
 import sklearn
 import os
 from statsmodels.tsa.api import VARMAX
-from scripts.python.tsa.utsmodel import SARIMAXData
-from scripts.python.tsa.ts_eval import *
+from .utsmodel import SARIMAXData
+from .ts_eval import *
 from .ts_utils import *

 class MultiTSData(SARIMAXData):
@@ -57,7 +57,6 @@ def __init__(self, country, var_name, exog, data=None):
         self.exog = exog

     def test_stationarity(self):
-        from .ts_utils import get_adf_df
         adf_df = get_adf_df(self.data[self.var_name], self.var_name)

         order = 0
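Because the imports are now package-relative, mtsmodel.py can no longer be run as a loose script out of scripts/python/tsa; it has to be imported through the package. A one-line sketch, assuming the repository root is on the import path:

# Assumes the repository root is on sys.path so that src/ is importable.
from src.tsa.mtsmodel import MultiTSData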
8 changes: 4 additions & 4 deletions src/tsa/ts_utils.py
@@ -1,13 +1,13 @@
 import pandas as pd
 import numpy as np

 import statsmodels
+from statsmodels.tsa.stattools import ccf
+from statsmodels.tsa.stattools import kpss

 def cross_corr_df(data: pd.DataFrame,
                   series_a: str,
                   series_b: str) -> pd.DataFrame:

-    from statsmodels.tsa.stattools import ccf

     sig_a, sig_b = data[series_a], data[series_b]

     ccorr = ccf(sig_a, sig_b, adjusted=False)
@@ -22,7 +22,7 @@ def cross_corr_df(data: pd.DataFrame,
 def kpss_test(data: pd.DataFrame,
               incl_columns: list) -> pd.DataFrame:

-    from statsmodels.tsa.stattools import kpss

     import warnings
     warnings.filterwarnings("ignore")
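With ccf and kpss hoisted to module level, the helpers can be exercised like this; a sketch with fabricated random-walk data, relying only on the signatures visible in the diff.

# Fabricated data; only the signatures shown above are assumed.
import numpy as np
import pandas as pd
from src.tsa.ts_utils import cross_corr_df, kpss_test

rng = np.random.default_rng(0)
data = pd.DataFrame({
    "cases": rng.normal(size=100).cumsum(),
    "searches": rng.normal(size=100).cumsum(),
})

print(cross_corr_df(data, "cases", "searches"))
print(kpss_test(data, incl_columns=["cases", "searches"]))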
12 changes: 5 additions & 7 deletions src/tsa/utsmodel.py
@@ -1,22 +1,22 @@
import os
import itertools
import seaborn as sns
import logging
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import scipy

#!pip install pmdarima
from statsmodels.tsa.seasonal import seasonal_decompose, STL
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.preprocessing import MinMaxScaler
from statsmodels.tsa.arima.model import ARIMA
import pmdarima as pm
from pmdarima import model_selection
from pmdarima import auto_arima
from pmdarima.model_selection import SlidingWindowForecastCV, cross_val_score
from .ts_eval import *
from .ts_utils import *


logging.basicConfig(filename='sarimax.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class SARIMAXData:
    def __init__(self,
                 country: str,
@@ -140,9 +140,8 @@ def transform(self):
         if self.transform_method == "scaledlogit":
             self.transformed_y = self.scaledlogit_transform(self.y)
         elif self.transform_method == "minmax":
-            from sklearn.preprocessing import MinMaxScaler
             self.minmax = MinMaxScaler()
-            self.transformed_y = minmax.fit_transform(self.y)
+            self.transformed_y = self.minmax.fit_transform(self.y)
         else:
             self.transformed_y = self.y

@@ -229,7 +228,6 @@ def compare_models(y, exog,
                    hyper_params=None,
                    verbose=0):

-    from pmdarima.model_selection import SlidingWindowForecastCV, cross_val_score

     if hyper_params is None:
         hyper_params = {
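The transform() hunk above fixes a latent NameError: the scaler is stored as self.minmax but was then referenced as the bare name minmax. A standalone sketch of the corrected pattern; the class and data are invented for illustration, not the repo's code.

# Invented minimal class showing the corrected pattern.
import numpy as np
from sklearn.preprocessing import MinMaxScaler

class MinMaxDemo:
    def __init__(self, y: np.ndarray):
        self.y = y.reshape(-1, 1)  # MinMaxScaler expects a 2-D array

    def transform(self) -> np.ndarray:
        self.minmax = MinMaxScaler()
        # Keeping the scaler on self lets callers invert the transform later.
        self.transformed_y = self.minmax.fit_transform(self.y)
        return self.transformed_y

demo = MinMaxDemo(np.arange(10, dtype=float))
scaled = demo.transform()
print(demo.minmax.inverse_transform(scaled)[:3])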
