Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ax/analysis/plotly/tests/test_marginal_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def setUp(self) -> None:
self.experiment.trials[i].mark_running(no_runner_required=True)
self.experiment.attach_data(
Data(
pd.DataFrame(
df=pd.DataFrame(
{
"trial_index": [i] * num_arms,
"arm_name": [f"0_{j}" for j in range(num_arms)],
Expand Down
4 changes: 1 addition & 3 deletions ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Any, TYPE_CHECKING

from ax.core.arm import Arm
from ax.core.data import Data, sort_by_trial_index_and_arm_name
from ax.core.data import Data
from ax.core.evaluations_to_data import raw_evaluations_to_data
from ax.core.generator_run import GeneratorRun, GeneratorRunType
from ax.core.metric import Metric, MetricFetchResult
Expand Down Expand Up @@ -442,8 +442,6 @@ def fetch_data(self, metrics: list[Metric] | None = None, **kwargs: Any) -> Data
data = Metric._unwrap_trial_data_multi(
results=self.fetch_data_results(metrics=metrics, **kwargs)
)
if not data.has_step_column:
data.full_df = sort_by_trial_index_and_arm_name(data.full_df)

return data

Expand Down
149 changes: 116 additions & 33 deletions ax/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from functools import cached_property
from io import StringIO
from logging import Logger
from typing import Any, TypeVar
from typing import Any

import numpy as np
import numpy.typing as npt
Expand All @@ -34,11 +34,40 @@

logger: Logger = get_logger(__name__)

TData = TypeVar("TData", bound="Data")
DF_REPR_MAX_LENGTH = 1000
MAP_KEY = "step"


class DataRow:
def __init__(
self,
trial_index: int,
arm_name: str,
metric_name: str,
metric_signature: str,
mean: float,
se: float,
step: float | None = None,
start_time: int | None = None,
end_time: int | None = None,
n: int | None = None,
) -> None:
self.trial_index: int = trial_index
self.arm_name: str = arm_name

self.metric_name: str = metric_name
self.metric_signature: str = metric_signature

self.mean: float = mean
self.se: float = se

self.step: float | None = step

self.start_time: int | None = start_time
self.end_time: int | None = end_time
self.n: int | None = n


class Data(Base, SerializationMixin):
"""Class storing numerical data for an experiment.

Expand Down Expand Up @@ -102,8 +131,6 @@ class Data(Base, SerializationMixin):
"start_time": pd.Timestamp,
"end_time": pd.Timestamp,
"n": int,
"frac_nonnull": np.float64,
"random_split": int,
MAP_KEY: float,
}

Expand All @@ -116,16 +143,19 @@ class Data(Base, SerializationMixin):
"metric_signature",
]

full_df: pd.DataFrame
_data_rows: list[DataRow]

def __init__(
self: TData,
self,
data_rows: Iterable[DataRow] | None = None,
df: pd.DataFrame | None = None,
_skip_ordering_and_validation: bool = False,
) -> None:
"""Initialize a ``Data`` object from the given DataFrame.

Args:
data_rows: Iterable of DataRows. If provided, this will be used as the
source of truth for Data, over df.
df: DataFrame with underlying data, and required columns. Data must
be unique at the level of ("trial_index", "arm_name",
"metric_name"), plus "step" if a "step" column is present. A
Expand All @@ -136,31 +166,84 @@ def __init__(
Intended only for use in `Data.filter`, where the contents
of the DataFrame are known to be ordered and valid.
"""
if df is None:
# Initialize with barebones DF with expected dtypes
self.full_df = pd.DataFrame.from_dict(
if data_rows is not None:
if isinstance(data_rows, pd.DataFrame):
raise ValueError(
"data_rows must be an iterable of DataRows, not a DataFrame."
)
self._data_rows = [*data_rows]
elif df is not None:
# Unroll the df into a list of DataRows
if missing_columns := self.REQUIRED_COLUMNS - {*df.columns}:
raise ValueError(
f"Dataframe must contain required columns {list(missing_columns)}."
)

self._data_rows = [
DataRow(
trial_index=row["trial_index"],
arm_name=row["arm_name"],
metric_name=row["metric_name"],
metric_signature=row["metric_signature"],
mean=row["mean"],
se=row["sem"],
step=row.get(MAP_KEY),
start_time=row.get("start_time"),
end_time=row.get("end_time"),
n=row.get("n"),
)
for _, row in df.iterrows()
]
else:
self._data_rows = []

self._memo_df: pd.DataFrame | None = None
self.has_step_column: bool = any(
row.step is not None for row in self._data_rows
)

@cached_property
def full_df(self) -> pd.DataFrame:
"""
Convert the DataRows into a pandas DataFrame. If step, start_time, or end_time
is None for all rows the column will be elided.
"""
if len(self._data_rows) == 0:
return pd.DataFrame.from_dict(
{
col: pd.Series([], dtype=self.COLUMN_DATA_TYPES[col])
for col in self.REQUIRED_COLUMNS
}
)
elif _skip_ordering_and_validation:
self.full_df = df
else:
columns = set(df.columns)
missing_columns = self.REQUIRED_COLUMNS - columns
if missing_columns:
raise ValueError(
f"Dataframe must contain required columns {list(missing_columns)}."
)
# Drop rows where every input is null. Since `dropna` can be slow, first
# check trial index to see if dropping nulls might be needed.
if df["trial_index"].isnull().any():
df = df.dropna(axis=0, how="all", ignore_index=True)
df = self._safecast_df(df=df)
self.full_df = self._get_df_with_cols_in_expected_order(df=df)
self._memo_df = None
self.has_step_column = MAP_KEY in self.full_df.columns

# Detect whether any of the optional attributes are present and should be
# included as columns in the full DataFrame.
include_step = any(row.step is not None for row in self._data_rows)
include_start_time = any(row.start_time is not None for row in self._data_rows)
include_end_time = any(row.end_time is not None for row in self._data_rows)
include_n = any(row.n is not None for row in self._data_rows)

records = [
{
"trial_index": row.trial_index,
"arm_name": row.arm_name,
"metric_name": row.metric_name,
"metric_signature": row.metric_signature,
"mean": row.mean,
"sem": row.se,
**({"step": row.step} if include_step else {}),
**({"start_time": row.start_time} if include_start_time else {}),
**({"end_time": row.end_time} if include_end_time else {}),
**({"n": row.n} if include_n else {}),
}
for row in self._data_rows
]

return self._get_df_with_cols_in_expected_order(
df=self._safecast_df(
df=pd.DataFrame.from_records(records),
),
)

@classmethod
def _get_df_with_cols_in_expected_order(cls, df: pd.DataFrame) -> pd.DataFrame:
Expand All @@ -175,7 +258,7 @@ def _get_df_with_cols_in_expected_order(cls, df: pd.DataFrame) -> pd.DataFrame:
return df

@classmethod
def _safecast_df(cls: type[TData], df: pd.DataFrame) -> pd.DataFrame:
def _safecast_df(cls, df: pd.DataFrame) -> pd.DataFrame:
"""Function for safely casting df to standard data types.

Needed because numpy does not support NaNs in integer arrays.
Expand Down Expand Up @@ -275,7 +358,7 @@ def df(self) -> pd.DataFrame:
return self._memo_df

@classmethod
def from_multiple_data(cls: type[TData], data: Iterable[Data]) -> TData:
def from_multiple_data(cls, data: Iterable[Data]) -> Data:
"""Combines multiple objects into one (with the concatenated
underlying dataframe).

Expand Down Expand Up @@ -339,21 +422,21 @@ def filter(
_skip_ordering_and_validation=True,
)

def clone(self: TData) -> TData:
def clone(self) -> Data:
    """Returns a new Data object with the same underlying dataframe."""
    # Deep-copies `full_df` so mutations to the clone's frame cannot leak
    # back into this instance; uses `self.__class__` so subclasses clone
    # to their own type.
    return self.__class__(df=deepcopy(self.full_df))

def __eq__(self, o: Data) -> bool:
    # Equal only when the other object is of the exact same class (subclass
    # instances compare unequal) and the full DataFrames match according to
    # `dataframe_equals`.
    return type(self) is type(o) and dataframe_equals(self.full_df, o.full_df)

def relativize(
self: TData,
self,
status_quo_name: str = "status_quo",
as_percent: bool = False,
include_sq: bool = False,
bias_correction: bool = True,
control_as_constant: bool = False,
) -> TData:
) -> Data:
"""Relativize a data object w.r.t. a status_quo arm.

Args:
Expand Down Expand Up @@ -437,12 +520,12 @@ def latest(self, rows_per_group: int = 1) -> Data:
)

def subsample(
self: TData,
self,
keep_every: int | None = None,
limit_rows_per_group: int | None = None,
limit_rows_per_metric: int | None = None,
include_first_last: bool = True,
) -> TData:
) -> Data:
"""Return a new Data that subsamples the `MAP_KEY` column in an
equally-spaced manner. This function considers only the relative ordering
of the `MAP_KEY` values, making it most suitable when these values are
Expand Down
Loading