Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ Unreleased
- Fix reflection of JSONB columns (#277)
- Fix compatibility issues with Alembic 1.18 (via SQLA 2.0.47)
- Update minimum Python version to 3.10
- Compile MySQL-style `func.timestampdiff(unit, start, end)` to a
PostgreSQL-style `EXTRACT(EPOCH FROM ...)` expression on the cockroachdb
dialect, casting the result to NUMERIC to avoid `float / decimal`
arithmetic errors. Enables cross-dialect ORMs (e.g. Apache Airflow) that
fall back to `timestampdiff` for non-PostgreSQL backends.


# Version 2.0.3
Expand Down
52 changes: 52 additions & 0 deletions sqlalchemy_cockroachdb/stmt_compiler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from sqlalchemy.dialects.postgresql.base import PGCompiler
from sqlalchemy.dialects.postgresql.base import PGIdentifierPreparer
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import GenericFunction

# This is extracted from CockroachDB's `sql.y`. Add keywords here if *NEW* reserved keywords
# are added to sql.y. DO NOT DELETE keywords here, even if they are deleted from sql.y:
Expand Down Expand Up @@ -100,3 +102,53 @@ class CockroachIdentifierPreparer(PGIdentifierPreparer):
class CockroachCompiler(PGCompiler):
def format_from_hint_text(self, sqltext, table, hint, iscrud):
return f"{sqltext}@{hint}"


class timestampdiff(GenericFunction):
"""MySQL-style ``timestampdiff(unit, start, end)`` for cross-dialect SQL.

CockroachDB does not implement MySQL's ``timestampdiff()``. Applications
that target multiple database backends (notably Apache Airflow's ORM)
sometimes call ``func.timestampdiff(...)`` and rely on the database to
accept it. Registering this :class:`GenericFunction` lets the cockroachdb
statement compiler translate the call into a PostgreSQL-style
``EXTRACT(EPOCH FROM (end - start))`` expression.
"""

inherit_cache = True
name = "timestampdiff"


_TIMESTAMPDIFF_UNIT_FACTOR = {
"MICROSECOND": " * 1000000",
"MILLISECOND": " * 1000",
"SECOND": "",
"MINUTE": " / 60",
"HOUR": " / 3600",
"DAY": " / 86400",
}


@compiles(timestampdiff, "cockroachdb")
def _compile_timestampdiff_cockroachdb(element, compiler, **kwargs):
"""Compile ``timestampdiff(unit, start, end)`` for the cockroachdb dialect.

The result is cast to ``NUMERIC`` so callers may safely combine it with
integer or numeric divisors. CockroachDB rejects ``float / decimal``
arithmetic that PostgreSQL accepts, and ``EXTRACT(EPOCH FROM ...)``
returns a float on CockroachDB.
"""
args = list(element.clauses)
if len(args) != 3:
raise ValueError(f"timestampdiff() expects 3 arguments (unit, start, end); got {len(args)}")
unit_token = compiler.process(args[0], **kwargs).strip().strip("'\"").upper()
if unit_token not in _TIMESTAMPDIFF_UNIT_FACTOR:
raise ValueError(
f"Unsupported timestampdiff() unit for cockroachdb dialect: {unit_token!r}. "
f"Supported units: {sorted(_TIMESTAMPDIFF_UNIT_FACTOR)}"
)
Comment thread
viragtripathi marked this conversation as resolved.
start_expr = compiler.process(args[1], **kwargs)
end_expr = compiler.process(args[2], **kwargs)
epoch_diff = f"CAST(EXTRACT(EPOCH FROM ({end_expr} - {start_expr})) AS NUMERIC)"
factor = _TIMESTAMPDIFF_UNIT_FACTOR[unit_token]
return f"({epoch_diff}{factor})" if factor else epoch_diff
91 changes: 91 additions & 0 deletions test/test_timestampdiff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import pytest
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy import TIMESTAMP
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.dialects.postgresql import dialect as postgresql_dialect
from sqlalchemy.sql import func
from sqlalchemy.testing import fixtures

from sqlalchemy_cockroachdb.psycopg2 import CockroachDBDialect_psycopg2
from sqlalchemy_cockroachdb.stmt_compiler import timestampdiff # noqa: F401 registers compiler


def _events_table():
metadata = MetaData()
return Table(
"events",
metadata,
Column("id", Integer, primary_key=True),
Column("start_date", TIMESTAMP),
Column("end_date", TIMESTAMP),
)


def _compile(stmt, dialect):
return str(stmt.compile(dialect=dialect, compile_kwargs={"literal_binds": True}))


class TimestampdiffCompilerTest(fixtures.TestBase):
"""Compile-only tests: no live database connection required."""

@pytest.fixture(autouse=True)
def _setup(self):
self.dialect = CockroachDBDialect_psycopg2()
self.events = _events_table()

@pytest.mark.parametrize(
"unit,expected_suffix",
[
("MICROSECOND", "* 1000000"),
("MILLISECOND", "* 1000"),
("MINUTE", "/ 60"),
("HOUR", "/ 3600"),
("DAY", "/ 86400"),
],
)
def test_compiles_with_arithmetic_suffix(self, unit, expected_suffix):
expr = func.timestampdiff(text(unit), self.events.c.start_date, self.events.c.end_date)
sql = _compile(select(expr), self.dialect)
assert "EXTRACT(EPOCH FROM" in sql
assert "AS NUMERIC" in sql
assert expected_suffix in sql

def test_seconds_has_no_arithmetic_suffix(self):
expr = func.timestampdiff(text("SECOND"), self.events.c.start_date, self.events.c.end_date)
sql = _compile(select(expr), self.dialect)
assert "EXTRACT(EPOCH FROM" in sql
assert "AS NUMERIC" in sql
after_cast = sql.split("AS NUMERIC", 1)[1]
assert " * " not in after_cast
assert " / " not in after_cast

def test_lowercase_unit_accepted(self):
expr = func.timestampdiff(
text("microsecond"), self.events.c.start_date, self.events.c.end_date
)
sql = _compile(select(expr), self.dialect)
assert "EXTRACT(EPOCH FROM" in sql
assert "* 1000000" in sql

def test_unknown_unit_rejected(self):
expr = func.timestampdiff(
text("FORTNIGHT"), self.events.c.start_date, self.events.c.end_date
)
with pytest.raises(ValueError, match="Unsupported timestampdiff"):
_compile(select(expr), self.dialect)

def test_wrong_arity_rejected(self):
expr = func.timestampdiff(text("SECOND"), self.events.c.start_date)
with pytest.raises(ValueError, match="3 arguments"):
_compile(select(expr), self.dialect)

def test_postgresql_dialect_unaffected(self):
"""The cockroachdb compiler hook must not change rendering for other dialects."""
expr = func.timestampdiff(text("SECOND"), self.events.c.start_date, self.events.c.end_date)
sql = _compile(select(expr), postgresql_dialect())
assert "EXTRACT" not in sql
assert "timestampdiff" in sql.lower()
Loading