Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@ Unreleased
- Fix reflection of JSONB columns (#277)
- Fix compatibility issues with Alembic 1.18 (via SQLA 2.0.47)
- Update minimum Python version to 3.10
- Compile MySQL-style `func.timestampdiff(unit, start, end)` to a
PostgreSQL-style `EXTRACT(EPOCH FROM ...)` expression on the cockroachdb
dialect. The arithmetic result is wrapped in `TRUNC()` so the value matches
MySQL's integer-truncation-toward-zero semantics (a 90-second diff at
`MINUTE` returns 1, not 1.5), and is cast to NUMERIC so callers may safely
combine it with integer or numeric divisors -- avoiding the `float / decimal`
arithmetic errors CockroachDB rejects but PostgreSQL accepts. Supported
units: MICROSECOND, MILLISECOND, SECOND, MINUTE, HOUR, DAY, WEEK.
Calendar-aware units (MONTH, QUARTER, YEAR) are explicitly rejected with a
specific error message because they require calendar walking that cannot be
derived from epoch arithmetic alone. Enables cross-dialect ORMs (e.g.
Apache Airflow) that fall back to `timestampdiff` for non-PostgreSQL
backends.


# Version 2.0.3
Expand Down
109 changes: 109 additions & 0 deletions sqlalchemy_cockroachdb/stmt_compiler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from sqlalchemy.dialects.postgresql.base import PGCompiler
from sqlalchemy.dialects.postgresql.base import PGIdentifierPreparer
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.elements import BindParameter
from sqlalchemy.sql.functions import GenericFunction

# This is extracted from CockroachDB's `sql.y`. Add keywords here if *NEW* reserved keywords
# are added to sql.y. DO NOT DELETE keywords here, even if they are deleted from sql.y:
Expand Down Expand Up @@ -100,3 +103,109 @@ class CockroachIdentifierPreparer(PGIdentifierPreparer):
class CockroachCompiler(PGCompiler):
def format_from_hint_text(self, sqltext, table, hint, iscrud):
return f"{sqltext}@{hint}"


class timestampdiff(GenericFunction):
"""MySQL-style ``timestampdiff(unit, start, end)`` for cross-dialect SQL.

CockroachDB does not implement MySQL's ``timestampdiff()``. Applications
that target multiple database backends (notably Apache Airflow's ORM)
sometimes call ``func.timestampdiff(...)`` and rely on the database to
accept it. Registering this :class:`GenericFunction` lets the cockroachdb
statement compiler translate the call into a PostgreSQL-style
``EXTRACT(EPOCH FROM (end - start))`` expression.
"""

inherit_cache = True
name = "timestampdiff"


_TIMESTAMPDIFF_UNIT_FACTOR = {
"MICROSECOND": " * 1000000",
"MILLISECOND": " * 1000",
"SECOND": "",
"MINUTE": " / 60",
"HOUR": " / 3600",
"DAY": " / 86400",
"WEEK": " / 604800",
}

# Calendar-aware units are intentionally not implemented. MySQL's
# ``TIMESTAMPDIFF(MONTH, ...)`` walks the calendar so that ``Feb 28 -> Mar 1``
# is one month while ``Mar 1 -> Mar 30`` is zero months. That logic cannot be
# derived from epoch arithmetic; a faithful implementation would need
# ``EXTRACT(YEAR FROM AGE(end, start))`` plus month math. Listing them here
# lets us emit a specific error rather than the generic "unsupported unit" one.
_TIMESTAMPDIFF_CALENDAR_AWARE_UNITS = frozenset({"MONTH", "QUARTER", "YEAR"})


def _resolve_timestampdiff_unit(unit_arg, compiler, **kwargs):
"""Extract the unit token from a ``timestampdiff()`` first argument.

The unit must be known at compile time so it can be turned into a SQL
arithmetic factor. Plain Python strings (``func.timestampdiff("SECOND", ...)``)
and ``literal("SECOND")`` reach the compiler as :class:`BindParameter`; if we
delegated to ``compiler.process`` those would render as parameter
placeholders (``$1`` / ``%(...)s``), so we extract ``.value`` directly.
Other constructs such as ``text("SECOND")`` and ``literal_column("SECOND")``
render as literal SQL tokens and go through the normal path.
"""
if isinstance(unit_arg, BindParameter):
raw = unit_arg.value
if not isinstance(raw, str):
raise ValueError(
"timestampdiff() unit must be a string; " f"got {type(raw).__name__} ({raw!r})"
Comment thread
viragtripathi marked this conversation as resolved.
)
return raw.strip().upper()
return compiler.process(unit_arg, **kwargs).strip().strip("'\"").upper()


