Rewritten Composite. Now should run more smoothly. Also updated callbacks and tested that the example runs.
perara committed Oct 21, 2024
1 parent df55ecb commit 79cb1c9
Showing 5 changed files with 358 additions and 150 deletions.
4 changes: 2 additions & 2 deletions examples/composite/TMCompositeCIFAR10Demo.py

@@ -50,10 +50,10 @@ def main(args):

    class TMCompositeCheckpointCallback(TMCompositeCallback):

-       def on_epoch_component_begin(self, component, epoch, logs=None):
+       def on_epoch_component_begin(self, component, epoch, logs=None, **kwargs):
            pass

-       def on_epoch_component_end(self, component, epoch, logs=None):
+       def on_epoch_component_end(self, component, epoch, logs=None, **kwargs):
            component.save(component_path / f"{component}-{epoch}.pkl")

    class TMCompositeEvaluationCallback(TMCompositeCallback):
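The demo's callbacks only change in their signatures: the epoch hooks now accept **kwargs, so any extra context forwarded by the callback dispatcher is tolerated without breaking existing overrides. A brief sketch of a custom callback written against the new signature (the timing behaviour is illustrative, not part of this commit):

import time

from tmu.composite.callbacks.base import TMCompositeCallback


class ComponentTimerCallback(TMCompositeCallback):
    """Hypothetical callback: reports wall-clock time per component epoch."""

    def __init__(self):
        super().__init__()
        self._started = {}

    def on_epoch_component_begin(self, component, epoch, logs=None, **kwargs):
        self._started[(str(component), epoch)] = time.perf_counter()

    def on_epoch_component_end(self, component, epoch, logs=None, **kwargs):
        start = self._started.pop((str(component), epoch), None)
        if start is not None:
            print(f"{component} epoch {epoch} took {time.perf_counter() - start:.2f}s")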
34 changes: 23 additions & 11 deletions tmu/composite/callbacks/base.py

@@ -1,37 +1,49 @@
+from dataclasses import dataclass
+from enum import auto, Enum
 from multiprocessing import Queue
+from typing import Any, Dict


-class TMCompositeCallback:
+class CallbackMethod(Enum):
+    ON_TRAIN_COMPOSITE_BEGIN = auto()
+    ON_TRAIN_COMPOSITE_END = auto()
+    ON_EPOCH_COMPONENT_BEGIN = auto()
+    ON_EPOCH_COMPONENT_END = auto()
+    UPDATE_PROGRESS = auto()
+
+
+@dataclass
+class CallbackMessage:
+    method: CallbackMethod
+    kwargs: Dict[str, Any]
+
+
+class TMCompositeCallback:
    def __init__(self):
        pass

-   def on_epoch_component_begin(self, component, epoch, logs=None):
+   def on_epoch_component_begin(self, component, epoch, logs=None, **kwargs):
        pass

-   def on_epoch_component_end(self, component, epoch, logs=None):
+   def on_epoch_component_end(self, component, epoch, logs=None, **kwargs):
        pass

-   def on_train_composite_end(self, composite, logs=None):
+   def on_train_composite_end(self, composite, logs=None, **kwargs):
        pass

-   def on_train_composite_begin(self, composite, logs=None):
+   def on_train_composite_begin(self, composite, logs=None, **kwargs):
        pass


class TMCompositeCallbackProxy:

    def __init__(self, queue: Queue):
        self.queue = queue

    def on_epoch_component_begin(self, component, epoch, logs=None):
-       self.queue.put(('on_epoch_component_begin', component, epoch, logs))
+       self.queue.put(CallbackMessage(CallbackMethod.ON_EPOCH_COMPONENT_BEGIN, {'component': component, 'epoch': epoch, 'logs': logs}))

    def on_epoch_component_end(self, component, epoch, logs=None):
-       self.queue.put(('on_epoch_component_end', component, epoch, logs))
+       self.queue.put(CallbackMessage(CallbackMethod.ON_EPOCH_COMPONENT_END, {'component': component, 'epoch': epoch, 'logs': logs}))

    def on_train_composite_end(self, composite, logs=None):
-       self.queue.put(('on_train_composite_end', composite, logs))
+       self.queue.put(CallbackMessage(CallbackMethod.ON_TRAIN_COMPOSITE_END, {'composite': composite, 'logs': logs}))

    def on_train_composite_begin(self, composite, logs=None):
-       self.queue.put(('on_train_composite_begin', composite, logs))
+       self.queue.put(CallbackMessage(CallbackMethod.ON_TRAIN_COMPOSITE_BEGIN, {'composite': composite, 'logs': logs}))
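Callback traffic through the multiprocessing queue is now a typed CallbackMessage carrying a CallbackMethod enum and a kwargs dict, rather than a positional tuple. A minimal sketch of what a consumer on the other end of the queue could look like; drain_callback_queue is an assumption for illustration, the actual dispatch presumably lives in the composite implementation, which is not part of this excerpt:

from multiprocessing import Queue
from typing import List

from tmu.composite.callbacks.base import CallbackMessage, CallbackMethod, TMCompositeCallback

# Map each enum value to the name of the hook it should trigger.
_DISPATCH = {
    CallbackMethod.ON_TRAIN_COMPOSITE_BEGIN: "on_train_composite_begin",
    CallbackMethod.ON_TRAIN_COMPOSITE_END: "on_train_composite_end",
    CallbackMethod.ON_EPOCH_COMPONENT_BEGIN: "on_epoch_component_begin",
    CallbackMethod.ON_EPOCH_COMPONENT_END: "on_epoch_component_end",
}


def drain_callback_queue(queue: Queue, callbacks: List[TMCompositeCallback]) -> None:
    """Hypothetical consumer: fan queued messages out to the callback hooks."""
    while not queue.empty():
        message: CallbackMessage = queue.get()
        hook_name = _DISPATCH.get(message.method)
        if hook_name is None:
            continue  # e.g. UPDATE_PROGRESS, which has no hook on the base class
        for callback in callbacks:
            # The kwargs dict carries 'component'/'epoch'/'logs' or 'composite'/'logs',
            # which is why the hooks now accept **kwargs.
            getattr(callback, hook_name)(**message.kwargs)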
33 changes: 20 additions & 13 deletions tmu/composite/components/base.py

@@ -1,4 +1,6 @@
 import abc
+import uuid
+
 import numpy as np
 from pathlib import Path
 from typing import Union, Tuple

@@ -13,6 +15,7 @@ def __init__(self, model_cls, model_config, epochs=1, **kwargs) -> None:
        self.model_cls = model_cls
        self.model_config = model_config
        self.epochs = epochs
+       self.uuid = uuid.uuid4()

        # Warn about unused kwargs
        if kwargs:

@@ -36,19 +39,23 @@ def preprocess(self, data: dict) -> dict:
        return data

    def fit(self, data: dict) -> None:
-       x_train, y_train = data["X"], data["Y"]
-
-       # Check if type is uint32
-       if x_train.dtype != np.uint32:
-           x_train = x_train.astype(np.uint32)
-
-       if y_train.dtype != np.uint32:
-           y_train = y_train.astype(np.uint32)
-
-       self.model_instance.fit(
-           x_train,
-           y_train,
-       )
+       try:
+           x_train, y_train = data["X"], data["Y"]
+
+           # Check if type is uint32
+           if x_train.dtype != np.uint32:
+               x_train = x_train.astype(np.uint32)
+
+           if y_train.dtype != np.uint32:
+               y_train = y_train.astype(np.uint32)
+
+           self.model_instance.fit(
+               x_train,
+               y_train,
+           )
+       except Exception as e:
+           print(f"Error: {e}")
+           raise e

    def predict(self, data: dict) -> Tuple[np.array, np.array]:
        X_test = data["X"]
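On the component side, each TMComponent now gets a uuid at construction, and fit() wraps the uint32 casting and the model call in a try/except that logs and re-raises. The data contract is unchanged: a dict with X and Y arrays. A minimal sketch of a concrete component on top of this base class (IdentityBooleanizerComponent is a hypothetical name, and the full TMComponent interface may require further overrides; only the base-class behaviour shown in the diff above is assumed):

import numpy as np

from tmu.composite.components.base import TMComponent


class IdentityBooleanizerComponent(TMComponent):
    """Hypothetical component: forwards already-binarized data unchanged.

    The base class's fit() casts X and Y to uint32 before calling the wrapped
    model_instance, and __init__ now tags every component with a unique uuid.
    """

    def preprocess(self, data: dict) -> dict:
        data = super().preprocess(data=data)
        return {
            "X": np.asarray(data["X"], dtype=np.uint32),
            "Y": np.asarray(data["Y"], dtype=np.uint32),
        }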
47 changes: 35 additions & 12 deletions tmu/composite/components/color_thermometer_scoring.py

@@ -1,26 +1,49 @@
 import numpy as np
+from typing import Dict, Any
 from tmu.composite.components.base import TMComponent


class ColorThermometerComponent(TMComponent):

    def __init__(self, model_cls, model_config, resolution=8, **kwargs) -> None:
        super().__init__(model_cls=model_cls, model_config=model_config, **kwargs)
+       if resolution < 2 or resolution > 255:
+           raise ValueError("Resolution must be between 2 and 255")
        self.resolution = resolution
+       self._thresholds = None
+
+   def _create_thresholds(self) -> None:
+       self._thresholds = np.linspace(0, 255, self.resolution + 1)[1:-1]

-   def preprocess(self, data: dict):
+   def preprocess(self, data: dict) -> Dict[str, Any]:
        super().preprocess(data=data)
+       X_org = data.get("X")
+       Y = data.get("Y")
+
+       if X_org is None:
+           raise ValueError("Input data 'X' is missing")
+
+       if X_org.ndim != 4:
+           raise ValueError(f"Expected 4D input, got {X_org.ndim}D")
+
+       if X_org.shape[-1] != 3:
+           raise ValueError(f"Expected 3 color channels, got {X_org.shape[-1]}")
+
+       if self._thresholds is None:
+           self._create_thresholds()

-       X_org = data["X"]
-       Y = data["Y"]
+       # Use broadcasting for efficient computation
+       X = (X_org[:, :, :, :, np.newaxis] >= self._thresholds).astype(np.uint8)

-       X = np.empty((X_org.shape[0], X_org.shape[1], X_org.shape[2], X_org.shape[3], self.resolution), dtype=np.uint8)
-       for z in range(self.resolution):
-           X[:, :, :, :, z] = X_org[:, :, :, :] >= (z + 1) * 255 / (self.resolution + 1)
+       # Reshape correctly
+       batch_size, height, width, channels, _ = X.shape
+       X = X.transpose(0, 1, 2, 4, 3).reshape(batch_size, height, width, channels * (self.resolution - 1))

-       X = X.reshape((X_org.shape[0], X_org.shape[1], X_org.shape[2], 3 * self.resolution))
+       return {
+           "X": X,
+           "Y": Y
+       }

-       return dict(
-           X=X,
-           Y=Y,
-       )
+   def get_output_shape(self, input_shape: tuple) -> tuple:
+       if len(input_shape) != 4:
+           raise ValueError(f"Expected 4D input shape, got {len(input_shape)}D")
+       return (*input_shape[:-1], input_shape[-1] * (self.resolution - 1))
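The loop-based thermometer encoding is replaced by a single broadcast comparison against precomputed thresholds, and the component now reports its output shape. A minimal sketch of the encoding itself on CIFAR-10-shaped data, mirroring the new preprocess logic with standalone numpy and no model attached:

import numpy as np

resolution = 8
# Interior thresholds, as in _create_thresholds(): resolution - 1 cut points in (0, 255).
thresholds = np.linspace(0, 255, resolution + 1)[1:-1]

# A CIFAR-10-sized batch: (batch, height, width, channels), uint8 pixel values.
X_org = np.random.randint(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)

# One broadcast comparison replaces the old per-level loop.
X = (X_org[:, :, :, :, np.newaxis] >= thresholds).astype(np.uint8)

# Move the threshold axis before the channel axis, then flatten the last two
# axes, exactly as the new preprocess does.
batch_size, height, width, channels, _ = X.shape
X = X.transpose(0, 1, 2, 4, 3).reshape(batch_size, height, width, channels * (resolution - 1))

print(X.shape)  # (4, 32, 32, 21), matching get_output_shape((4, 32, 32, 3))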
