Skip to content

Commit

Permalink
test: test yahoo finance connector
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri committed Sep 7, 2023
1 parent 8d9cd72 commit f5c4be0
Showing 1 changed file with 101 additions and 0 deletions.
101 changes: 101 additions & 0 deletions tests/connectors/test_yahoo_finance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from unittest.mock import patch
import pandas as pd
import pytest
import yfinance as yf
from pandasai.connectors.yahoo_finance import YahooFinanceConnector


@pytest.fixture
def stock_ticker():
return "AAPL"


@pytest.fixture
def where():
return [["column1", "=", "value1"], ["column2", ">", "value2"]]


@pytest.fixture
def cache_interval():
return 600


@pytest.fixture
def yahoo_finance_config(stock_ticker, where, cache_interval):
return {
"dialect": "yahoo_finance",
"username": "",
"password": "",
"host": "yahoo.finance.com",
"port": 443,
"database": "stock_data",
"table": stock_ticker,
"where": where,
}


@pytest.fixture
def yahoo_finance_connector(stock_ticker, where, cache_interval):
return YahooFinanceConnector(stock_ticker, where, cache_interval)


def test_head(yahoo_finance_connector):
with patch.object(yf.Ticker, "history") as mock_history:
mock_history.return_value = pd.DataFrame(
{
"Open": [1.0, 2.0, 3.0, 4.0, 5.0],
"High": [2.0, 3.0, 4.0, 5.0, 6.0],
"Low": [0.5, 1.5, 2.5, 3.5, 4.5],
"Close": [1.5, 2.5, 3.5, 4.5, 5.5],
"Volume": [100, 200, 300, 400, 500],
}
)
expected_result = pd.DataFrame(
{
"Open": [1.0, 2.0, 3.0, 4.0, 5.0],
"High": [2.0, 3.0, 4.0, 5.0, 6.0],
"Low": [0.5, 1.5, 2.5, 3.5, 4.5],
"Close": [1.5, 2.5, 3.5, 4.5, 5.5],
"Volume": [100, 200, 300, 400, 500],
}
)
assert yahoo_finance_connector.head().equals(expected_result)


def test_get_cache_path(yahoo_finance_connector):
with patch("os.path.join") as mock_join:
expected_result = "../AAPL_data.csv"
mock_join.return_value = expected_result
assert yahoo_finance_connector._get_cache_path() == expected_result


def test_rows_count(yahoo_finance_connector):
with patch.object(yf.Ticker, "history") as mock_history:
mock_history.return_value = pd.DataFrame(
{
"Open": [1.0, 2.0, 3.0, 4.0, 5.0],
"High": [2.0, 3.0, 4.0, 5.0, 6.0],
"Low": [0.5, 1.5, 2.5, 3.5, 4.5],
"Close": [1.5, 2.5, 3.5, 4.5, 5.5],
"Volume": [100, 200, 300, 400, 500],
}
)
assert yahoo_finance_connector.rows_count == 5


def test_columns_count(yahoo_finance_connector):
with patch.object(yf.Ticker, "history") as mock_history:
mock_history.return_value = pd.DataFrame(
{
"Open": [1.0, 2.0, 3.0, 4.0, 5.0],
"High": [2.0, 3.0, 4.0, 5.0, 6.0],
"Low": [0.5, 1.5, 2.5, 3.5, 4.5],
"Close": [1.5, 2.5, 3.5, 4.5, 5.5],
"Volume": [100, 200, 300, 400, 500],
}
)
assert yahoo_finance_connector.columns_count == 5


def test_fallback_name(yahoo_finance_connector, stock_ticker):
assert yahoo_finance_connector.fallback_name == stock_ticker

0 comments on commit f5c4be0

Please sign in to comment.