@compiles(timestampdiff, "cockroachdb")
def _compile_timestampdiff_cockroachdb(element, compiler, **kwargs):
"""Compile ``timestampdiff(unit, start, end)`` for the cockroachdb dialect.

Output shape::

TRUNC(CAST(EXTRACT(EPOCH FROM (end - start)) AS NUMERIC) <factor>)

The ``TRUNC()`` wrap matches MySQL's ``TIMESTAMPDIFF`` semantics: the
result is the integer count of complete units between the two timestamps,
truncated toward zero. Without it, a 90-second diff at ``MINUTE`` would
return ``1.5`` on cockroachdb where MySQL returns ``1``.

The cast to ``NUMERIC`` (rather than to ``BIGINT``) is intentional. It
keeps the value integer-truncated like MySQL while still allowing
downstream divisors -- e.g. Apache Airflow's
``timestampdiff(MICROSECOND, ...) / 1_000_000`` pattern -- to do
floating-point division on cockroachdb. Returning ``BIGINT`` would force
integer division on the divisor and silently lose sub-second precision.

Calendar-aware units (``MONTH``, ``QUARTER``, ``YEAR``) are intentionally
rejected with a specific error; see ``_TIMESTAMPDIFF_CALENDAR_AWARE_UNITS``
for the rationale.
"""
args = list(element.clauses)
if len(args) != 3:
raise ValueError(f"timestampdiff() expects 3 arguments (unit, start, end); got {len(args)}")
unit_token = _resolve_timestampdiff_unit(args[0], compiler, **kwargs)
if unit_token in _TIMESTAMPDIFF_CALENDAR_AWARE_UNITS:
raise ValueError(
f"timestampdiff() unit {unit_token!r} is not supported on the cockroachdb "
"dialect. Calendar-aware units (MONTH, QUARTER, YEAR) require calendar "
"walking (e.g. Feb 28 -> Mar 1 is 1 month) that cannot be derived from "
"epoch arithmetic alone, and are intentionally omitted. "
"If you need them, please open an issue at "
"https://github.com/cockroachdb/sqlalchemy-cockroachdb/issues."
)
if unit_token not in _TIMESTAMPDIFF_UNIT_FACTOR:
raise ValueError(
f"Unsupported timestampdiff() unit for cockroachdb dialect: {unit_token!r}. "
f"Supported units: {sorted(_TIMESTAMPDIFF_UNIT_FACTOR)}. "
"Pass the unit as a plain string, sqlalchemy.literal(unit), "
"or sqlalchemy.text(unit)."
)
Comment thread
viragtripathi marked this conversation as resolved.
start_expr = compiler.process(args[1], **kwargs)
end_expr = compiler.process(args[2], **kwargs)
epoch_diff = f"CAST(EXTRACT(EPOCH FROM ({end_expr} - {start_expr})) AS NUMERIC)"
factor = _TIMESTAMPDIFF_UNIT_FACTOR[unit_token]
return f"TRUNC({epoch_diff}{factor})"
158 changes: 158 additions & 0 deletions test/test_timestampdiff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import pytest
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy import TIMESTAMP
from sqlalchemy import literal
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.dialects.postgresql import dialect as postgresql_dialect
from sqlalchemy.sql import func
from sqlalchemy.testing import fixtures

from sqlalchemy_cockroachdb.psycopg2 import CockroachDBDialect_psycopg2
from sqlalchemy_cockroachdb.stmt_compiler import timestampdiff # noqa: F401 registers compiler


def _events_table():
metadata = MetaData()
return Table(
"events",
metadata,
Column("id", Integer, primary_key=True),
Column("start_date", TIMESTAMP),
Column("end_date", TIMESTAMP),
)


def _compile(stmt, dialect, literal_binds=True):
kwargs = {"literal_binds": True} if literal_binds else {}
return str(stmt.compile(dialect=dialect, compile_kwargs=kwargs))


class TimestampdiffCompilerTest(fixtures.TestBase):
"""Compile-only tests: no live database connection required."""

@pytest.fixture(autouse=True)
def _setup(self):
self.dialect = CockroachDBDialect_psycopg2()
self.events = _events_table()

@pytest.mark.parametrize(
"unit,expected_suffix",
[
("MICROSECOND", "* 1000000"),
("MILLISECOND", "* 1000"),
("MINUTE", "/ 60"),
("HOUR", "/ 3600"),
("DAY", "/ 86400"),
("WEEK", "/ 604800"),
],
)
def test_compiles_with_arithmetic_suffix(self, unit, expected_suffix):
expr = func.timestampdiff(text(unit), self.events.c.start_date, self.events.c.end_date)
sql = _compile(select(expr), self.dialect)
assert "EXTRACT(EPOCH FROM" in sql
assert "AS NUMERIC" in sql
assert expected_suffix in sql
assert "TRUNC(" in sql

