Skip to content

Commit

Permalink
chore: fix update endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoreira-valory committed Dec 20, 2024
1 parent ae546a7 commit b529640
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 43 deletions.
6 changes: 4 additions & 2 deletions operate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,13 +776,14 @@ async def _update_service(request: Request) -> JSONResponse:
service_config_id=service_config_id,
service_template=template,
allow_different_service_public_id=allow_different_service_public_id,
partial_update=False,
)

return JSONResponse(content=output.json)

@app.patch("/api/v2/service/{service_config_id}")
@with_retries
async def _partial_update_service(request: Request) -> JSONResponse:
async def _partial_update_service(request: Request) -> JSONResponse:
"""Partially update a service (merge update)."""
if operate.password is None:
return USER_NOT_LOGGED_IN_ERROR
Expand All @@ -797,10 +798,11 @@ async def _partial_update_service(request: Request) -> JSONResponse:
allow_different_service_public_id = template.get(
"allow_different_service_public_id", False
)
output = manager.partial_update(
output = manager.update(
service_config_id=service_config_id,
service_template=template,
allow_different_service_public_id=allow_different_service_public_id,
partial_update=True,
)

return JSONResponse(content=output.json)
Expand Down
20 changes: 6 additions & 14 deletions operate/services/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1689,25 +1689,17 @@ def update(
service_config_id: str,
service_template: ServiceTemplate,
allow_different_service_public_id: bool = False,
partial_update: bool = True,
) -> Service:
"""Update a service."""

self.logger.info(f"Updating {service_config_id=}")
service = self.load(service_config_id=service_config_id)
service.update(service_template, allow_different_service_public_id)
return service

def update(
self,
service_config_id: str,
service_template: ServiceTemplate,
allow_different_service_public_id: bool = False,
) -> Service:
"""Update a service."""

self.logger.info(f"Updating {service_config_id=}")
service = self.load(service_config_id=service_config_id)
service.update(service_template, allow_different_service_public_id)
service.update(
service_template=service_template,
allow_different_service_public_id=allow_different_service_public_id,
partial_update=partial_update,
)
return service

def update_all_matching(
Expand Down
72 changes: 45 additions & 27 deletions operate/services/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
SERVICE_CONFIG_VERSION = 4
SERVICE_CONFIG_PREFIX = "sc-"

DUMMY_MULTISIG = "0xm"
NON_EXISTENT_MULTISIG = "0xm"
NON_EXISTENT_TOKEN = -1

DEFAULT_TRADER_ENV_VARS = {
Expand Down Expand Up @@ -897,7 +897,7 @@ def new( # pylint: disable=too-many-locals
chain_data = OnChainData(
instances=[],
token=NON_EXISTENT_TOKEN,
multisig=DUMMY_MULTISIG,
multisig=NON_EXISTENT_MULTISIG,
staked=False,
on_chain_state=OnChainState.NON_EXISTENT,
user_params=OnChainUserParams.from_json(config), # type: ignore
Expand Down Expand Up @@ -981,47 +981,65 @@ def update(
self,
service_template: ServiceTemplate,
allow_different_service_public_id: bool = False,
partial_update: bool = False,
) -> None:
"""Update service."""

target_hash = service_template["hash"]
target_service_public_id = Service.get_service_public_id(target_hash, self.path)

if not allow_different_service_public_id and (
self.service_public_id() != target_service_public_id
):
raise ValueError(
f"Trying to update a service with a different public id: {self.service_public_id()=} {self.hash=} {target_service_public_id=} {target_hash=}."
target_hash = service_template.get("hash")
if target_hash:
target_service_public_id = Service.get_service_public_id(
target_hash, self.path
)

if not allow_different_service_public_id and (
self.service_public_id() != target_service_public_id
):
raise ValueError(
f"Trying to update a service with a different public id: {self.service_public_id()=} {self.hash=} {target_service_public_id=} {target_hash=}."
)

self.hash = service_template.get("hash", self.hash)

# hash_history - Only update if latest inserted hash is different
if self.hash_history[max(self.hash_history.keys())] != self.hash:
current_timestamp = int(time.time())
self.hash_history[current_timestamp] = self.hash

self.home_chain = service_template.get("home_chain", self.home_chain)
self.description = service_template.get("description", self.description)
self.name = service_template.get("name", self.name)

shutil.rmtree(self.service_path)
service_path = Path(
IPFSTool().download(
hash_id=service_template["hash"],
hash_id=self.hash,
target_dir=self.path,
)
)
self.service_path = service_path
self.name = service_template["name"]
self.hash = service_template["hash"]
self.description = service_template["description"]

# TODO temporarily disable update env variables - hotfix for Memeooorr
# self.env_variables = service_template["env_variables"]

# Only update hash_history if latest inserted hash is different
if self.hash_history[max(self.hash_history.keys())] != service_template["hash"]:
current_timestamp = int(time.time())
self.hash_history[current_timestamp] = service_template["hash"]

self.home_chain = service_template["home_chain"]
# env_variables
if partial_update:
for var, attrs in service_template.get("env_variables", {}).items():
self.env_variables.setdefault(var, {}).update(attrs)
else:
self.env_variables = service_template["env_variables"]

# chain_configs
# TODO support remove chains for non-partial updates
# TODO ensure all and only existing chains are passed for non-partial updates
ledger_configs = ServiceHelper(path=self.service_path).ledger_configs()
for chain, config in service_template["configurations"].items():
for chain, new_config in service_template.get("configurations", {}).items():
if chain in self.chain_configs:
# The template is providing a chain configuration that already
# exists in this service - update only the user parameters.
# This is to avoid losing on-chain data like safe, token, etc.
if partial_update:
config = self.chain_configs[chain].chain_data.user_params.json
config.update(new_config)
else:
config = new_config

self.chain_configs[
chain
].chain_data.user_params = OnChainUserParams.from_json(
Expand All @@ -1032,15 +1050,15 @@ def update(
# not currently exist in this service - copy all config as
# when creating a new service.
ledger_config = ledger_configs[chain]
ledger_config.rpc = config["rpc"]
ledger_config.rpc = new_config["rpc"]

chain_data = OnChainData(
instances=[],
token=NON_EXISTENT_TOKEN,
multisig=DUMMY_MULTISIG,
multisig=NON_EXISTENT_MULTISIG,
staked=False,
on_chain_state=OnChainState.NON_EXISTENT,
user_params=OnChainUserParams.from_json(config), # type: ignore
user_params=OnChainUserParams.from_json(new_config), # type: ignore
)

self.chain_configs[chain] = ChainConfig(
Expand Down
182 changes: 182 additions & 0 deletions tests/test_services_manage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# -*- coding: utf-8 -*-
# ------------------------------------------------------------------------------
#
# Copyright 2024 Valory AG
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ------------------------------------------------------------------------------

"""Tests for services.service module."""

import random
import string
import typing as t
from pathlib import Path

import pytest
from deepdiff import DeepDiff

from operate.cli import OperateApp
from operate.operate_types import ServiceTemplate
from .test_services_service import DEFAULT_CONFIG_KWARGS

ROOT_PATH = Path(__file__).resolve().parent
OPERATE_HOME = ROOT_PATH / ".operate_test"

from operate.services.service import Service


@pytest.fixture
def random_string() -> str:
length = 8
chars = string.ascii_letters + string.digits
return "".join(random.choices(chars, k=length))


def get_template(**kwargs: t.Any) -> ServiceTemplate:
"""get_template"""

return {
"name": kwargs.get("name"),
"hash": kwargs.get("hash"),
"description": kwargs.get("description"),
"image": "https://image_url",
"service_version": "",
"home_chain": "gnosis",
"configurations": {
"gnosis": {
"staking_program_id": kwargs.get("staking_program_id"),
"nft": kwargs.get("nft"),
"rpc": "http://localhost:8545",
"threshold": kwargs.get("threshold"),
"agent_id": kwargs.get("agent_id"),
"use_staking": kwargs.get("use_staking"),
"use_mech_marketplace": kwargs.get("use_mech_marketplace"),
"cost_of_bond": kwargs.get("cost_of_bond"),
"fund_requirements": {
"agent": kwargs.get("fund_requirements_agent"),
"safe": kwargs.get("fund_requirements_safe"),
},
"fallback_chain_params": {}
}
},
"env_variables": {
"VAR1": {
"name": "var1_name",
"description": "var1_description",
"value": "var1_value",
"provision_type": "var1_provision_type",
},
"VAR2": {
"name": "var2_name",
"description": "var2_description",
"value": "var2_value",
"provision_type": "var2_provision_type",
},
},
}


class TestServiceManager:
"""Tests for services.manager.ServiceManager class."""

@pytest.mark.parametrize("update_new_var", [True])
@pytest.mark.parametrize("update_update_var", [True])
@pytest.mark.parametrize("update_name", [True])
@pytest.mark.parametrize("update_description", [True])
@pytest.mark.parametrize("update_hash", [True])
def test_service_update(
self,
update_new_var: bool,
update_update_var: bool,
update_name: bool,
update_description: bool,
update_hash: bool,
tmp_path: Path,
random_string: str,
) -> None:
"""Test operate.service_manager().update()"""

operate = OperateApp(
home=tmp_path / ".operate_test",
)
operate.setup()
password = random_string
operate.create_user_account(password=password)
operate.password = password
service_manager = operate.service_manager()
service_template = get_template(**DEFAULT_CONFIG_KWARGS)
service = service_manager.create(service_template)
service_config_id = service.service_config_id
service_json = service_manager.load(service_config_id).json

new_hash = "bafybeicts6zhavxzz2rxahz3wzs2pzamoq64n64wp4q4cdanfuz7id6c2q"
VAR2_updated_attributes = {
"name": "var2_name_updated",
"description": "var2_description_updated",
"value": "var2_value_updated",
"provision_type": "var2_provision_type_updated",
"extra_attr": "extra_val",
}

VAR3_attributes = {
"name": "var3_name",
"description": "var3_description",
"value": "var3_value",
"provision_type": "var3_provision_type",
}

# Partial update
update_template: t.Dict = {}
expected_service_json = service_json.copy()

if update_new_var:
update_template["env_variables"] = update_template.get("env_variables", {})
update_template["env_variables"]["VAR3"] = VAR3_attributes
expected_service_json["env_variables"]["VAR3"] = VAR3_attributes

if update_update_var:
update_template["env_variables"] = update_template.get("env_variables", {})
update_template["env_variables"]["VAR2"] = VAR2_updated_attributes
expected_service_json["env_variables"]["VAR2"] = VAR2_updated_attributes

if update_name:
update_template["name"] = "name_updated"
expected_service_json["name"] = "name_updated"

if update_description:
update_template["description"] = "description_updated"
expected_service_json["description"] = "description_updated"

if update_hash:
update_template["hash"] = new_hash
expected_service_json["hash"] = new_hash

service_manager.update(
service_config_id=service_config_id,
service_template=update_template,
allow_different_service_public_id=False,
partial_update=True,
)
service_json = service_manager.load(service_config_id).json

if update_hash:
timestamp = max(service_json["hash_history"].keys())
expected_service_json["hash_history"][timestamp] = new_hash

diff = DeepDiff(service_json, expected_service_json)
if diff:
print(diff)

assert not diff, "Updated service does not match expected service."

0 comments on commit b529640

Please sign in to comment.