diff --git a/dataframely/columns/array.py b/dataframely/columns/array.py index ddd9e9d..e100449 100644 --- a/dataframely/columns/array.py +++ b/dataframely/columns/array.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2025-2025 +# Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations @@ -97,8 +97,17 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: } def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: - # NOTE: We might want to add support for PostgreSQL's ARRAY type or use JSON in the future. - raise NotImplementedError("SQL column cannot have 'Array' type.") + match dialect.name: + case "postgresql": + # Note that the length of the array in each dimension is not supported in SQLAlchemy + # That is because PostgreSQL does not enforce the length anyway + return sa.ARRAY( + self.inner.sqlalchemy_dtype(dialect), dimensions=len(self.shape) + ) + case _: + raise NotImplementedError( + f"SQL column cannot have 'Array' type for dialect '{dialect}'." + ) def _pyarrow_field_of_shape(self, shape: Sequence[int]) -> pa.Field: if shape: diff --git a/dataframely/columns/binary.py b/dataframely/columns/binary.py index ceb8a7a..a6a6b4b 100644 --- a/dataframely/columns/binary.py +++ b/dataframely/columns/binary.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2025-2025 +# Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations @@ -21,9 +21,11 @@ def dtype(self) -> pl.DataType: return pl.Binary() def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: - if dialect.name == "mssql": - return sa.VARBINARY() - return sa.LargeBinary() + match dialect.name: + case "mssql": + return sa.VARBINARY() + case _: + return sa.LargeBinary() @property def pyarrow_dtype(self) -> pa.DataType: diff --git a/dataframely/columns/list.py b/dataframely/columns/list.py index 6c1c66a..cd5ed80 100644 --- a/dataframely/columns/list.py +++ b/dataframely/columns/list.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2025-2025 +# Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations @@ -120,8 +120,13 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: } def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: - # NOTE: We might want to add support for PostgreSQL's ARRAY type or use JSON in the future. - raise NotImplementedError("SQL column cannot have 'List' type.") + match dialect.name: + case "postgresql": + return sa.ARRAY(self.inner.sqlalchemy_dtype(dialect)) + case _: + raise NotImplementedError( + f"SQL column cannot have 'List' type for dialect '{dialect}'." + ) @property def pyarrow_dtype(self) -> pa.DataType: diff --git a/tests/columns/test_sqlalchemy_columns.py b/tests/columns/test_sqlalchemy_columns.py index 1fc65a2..6731202 100644 --- a/tests/columns/test_sqlalchemy_columns.py +++ b/tests/columns/test_sqlalchemy_columns.py @@ -95,6 +95,10 @@ def test_mssql_datatype(column: Column, datatype: str) -> None: (dy.String(regex="^[abc]{1,3}d$"), "VARCHAR(4)"), (dy.Enum(["foo", "bar"]), "CHAR(3)"), (dy.Enum(["a", "abc"]), "VARCHAR(3)"), + (dy.List(dy.Integer()), "INTEGER[]"), + (dy.List(dy.String(max_length=5)), "VARCHAR(5)[]"), + (dy.Array(dy.Integer(), shape=5), "INTEGER[]"), + (dy.Array(dy.String(max_length=5), shape=(2, 1)), "VARCHAR(5)[][]"), (dy.Struct({"a": dy.String(nullable=True)}), "JSONB"), ], ) @@ -137,7 +141,7 @@ def test_sql_multiple_columns(dialect: Dialect) -> None: assert len(schema.to_sqlalchemy_columns(dialect)) == 2 -@pytest.mark.parametrize("dialect", [MSDialect_pyodbc(), PGDialect_psycopg2()]) +@pytest.mark.parametrize("dialect", [MSDialect_pyodbc()]) def test_raise_for_list_column(dialect: Dialect) -> None: with pytest.raises( NotImplementedError, match="SQL column cannot have 'List' type." @@ -145,7 +149,7 @@ def test_raise_for_list_column(dialect: Dialect) -> None: dy.List(dy.String()).sqlalchemy_dtype(dialect) -@pytest.mark.parametrize("dialect", [MSDialect_pyodbc(), PGDialect_psycopg2()]) +@pytest.mark.parametrize("dialect", [MSDialect_pyodbc()]) def test_raise_for_array_column(dialect: Dialect) -> None: with pytest.raises( NotImplementedError, match="SQL column cannot have 'Array' type."