def test_seconds_has_no_arithmetic_suffix(self):
"""SECOND has no factor token but is still wrapped in TRUNC for MySQL parity."""
expr = func.timestampdiff(text("SECOND"), self.events.c.start_date, self.events.c.end_date)
sql = _compile(select(expr), self.dialect)
assert "EXTRACT(EPOCH FROM" in sql
assert "AS NUMERIC" in sql
assert "TRUNC(" in sql
after_cast = sql.split("AS NUMERIC", 1)[1]
assert " * " not in after_cast
assert " / " not in after_cast

def test_truncates_to_match_mysql_semantics(self):
"""Regression for #301 review: MySQL TIMESTAMPDIFF returns BIGINT (truncated
toward zero), so a 90-second diff at MINUTE returns 1, not 1.5. The cockroachdb
compilation must wrap the arithmetic in TRUNC() so the *value* matches MySQL,
even though we deliberately keep the NUMERIC return type for downstream
divisor compatibility (see compiler docstring).
"""
Comment thread
viragtripathi marked this conversation as resolved.
expr = func.timestampdiff(text("MINUTE"), self.events.c.start_date, self.events.c.end_date)
sql = _compile(select(expr), self.dialect)
# TRUNC must wrap the divided expression, not just the EPOCH cast.
assert "TRUNC(" in sql
trunc_idx = sql.index("TRUNC(")
factor_idx = sql.index("/ 60")
assert trunc_idx < factor_idx, f"TRUNC should wrap the division; got SQL: {sql!r}"
# The closing paren of TRUNC must come after the factor: "/ 60)" must appear.
assert "/ 60)" in sql
Comment thread
viragtripathi marked this conversation as resolved.
Outdated

def test_lowercase_unit_accepted(self):
expr = func.timestampdiff(
text("microsecond"), self.events.c.start_date, self.events.c.end_date
)
sql = _compile(select(expr), self.dialect)
assert "EXTRACT(EPOCH FROM" in sql
assert "* 1000000" in sql
assert "TRUNC(" in sql

def test_unknown_unit_rejected(self):
expr = func.timestampdiff(
text("FORTNIGHT"), self.events.c.start_date, self.events.c.end_date
)
with pytest.raises(ValueError, match="Unsupported timestampdiff"):
_compile(select(expr), self.dialect)

@pytest.mark.parametrize("unit", ["MONTH", "QUARTER", "YEAR", "month", "Year"])
def test_calendar_aware_units_rejected_with_explanation(self, unit):
"""MONTH/QUARTER/YEAR must be rejected with a specific error explaining
why they're intentionally omitted (calendar-walking vs epoch arithmetic),
not the generic 'unsupported unit' error.
"""
expr = func.timestampdiff(text(unit), self.events.c.start_date, self.events.c.end_date)
with pytest.raises(ValueError, match="Calendar-aware units"):
_compile(select(expr), self.dialect)

def test_wrong_arity_rejected(self):
expr = func.timestampdiff(text("SECOND"), self.events.c.start_date)
with pytest.raises(ValueError, match="3 arguments"):
_compile(select(expr), self.dialect)

def test_postgresql_dialect_unaffected(self):
"""The cockroachdb compiler hook must not change rendering for other dialects."""
expr = func.timestampdiff(text("SECOND"), self.events.c.start_date, self.events.c.end_date)
sql = _compile(select(expr), postgresql_dialect())
assert "EXTRACT" not in sql
assert "timestampdiff" in sql.lower()

def test_plain_string_unit_accepted(self):
"""Plain Python string unit must resolve to a real unit, not a bound placeholder.

Comment thread
viragtripathi marked this conversation as resolved.
Compiles WITHOUT literal_binds to mirror how Airflow's ORM actually executes
statements — BindParameters render as ``$1`` / ``%(...)s`` unless we extract
the value at compile time.
"""
expr = func.timestampdiff("MICROSECOND", self.events.c.start_date, self.events.c.end_date)
sql = _compile(select(expr), self.dialect, literal_binds=False)
assert "EXTRACT(EPOCH FROM" in sql
assert "* 1000000" in sql
assert "TRUNC(" in sql
assert "%(" not in sql
assert "$1" not in sql

def test_literal_unit_accepted(self):
"""``literal('SECOND')`` should resolve via BindParameter value, not a placeholder."""
expr = func.timestampdiff(
literal("SECOND"), self.events.c.start_date, self.events.c.end_date
)
sql = _compile(select(expr), self.dialect, literal_binds=False)
assert "EXTRACT(EPOCH FROM" in sql
assert "AS NUMERIC" in sql
assert "TRUNC(" in sql
assert "%(" not in sql
assert "$1" not in sql

def test_non_string_bind_value_rejected_clearly(self):
"""A BindParameter whose value isn't a string must produce a clear error."""
expr = func.timestampdiff(123, self.events.c.start_date, self.events.c.end_date)
with pytest.raises(ValueError, match="must be a string"):
_compile(select(expr), self.dialect, literal_binds=False)
Loading