From cf6e91ed48943c361ce2cbc4e45e4b92a30933a0 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Wed, 14 Jan 2026 11:53:07 +0100 Subject: [PATCH 1/6] feat: Add support for converting `dy.List` columns to postgres array --- dataframely/columns/array.py | 13 ++++++++++--- dataframely/columns/list.py | 9 ++++++--- tests/columns/test_sqlalchemy_columns.py | 10 +++++++--- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/dataframely/columns/array.py b/dataframely/columns/array.py index ddd9e9d..a7538bf 100644 --- a/dataframely/columns/array.py +++ b/dataframely/columns/array.py @@ -10,7 +10,7 @@ import polars as pl -from dataframely._compat import pa, sa, sa_TypeEngine +from dataframely._compat import PGDialect_psycopg2, pa, sa, sa_TypeEngine from dataframely.random import Generator from ._base import Check, Column @@ -97,8 +97,15 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: } def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: - # NOTE: We might want to add support for PostgreSQL's ARRAY type or use JSON in the future. - raise NotImplementedError("SQL column cannot have 'Array' type.") + if isinstance(dialect, PGDialect_psycopg2): + # Note that the length of the array in each dimension is not supported in SQLAlchemy + # That is because PostgreSQL does not enforce the length anyway + return sa.ARRAY( + self.inner.sqlalchemy_dtype(dialect), dimensions=len(self.shape) + ) + raise NotImplementedError( + f"SQL column cannot have 'Array' type for dialect '{dialect}'." + ) def _pyarrow_field_of_shape(self, shape: Sequence[int]) -> pa.Field: if shape: diff --git a/dataframely/columns/list.py b/dataframely/columns/list.py index 6c1c66a..1351cd3 100644 --- a/dataframely/columns/list.py +++ b/dataframely/columns/list.py @@ -11,7 +11,7 @@ from polars.expr.array import ExprArrayNameSpace from polars.expr.list import ExprListNameSpace -from dataframely._compat import pa, sa, sa_TypeEngine +from dataframely._compat import PGDialect_psycopg2, pa, sa, sa_TypeEngine from dataframely._polars import PolarsDataType from dataframely.random import Generator @@ -120,8 +120,11 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: } def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: - # NOTE: We might want to add support for PostgreSQL's ARRAY type or use JSON in the future. - raise NotImplementedError("SQL column cannot have 'List' type.") + if isinstance(dialect, PGDialect_psycopg2): + return sa.ARRAY(self.inner.sqlalchemy_dtype(dialect)) + raise NotImplementedError( + f"SQL column cannot have 'List' type for dialect '{dialect}'." + ) @property def pyarrow_dtype(self) -> pa.DataType: diff --git a/tests/columns/test_sqlalchemy_columns.py b/tests/columns/test_sqlalchemy_columns.py index 732b395..f49a5aa 100644 --- a/tests/columns/test_sqlalchemy_columns.py +++ b/tests/columns/test_sqlalchemy_columns.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2025-2025 +# Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause import pytest @@ -95,6 +95,10 @@ def test_mssql_datatype(column: Column, datatype: str) -> None: (dy.String(regex="^[abc]{1,3}d$"), "VARCHAR(4)"), (dy.Enum(["foo", "bar"]), "CHAR(3)"), (dy.Enum(["a", "abc"]), "VARCHAR(3)"), + (dy.List(dy.Integer()), "INTEGER[]"), + (dy.List(dy.String(max_length=5)), "VARCHAR(5)[]"), + (dy.Array(dy.Integer(), shape=5), "INTEGER[]"), + (dy.Array(dy.String(max_length=5), shape=(2, 1)), "VARCHAR(5)[][]"), ], ) def test_postgres_datatype(column: Column, datatype: str) -> None: @@ -136,7 +140,7 @@ def test_sql_multiple_columns(dialect: Dialect) -> None: assert len(schema.to_sqlalchemy_columns(dialect)) == 2 -@pytest.mark.parametrize("dialect", [MSDialect_pyodbc(), PGDialect_psycopg2()]) +@pytest.mark.parametrize("dialect", [MSDialect_pyodbc()]) def test_raise_for_list_column(dialect: Dialect) -> None: with pytest.raises( NotImplementedError, match="SQL column cannot have 'List' type." @@ -144,7 +148,7 @@ def test_raise_for_list_column(dialect: Dialect) -> None: dy.List(dy.String()).sqlalchemy_dtype(dialect) -@pytest.mark.parametrize("dialect", [MSDialect_pyodbc(), PGDialect_psycopg2()]) +@pytest.mark.parametrize("dialect", [MSDialect_pyodbc()]) def test_raise_for_array_column(dialect: Dialect) -> None: with pytest.raises( NotImplementedError, match="SQL column cannot have 'Array' type." From d90015d8684a748f45e8f2b78ad478510717c177 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Wed, 14 Jan 2026 16:42:17 +0100 Subject: [PATCH 2/6] review --- dataframely/columns/array.py | 6 +++--- dataframely/columns/list.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dataframely/columns/array.py b/dataframely/columns/array.py index a7538bf..768cc0c 100644 --- a/dataframely/columns/array.py +++ b/dataframely/columns/array.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2025-2025 +# Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations @@ -10,7 +10,7 @@ import polars as pl -from dataframely._compat import PGDialect_psycopg2, pa, sa, sa_TypeEngine +from dataframely._compat import pa, sa, sa_TypeEngine from dataframely.random import Generator from ._base import Check, Column @@ -97,7 +97,7 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: } def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: - if isinstance(dialect, PGDialect_psycopg2): + if dialect.name == "postgresql": # Note that the length of the array in each dimension is not supported in SQLAlchemy # That is because PostgreSQL does not enforce the length anyway return sa.ARRAY( diff --git a/dataframely/columns/list.py b/dataframely/columns/list.py index 1351cd3..f219308 100644 --- a/dataframely/columns/list.py +++ b/dataframely/columns/list.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2025-2025 +# Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations @@ -11,7 +11,7 @@ from polars.expr.array import ExprArrayNameSpace from polars.expr.list import ExprListNameSpace -from dataframely._compat import PGDialect_psycopg2, pa, sa, sa_TypeEngine +from dataframely._compat import pa, sa, sa_TypeEngine from dataframely._polars import PolarsDataType from dataframely.random import Generator @@ -120,7 +120,7 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: } def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: - if isinstance(dialect, PGDialect_psycopg2): + if dialect.name == "postgresql": return sa.ARRAY(self.inner.sqlalchemy_dtype(dialect)) raise NotImplementedError( f"SQL column cannot have 'List' type for dialect '{dialect}'." From 5ab33423b7b537d240d80a6d62a16abcaa3e92f9 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Wed, 14 Jan 2026 16:49:30 +0100 Subject: [PATCH 3/6] fix --- dataframely/columns/array.py | 20 +++++++++++--------- dataframely/columns/list.py | 12 +++++++----- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/dataframely/columns/array.py b/dataframely/columns/array.py index 768cc0c..de97e61 100644 --- a/dataframely/columns/array.py +++ b/dataframely/columns/array.py @@ -97,15 +97,17 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: } def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: - if dialect.name == "postgresql": - # Note that the length of the array in each dimension is not supported in SQLAlchemy - # That is because PostgreSQL does not enforce the length anyway - return sa.ARRAY( - self.inner.sqlalchemy_dtype(dialect), dimensions=len(self.shape) - ) - raise NotImplementedError( - f"SQL column cannot have 'Array' type for dialect '{dialect}'." - ) + match dialect.name: + case "postgresql": + # Note that the length of the array in each dimension is not supported in SQLAlchemy + # That is because PostgreSQL does not enforce the length anyway + return sa.ARRAY( + self.inner.sqlalchemy_dtype(dialect), dimensions=len(self.shape) + ) + case _: + raise NotImplementedError( + f"SQL column cannot have 'List' type for dialect '{dialect}'." + ) def _pyarrow_field_of_shape(self, shape: Sequence[int]) -> pa.Field: if shape: diff --git a/dataframely/columns/list.py b/dataframely/columns/list.py index f219308..cd5ed80 100644 --- a/dataframely/columns/list.py +++ b/dataframely/columns/list.py @@ -120,11 +120,13 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: } def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: - if dialect.name == "postgresql": - return sa.ARRAY(self.inner.sqlalchemy_dtype(dialect)) - raise NotImplementedError( - f"SQL column cannot have 'List' type for dialect '{dialect}'." - ) + match dialect.name: + case "postgresql": + return sa.ARRAY(self.inner.sqlalchemy_dtype(dialect)) + case _: + raise NotImplementedError( + f"SQL column cannot have 'List' type for dialect '{dialect}'." + ) @property def pyarrow_dtype(self) -> pa.DataType: From 51ea29eacc7a7934025541fcf4b71cb180b7de33 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Wed, 14 Jan 2026 16:52:07 +0100 Subject: [PATCH 4/6] fix --- dataframely/columns/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataframely/columns/array.py b/dataframely/columns/array.py index de97e61..e100449 100644 --- a/dataframely/columns/array.py +++ b/dataframely/columns/array.py @@ -106,7 +106,7 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: ) case _: raise NotImplementedError( - f"SQL column cannot have 'List' type for dialect '{dialect}'." + f"SQL column cannot have 'Array' type for dialect '{dialect}'." ) def _pyarrow_field_of_shape(self, shape: Sequence[int]) -> pa.Field: From ce4008d64668282d91644fced1729ef03c92ac2d Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Wed, 14 Jan 2026 16:54:22 +0100 Subject: [PATCH 5/6] cleanup --- dataframely/columns/binary.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dataframely/columns/binary.py b/dataframely/columns/binary.py index ceb8a7a..1686816 100644 --- a/dataframely/columns/binary.py +++ b/dataframely/columns/binary.py @@ -21,9 +21,11 @@ def dtype(self) -> pl.DataType: return pl.Binary() def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: - if dialect.name == "mssql": - return sa.VARBINARY() - return sa.LargeBinary() + match dialect.name: + case "mssql": + return sa.VARBINARY() + case _: + return sa.LargeBinary() @property def pyarrow_dtype(self) -> pa.DataType: From 13bcd6adfe719cae209ce67610334b40aa4f0fd1 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Wed, 14 Jan 2026 16:59:11 +0100 Subject: [PATCH 6/6] fix --- dataframely/columns/binary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataframely/columns/binary.py b/dataframely/columns/binary.py index 1686816..a6a6b4b 100644 --- a/dataframely/columns/binary.py +++ b/dataframely/columns/binary.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2025-2025 +# Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations