This commit is contained in:
2025-05-22 20:25:38 +02:00
parent 09f6750c2b
commit ce03fbf12f
529 changed files with 3353 additions and 3312 deletions

View File

@@ -50,6 +50,7 @@ class CursorSQL(SQLMatchRule):
if self.statement != stmt.statement or (
self.params is not None and self.params != stmt.parameters
):
self.consume_statement = True
self.errormessage = (
"Testing for exact SQL %s parameters %s received %s %s"
% (

View File

@@ -22,10 +22,15 @@ from typing import Tuple
from typing import TypeVar
from typing import Union
from . import mock
from . import requirements as _requirements
from .util import fail
from .. import util
requirements = None
# default requirements; this is replaced by plugin_base when pytest
# is run
requirements = _requirements.SuiteRequirements()
db = None
db_url = None
db_opts = None
@@ -42,7 +47,42 @@ if typing.TYPE_CHECKING:
_fixture_functions: FixtureFunctions
else:
_fixture_functions = None # installed by plugin_base
class _NullFixtureFunctions:
def _null_decorator(self):
def go(fn):
return fn
return go
def skip_test_exception(self, *arg, **kw):
return Exception()
@property
def add_to_marker(self):
return mock.Mock()
def mark_base_test_class(self):
return self._null_decorator()
def combinations(self, *arg_sets, **kw):
return self._null_decorator()
def param_ident(self, *parameters):
return self._null_decorator()
def fixture(self, *arg, **kw):
return self._null_decorator()
def get_current_test_name(self):
return None
def async_test(self, fn):
return fn
# default fixture functions; these are replaced by plugin_base when
# pytest runs
_fixture_functions = _NullFixtureFunctions()
_FN = TypeVar("_FN", bound=Callable[..., Any])
@@ -121,10 +161,7 @@ def combinations(
)
def combinations_list(
arg_iterable: Iterable[Tuple[Any,]],
**kw,
):
def combinations_list(arg_iterable: Iterable[Tuple[Any, ...]], **kw):
"As combination, but takes a single iterable"
return combinations(*arg_iterable, **kw)

View File

@@ -14,6 +14,7 @@ from argparse import Namespace
import configparser
import logging
import os
from pathlib import Path
import re
import sys
from typing import Any
@@ -320,6 +321,10 @@ def _log(opt_str, value, parser):
def _list_dbs(*args):
if file_config is None:
# assume the current working directory is the one containing the
# setup file
read_config(Path.cwd())
print("Available --db options (use --dburi to override)")
for macro in sorted(file_config.options("db")):
print("%20s\t%s" % (macro, file_config.get("db", macro)))
@@ -420,6 +425,7 @@ def _engine_uri(options, file_config):
from sqlalchemy import testing
from sqlalchemy.testing import config
from sqlalchemy.testing import provision
from sqlalchemy.engine import url as sa_url
if options.dburi:
db_urls = list(options.dburi)
@@ -444,18 +450,19 @@ def _engine_uri(options, file_config):
config._current = None
expanded_urls = list(provision.generate_db_urls(db_urls, extra_drivers))
for db_url in expanded_urls:
log.info("Adding database URL: %s", db_url)
if options.write_idents and provision.FOLLOWER_IDENT:
if options.write_idents and provision.FOLLOWER_IDENT:
for db_url in [sa_url.make_url(db_url) for db_url in db_urls]:
with open(options.write_idents, "a") as file_:
file_.write(
f"{provision.FOLLOWER_IDENT} "
f"{db_url.render_as_string(hide_password=False)}\n"
)
expanded_urls = list(provision.generate_db_urls(db_urls, extra_drivers))
for db_url in expanded_urls:
log.info("Adding database URL: %s", db_url)
cfg = provision.setup_config(
db_url, options, file_config, provision.FOLLOWER_IDENT
)
@@ -473,9 +480,6 @@ def _setup_requirements(argument):
from sqlalchemy.testing import config
from sqlalchemy import testing
if config.requirements is not None:
return
modname, clsname = argument.split(":")
# importlib.import_module() only introduced in 2.7, a little

View File

@@ -22,9 +22,8 @@ from __future__ import annotations
import platform
from . import asyncio as _test_asyncio
from . import config
from . import exclusions
from . import only_on
from .exclusions import only_on
from .. import create_engine
from .. import util
from ..pool import QueuePool
@@ -59,6 +58,12 @@ class SuiteRequirements(Requirements):
return exclusions.closed()
@property
def uuid_data_type(self):
"""Return databases that support the UUID datatype."""
return exclusions.closed()
@property
def foreign_keys(self):
"""Target database must support foreign keys."""
@@ -840,6 +845,14 @@ class SuiteRequirements(Requirements):
"""Target driver can create tables with a name like 'some " table'"""
return exclusions.open()
@property
def datetime_interval(self):
"""target dialect supports rendering of a datetime.timedelta as a
literal string, e.g. via the TypeEngine.literal_processor() method.
"""
return exclusions.closed()
@property
def datetime_literals(self):
"""target dialect supports rendering of a date, time, or datetime as a
@@ -1448,10 +1461,14 @@ class SuiteRequirements(Requirements):
@property
def timing_intensive(self):
from . import config
return config.add_to_marker.timing_intensive
@property
def memory_intensive(self):
from . import config
return config.add_to_marker.memory_intensive
@property

View File

@@ -18,7 +18,6 @@ from ... import literal_column
from ... import Numeric
from ... import select
from ... import String
from ...dialects.postgresql import BYTEA
from ...types import LargeBinary
from ...types import UUID
from ...types import Uuid
@@ -104,6 +103,15 @@ class InsertBehaviorTest(fixtures.TablesTest):
Column("id", Integer, primary_key=True, autoincrement=False),
Column("data", String(50)),
)
Table(
"no_implicit_returning",
metadata,
Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
),
Column("data", String(50)),
implicit_returning=False,
)
Table(
"includes_defaults",
metadata,
@@ -119,6 +127,33 @@ class InsertBehaviorTest(fixtures.TablesTest):
),
)
@testing.variation("style", ["plain", "return_defaults"])
@testing.variation("executemany", [True, False])
def test_no_results_for_non_returning_insert(
self, connection, style, executemany
):
"""test another INSERT issue found during #10453"""
table = self.tables.no_implicit_returning
stmt = table.insert()
if style.return_defaults:
stmt = stmt.return_defaults()
if executemany:
data = [
{"data": "d1"},
{"data": "d2"},
{"data": "d3"},
{"data": "d4"},
{"data": "d5"},
]
else:
data = {"data": "d1"}
r = connection.execute(stmt, data)
assert not r.returns_rows
@requirements.autoincrement_insert
def test_autoclose_on_insert(self, connection):
r = connection.execute(
@@ -394,7 +429,7 @@ class ReturningTest(fixtures.TablesTest):
True,
testing.requires.float_or_double_precision_behaves_generically,
),
(Float(), 8.5514, False),
(Float(), 8.5514, True),
(
Float(8),
8.5514,
@@ -517,7 +552,6 @@ class ReturningTest(fixtures.TablesTest):
b"this is binary",
),
("LargeBinary2", LargeBinary(), b"7\xe7\x9f"),
("PG BYTEA", BYTEA(), b"7\xe7\x9f", testing.only_on("postgresql")),
argnames="type_,value",
id_="iaa",
)

View File

@@ -287,6 +287,65 @@ class HasIndexTest(fixtures.TablesTest):
)
class BizarroCharacterFKResolutionTest(fixtures.TestBase):
"""tests for #10275"""
__backend__ = True
@testing.combinations(
("id",), ("(3)",), ("col%p",), ("[brack]",), argnames="columnname"
)
@testing.variation("use_composite", [True, False])
@testing.combinations(
("plain",),
("(2)",),
("per % cent",),
("[brackets]",),
argnames="tablename",
)
def test_fk_ref(
self, connection, metadata, use_composite, tablename, columnname
):
tt = Table(
tablename,
metadata,
Column(columnname, Integer, key="id", primary_key=True),
test_needs_fk=True,
)
if use_composite:
tt.append_column(Column("id2", Integer, primary_key=True))
if use_composite:
Table(
"other",
metadata,
Column("id", Integer, primary_key=True),
Column("ref", Integer),
Column("ref2", Integer),
sa.ForeignKeyConstraint(["ref", "ref2"], [tt.c.id, tt.c.id2]),
test_needs_fk=True,
)
else:
Table(
"other",
metadata,
Column("id", Integer, primary_key=True),
Column("ref", ForeignKey(tt.c.id)),
test_needs_fk=True,
)
metadata.create_all(connection)
m2 = MetaData()
o2 = Table("other", m2, autoload_with=connection)
t1 = m2.tables[tablename]
assert o2.c.ref.references(t1.c[0])
if use_composite:
assert o2.c.ref2.references(t1.c[1])
class QuotedNameArgumentTest(fixtures.TablesTest):
run_create_tables = "once"
__backend__ = True
@@ -3053,6 +3112,7 @@ __all__ = (
"ComponentReflectionTestExtra",
"TableNoColumnsTest",
"QuotedNameArgumentTest",
"BizarroCharacterFKResolutionTest",
"HasTableTest",
"HasIndexTest",
"NormalizedNameTest",

View File

@@ -254,7 +254,7 @@ class ServerSideCursorsTest(
elif self.engine.dialect.driver == "pymysql":
sscursor = __import__("pymysql.cursors").cursors.SSCursor
return isinstance(cursor, sscursor)
elif self.engine.dialect.driver in ("aiomysql", "asyncmy"):
elif self.engine.dialect.driver in ("aiomysql", "asyncmy", "aioodbc"):
return cursor.server_side
elif self.engine.dialect.driver == "mysqldb":
sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
@@ -311,7 +311,7 @@ class ServerSideCursorsTest(
True,
"SELECT 1 FOR UPDATE",
True,
testing.skip_if("sqlite"),
testing.skip_if(["sqlite", "mssql"]),
),
("text_no_ss", False, text("select 42"), False),
(

View File

@@ -66,6 +66,49 @@ class RowCountTest(fixtures.TablesTest):
eq_(rows, self.data)
@testing.variation("statement", ["update", "delete", "insert", "select"])
@testing.variation("close_first", [True, False])
def test_non_rowcount_scenarios_no_raise(
self, connection, statement, close_first
):
employees_table = self.tables.employees
# WHERE matches 3, 3 rows changed
department = employees_table.c.department
if statement.update:
r = connection.execute(
employees_table.update().where(department == "C"),
{"department": "Z"},
)
elif statement.delete:
r = connection.execute(
employees_table.delete().where(department == "C"),
{"department": "Z"},
)
elif statement.insert:
r = connection.execute(
employees_table.insert(),
[
{"employee_id": 25, "name": "none 1", "department": "X"},
{"employee_id": 26, "name": "none 2", "department": "Z"},
{"employee_id": 27, "name": "none 3", "department": "Z"},
],
)
elif statement.select:
s = select(
employees_table.c.name, employees_table.c.department
).where(employees_table.c.department == "C")
r = connection.execute(s)
r.all()
else:
statement.fail()
if close_first:
r.close()
assert r.rowcount in (-1, 3)
def test_update_rowcount1(self, connection):
employees_table = self.tables.employees

View File

@@ -28,6 +28,7 @@ from ... import Date
from ... import DateTime
from ... import Float
from ... import Integer
from ... import Interval
from ... import JSON
from ... import literal
from ... import literal_column
@@ -82,6 +83,11 @@ class _LiteralRoundTripFixture:
)
connection.execute(ins)
ins = t.insert().values(
x=literal(None, type_, literal_execute=True)
)
connection.execute(ins)
if support_whereclause and self.supports_whereclause:
if compare:
stmt = t.select().where(
@@ -108,7 +114,7 @@ class _LiteralRoundTripFixture:
)
)
else:
stmt = t.select()
stmt = t.select().where(t.c.x.is_not(None))
rows = connection.execute(stmt).all()
assert rows, "No rows returned"
@@ -118,6 +124,10 @@ class _LiteralRoundTripFixture:
value = filter_(value)
assert value in output
stmt = t.select().where(t.c.x.is_(None))
rows = connection.execute(stmt).all()
eq_(rows, [(None,)])
return run
@@ -452,6 +462,102 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase):
)
class IntervalTest(_LiteralRoundTripFixture, fixtures.TestBase):
__requires__ = ("datetime_interval",)
__backend__ = True
datatype = Interval
data = datetime.timedelta(days=1, seconds=4)
def test_literal(self, literal_round_trip):
literal_round_trip(self.datatype, [self.data], [self.data])
def test_select_direct_literal_interval(self, connection):
row = connection.execute(select(literal(self.data))).first()
eq_(row, (self.data,))
def test_arithmetic_operation_literal_interval(self, connection):
now = datetime.datetime.now().replace(microsecond=0)
# Able to subtract
row = connection.execute(
select(literal(now) - literal(self.data))
).scalar()
eq_(row, now - self.data)
# Able to Add
row = connection.execute(
select(literal(now) + literal(self.data))
).scalar()
eq_(row, now + self.data)
@testing.fixture
def arithmetic_table_fixture(cls, metadata, connection):
class Decorated(TypeDecorator):
impl = cls.datatype
cache_ok = True
it = Table(
"interval_table",
metadata,
Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
),
Column("interval_data", cls.datatype),
Column("date_data", DateTime),
Column("decorated_interval_data", Decorated),
)
it.create(connection)
return it
def test_arithmetic_operation_table_interval_and_literal_interval(
self, connection, arithmetic_table_fixture
):
interval_table = arithmetic_table_fixture
data = datetime.timedelta(days=2, seconds=5)
connection.execute(
interval_table.insert(), {"id": 1, "interval_data": data}
)
# Subtraction Operation
value = connection.execute(
select(interval_table.c.interval_data - literal(self.data))
).scalar()
eq_(value, data - self.data)
# Addition Operation
value = connection.execute(
select(interval_table.c.interval_data + literal(self.data))
).scalar()
eq_(value, data + self.data)
def test_arithmetic_operation_table_date_and_literal_interval(
self, connection, arithmetic_table_fixture
):
interval_table = arithmetic_table_fixture
now = datetime.datetime.now().replace(microsecond=0)
connection.execute(
interval_table.insert(), {"id": 1, "date_data": now}
)
# Subtraction Operation
value = connection.execute(
select(interval_table.c.date_data - literal(self.data))
).scalar()
eq_(value, (now - self.data))
# Addition Operation
value = connection.execute(
select(interval_table.c.date_data + literal(self.data))
).scalar()
eq_(value, (now + self.data))
class PrecisionIntervalTest(IntervalTest):
__requires__ = ("datetime_interval",)
__backend__ = True
datatype = Interval(day_precision=9, second_precision=9)
data = datetime.timedelta(days=103, seconds=4)
class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase):
compare = None
@@ -1940,6 +2046,8 @@ __all__ = (
"TextTest",
"NumericTest",
"IntegerTest",
"IntervalTest",
"PrecisionIntervalTest",
"CastTypeDecoratorTest",
"DateTimeHistoricTest",
"DateTimeCoercedToDateTimeTest",

View File

@@ -6,6 +6,7 @@ from ..schema import Column
from ..schema import Table
from ... import Integer
from ... import String
from ... import testing
class SimpleUpdateDeleteTest(fixtures.TablesTest):
@@ -58,5 +59,71 @@ class SimpleUpdateDeleteTest(fixtures.TablesTest):
[(1, "d1"), (3, "d3")],
)
@testing.variation("criteria", ["rows", "norows", "emptyin"])
@testing.requires.update_returning
def test_update_returning(self, connection, criteria):
t = self.tables.plain_pk
stmt = t.update().returning(t.c.id, t.c.data)
if criteria.norows:
stmt = stmt.where(t.c.id == 10)
elif criteria.rows:
stmt = stmt.where(t.c.id == 2)
elif criteria.emptyin:
stmt = stmt.where(t.c.id.in_([]))
else:
criteria.fail()
r = connection.execute(stmt, dict(data="d2_new"))
assert not r.is_insert
assert r.returns_rows
eq_(r.keys(), ["id", "data"])
if criteria.rows:
eq_(r.all(), [(2, "d2_new")])
else:
eq_(r.all(), [])
eq_(
connection.execute(t.select().order_by(t.c.id)).fetchall(),
[(1, "d1"), (2, "d2_new"), (3, "d3")]
if criteria.rows
else [(1, "d1"), (2, "d2"), (3, "d3")],
)
@testing.variation("criteria", ["rows", "norows", "emptyin"])
@testing.requires.delete_returning
def test_delete_returning(self, connection, criteria):
t = self.tables.plain_pk
stmt = t.delete().returning(t.c.id, t.c.data)
if criteria.norows:
stmt = stmt.where(t.c.id == 10)
elif criteria.rows:
stmt = stmt.where(t.c.id == 2)
elif criteria.emptyin:
stmt = stmt.where(t.c.id.in_([]))
else:
criteria.fail()
r = connection.execute(stmt)
assert not r.is_insert
assert r.returns_rows
eq_(r.keys(), ["id", "data"])
if criteria.rows:
eq_(r.all(), [(2, "d2")])
else:
eq_(r.all(), [])
eq_(
connection.execute(t.select().order_by(t.c.id)).fetchall(),
[(1, "d1"), (3, "d3")]
if criteria.rows
else [(1, "d1"), (2, "d2"), (3, "d3")],
)
__all__ = ("SimpleUpdateDeleteTest",)