WARNING: THIS SITE IS A MIRROR OF GITHUB.COM / YOU CANNOT LOG IN OR REGISTER ACCOUNTS HERE / THE CONTENTS ARE PROVIDED AS-IS / THIS SITE ASSUMES NO RESPONSIBILITY FOR ANY DISPLAYED CONTENT OR LINKS / IF YOU FIND SOMETHING THAT MAY NOT BE SUITABLE FOR EVERYONE, CONTACT THE ADMIN AT ilovescratch@foxmail.com
Skip to content

Commit eaa6b03

Browse files
authored
Merge pull request #938 from alan-turing-institute/fix_test_set
Let user specify test set
2 parents 9ae3626 + 863c66f commit eaa6b03

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

autoemulate/core/compare.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(
5555
self,
5656
x: InputLike,
5757
y: InputLike,
58+
test_data: tuple[InputLike, InputLike] | None = None,
5859
models: list[type[Emulator] | str] | None = None,
5960
x_transforms_list: list[list[Transform | dict]] | None = None,
6061
y_transforms_list: list[list[Transform | dict]] | None = None,
@@ -81,6 +82,9 @@ def __init__(
8182
Input features.
8283
y: InputLike or None
8384
Target values (not needed if x is a Dataset).
85+
test_data: tuple[InputLike, InputLike] | None
86+
Optional test data as a tuple (x_test, y_test). If None, a random split
87+
from the provided data is used. Defaults to None.
8488
models: list[type[Emulator]] | None
8589
List of emulator classes to compare. If None, all available emulators
8690
are used.
@@ -164,7 +168,17 @@ def __init__(
164168
self.models = updated_models
165169
if random_seed is not None:
166170
set_random_seed(seed=random_seed)
167-
self.train_val, self.test = self._random_split(self._convert_to_dataset(x, y))
171+
172+
if test_data is None:
173+
self.train_val, self.test = self._random_split(
174+
self._convert_to_dataset(x, y)
175+
)
176+
else:
177+
self.train_val = self._convert_to_dataset(x, y)
178+
test_x, test_y = self._move_tensors_to_device(
179+
*self._convert_to_tensors(*test_data)
180+
)
181+
self.test = self._convert_to_dataset(test_x, test_y)
168182

169183
# Run the compare method with the provided models
170184
if not self.models:

tests/core/test_compare.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from autoemulate.emulators import DEFAULT_EMULATORS
1111
from autoemulate.emulators.base import Emulator
1212
from torch.distributions import Transform
13+
from torch.utils.data import TensorDataset
1314

1415

1516
@pytest.mark.parametrize("device", SUPPORTED_DEVICES)
@@ -462,3 +463,29 @@ def __call__(
462463
metric_names = [m.name for m in result.test_metrics]
463464
assert "custom_r2" in metric_names
464465
assert "rmse" in metric_names
466+
467+
468+
def test_ae_with_fixed_test_data(sample_data_for_ae_compare):
    """Test AutoEmulate with a fixed, user-supplied test dataset.

    The first ``test_size`` rows of the fixture data are held out and passed
    via ``test_data``; the remaining rows are passed as the training data.
    The test then checks that ``ae.test`` and ``ae.train_val`` are
    ``TensorDataset`` objects holding exactly those rows, i.e. that no random
    split was applied when ``test_data`` is given.
    """
    # Local import keeps this test self-contained regardless of which torch
    # names the module already imports at file level.
    import torch

    x, y = sample_data_for_ae_compare
    models: list[str | type[Emulator]] = ["mlp", "RandomForest"]

    # Create fixed test set
    test_size = 25
    x_test, y_test = x[:test_size], y[:test_size]
    x_train, y_train = x[test_size:], y[test_size:]

    ae = AutoEmulate(
        x_train,
        y_train,
        models=models,
        test_data=(x_test, y_test),
        n_iter=2,
        n_splits=2,
        model_params={},  # Skip tuning for speed
    )

    # Compare tensor *values* rather than `tensors == (a, b)`: tuple equality
    # only succeeds through the object-identity shortcut, and falls back to
    # an ambiguous element-wise tensor comparison (which raises) as soon as
    # the stored tensors are copies of the inputs.
    # NOTE(review): assumes the fixture yields torch tensors on the same
    # device AutoEmulate stores its data on — confirm against the fixture.
    assert isinstance(ae.test, TensorDataset)
    assert len(ae.test.tensors) == 2
    stored_x_test, stored_y_test = ae.test.tensors
    assert torch.equal(stored_x_test, x_test)
    assert torch.equal(stored_y_test, y_test)

    assert isinstance(ae.train_val, TensorDataset)
    assert len(ae.train_val.tensors) == 2
    stored_x_train, stored_y_train = ae.train_val.tensors
    assert torch.equal(stored_x_train, x_train)
    assert torch.equal(stored_y_train, y_train)

0 commit comments

Comments
 (0)