Skip to content

Commit

Permalink
u_roll now works properly after fixes of indexing and resetting.
Browse files Browse the repository at this point in the history
  • Loading branch information
Suchismit4 committed Jan 8, 2025
1 parent 3a78db5 commit d6e39c0
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 8 deletions.
Binary file modified apple_tsla_ema.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 6 additions & 6 deletions example_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def main():
"config": {
"provider": "yfinance",
"symbols": ["AAPL", "TSLA"],
"start_date": "2024-01-01",
"start_date": "2015-01-01",
"end_date": "2024-12-31",
}
}])
Expand Down Expand Up @@ -73,7 +73,7 @@ def ema(i: int, carry, block: jnp.ndarray, window_size: int):
dataset = datasets["openbb/equity/price/historical"]

# Rolling-EMA of "close" over a 200-day window
ema_dataset = dataset.dt.rolling(dim='time', window=60).reduce(ema)
ema_dataset = dataset.dt.rolling(dim='time', window=252).reduce(ema)

# Convert to time-indexed form for plotting
# -- Original closing prices --
Expand Down Expand Up @@ -101,9 +101,9 @@ def ema(i: int, carry, block: jnp.ndarray, window_size: int):
x="time", ax=ax1, label="AAPL Close", color="blue", linestyle="-"
)
apple_close_ema.plot.line(
x="time", ax=ax1, label="AAPL EMA(10)", color="blue", linestyle="--"
x="time", ax=ax1, label="AAPL EMA", color="blue", linestyle="--"
)
ax1.set_title("Apple (AAPL) Closing Prices vs. EMA(10)")
ax1.set_title("Apple (AAPL) Closing Prices vs. EMA")
ax1.set_xlabel("Time")
ax1.set_ylabel("Price (USD)")
ax1.legend()
Expand All @@ -113,9 +113,9 @@ def ema(i: int, carry, block: jnp.ndarray, window_size: int):
x="time", ax=ax2, label="TSLA Close", color="red", linestyle="-"
)
tsla_close_ema.plot.line(
x="time", ax=ax2, label="TSLA EMA(10)", color="red", linestyle="--"
x="time", ax=ax2, label="TSLA EMA", color="red", linestyle="--"
)
ax2.set_title("Tesla (TSLA) Closing Prices vs. EMA(10)")
ax2.set_title("Tesla (TSLA) Closing Prices vs. EMA")
ax2.set_xlabel("Time")
ax2.set_ylabel("Price (USD)")
ax2.legend()
Expand Down
Binary file modified src/data/__pycache__/processors_registry.cpython-312.pyc
Binary file not shown.
Binary file modified src/data/core/__pycache__/util.cpython-312.pyc
Binary file not shown.
4 changes: 2 additions & 2 deletions src/data/core/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def __init__(self,
self.window = window

self.mask = mask
self.indices = indices.astype(jnp.int32)
self.indices = jnp.where(indices == -1, 0, indices).astype(jnp.int32)

@eqx.filter_jit
def reduce(
Expand Down Expand Up @@ -461,7 +461,7 @@ def reduce(

# Remove the extra dimension added earlier
rolled_full = rolled_full[..., 0] # Shape: (T_full, assets, 1)
jax.debug.breakpoint()
# jax.debug.breakpoint()

# Reconstruct the DataArray with rolled data
rolled_da = stacked_obj.copy(data=rolled_full)
Expand Down

0 comments on commit d6e39c0

Please sign in to comment.