diff --git a/ax/analysis/plotly/tests/test_marginal_effects.py b/ax/analysis/plotly/tests/test_marginal_effects.py index 25b299d5f77..107381844cd 100644 --- a/ax/analysis/plotly/tests/test_marginal_effects.py +++ b/ax/analysis/plotly/tests/test_marginal_effects.py @@ -43,7 +43,7 @@ def setUp(self) -> None: self.experiment.trials[i].mark_running(no_runner_required=True) self.experiment.attach_data( Data( - pd.DataFrame( + df=pd.DataFrame( { "trial_index": [i] * num_arms, "arm_name": [f"0_{j}" for j in range(num_arms)], diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index 649c17c57e9..8e9a4d2853e 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -15,7 +15,7 @@ from typing import Any, TYPE_CHECKING from ax.core.arm import Arm -from ax.core.data import Data, sort_by_trial_index_and_arm_name +from ax.core.data import Data from ax.core.evaluations_to_data import raw_evaluations_to_data from ax.core.generator_run import GeneratorRun, GeneratorRunType from ax.core.metric import Metric, MetricFetchResult @@ -442,8 +442,6 @@ def fetch_data(self, metrics: list[Metric] | None = None, **kwargs: Any) -> Data data = Metric._unwrap_trial_data_multi( results=self.fetch_data_results(metrics=metrics, **kwargs) ) - if not data.has_step_column: - data.full_df = sort_by_trial_index_and_arm_name(data.full_df) return data diff --git a/ax/core/data.py b/ax/core/data.py index 0ebd6fdc01a..e6e98d5a66b 100644 --- a/ax/core/data.py +++ b/ax/core/data.py @@ -14,7 +14,7 @@ from functools import cached_property from io import StringIO from logging import Logger -from typing import Any, TypeVar +from typing import Any import numpy as np import numpy.typing as npt @@ -34,11 +34,40 @@ logger: Logger = get_logger(__name__) -TData = TypeVar("TData", bound="Data") DF_REPR_MAX_LENGTH = 1000 MAP_KEY = "step" +class DataRow: + def __init__( + self, + trial_index: int, + arm_name: str, + metric_name: str, + metric_signature: str, + mean: float, + se: float, + step: float | None = None, + start_time: int | None = None, + end_time: int | None = None, + n: int | None = None, + ) -> None: + self.trial_index: int = trial_index + self.arm_name: str = arm_name + + self.metric_name: str = metric_name + self.metric_signature: str = metric_signature + + self.mean: float = mean + self.se: float = se + + self.step: float | None = step + + self.start_time: int | None = start_time + self.end_time: int | None = end_time + self.n: int | None = n + + class Data(Base, SerializationMixin): """Class storing numerical data for an experiment. @@ -102,8 +131,6 @@ class Data(Base, SerializationMixin): "start_time": pd.Timestamp, "end_time": pd.Timestamp, "n": int, - "frac_nonnull": np.float64, - "random_split": int, MAP_KEY: float, } @@ -116,16 +143,19 @@ "metric_signature", ] - full_df: pd.DataFrame + _data_rows: list[DataRow] def __init__( - self: TData, + self, + data_rows: Iterable[DataRow] | None = None, df: pd.DataFrame | None = None, _skip_ordering_and_validation: bool = False, ) -> None: """Initialize a ``Data`` object from the given DataFrame. Args: + data_rows: Iterable of DataRows. If provided, these take precedence + over df as the source of truth for the Data object. df: DataFrame with underlying data, and required columns. Data must be unique at the level of ("trial_index", "arm_name", "metric_name"), plus "step" if a "step" column is present. A @@ -136,31 +166,84 @@ def __init__( Intended only for use in `Data.filter`, where the contents of the DataFrame are known to be ordered and valid.
""" - if df is None: - # Initialize with barebones DF with expected dtypes - self.full_df = pd.DataFrame.from_dict( + if data_rows is not None: + if isinstance(data_rows, pd.DataFrame): + raise ValueError( + "data_rows must be an iterable of DataRows, not a DataFrame." + ) + self._data_rows = [*data_rows] + elif df is not None: + # Unroll the df into a list of DataRows + if missing_columns := self.REQUIRED_COLUMNS - {*df.columns}: + raise ValueError( + f"Dataframe must contain required columns {list(missing_columns)}." + ) + + self._data_rows = [ + DataRow( + trial_index=row["trial_index"], + arm_name=row["arm_name"], + metric_name=row["metric_name"], + metric_signature=row["metric_signature"], + mean=row["mean"], + se=row["sem"], + step=row.get(MAP_KEY), + start_time=row.get("start_time"), + end_time=row.get("end_time"), + n=row.get("n"), + ) + for _, row in df.iterrows() + ] + else: + self._data_rows = [] + + self._memo_df: pd.DataFrame | None = None + self.has_step_column: bool = any( + row.step is not None for row in self._data_rows + ) + + @cached_property + def full_df(self) -> pd.DataFrame: + """ + Convert the DataRows into a pandas DataFrame. If step, start_time, or end_time + is None for all rows the column will be elided. + """ + if len(self._data_rows) == 0: + return pd.DataFrame.from_dict( { col: pd.Series([], dtype=self.COLUMN_DATA_TYPES[col]) for col in self.REQUIRED_COLUMNS } ) - elif _skip_ordering_and_validation: - self.full_df = df - else: - columns = set(df.columns) - missing_columns = self.REQUIRED_COLUMNS - columns - if missing_columns: - raise ValueError( - f"Dataframe must contain required columns {list(missing_columns)}." - ) - # Drop rows where every input is null. Since `dropna` can be slow, first - # check trial index to see if dropping nulls might be needed. - if df["trial_index"].isnull().any(): - df = df.dropna(axis=0, how="all", ignore_index=True) - df = self._safecast_df(df=df) - self.full_df = self._get_df_with_cols_in_expected_order(df=df) - self._memo_df = None - self.has_step_column = MAP_KEY in self.full_df.columns + + # Detect whether any of the optional attributes are present and should be + # included as columns in the full DataFrame. + include_step = any(row.step is not None for row in self._data_rows) + include_start_time = any(row.start_time is not None for row in self._data_rows) + include_end_time = any(row.end_time is not None for row in self._data_rows) + include_n = any(row.n is not None for row in self._data_rows) + + records = [ + { + "trial_index": row.trial_index, + "arm_name": row.arm_name, + "metric_name": row.metric_name, + "metric_signature": row.metric_signature, + "mean": row.mean, + "sem": row.se, + **({"step": row.step} if include_step else {}), + **({"start_time": row.start_time} if include_start_time else {}), + **({"end_time": row.end_time} if include_end_time else {}), + **({"n": row.n} if include_n else {}), + } + for row in self._data_rows + ] + + return self._get_df_with_cols_in_expected_order( + df=self._safecast_df( + df=pd.DataFrame.from_records(records), + ), + ) @classmethod def _get_df_with_cols_in_expected_order(cls, df: pd.DataFrame) -> pd.DataFrame: @@ -175,7 +258,7 @@ def _get_df_with_cols_in_expected_order(cls, df: pd.DataFrame) -> pd.DataFrame: return df @classmethod - def _safecast_df(cls: type[TData], df: pd.DataFrame) -> pd.DataFrame: + def _safecast_df(cls, df: pd.DataFrame) -> pd.DataFrame: """Function for safely casting df to standard data types. 
Needed because numpy does not support NaNs in integer arrays. @@ -275,7 +358,7 @@ def df(self) -> pd.DataFrame: return self._memo_df @classmethod - def from_multiple_data(cls: type[TData], data: Iterable[Data]) -> TData: + def from_multiple_data(cls, data: Iterable[Data]) -> Data: """Combines multiple objects into one (with the concatenated underlying dataframe). @@ -339,7 +422,7 @@ def filter( _skip_ordering_and_validation=True, ) - def clone(self: TData) -> TData: + def clone(self) -> Data: """Returns a new Data object with the same underlying dataframe.""" return self.__class__(df=deepcopy(self.full_df)) @@ -347,13 +430,13 @@ def __eq__(self, o: Data) -> bool: return type(self) is type(o) and dataframe_equals(self.full_df, o.full_df) def relativize( - self: TData, + self, status_quo_name: str = "status_quo", as_percent: bool = False, include_sq: bool = False, bias_correction: bool = True, control_as_constant: bool = False, - ) -> TData: + ) -> Data: """Relativize a data object w.r.t. a status_quo arm. Args: @@ -437,12 +520,12 @@ def latest(self, rows_per_group: int = 1) -> Data: ) def subsample( - self: TData, + self, keep_every: int | None = None, limit_rows_per_group: int | None = None, limit_rows_per_metric: int | None = None, include_first_last: bool = True, - ) -> TData: + ) -> Data: """Return a new Data that subsamples the `MAP_KEY` column in an equally-spaced manner. This function considers only the relative ordering of the `MAP_KEY` values, making it most suitable when these values are diff --git a/ax/core/tests/test_data.py b/ax/core/tests/test_data.py index 8567135a126..45365e2a93f 100644 --- a/ax/core/tests/test_data.py +++ b/ax/core/tests/test_data.py @@ -119,29 +119,22 @@ def get_test_dataframe() -> pd.DataFrame: ) -class TestDataBase(TestCase): - # Also run with has_step_column = True below - has_step_column: bool = False +class DataTest(TestCase): + """Tests for Data without a "step" column.""" def setUp(self) -> None: super().setUp() self.data_without_df = Data() - df = get_test_dataframe() - if not self.has_step_column: - self.df = df - self.data_with_df = Data(df=self.df) - else: - df_1 = df.copy().assign(**{MAP_KEY: 0}) - df_2 = df.copy().assign(**{MAP_KEY: 1}) - self.df = pd.concat((df_1, df_2)) - self.data_with_df = Data(df=self.df) - + self.df = get_test_dataframe() + self.data_with_df = Data(df=self.df) self.metric_name_to_signature = {"a": "a_signature", "b": "b_signature"} def test_init(self) -> None: + # Test equality self.assertEqual(self.data_without_df, self.data_without_df) self.assertEqual(self.data_with_df, self.data_with_df) + # Test accessing values df = self.data_with_df.df self.assertEqual( float(df[df["arm_name"] == "0_0"][df["metric_name"] == "a"]["mean"].item()), @@ -152,7 +145,14 @@ def test_init(self) -> None: 0.5, ) - self.assertEqual(self.data_with_df.has_step_column, self.has_step_column) + # Test has_step_column is False + self.assertFalse(self.data_with_df.has_step_column) + + # Test empty initialization + empty = Data() + self.assertTrue(empty.full_df.empty) + self.assertEqual(set(empty.full_df.columns), empty.REQUIRED_COLUMNS) + self.assertFalse(empty.has_step_column) def test_clone(self) -> None: data = self.data_with_df @@ -164,14 +164,9 @@ def test_clone(self) -> None: self.assertIsNot(data, data_clone) self.assertIsNot(data.df, data_clone.df) self.assertIsNone(data_clone._db_id) - if self.has_step_column: - self.assertIsNot(data.full_df, data_clone.full_df) - self.assertTrue(data.full_df.equals(data_clone.full_df)) def 
test_BadData(self) -> None: data = {"bad_field": "0_0", "bad_field_2": {"x": 0, "y": "a"}} - if self.has_step_column: - data[MAP_KEY] = "0" df = pd.DataFrame([data]) with self.assertRaisesRegex( ValueError, "Dataframe must contain required columns" @@ -184,25 +179,13 @@ def test_EmptyData(self) -> None: self.assertTrue(df.empty) self.assertTrue(Data.from_multiple_data([]).df.empty) - if data.has_step_column: - self.assertTrue(data.full_df.empty) - expected_columns = Data.REQUIRED_COLUMNS.union({MAP_KEY}) - else: - expected_columns = Data.REQUIRED_COLUMNS + expected_columns = Data.REQUIRED_COLUMNS self.assertEqual(set(df.columns), expected_columns) def test_from_multiple_with_generator(self) -> None: data = Data.from_multiple_data(self.data_with_df for _ in range(2)) self.assertEqual(len(data.full_df), 2 * len(self.data_with_df.full_df)) - self.assertEqual(data.has_step_column, self.has_step_column) - - def test_extra_columns(self) -> None: - value = 3 - extra_col_df = self.df.assign(foo=value) - data = Data(df=extra_col_df) - self.assertIn("foo", data.full_df.columns) - self.assertIn("foo", data.df.columns) - self.assertTrue((data.full_df["foo"] == value).all()) + self.assertFalse(data.has_step_column) def test_get_df_with_cols_in_expected_order(self) -> None: with self.subTest("Wrong order"): @@ -235,26 +218,6 @@ def test_trial_indices(self) -> None: set(self.data_with_df.full_df["trial_index"].unique()), ) - -class TestMapData(TestDataBase): - has_step_column = True - - -class DataTest(TestCase): - """Tests that are specific to Data without a "step" column.""" - - def setUp(self) -> None: - super().setUp() - self.df = get_test_dataframe() - self.metric_name_to_signature = {"a": "a_signature", "b": "b_signature"} - - def test_init(self) -> None: - # Initialize empty - empty = Data() - self.assertTrue(empty.full_df.empty) - self.assertEqual(set(empty.full_df.columns), empty.REQUIRED_COLUMNS) - self.assertFalse(empty.has_step_column) - def test_repr(self) -> None: self.assertEqual( str(Data(df=self.df)), @@ -263,13 +226,6 @@ def test_repr(self) -> None: with patch(f"{Data.__module__}.DF_REPR_MAX_LENGTH", 500): self.assertEqual(str(Data(df=self.df)), REPR_500) - def test_OtherClassInequality(self) -> None: - class CustomData(Data): - pass - - data = CustomData(df=self.df) - self.assertNotEqual(data, Data(self.df)) - def test_from_multiple(self) -> None: with self.subTest("Combinining non-empty Data"): data = Data.from_multiple_data([Data(df=self.df), Data(df=self.df)]) @@ -279,34 +235,6 @@ def test_from_multiple(self) -> None: data = Data.from_multiple_data([Data(), Data()]) self.assertEqual(data, Data()) - with self.subTest("Different types"): - - class CustomData(Data): - pass - - data = Data.from_multiple_data([CustomData(), CustomData()]) - self.assertEqual(data, Data()) - data = CustomData.from_multiple_data([Data(), CustomData()]) - self.assertEqual(data, CustomData()) - - def test_FromMultipleDataMismatchedTypes(self) -> None: - # create two custom data types - class CustomDataA(Data): - pass - - class CustomDataB(Data): - pass - - # Test using `Data.from_multiple_data` to combine non-Data types - data = Data.from_multiple_data([CustomDataA(), CustomDataB()]) - self.assertEqual(data, Data()) - - # multiple non-empty types - data_elt_A = CustomDataA(df=self.df) - data_elt_B = CustomDataB(df=self.df) - data = Data.from_multiple_data([data_elt_A, data_elt_B]) - self.assertEqual(len(data.full_df), 2 * len(self.df)) - def test_filter(self) -> None: data = Data(df=self.df) # Test that 
filter throws when we provide metric names and metric signatures diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index 9488ee8bbf3..9f47be694f6 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -673,7 +673,7 @@ def test_fetch_and_store_data(self) -> None: # Verify we do get the stored data if there are an unimplemented metrics. # Remove attached data for nonexistent metric. - exp.data.full_df = exp.data.full_df.loc[lambda x: x["metric_name"] != "z"] + exp.data = Data(df=exp.data.full_df.loc[lambda x: x["metric_name"] != "z"]) # Remove implemented metric that is `available_while_running` # (and therefore not pulled from cache). @@ -685,7 +685,9 @@ def test_fetch_and_store_data(self) -> None: looked_up_df = looked_up_data.full_df self.assertFalse((looked_up_df["metric_name"] == "z").any()) self.assertTrue( - batch.fetch_data().full_df.equals( + batch.fetch_data() + .full_df.sort_values(["arm_name", "metric_name"], ignore_index=True) + .equals( looked_up_df.loc[lambda x: (x["trial_index"] == 0)].sort_values( ["arm_name", "metric_name"], ignore_index=True ) diff --git a/ax/plot/pareto_utils.py b/ax/plot/pareto_utils.py index 996c5c0fd99..a11af3c67dc 100644 --- a/ax/plot/pareto_utils.py +++ b/ax/plot/pareto_utils.py @@ -207,7 +207,7 @@ def get_observed_pareto_frontiers( ): # Make sure status quo is always included, for derelativization arm_names.append(experiment.status_quo.name) - data = Data(data.df[data.df["arm_name"].isin(arm_names)]) + data = Data(df=data.df[data.df["arm_name"].isin(arm_names)]) adapter = get_tensor_converter_adapter(experiment=experiment, data=data) pareto_observations = observed_pareto_frontier(adapter=adapter) # Convert to ParetoFrontierResults diff --git a/ax/plot/scatter.py b/ax/plot/scatter.py index a3ac1987146..27af4eb3c97 100644 --- a/ax/plot/scatter.py +++ b/ax/plot/scatter.py @@ -1731,7 +1731,7 @@ def tile_observations( if data is None: data = experiment.fetch_data() if arm_names is not None: - data = Data(data.df[data.df["arm_name"].isin(arm_names)]) + data = Data(df=data.df[data.df["arm_name"].isin(arm_names)]) m_ts = Generators.THOMPSON( data=data, search_space=experiment.search_space, diff --git a/ax/plot/tests/test_fitted_scatter.py b/ax/plot/tests/test_fitted_scatter.py index ef13bfd9016..9642dba9cd2 100644 --- a/ax/plot/tests/test_fitted_scatter.py +++ b/ax/plot/tests/test_fitted_scatter.py @@ -33,7 +33,7 @@ def test_fitted_scatter(self) -> None: model = Generators.BOTORCH_MODULAR( # Adapter kwargs experiment=exp, - data=Data.from_multiple_data([data, Data(df)]), + data=Data.from_multiple_data([data, Data(df=df)]), ) # Assert that each type of plot can be constructed successfully scalarized_metric_config = [ diff --git a/ax/plot/tests/test_pareto_utils.py b/ax/plot/tests/test_pareto_utils.py index bc97bbf3891..de2259a4f6d 100644 --- a/ax/plot/tests/test_pareto_utils.py +++ b/ax/plot/tests/test_pareto_utils.py @@ -107,7 +107,7 @@ def test_get_observed_pareto_frontiers(self) -> None: # For the check below, compute which arms are better than SQ df = experiment.fetch_data().df df["sem"] = np.nan - data = Data(df) + data = Data(df=df) sq_val = df[(df["arm_name"] == "status_quo") & (df["metric_name"] == "m1")][ "mean" ].values[0] diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 00726651976..96d2c613007 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -705,10 
+705,6 @@ def test_decode_map_data_backward_compatible(self) -> None: class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, ) self.assertEqual(len(map_data.full_df), 2) - # Even though the "epoch" and "timestamps" columns have not been - # renamed to "step", they are present - self.assertEqual(map_data.full_df["epoch"].tolist(), [0.0, 1.0]) - self.assertEqual(map_data.full_df["timestamps"].tolist(), [3.0, 4.0]) self.assertIsInstance(map_data, Data) with self.subTest("Single map key"): @@ -729,8 +725,8 @@ def test_decode_map_data_backward_compatible(self) -> None: decoder_registry=CORE_DECODER_REGISTRY, class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, ) - self.assertIn("epoch", map_data.full_df.columns) - self.assertEqual(map_data.full_df["epoch"].tolist(), [0.0, 1.0]) + self.assertEqual(len(map_data.full_df), 2) + self.assertIsInstance(map_data, Data) with self.subTest("No map key"): data_json = { diff --git a/ax/storage/sqa_store/utils.py b/ax/storage/sqa_store/utils.py index dfff1786dbf..3fc233ff15e 100644 --- a/ax/storage/sqa_store/utils.py +++ b/ax/storage/sqa_store/utils.py @@ -44,6 +44,7 @@ # don't need to recur into them during `copy_db_ids`. "auxiliary_experiments_by_purpose", "_metric_fetching_errors", + "_data_rows", } SKIP_ATTRS_ERROR_SUFFIX = "Consider adding to COPY_DB_IDS_ATTRS_TO_SKIP if appropriate." diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 6d6dff0e638..6a33f38ceb3 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -2581,7 +2581,7 @@ def get_branin_data_batch( for i in range(len(means)) for metric in metrics ] - return Data(pd.DataFrame.from_records(records)) + return Data(df=pd.DataFrame.from_records(records)) def get_branin_data_multi_objective(
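Usage sketch (reviewer note, not part of the patch): a minimal illustration of the two construction paths after this refactor, assuming the patch applies cleanly. DataRow, Data, data_rows, df, and has_step_column come from the ax/core/data.py hunks above; the trial/arm/metric values are invented for illustration, and the DataFrame columns mirror the REQUIRED_COLUMNS referenced in the patch.

import pandas as pd

from ax.core.data import Data, DataRow

# New path: construct Data directly from DataRows, the in-memory source of truth.
rows = [
    DataRow(
        trial_index=0,
        arm_name="0_0",
        metric_name="a",
        metric_signature="a_signature",
        mean=2.0,
        se=0.2,
    ),
    DataRow(
        trial_index=0,
        arm_name="0_0",
        metric_name="b",
        metric_signature="b_signature",
        mean=1.8,
        se=0.3,
        step=1.0,  # optional; any non-None step makes has_step_column True
    ),
]
data = Data(data_rows=rows)
assert data.has_step_column

# full_df is now a cached_property materialized from the rows; the optional
# step/start_time/end_time/n columns appear only when some row sets them.
print(data.full_df)

# Old path still works, but only by keyword: data_rows is now the first
# positional parameter, so Data(df) passes the DataFrame as data_rows and
# raises ValueError -- which is why call sites above change Data(df) to
# Data(df=df). Note the DataFrame column is "sem", the DataRow attribute "se".
df = pd.DataFrame.from_records(
    [
        {
            "trial_index": 0,
            "arm_name": "0_0",
            "metric_name": "a",
            "metric_signature": "a_signature",
            "mean": 2.0,
            "sem": 0.2,
        }
    ]
)
data_from_df = Data(df=df)
assert not data_from_df.has_step_column

One consequence worth noting for reviewers: because full_df is now derived (and cached) from _data_rows, in-place edits to full_df no longer feed back into the Data object. That appears to be why test_experiment.py rebuilds Data(df=...) instead of assigning to exp.data.full_df, and why it now sorts fetch_data().full_df before comparing, since base_trial.fetch_data no longer re-sorts by trial index and arm name.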