Skip to content

Commit

Permalink
Format parity_plot correctly (#210)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcosfelt authored Sep 30, 2022
1 parent d215362 commit e04d0fa
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions summit/benchmarks/experimental_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def _caclulate_input_dimensions(domain: Domain, descriptors_features):

@staticmethod
def _create_input_preprocessor(domain, **kwargs):
"""Create feature preprocessors """
"""Create feature preprocessors"""
transformers = []
# Numeric transforms
numeric_features = [
Expand Down Expand Up @@ -503,7 +503,7 @@ def _create_input_preprocessor(domain, **kwargs):

@staticmethod
def _create_output_preprocessor(output_variable_names):
""""Create target preprocessors"""
""" "Create target preprocessors"""
transformers = [
("scale", StandardScaler(), output_variable_names),
("dst", FunctionTransformer(numpy_to_tensor), output_variable_names),
Expand Down Expand Up @@ -865,13 +865,13 @@ def make_parity_plot(
handles = []
r2_train = r2_score(y_train, y_train_pred)
r2_train_patch = mpatches.Patch(
label=f"Train R2 = {r2_train:.2f}", color=train_color
label=r"Train $R^2$ =" + f"{r2_train:.2f}", color=train_color
)
handles.append(r2_train_patch)
if y_test is not None:
r2_test = r2_score(y_test, y_test_pred)
r2_test_patch = mpatches.Patch(
label=f"Test R2 = {r2_test:.2f}", color=test_color
label=r"Test $R^2$ =" + f"{r2_test:.2f}", color=test_color
)
handles.append(r2_test_patch)

Expand All @@ -888,7 +888,7 @@ def make_parity_plot(


def numpy_to_tensor(X):
"""Convert datasets into """
"""Convert datasets into"""
if issparse(X):
X = X.todense()
return torch.tensor(X).float()
Expand Down

0 comments on commit e04d0fa

Please sign in to comment.