From 9e6dcd9cd8b68d7b10f8f9cce5cf0557d3148eaf Mon Sep 17 00:00:00 2001 From: alessio <104512673+alessiocastrica@users.noreply.github.com> Date: Mon, 7 Oct 2024 03:37:13 +0200 Subject: [PATCH] Rebalancing endpoints (#362) * feat: draft rebalancing create portfolio * feat: get all portfolios and create subscription * feat: get portfolio by ID * feat: update portfolio by ID * feat: inactivate portfolio by ID * feat: get all subscriptions * feat: get subscription by ID * feat: delete subscription by ID * feat: manual run endpoint * feat: list all runs with page iteration * feat: get and cancel run endpoints * fix: pagination like other endpoints * fix: missing client docstrings * chore: adjust typing * chore: black * feat: adjust rebalancing --------- Co-authored-by: Chihiro Hio --- alpaca/broker/client.py | 395 ++++++++++- alpaca/broker/enums.py | 97 +++ alpaca/broker/models/__init__.py | 1 + alpaca/broker/models/rebalancing.py | 80 +++ alpaca/broker/requests.py | 153 +++- alpaca/common/rest.py | 6 +- alpaca/common/types.py | 4 +- .../broker_client/test_rebalancing_routes.py | 664 ++++++++++++++++++ 8 files changed, 1387 insertions(+), 13 deletions(-) create mode 100644 alpaca/broker/models/rebalancing.py create mode 100644 tests/broker/broker_client/test_rebalancing_routes.py diff --git a/alpaca/broker/client.py b/alpaca/broker/client.py index 61db096d..3ca234cf 100644 --- a/alpaca/broker/client.py +++ b/alpaca/broker/client.py @@ -1,30 +1,63 @@ import base64 import warnings -from typing import Callable, Iterator, List, Optional, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Type, Union from uuid import UUID import sseclient from pydantic import TypeAdapter from requests import HTTPError, Response +from alpaca.broker.enums import ACHRelationshipStatus from alpaca.broker.models import ( Account, ACHRelationship, Bank, + BaseModel, BatchJournalResponse, - CIPInfo, Journal, Order, + Portfolio, + RebalancingRun, + Subscription, TradeAccount, TradeDocument, Transfer, ) +from alpaca.broker.requests import ( + CreateAccountRequest, + CreateACHRelationshipRequest, + CreateACHTransferRequest, + CreateBankRequest, + CreateBankTransferRequest, + CreateBatchJournalRequest, + CreateJournalRequest, + CreatePlaidRelationshipRequest, + CreatePortfolioRequest, + CreateReverseBatchJournalRequest, + CreateRunRequest, + CreateSubscriptionRequest, + GetAccountActivitiesRequest, + GetEventsRequest, + GetJournalsRequest, + GetPortfoliosRequest, + GetRunsRequest, + GetSubscriptionsRequest, + GetTradeDocumentsRequest, + GetTransfersRequest, + ListAccountsRequest, + OrderRequest, + UpdateAccountRequest, + UpdatePortfolioRequest, + UploadDocumentRequest, +) +from alpaca.common import RawData from alpaca.common.constants import ( ACCOUNT_ACTIVITIES_DEFAULT_PAGE_SIZE, BROKER_DOCUMENT_UPLOAD_LIMIT, ) from alpaca.common.enums import BaseURL, PaginationType from alpaca.common.exceptions import APIError +from alpaca.common.rest import HTTPResult, RESTClient from alpaca.common.utils import validate_symbol_or_asset_id, validate_uuid_id_param from alpaca.trading.enums import ActivityType from alpaca.trading.models import AccountConfiguration as TradeAccountConfiguration @@ -142,6 +175,62 @@ def _get_auth_headers(self) -> dict: return {"Authorization": "Basic " + auth_string_encoded.decode("utf-8")} + def _iterate_over_pages( + self, + endpoint: str, + params: Dict[str, Any], + response_field: str, + base_model_type: Type[BaseModel], + max_items_limit: Optional[int] = None, + ) -> Iterator[Union[RawData, BaseModel]]: + """ + Internal method to iterate over the result pages. + """ + + # we need to track total items retrieved + total_items = 0 + page_size = params.get("limit", 100) + + while True: + if max_items_limit is not None: + normalized_page_size = min( + int(max_items_limit) - total_items, page_size + ) + params["limit"] = normalized_page_size + + response = self.get(endpoint, params) + if response is None: + break + result = response.get(response_field, None) + + if not isinstance(result, List) or len(result) == 0: + break + + num_items_returned = len(result) + if ( + max_items_limit is not None + and num_items_returned + total_items > max_items_limit + ): + result = result[: (max_items_limit - total_items)] + total_items += max_items_limit - total_items + else: + total_items += num_items_returned + + if self._use_raw_data: + yield result + else: + yield TypeAdapter(type=List[base_model_type]).validate_python(result) + + if max_items_limit is not None and total_items >= max_items_limit: + break + + page_token = response.get("next_page_token", None) + + if page_token is None: + break + + params["page_token"] = page_token + # ############################## ACCOUNTS/TRADING ACCOUNTS ################################# # def create_account( @@ -1929,3 +2018,305 @@ def _get_sse_headers(self) -> dict: headers["Accept"] = "text/event-stream" return headers + + # ############################## REBALANCING ################################# # + + def create_portfolio( + self, portfolio_request: CreatePortfolioRequest + ) -> Union[Portfolio, RawData]: + """ + Create a new portfolio. + + ref. https://docs.alpaca.markets/reference/post-v1-rebalancing-portfolios + + Args: + portfolio_request (CreatePortfolioRequest): The details required to create a new portfolio. + + Returns: + Portfolio: Newly created portfolio. + """ + + response = self.post( + "/rebalancing/portfolios", data=portfolio_request.to_request_fields() + ) + + if self._use_raw_data: + return response + + return Portfolio(**response) + + def get_all_portfolios( + self, + filter: Optional[GetPortfoliosRequest] = None, + ) -> Union[List[Portfolio], List[RawData]]: + """ + Retrieves all portfolios based on the filter provided. + + ref. https://docs.alpaca.markets/reference/get-v1-rebalancing-portfolios + + Args: + filter (Optional[GetPortfoliosRequest]): Filter criteria to narrow down portfolio list. + + Returns: + List[Portfolio]: List of portfolios. + """ + + response = self.get( + "/rebalancing/portfolios", filter.to_request_fields() if filter else {} + ) + + if self._use_raw_data: + return response + + return TypeAdapter( + List[Portfolio], + ).validate_python(response) + + def get_portfolio_by_id( + self, portfolio_id: Union[UUID, str] + ) -> Union[Portfolio, RawData]: + """ + Retrieves a specific portfolio using its ID. + + Args: + portfolio_id (Union[UUID, str]): The ID of the desired portfolio. + + Returns: + Portfolio: The portfolio queried. + """ + + response = self.get(f"/rebalancing/portfolios/{portfolio_id}") + + if self._use_raw_data: + return response + + return Portfolio(**response) + + def update_portfolio_by_id( + self, + portfolio_id: Union[UUID, str], + update_request: UpdatePortfolioRequest, + ) -> Union[Portfolio, RawData]: + """ + Updates a portfolio by ID. + If weights or conditions are changed, all subscribed accounts will be evaluated for rebalancing at the next opportunity (normal market hours). + If a cooldown is active on the portfolio, the rebalancing will occur after the cooldown expired. + + ref. https://docs.alpaca.markets/reference/patch-v1-rebalancing-portfolios-portfolio_id-1 + + Args: + portfolio_id (Union[UUID, str]): The ID of the portfolio to be updated. + update_request: The details to be updated for the portfolio. + + Returns: + Portfolio: Updated portfolio. + """ + portfolio_id = validate_uuid_id_param(portfolio_id) + + response = self.patch( + f"/rebalancing/portfolios/{portfolio_id}", + data=update_request.to_request_fields(), + ) + + if self._use_raw_data: + return response + + return Portfolio(**response) + + def inactivate_portfolio_by_id(self, portfolio_id: Union[UUID, str]) -> None: + """ + Sets a portfolio to “inactive”, so it can be filtered out of the list request. + Only permitted if there are no active subscriptions to this portfolio and this portfolio is not a listed in the weights of any active portfolios. + Inactive portfolios cannot be linked in new subscriptions or added as weights to new portfolios. + + ref. https://docs.alpaca.markets/reference/delete-v1-rebalancing-portfolios-portfolio_id-1 + + Args: + portfolio_id (Union[UUID, str]): The ID of the portfolio to be inactivated. + """ + portfolio_id = validate_uuid_id_param(portfolio_id) + + self.delete( + f"/rebalancing/portfolios/{portfolio_id}", + ) + + def create_subscription( + self, subscription_request: CreateSubscriptionRequest + ) -> Union[Subscription, RawData]: + """ + Create a new subscription. + + Args: + subscription_request (CreateSubscriptionRequest): The details required to create a new subscription. + + Returns: + Subscription: Newly created subscription. + """ + + response = self.post( + "/rebalancing/subscriptions", data=subscription_request.to_request_fields() + ) + + if self._use_raw_data: + return response + + return Subscription(**response) + + def get_all_subscriptions( + self, + filter: Optional[GetSubscriptionsRequest] = None, + max_items_limit: Optional[int] = None, + handle_pagination: Optional[PaginationType] = None, + ) -> Union[List[Subscription], List[RawData]]: + """ + Retrieves all subscriptions based on the filter provided. + + ref. https://docs.alpaca.markets/reference/get-v1-rebalancing-subscriptions-1 + + Args: + filter (Optional[GetSubscriptionsRequest]): Filter criteria to narrow down subscription list. + max_items_limit (Optional[int]): A maximum number of items to return over all for when handle_pagination is + of type `PaginationType.FULL`. Ignored otherwise. + handle_pagination (Optional[PaginationType]): What kind of pagination you want. If None then defaults to + `PaginationType.FULL`. + + Returns: + List[Subscription]: List of subscriptions. + """ + handle_pagination = BrokerClient._validate_pagination( + max_items_limit, handle_pagination + ) + + subscriptions_iterator = self._iterate_over_pages( + endpoint="/rebalancing/subscriptions", + params=filter.to_request_fields() if filter else {}, + response_field="subscriptions", + base_model_type=Subscription, + max_items_limit=max_items_limit, + ) + + return BrokerClient._return_paginated_result( + subscriptions_iterator, handle_pagination + ) + + def get_subscription_by_id( + self, subscription_id: Union[UUID, str] + ) -> Union[Subscription, RawData]: + """ + Get a subscription by its ID. + + Args: + subscription_id (Union[UUID, str]): The ID of the desired subscription. + + Returns: + Subscription: The subscription queried. + """ + subscription_id = validate_uuid_id_param(subscription_id) + + response = self.get(f"/rebalancing/subscriptions/{subscription_id}") + + if self._use_raw_data: + return response + + return Subscription(**response) + + def unsubscribe_account(self, subscription_id: Union[UUID, str]) -> None: + """ + Deletes the subscription which stops the rebalancing of an account. + + Args: + subscription_id (Union[UUID, str]): The ID of the subscription to be removed. + """ + subscription_id = validate_uuid_id_param(subscription_id) + + self.delete( + f"/rebalancing/subscriptions/{subscription_id}", + ) + + def create_manual_run( + self, rebalancing_run_request: CreateRunRequest + ) -> Union[RebalancingRun, RawData]: + """ + Create a new manual rebalancing run. + + Args: + rebalancing_run_request: The details required to create a new rebalancing run. + + Returns: + RebalancingRun: The rebalancing run initiated. + """ + + response = self.post( + "/rebalancing/runs", data=rebalancing_run_request.to_request_fields() + ) + + if self._use_raw_data: + return response + + return RebalancingRun(**response) + + def get_all_runs( + self, + filter: Optional[GetRunsRequest] = None, + max_items_limit: Optional[int] = None, + handle_pagination: Optional[PaginationType] = None, + ) -> Union[List[RebalancingRun], List[RawData]]: + """ + Get all runs. + + Args: + filter (Optional[GetRunsRequest]): Filter criteria to narrow down run list. + max_items_limit (Optional[int]): A maximum number of items to return over all for when handle_pagination is + of type `PaginationType.FULL`. Ignored otherwise. + handle_pagination (Optional[PaginationType]): What kind of pagination you want. If None then defaults to + `PaginationType.FULL`. + + Returns: + List[RebalancingRun]: List of rebalancing runs. + """ + handle_pagination = BrokerClient._validate_pagination( + max_items_limit, handle_pagination + ) + + runs_iterator = self._iterate_over_pages( + endpoint="/rebalancing/runs", + params=filter.to_request_fields() if filter else {}, + response_field="runs", + base_model_type=RebalancingRun, + max_items_limit=max_items_limit, + ) + + return BrokerClient._return_paginated_result(runs_iterator, handle_pagination) + + def get_run_by_id(self, run_id: Union[UUID, str]) -> Union[RebalancingRun, RawData]: + """ + Get a run by its ID. + + Args: + run_id (Union[UUID, str]): The ID of the desired rebalancing run. + + Returns: + RebalancingRun: The rebalancing run queried. + """ + run_id = validate_uuid_id_param(run_id) + + response = self.get(f"/rebalancing/runs/{run_id}") + + if self._use_raw_data: + return response + + return RebalancingRun(**response) + + def cancel_run_by_id(self, run_id: Union[UUID, str]) -> None: + """ + Cancels a run. + + Only runs within certain statuses (QUEUED, CANCELED, SELLS_IN_PROGRESS, BUYS_IN_PROGRESS) are cancelable. + If this endpoint is called after orders have been submitted, we’ll attempt to cancel the orders. + + Args: + run_id (Union[UUID, str]): The ID of the desired rebalancing run. + """ + run_id = validate_uuid_id_param(run_id) + + self.delete(f"/rebalancing/runs/{run_id}") diff --git a/alpaca/broker/enums.py b/alpaca/broker/enums.py index 1da833cc..5362f94b 100644 --- a/alpaca/broker/enums.py +++ b/alpaca/broker/enums.py @@ -426,3 +426,100 @@ class JournalStatus(str, Enum): REFUSED = "refused" CORRECT = "correct" DELETED = "deleted" + + +class PortfolioStatus(str, Enum): + """ + The possible values of the Portfolio status. + + See https://docs.alpaca.markets/reference/get-v1-rebalancing-portfolios + """ + + ACTIVE = "active" + INACTIVE = "inactive" + NEEDS_ADJUSTMENT = "needs_adjustment" + + +class WeightType(str, Enum): + """ + The possible values of the Weight type. + + See https://docs.alpaca.markets/reference/post-v1-rebalancing-portfolios + """ + + CASH = "cash" + ASSET = "asset" + + +class RebalancingConditionsType(str, Enum): + """ + The possible values of the Rebalancing Conditions type. + + See https://docs.alpaca.markets/reference/post-v1-rebalancing-portfolios + """ + + DRIFT_BAND = "drift_band" + CALENDAR = "calendar" + + +class DriftBandSubType(str, Enum): + """ + The possible values of the Rebalancing Conditions subtype for drift_band. + + See https://docs.alpaca.markets/reference/post-v1-rebalancing-portfolios + """ + + ABSOLUTE = "absolute" + RELATIVE = "relative" + + +class CalendarSubType(str, Enum): + """ + The possible values of the Rebalancing Conditions subtype for drift_band. + + See https://docs.alpaca.markets/reference/post-v1-rebalancing-portfolios + """ + + WEEKLY = "weekly" + MONTHLY = "monthly" + QUARTERLY = "quarterly" + ANNUALLY = "annually" + + +class RunType(str, Enum): + """ + The possible values of the Run type. + + See https://docs.alpaca.markets/reference/post-v1-rebalancing-runs + """ + + FULL_REBALANCE = "full_rebalance" + INVEST_CASH = "invest_cash" + + +class RunInitiatedFrom(str, Enum): + """ + The possible values of the initiated_from field. + + See https://docs.alpaca.markets/docs/portfolio-rebalancing + """ + + SYSTEM = "system" + API = "api" + + +class RunStatus(str, Enum): + """ + The possible values of the Run status. + + See https://docs.alpaca.markets/reference/get-v1-rebalancing-runs + """ + + QUEUED = "QUEUED" + IN_PROGRESS = "IN_PROGRESS" + CANCELED = "CANCELED" + CANCELED_MID_RUN = "CANCELED_MID_RUN" + ERROR = "ERROR" + TIMEOUT = "TIMEOUT" + COMPLETED_SUCCESS = "COMPLETED_SUCCESS" + COMPLETED_ADJUSTED = "COMPLETED_ADJUSTED" diff --git a/alpaca/broker/models/__init__.py b/alpaca/broker/models/__init__.py index b02b2ec7..e007ca79 100644 --- a/alpaca/broker/models/__init__.py +++ b/alpaca/broker/models/__init__.py @@ -4,3 +4,4 @@ from .funding import * from .trading import * from .journals import * +from .rebalancing import * diff --git a/alpaca/broker/models/rebalancing.py b/alpaca/broker/models/rebalancing.py new file mode 100644 index 00000000..9fdd6cac --- /dev/null +++ b/alpaca/broker/models/rebalancing.py @@ -0,0 +1,80 @@ +from datetime import datetime +from typing import List, Optional +from uuid import UUID + +from alpaca.broker.enums import PortfolioStatus, RunInitiatedFrom, RunStatus, RunType +from alpaca.broker.models import Order +from alpaca.broker.requests import RebalancingConditions, Weight +from alpaca.common.models import ValidateBaseModel as BaseModel + + +class Portfolio(BaseModel): + """ + Portfolio response model. + + https://docs.alpaca.markets/reference/get-v1-rebalancing-portfolios + """ + + id: UUID + name: str + description: str + status: PortfolioStatus + cooldown_days: int + created_at: datetime + updated_at: datetime + weights: List[Weight] + rebalance_conditions: Optional[List[RebalancingConditions]] = None + + +class Subscription(BaseModel): + """ + Subscription response model. + + https://docs.alpaca.markets/reference/get-v1-rebalancing-subscriptions-1 + """ + + id: UUID + account_id: UUID + portfolio_id: UUID + created_at: datetime + last_rebalanced_at: Optional[datetime] = None + + +class SkippedOrder(BaseModel): + """ + Skipped order response model. + + https://docs.alpaca.markets/reference/get-v1-rebalancing-runs-run_id-1 + """ + + symbol: str + side: Optional[str] = None + notional: Optional[str] = None + currency: Optional[str] = None + reason: str + reason_details: str + + +class RebalancingRun(BaseModel): + """ + Rebalancing run response model. + + https://docs.alpaca.markets/reference/get-v1-rebalancing-runs + """ + + id: UUID + account_id: UUID + type: RunType + amount: Optional[str] = None + portfolio_id: UUID + weights: List[Weight] + initiated_from: Optional[RunInitiatedFrom] = None + created_at: datetime + updated_at: datetime + completed_at: Optional[datetime] = None + canceled_at: Optional[datetime] = None + status: RunStatus + reason: Optional[str] = None + orders: Optional[List[Order]] = None + failed_orders: Optional[List[Order]] = None + skipped_orders: Optional[List[SkippedOrder]] = None diff --git a/alpaca/broker/requests.py b/alpaca/broker/requests.py index a46f5531..0bbb8710 100644 --- a/alpaca/broker/requests.py +++ b/alpaca/broker/requests.py @@ -1,8 +1,8 @@ from datetime import date, datetime -from typing import List, Optional, Union, Dict, Any +from typing import Any, Dict, List, Optional, Union from uuid import UUID -from pydantic import model_validator, field_validator +from pydantic import field_validator, model_validator from alpaca.broker.models.accounts import ( AccountDocument, @@ -16,12 +16,17 @@ from alpaca.broker.enums import ( AccountEntities, BankAccountType, + CalendarSubType, DocumentType, - EmploymentStatus, + DriftBandSubType, FeePaymentMethod, FundingSource, IdentifierType, - TaxIdType, + JournalEntryType, + JournalStatus, + PortfolioStatus, + RebalancingConditionsType, + RunType, TradeDocumentType, TransferDirection, TransferTiming, @@ -29,9 +34,9 @@ UploadDocumentMimeType, UploadDocumentSubType, VisaType, - JournalEntryType, - JournalStatus, + WeightType, ) +from alpaca.common.models import BaseModel from alpaca.common.enums import Sort, SupportedCurrencies from alpaca.trading.enums import ActivityType, AccountStatus, OrderType, AssetClass from alpaca.common.requests import NonEmptyRequest @@ -44,7 +49,6 @@ TrailingStopOrderRequest as BaseTrailingStopOrderRequest, ) - # ############################## Accounts ################################# # @@ -985,3 +989,138 @@ class GetEventsRequest(NonEmptyRequest): until: Optional[Union[date, str]] = None since_id: Optional[int] = None until_id: Optional[int] = None + + +# ############################## Rebalancing ################################# # + + +class Weight(BaseModel): + """ + Weight model. + + https://docs.alpaca.markets/reference/post-v1-rebalancing-portfolios + """ + + type: WeightType + symbol: Optional[str] = None + percent: float + + @field_validator("percent") + def percent_must_be_positive(cls, value: float) -> float: + """Validate and round the percent field to 2 decimal places.""" + if value <= 0: + raise ValueError("You must provide an amount > 0.") + return round(value, 2) + + @model_validator(mode="before") + def validator(cls, values: dict) -> dict: + """Verify that the symbol is provided when the weights type is asset.""" + if ( + values["type"] == WeightType.ASSET.value + and values.get("symbol", None) is None + ): + raise ValueError + return values + + +class RebalancingConditions(BaseModel): + """ + Rebalancing conditions model. + + https://docs.alpaca.markets/reference/post-v1-rebalancing-portfolios + """ + + type: RebalancingConditionsType + sub_type: Union[DriftBandSubType, CalendarSubType] + percent: Optional[float] = None + day: Optional[str] = None + + +class CreatePortfolioRequest(NonEmptyRequest): + """ + Portfolio request model. + + https://docs.alpaca.markets/reference/post-v1-rebalancing-portfolios + """ + + name: str + description: str + weights: List[Weight] + cooldown_days: int + rebalance_conditions: Optional[List[RebalancingConditions]] = None + + +class UpdatePortfolioRequest(NonEmptyRequest): + """ + Portfolio request update model. + + https://docs.alpaca.markets/reference/patch-v1-rebalancing-portfolios-portfolio_id-1 + """ + + name: Optional[str] = None + description: Optional[str] = None + weights: Optional[List[Weight]] = None + cooldown_days: Optional[int] = None + rebalance_conditions: Optional[List[RebalancingConditions]] = None + + +class GetPortfoliosRequest(NonEmptyRequest): + """ + Get portfolios request query parameters. + + https://docs.alpaca.markets/reference/get-v1-rebalancing-portfolios + """ + + name: Optional[str] = None + description: Optional[str] = None + symbol: Optional[str] = None + portfolio_id: Optional[UUID] = None + status: Optional[PortfolioStatus] = None + + +class CreateSubscriptionRequest(NonEmptyRequest): + """ + Subscription request model. + + https://docs.alpaca.markets/reference/post-v1-rebalancing-subscriptions-1 + """ + + account_id: UUID + portfolio_id: UUID + + +class GetSubscriptionsRequest(NonEmptyRequest): + """ + Get subscriptions request query parameters. + + https://docs.alpaca.markets/reference/get-v1-rebalancing-subscriptions-1 + """ + + account_id: Optional[UUID] = None + portfolio_id: Optional[UUID] = None + limit: Optional[int] = None + page_token: Optional[str] = None + + +class CreateRunRequest(NonEmptyRequest): + """ + Manually creates a rebalancing run. + + https://docs.alpaca.markets/reference/post-v1-rebalancing-runs + """ + + account_id: UUID + type: RunType + weights: List[Weight] + + +class GetRunsRequest(NonEmptyRequest): + """ + Get runs request query parameters. + + https://docs.alpaca.markets/reference/get-v1-rebalancing-runs + """ + + account_id: Optional[UUID] = None + type: Optional[RunType] = None + limit: Optional[int] = None diff --git a/alpaca/common/rest.py b/alpaca/common/rest.py index 021ec7bf..a42662fb 100644 --- a/alpaca/common/rest.py +++ b/alpaca/common/rest.py @@ -57,7 +57,7 @@ def __init__( """ self._api_key, self._secret_key, self._oauth_token = self._validate_credentials( - api_key, secret_key, oauth_token + api_key=api_key, secret_key=secret_key, oauth_token=oauth_token ) self._api_version: str = api_version self._base_url: Union[BaseURL, str] = base_url @@ -207,7 +207,9 @@ def _one_request(self, method: str, url: str, opts: dict, retry: int) -> dict: if response.text != "": return response.json() - def get(self, path: str, data: Union[dict, str] = None, **kwargs) -> HTTPResult: + def get( + self, path: str, data: Optional[Union[dict, str]] = None, **kwargs + ) -> HTTPResult: """Performs a single GET request Args: diff --git a/alpaca/common/types.py b/alpaca/common/types.py index f19420c3..632aae9c 100644 --- a/alpaca/common/types.py +++ b/alpaca/common/types.py @@ -1,7 +1,7 @@ -from typing import Dict, Any, List, Union, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union RawData = Dict[str, Any] # TODO: Refine this type HTTPResult = Union[dict, List[dict], Any] -Credentials = Tuple[str, str] +Credentials = Tuple[Optional[str], Optional[str], Optional[str]] diff --git a/tests/broker/broker_client/test_rebalancing_routes.py b/tests/broker/broker_client/test_rebalancing_routes.py new file mode 100644 index 00000000..af9da783 --- /dev/null +++ b/tests/broker/broker_client/test_rebalancing_routes.py @@ -0,0 +1,664 @@ +from uuid import UUID + +from requests_mock import Mocker + +from alpaca.broker.client import BrokerClient +from alpaca.broker.enums import WeightType +from alpaca.broker.models import Portfolio, RebalancingRun, Subscription +from alpaca.broker.requests import ( + CreatePortfolioRequest, + CreateRunRequest, + CreateSubscriptionRequest, + GetPortfoliosRequest, + GetRunsRequest, + GetSubscriptionsRequest, + UpdatePortfolioRequest, +) +from alpaca.common.enums import BaseURL + + +def test_create_portfolio(reqmock: Mocker, client: BrokerClient) -> None: + """Test to create a portfolio.""" + reqmock.post( + f"{BaseURL.BROKER_SANDBOX.value}/v1/rebalancing/portfolios", + text=""" + { + "id": "6819ecd2-db92-4688-821d-8fac2a8f4744", + "name": "Balanced", + "description": "A balanced portfolio of stocks and bonds", + "status": "active", + "cooldown_days": 7, + "created_at": "2022-08-06T19:12:13.555858187-04:00", + "updated_at": "2022-08-06T19:12:13.628551899-04:00", + "weights": [ + { + "type": "cash", + "symbol": null, + "percent": "5" + }, + { + "type": "asset", + "symbol": "SPY", + "percent": "60" + }, + { + "type": "asset", + "symbol": "TLT", + "percent": "35" + } + ], + "rebalance_conditions": [ + { + "type": "drift_band", + "sub_type": "absolute", + "percent": "5", + "day": null + }, + { + "type": "drift_band", + "sub_type": "relative", + "percent": "20", + "day": null + } + ] + } + """, + ) + + portfolio_request = CreatePortfolioRequest( + **{ + "name": "Balanced", + "description": "A balanced portfolio of stocks and bonds", + "weights": [ + {"type": "cash", "percent": "5"}, + {"type": "asset", "symbol": "SPY", "percent": "60"}, + {"type": "asset", "symbol": "TLT", "percent": "35"}, + ], + "cooldown_days": 7, + "rebalance_conditions": [ + {"type": "drift_band", "sub_type": "absolute", "percent": "5"}, + {"type": "drift_band", "sub_type": "relative", "percent": "20"}, + ], + } + ) + ptf = client.create_portfolio(portfolio_request) + + assert reqmock.called_once + assert isinstance(ptf, Portfolio) + + +def test_get_all_portfolios(reqmock: Mocker, client: BrokerClient) -> None: + """Test the get_all_portfolios method.""" + reqmock.get( + f"{BaseURL.BROKER_SANDBOX.value}/v1/rebalancing/portfolios", + text=""" + [ + { + "id": "57d4ec79-9658-4916-9eb1-7c672be97e3e", + "name": "My Portfolio", + "description": "Some description", + "status": "active", + "cooldown_days": 2, + "created_at": "2022-07-28T20:33:59.665962Z", + "updated_at": "2022-07-28T20:33:59.786528Z", + "weights": [ + { + "type": "asset", + "symbol": "AAPL", + "percent": "35" + }, + { + "type": "asset", + "symbol": "TSLA", + "percent": "20" + }, + { + "type": "asset", + "symbol": "SPY", + "percent": "45" + } + ], + "rebalance_conditions": [ + { + "type": "drift_band", + "sub_type": "absolute", + "percent": "5", + "day": null + }, + { + "type": "drift_band", + "sub_type": "relative", + "percent": "20", + "day": null + } + ] + }, + { + "id": "6819ecd2-db92-4688-821d-8fac2a8f4744", + "name": "Balanced", + "description": "A balanced portfolio of stocks and bonds", + "status": "active", + "cooldown_days": 7, + "created_at": "2022-08-06T23:12:13.555858Z", + "updated_at": "2022-08-06T23:12:13.628551Z", + "weights": [ + { + "type": "cash", + "symbol": null, + "percent": "5" + }, + { + "type": "asset", + "symbol": "SPY", + "percent": "60" + }, + { + "type": "asset", + "symbol": "TLT", + "percent": "35" + } + ], + "rebalance_conditions": [ + { + "type": "drift_band", + "sub_type": "absolute", + "percent": "5", + "day": null + }, + { + "type": "drift_band", + "sub_type": "relative", + "percent": "20", + "day": null + } + ] + }, + { + "id": "2d49d00e-ab1c-4014-89d8-70c5f64df2fc", + "name": "Balanced Two", + "description": "A balanced portfolio of stocks and bonds", + "status": "active", + "cooldown_days": 7, + "created_at": "2022-08-07T18:56:45.116867Z", + "updated_at": "2022-08-07T18:56:45.196857Z", + "weights": [ + { + "type": "cash", + "symbol": null, + "percent": "5" + }, + { + "type": "asset", + "symbol": "SPY", + "percent": "60" + }, + { + "type": "asset", + "symbol": "TLT", + "percent": "35" + } + ], + "rebalance_conditions": [ + { + "type": "drift_band", + "sub_type": "absolute", + "percent": "5", + "day": null + }, + { + "type": "drift_band", + "sub_type": "relative", + "percent": "20", + "day": null + } + ] + } +] + """, + ) + response = client.get_all_portfolios(filter=GetPortfoliosRequest()) + + assert reqmock.called_once + assert len(response) > 0 + assert isinstance(response[0], Portfolio) + + +def test_get_portfolio_by_id(reqmock: Mocker, client: BrokerClient) -> None: + """Test the get_portfolio_by_id method.""" + ptf_id = UUID("57d4ec79-9658-4916-9eb1-7c672be97e3e") + reqmock.get( + f"{BaseURL.BROKER_SANDBOX.value}/v1/rebalancing/portfolios/{ptf_id}", + text=""" + { + "id": "57d4ec79-9658-4916-9eb1-7c672be97e3e", + "name": "My Portfolio", + "description": "Some description", + "status": "active", + "cooldown_days": 2, + "created_at": "2022-07-28T20:33:59.665962Z", + "updated_at": "2022-07-28T20:33:59.786528Z", + "weights": [ + { + "type": "asset", + "symbol": "AAPL", + "percent": "35" + }, + { + "type": "asset", + "symbol": "TSLA", + "percent": "20" + }, + { + "type": "asset", + "symbol": "SPY", + "percent": "45" + } + ], + "rebalance_conditions": [ + { + "type": "drift_band", + "sub_type": "absolute", + "percent": "5", + "day": null + }, + { + "type": "drift_band", + "sub_type": "relative", + "percent": "20", + "day": null + } + ] + } + """, + ) + response = client.get_portfolio_by_id(portfolio_id=ptf_id) + + assert reqmock.called_once + assert isinstance(response, Portfolio) + assert response.id == ptf_id + + +def test_update_portfolio_by_id(reqmock: Mocker, client: BrokerClient) -> None: + """Test the update_portfolio_by_id method.""" + ptf_id = UUID("57d4ec79-9658-4916-9eb1-7c672be97e3e") + reqmock.patch( + f"{BaseURL.BROKER_SANDBOX.value}/v1/rebalancing/portfolios/{ptf_id}", + text=""" + { + "id": "57d4ec79-9658-4916-9eb1-7c672be97e3e", + "name": "My Portfolio", + "description": "Some description", + "status": "active", + "cooldown_days": 2, + "created_at": "2022-07-28T20:33:59.665962Z", + "updated_at": "2022-07-28T20:33:59.786528Z", + "weights": [ + { + "type": "cash", + "percent": "10" + }, + { + "type": "asset", + "symbol": "GOOG", + "percent": "90" + } + ], + "rebalance_conditions": [ + { + "type": "drift_band", + "sub_type": "absolute", + "percent": "5", + "day": null + }, + { + "type": "drift_band", + "sub_type": "relative", + "percent": "20", + "day": null + } + ] + } + """, + ) + response = client.update_portfolio_by_id( + portfolio_id=ptf_id, + update_request=UpdatePortfolioRequest( + **{ + "weights": [ + {"type": "cash", "percent": "10"}, + {"type": "asset", "symbol": "GOOG", "percent": "90"}, + ] + } + ), + ) + + assert reqmock.called_once + assert isinstance(response, Portfolio) + assert response.id == ptf_id + assert response.weights[0].type == WeightType.CASH + assert response.weights[0].percent == 10 + assert response.weights[1].type == WeightType.ASSET + assert response.weights[1].percent == 90 + + +def test_inactivate_portfolio_by_id(reqmock: Mocker, client: BrokerClient) -> None: + """Test the inactivate_portfolio_by_id method.""" + ptf_id = UUID("57d4ec79-9658-4916-9eb1-7c672be97e3e") + reqmock.delete(f"{BaseURL.BROKER_SANDBOX.value}/v1/rebalancing/portfolios/{ptf_id}") + client.inactivate_portfolio_by_id( + portfolio_id=ptf_id, + ) + assert reqmock.called_once + + +def test_create_subscription(reqmock: Mocker, client: BrokerClient) -> None: + """Test to create a portfolio subscription.""" + reqmock.post( + f"{BaseURL.BROKER_SANDBOX.value}/v1/rebalancing/subscriptions", + text=""" + { + "id": "2ded098b-ee17-4f48-9496-f8b66e3627aa", + "account_id": "bf2b0f93-f296-4276-a9cf-288586cf4fb7", + "portfolio_id": "57d4ec79-9658-4916-9eb1-7c672be97e3e", + "last_rebalanced_at": null, + "created_at": "2022-08-06T19:34:43.428080852-04:00" + } + """, + ) + subscription_request = CreateSubscriptionRequest( + **{ + "account_id": "bf2b0f93-f296-4276-a9cf-288586cf4fb7", + "portfolio_id": "57d4ec79-9658-4916-9eb1-7c672be97e3e", + } + ) + subscription = client.create_subscription(subscription_request) + + assert reqmock.called_once + assert isinstance(subscription, Subscription) + + +def test_get_all_subscriptions(reqmock: Mocker, client: BrokerClient) -> None: + """Test the get_all_subscriptions method.""" + reqmock.get( + f"{BaseURL.BROKER_SANDBOX.value}/v1/rebalancing/subscriptions", + text=""" + { + "subscriptions": [ + { + "id": "9341be15-8786-4d23-ba1a-fc10ef4f90f4", + "account_id": "bf2b0f93-f296-4276-a9cf-288586cf4fb7", + "portfolio_id": "57d4ec79-9658-4916-9eb1-7c672be97e3e", + "last_rebalanced_at": null, + "created_at": "2022-08-07T23:52:05.942964Z" + } + ], + "next_page_token": null +} + """, + ) + response = client.get_all_subscriptions(filter=GetSubscriptionsRequest()) + + assert reqmock.called_once + assert len(response) > 0 + assert isinstance(response[0], Subscription) + + +def test_get_subscription_by_id(reqmock: Mocker, client: BrokerClient) -> None: + """Test the get_subscription_by_id method.""" + sub_id = UUID("9341be15-8786-4d23-ba1a-fc10ef4f90f4") + reqmock.get( + f"{BaseURL.BROKER_SANDBOX.value}/v1/rebalancing/subscriptions/{sub_id}", + text=""" + { + "id": "9341be15-8786-4d23-ba1a-fc10ef4f90f4", + "account_id": "bf2b0f93-f296-4276-a9cf-288586cf4fb7", + "portfolio_id": "57d4ec79-9658-4916-9eb1-7c672be97e3e", + "last_rebalanced_at": null, + "created_at": "2022-08-07T23:52:05.942964Z" + } + """, + ) + response = client.get_subscription_by_id(subscription_id=sub_id) + + assert reqmock.called_once + assert isinstance(response, Subscription) + assert response.id == sub_id + + +def test_unsubscribe_account(reqmock: Mocker, client: BrokerClient) -> None: + """Test the unsubscribe_account method.""" + sub_id = UUID("9341be15-8786-4d23-ba1a-fc10ef4f90f4") + reqmock.delete( + f"{BaseURL.BROKER_SANDBOX.value}/v1/rebalancing/subscriptions/{sub_id}" + ) + client.unsubscribe_account( + subscription_id=sub_id, + ) + assert reqmock.called_once + + +def test_create_manual_run(reqmock: Mocker, client: BrokerClient) -> None: + """Test to create a portfolio subscription.""" + reqmock.post( + f"{BaseURL.BROKER_SANDBOX.value}/v1/rebalancing/runs", + text=""" + { + "id": "b4f32f6f-f8b3-4f8e-9b36-30b560000bfa", + "type": "full_rebalance", + "amount": null, + "initiated_from": "api", + "status": "QUEUED", + "reason": null, + "account_id": "bf2b0f93-f296-4276-a9cf-288586cf4fb7", + "portfolio_id": "448ba7b3-2fda-4d8e-ac9f-61ff2aa36c60", + "weights": [ + { + "type": "asset", + "symbol": "AAPL", + "percent": "35" + }, + { + "type": "asset", + "symbol": "TSLA", + "percent": "20" + }, + { + "type": "asset", + "symbol": "SPY", + "percent": "45" + } + ], + "orders": [], + "completed_at": null, + "canceled_at": null, + "created_at": "2023-10-17T10:16:55.582507Z", + "updated_at": "2023-10-17T10:16:55.582507Z" + } + """, + ) + run_req = CreateRunRequest( + **{ + "account_id": "bf2b0f93-f296-4276-a9cf-288586cf4fb7", + "type": "full_rebalance", + "weights": [ + {"type": "asset", "symbol": "AAPL", "percent": "35"}, + {"type": "asset", "symbol": "TSLA", "percent": "20"}, + {"type": "asset", "symbol": "SPY", "percent": "45"}, + ], + } + ) + run_resp = client.create_manual_run(run_req) + + assert reqmock.called_once + assert isinstance(run_resp, RebalancingRun) + + +def test_get_all_runs(reqmock: Mocker, client: BrokerClient) -> None: + """Test the get_all_runs method.""" + reqmock.get( + f"{BaseURL.BROKER_SANDBOX.value}/v1/rebalancing/runs", + response_list=[ + { + "text": """ + { + "runs": [ + { + "id": "2ad28f83-796c-4c4d-895e-d360aeb95297", + "type": "full_rebalance", + "amount": null, + "initiated_from": "system", + "status": "CANCELED", + "reason": "create OMS order: create closed order exceeded max retries: order not a filled state", + "account_id": "cf175fbc-ca19-4741-88ed-70d6c133f8d7", + "portfolio_id": "ac89fa84-e3bb-48ff-9b81-d5f313e77463", + "weights": [ + { + "type": "cash", + "symbol": null, + "percent": "5" + }, + { + "type": "asset", + "symbol": "SPY", + "percent": "60" + }, + { + "type": "asset", + "symbol": "TLT", + "percent": "35" + } + ], + "orders": [], + "completed_at": null, + "canceled_at": null, + "created_at": "2022-04-14T10:46:08.045817Z", + "updated_at": "2022-04-14T13:11:07.84719Z" + } + ], + "next_page_token": 1 + } + """ + }, + { + "text": """ + { + "runs": [ + { + "id": "2ad28f83-796c-4c4d-895e-d360aeb95297", + "type": "full_rebalance", + "amount": null, + "initiated_from": "system", + "status": "CANCELED", + "reason": "create OMS order: create closed order exceeded max retries: order not a filled state", + "account_id": "cf175fbc-ca19-4741-88ed-70d6c133f8d7", + "portfolio_id": "ac89fa84-e3bb-48ff-9b81-d5f313e77463", + "weights": [ + { + "type": "cash", + "symbol": null, + "percent": "5" + }, + { + "type": "asset", + "symbol": "SPY", + "percent": "60" + }, + { + "type": "asset", + "symbol": "TLT", + "percent": "35" + } + ], + "orders": [], + "skipped_orders": [ + { + "symbol": "SPY", + "side": "buy", + "notional": "0", + "currency": "USD", + "reason": "ORDER_LESS_THAN_MIN_NOTIONAL", + "reason_details": "order notional value ($0) is less than min-notional set for correspondent ($1)" + } + ], + "completed_at": null, + "canceled_at": null, + "created_at": "2022-04-14T10:46:08.045817Z", + "updated_at": "2022-04-14T13:11:07.84719Z" + } + ], + "next_page_token": null + } + """ + }, + ], + ) + response = client.get_all_runs(filter=GetRunsRequest()) + + assert reqmock.call_count == 2 + assert len(response) == 2 + assert isinstance(response[0], RebalancingRun) + + +def test_get_run_by_id(reqmock: Mocker, client: BrokerClient) -> None: + """Test the get_run_by_id method.""" + run_id = UUID("2ad28f83-796c-4c4d-895e-d360aeb95297") + reqmock.get( + f"{BaseURL.BROKER_SANDBOX.value}/v1/rebalancing/runs/{run_id}", + text=""" + { + "id": "2ad28f83-796c-4c4d-895e-d360aeb95297", + "type": "full_rebalance", + "amount": null, + "initiated_from": "system", + "status": "CANCELED", + "reason": "create OMS order: create closed order exceeded max retries: order not a filled state", + "account_id": "cf175fbc-ca19-4741-88ed-70d6c133f8d7", + "portfolio_id": "ac89fa84-e3bb-48ff-9b81-d5f313e77463", + "weights": [ + { + "type": "cash", + "symbol": null, + "percent": "5" + }, + { + "type": "asset", + "symbol": "SPY", + "percent": "60" + }, + { + "type": "asset", + "symbol": "TLT", + "percent": "35" + } + ], + "orders": [], + "skipped_orders": [ + { + "symbol": "SPY", + "side": "buy", + "notional": "0", + "currency": "USD", + "reason": "ORDER_LESS_THAN_MIN_NOTIONAL", + "reason_details": "order notional value ($0) is less than min-notional set for correspondent ($1)" + } + ], + "completed_at": null, + "canceled_at": null, + "created_at": "2022-04-14T10:46:08.045817Z", + "updated_at": "2022-04-14T13:11:07.84719Z" + } + """, + ) + response = client.get_run_by_id(run_id=run_id) + + assert reqmock.called_once + assert isinstance(response, RebalancingRun) + assert response.id == run_id + + +def test_cancel_run_by_id(reqmock: Mocker, client: BrokerClient) -> None: + """Test the cancel_run_by_id method.""" + run_id = UUID("9341be15-8786-4d23-ba1a-fc10ef4f90f4") + reqmock.delete(f"{BaseURL.BROKER_SANDBOX.value}/v1/rebalancing/runs/{run_id}") + client.cancel_run_by_id(run_id=run_id) + assert reqmock.called_once