Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Version 2.0.5
Unreleased

- Re-work get_multi_columns() to
  - include identity column info (#297)
  - avoid parse error when reflecting ENUMs (#303)

# Version 2.0.4
April 23, 2026
Expand Down
282 changes: 97 additions & 185 deletions sqlalchemy_cockroachdb/base.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,26 @@
import collections
import re
import threading
from sqlalchemy import text
from sqlalchemy import ARRAY
from sqlalchemy import BIGINT
from sqlalchemy import BLOB
from sqlalchemy import DECIMAL
from sqlalchemy import DOUBLE_PRECISION
from sqlalchemy import FLOAT
from sqlalchemy import INTEGER
from sqlalchemy import NUMERIC
from sqlalchemy import REAL
from sqlalchemy import SMALLINT
from sqlalchemy import TEXT
from sqlalchemy import VARCHAR
from sqlalchemy.dialects.postgresql.base import PGDialect
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.dialects.postgresql import INET
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.dialects.postgresql import BYTEA
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.util import warn

import sqlalchemy.types as sqltypes

from .stmt_compiler import CockroachCompiler, CockroachIdentifierPreparer
from .ddl_compiler import CockroachDDLCompiler


# Map type names (as returned by information_schema) to sqlalchemy type
# objects.
#
# Keys are lower-case type names; get_columns() lower-cases the reflected
# name before looking it up here and falls back to NULLTYPE (with a
# warning) when the name is missing.
#
# Most values are type *classes*; the two timezone-aware timestamp entries
# are already-instantiated TIMESTAMP(timezone=True) objects.
#
# TODO(bdarnell): test more of these. The stock test suite only covers
# a few basic ones.
_type_map = {
# Booleans.
"bool": sqltypes.BOOLEAN, # introspection returns "BOOL" not boolean
"boolean": sqltypes.BOOLEAN,
# Integers: every width and alias maps to the same INT type.
"bigint": sqltypes.INT,
"int": sqltypes.INT,
"int2": sqltypes.INT,
"int4": sqltypes.INT,
"int64": sqltypes.INT,
"int8": sqltypes.INT,
"integer": sqltypes.INT,
"smallint": sqltypes.INT,
# Floating point: single and double precision both map to FLOAT.
"double precision": sqltypes.FLOAT,
"float": sqltypes.FLOAT,
"float4": sqltypes.FLOAT,
"float8": sqltypes.FLOAT,
"real": sqltypes.FLOAT,
# Fixed-point numerics.
"dec": sqltypes.DECIMAL,
"decimal": sqltypes.DECIMAL,
"numeric": sqltypes.DECIMAL,
# Dates, times and intervals.
"date": sqltypes.DATE,
"time": sqltypes.Time,
"time without time zone": sqltypes.Time,
"timestamp": sqltypes.TIMESTAMP,
"timestamptz": sqltypes.TIMESTAMP(timezone=True),
"timestamp with time zone": sqltypes.TIMESTAMP(timezone=True),
"timestamp without time zone": sqltypes.TIMESTAMP,
"interval": sqltypes.Interval,
# Character data: "string" and "text" are reported as VARCHAR.
"char": sqltypes.CHAR,
"char varying": sqltypes.VARCHAR,
"character": sqltypes.CHAR,
"character varying": sqltypes.VARCHAR,
"string": sqltypes.VARCHAR,
"text": sqltypes.VARCHAR,
"varchar": sqltypes.VARCHAR,
# Binary data.
"blob": sqltypes.BLOB,
"bytea": sqltypes.BLOB,
"bytes": sqltypes.BLOB,
# JSON; "jsonb" uses the PostgreSQL dialect type.
"json": sqltypes.JSON,
"jsonb": JSONB,
# Other types taken from the PostgreSQL dialect.
"uuid": UUID,
"inet": INET,
}


class _SavepointState(threading.local):
"""Hack to override names used in savepoint statements.

Expand Down Expand Up @@ -152,7 +108,7 @@ def _get_server_version_info(self, conn):
# PGDialect expects a postgres server version number here,
# although we've overridden most of the places where it's
# used.
return (9, 5, 0)
return (12, 0, 0)

def get_table_names(self, conn, schema=None, **kw):
# Upstream implementation needs correlated subqueries.
Expand All @@ -175,126 +131,88 @@ def has_table(self, conn, table, schema=None, info_cache=None):
return any(t == table for t in self.get_table_names(conn, schema=schema))

def get_multi_columns(self, connection, schema, filter_names, scope, kind, **kw):
if not filter_names:
filter_names = self.get_table_names(connection, schema)
return {
(schema, table_name): self.get_columns(connection, table_name, schema, **kw)
for table_name in filter_names
}

# The upstream implementations of the reflection functions below depend on
# correlated subqueries which are not yet supported.
def get_columns(self, conn, table_name, schema=None, **kw):
_include_hidden = kw.get("include_hidden", False)
if not self._is_v191plus:
# v2.x does not have is_generated or generation_expression
sql = (
"SELECT column_name, data_type, is_nullable::bool, column_default,"
"numeric_precision, numeric_scale, character_maximum_length, "
"NULL AS is_generated, NULL AS generation_expression, is_hidden::bool,"
"column_comment AS comment "
"FROM information_schema.columns "
"WHERE table_schema = :table_schema AND table_name = :table_name "
)
sql += "" if _include_hidden else "AND NOT is_hidden::bool"
rows = conn.execute(
text(sql),
{"table_schema": schema or self.default_schema_name, "table_name": table_name},
)
else:
# v19.1 or later. Information schema columns are all usable.
sql = (
"SELECT column_name, data_type, is_nullable::bool, column_default, "
"numeric_precision, numeric_scale, character_maximum_length, "
"CASE is_generated WHEN 'ALWAYS' THEN true WHEN 'NEVER' THEN false "
"ELSE is_generated::bool END AS is_generated, "
"generation_expression, is_hidden::bool, crdb_sql_type, column_comment AS comment "
"FROM information_schema.columns "
"WHERE table_schema = :table_schema AND table_name = :table_name "
)
sql += "" if _include_hidden else "AND NOT is_hidden::bool"
rows = conn.execute(
text(sql),
{"table_schema": schema or self.default_schema_name, "table_name": table_name},
)

res = []
for row in rows:
name, type_str, nullable, default = row[:4]
if type_str == "ARRAY":
is_array = True
type_str, _ = row.crdb_sql_type.split("[", maxsplit=1)
else:
is_array = False
# When there are type parameters, attach them to the
# returned type object.
m = re.match(r"^(\w+(?: \w+)*)(?:\(([0-9, ]*)\))?$", type_str)
if m is None:
warn("Could not parse type name '%s'" % type_str)
typ = sqltypes.NULLTYPE
else:
type_name, type_args = m.groups()
try:
type_class = _type_map[type_name.lower()]
except KeyError:
warn(f"Did not recognize type '{type_name}' of column '{name}'")
type_class = sqltypes.NULLTYPE
if type_args:
typ = type_class(*[int(s.strip()) for s in type_args.split(",")])
elif type_class is sqltypes.DECIMAL:
typ = type_class(
precision=row.numeric_precision,
scale=row.numeric_scale,
)
elif type_class is sqltypes.VARCHAR or type_class is sqltypes.CHAR:
typ = type_class(length=row.character_maximum_length)
else:
typ = type_class
if row.is_generated:
# Currently, all computed columns are persisted.
computed = dict(sqltext=row.generation_expression, persisted=True)
default = None
else:
computed = None
# Check if a sequence is being used and adjust the default value.
autoincrement = False
if default is not None:
nextval_match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the old code rewrote bare nextval('seq') to become nextval('"schema".seq') when a schema was provided. The new code drops it entirely, unless I missed something.

Was it intentional? If so, a note in the commit message explaining why would be useful. And possibly even a mention in CHANGES.md?

Tools that depend on schema-qualified sequence names might be affected.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, cockroachdb/cockroach#170049 should take care of that.

unique_rowid_match = re.search(r"""unique_rowid\(""", default)
if nextval_match is not None or unique_rowid_match is not None:
if issubclass(type_class, sqltypes.Integer):
autoincrement = True
# the default is related to a Sequence
sch = schema
if (
nextval_match is not None
and "." not in nextval_match.group(2)
and sch is not None
):
# unconditionally quote the schema name. this could
# later be enhanced to obey quoting rules /
# "quote schema"
default = (
nextval_match.group(1)
+ ('"%s"' % sch)
+ "."
+ nextval_match.group(2)
+ nextval_match.group(3)
multi_columns = super().get_multi_columns(
connection, schema, filter_names, scope, kind, **kw
)
to_return = []
if multi_columns:
current = connection.execute(
text("select current_database() as db, current_schema() as schema")
).one()
for table, columns in multi_columns:
if table not in [
(None, "geography_columns"),
(None, "geometry_columns"),
(None, "spatial_ref_sys"),
]:
table_columns = (
connection.execute(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is doing a per-table information_schema.columns lookup inside the loop, on top of the one PGDialect already ran. Could we replace it with one bulk SELECT, run once outside the loop, whose results are stored in a Python dict keyed by (schema, table)?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to keep it simple for now. We can re-visit this if user feedback (re: performance) warrants.

text(
"select column_name, is_hidden::bool "
"from information_schema.columns "
"where table_catalog = :tc "
"and table_schema = :ts and table_name = :tn"
),
dict(tc=current.db, ts=table[0] or current.schema, tn=table[1]),
Comment thread
gordthompson marked this conversation as resolved.
)

column_info = dict(
name=name,
type=ARRAY(typ) if is_array else typ,
nullable=nullable,
default=default,
autoincrement=autoincrement,
is_hidden=row.is_hidden,
comment=row.comment,
)
if computed is not None:
column_info["computed"] = computed
res.append(column_info)
return res
.mappings()
.all()
)
is_hidden = {x["column_name"]: x["is_hidden"] for x in table_columns}
for col in columns[:]:
if is_hidden[col["name"]] and not _include_hidden:
Comment thread
gordthompson marked this conversation as resolved.
columns.remove(col)
Comment thread
gordthompson marked this conversation as resolved.
else:
col["is_hidden"] = is_hidden[col["name"]]
if col["default"] == "unique_rowid()":
col["autoincrement"] = True
if isinstance(col["type"], BIGINT):
col["type"] = INTEGER()
elif isinstance(col["type"], ARRAY) and isinstance(
col["type"].item_type, BIGINT
):
col["type"].item_type = INTEGER()
elif isinstance(col["type"], BYTEA):
col["type"] = BLOB()
elif isinstance(col["type"], ARRAY) and isinstance(
col["type"].item_type, BYTEA
):
col["type"].item_type = BLOB()
elif isinstance(col["type"], DOUBLE_PRECISION):
col["type"] = FLOAT()
elif isinstance(col["type"], ARRAY) and isinstance(
col["type"].item_type, DOUBLE_PRECISION
):
col["type"].item_type = FLOAT()
elif isinstance(col["type"], NUMERIC):
col["type"] = DECIMAL(col["type"].precision, col["type"].scale)
elif isinstance(col["type"], ARRAY) and isinstance(
col["type"].item_type, NUMERIC
):
col["type"].item_type = DECIMAL(
col["type"].item_type.precision, col["type"].item_type.scale
)
elif isinstance(col["type"], REAL):
col["type"] = FLOAT()
elif isinstance(col["type"], ARRAY) and isinstance(
col["type"].item_type, REAL
):
col["type"].item_type = FLOAT()
elif isinstance(col["type"], SMALLINT):
col["type"] = INTEGER()
elif isinstance(col["type"], ARRAY) and isinstance(
col["type"].item_type, SMALLINT
):
col["type"].item_type = INTEGER()
elif isinstance(col["type"], TEXT):
col["type"] = VARCHAR()
elif isinstance(col["type"], ARRAY) and isinstance(
col["type"].item_type, TEXT
):
col["type"].item_type = VARCHAR()
to_return.append((table, columns))
return to_return

def get_indexes(self, conn, table_name, schema=None, **kw):
if self._is_v192plus:
Expand Down Expand Up @@ -348,12 +266,8 @@ def get_indexes(self, conn, table_name, schema=None, **kw):
)
return result

def get_multi_indexes(
self, connection, schema, filter_names, scope, kind, **kw
):
result = super().get_multi_indexes(
connection, schema, filter_names, scope, kind, **kw
)
def get_multi_indexes(self, connection, schema, filter_names, scope, kind, **kw):
result = super().get_multi_indexes(connection, schema, filter_names, scope, kind, **kw)
if schema is None:
result = dict(result)
for k in [
Expand Down Expand Up @@ -418,9 +332,7 @@ def get_unique_constraints(self, conn, table_name, schema=None, **kw):
res.append(index)
return res

def get_multi_check_constraints(
self, connection, schema, filter_names, scope, kind, **kw
):
def get_multi_check_constraints(self, connection, schema, filter_names, scope, kind, **kw):
result = super().get_multi_check_constraints(
connection, schema, filter_names, scope, kind, **kw
)
Expand Down
Loading