Skip to content

Commit

Permalink
made modifications in line with Lars comments and they passed the tes…
Browse files Browse the repository at this point in the history
…ts that I had previously written
  • Loading branch information
eodole committed Jan 4, 2025
1 parent 2de69c9 commit 9f6a511
Showing 1 changed file with 16 additions and 45 deletions.
61 changes: 16 additions & 45 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,66 +76,37 @@ def __call__(self, data: dict[str, any], *, inverse: bool = False, **kwargs) ->
return self.forward(data, **kwargs)

def __repr__(self):
str_transf = ''
for i in range(0, len(self.transforms)):
str_transf = str_transf + str(i) + ': ' + repr(self.transforms[i])
if i != len(self.transforms) - 1:
str_transf = str_transf + ' -> '
return f"Adapter([{str_transf}])"
str_transf = ""
if isinstance(self.transforms, list):
for i in range(0, len(self.transforms)):
str_transf = str_transf + str(i) + ": " + repr(self.transforms[i])
if i != len(self.transforms) - 1:
str_transf = str_transf + " -> "
return f"Adapter([{str_transf}])"
else:
return f"Adapter([ 0: {repr(self.transforms)}])"

def __getitem__(self, index):
if isinstance(index, slice):
if index.start > index.stop:
raise IndexError("Index slice must be positive integers such that a < b for adapter[a:b]")
if index.stop < len(self.transforms):
sliced_transforms = self.transforms[index]
new_adapter = Adapter(transforms=sliced_transforms)
return new_adapter
else:
raise IndexError("Index slice out of range")

elif isinstance(index, int):
if index < 0:
index = index + len(self.transforms) # negative indexing
if index < 0 or index >= len(self.transforms):
raise IndexError("Adapter index out of range.")
sliced_transforms = self.transforms[index]
new_adapter = Adapter(transforms=sliced_transforms)
return new_adapter
else:
raise TypeError("Invalid index type. Must be int or slice.")
return Adapter(transforms=self.transforms[index])

def __setitem__(self, index, new_value):
if not isinstance(new_value, Adapter):
raise TypeError("new_value must be an Adapter instance")

new_transform = new_value.transforms
# new_transform = new_value.transforms

if len(new_transform) == 0:
# To be tested
if len(new_value.transforms) == 0:
raise ValueError(
"new_value is an Adapter instance without any specified transforms, new_value Adapter must contain at least one transform."
)

if isinstance(index, slice):
if index.start > index.stop:
raise IndexError("Index slice must be positive integers such that a < b for adapter[a:b]")

if index.stop < len(self.transforms):
self.transforms[index] = new_transform

else:
raise IndexError("Index slice out of range")
self.transforms[index] = new_value.transforms[:]

elif isinstance(index, int):
if index < 0: # negative indexing
index = index + len(self.transforms)

if index < 0 or index >= len(self.transforms):
raise IndexError("Index out of range.")
# could add that if the index is out of range, like index == len
# then we just add the transform
self.transforms[index : index + 1] = new_value.transforms[:]

self.transforms[index] = new_transform
else:
raise TypeError("Invalid index type. Must be int or slice.")

Expand Down Expand Up @@ -165,7 +136,7 @@ def apply(
self.transforms.append(transform)
return self

# Begin of transformed derived from transform classes
# Begin of transforms derived from transform classes
def as_set(self, keys: str | Sequence[str]):
if isinstance(keys, str):
keys = [keys]
Expand Down

0 comments on commit 9f6a511

Please sign in to comment.