This commit is contained in:
2025-05-22 20:25:38 +02:00
parent 09f6750c2b
commit ce03fbf12f
529 changed files with 3353 additions and 3312 deletions

View File

@@ -411,7 +411,7 @@ def Nullable(
.. versionadded:: 2.0.20
"""
return val # type: ignore
return val
@overload

View File

@@ -300,7 +300,7 @@ class Annotated(SupportsAnnotations):
def _annotate(self, values: _AnnotationDict) -> Self:
_values = self._annotations.union(values)
new = self._with_annotations(_values) # type: ignore
new = self._with_annotations(_values)
return new
def _with_annotations(self, values: _AnnotationDict) -> Self:

View File

@@ -273,7 +273,7 @@ def _generative(fn: _Fn) -> _Fn:
"""
@util.decorator # type: ignore
@util.decorator
def _generative(
fn: _Fn, self: _SelfGenerativeType, *args: Any, **kw: Any
) -> _SelfGenerativeType:
@@ -299,7 +299,7 @@ def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]:
for name in names
]
@util.decorator # type: ignore
@util.decorator
def check(fn, *args, **kw):
# make pylance happy by not including "self" in the argument
# list
@@ -315,7 +315,7 @@ def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]:
raise exc.InvalidRequestError(msg)
return fn(self, *args, **kw)
return check # type: ignore
return check
def _clone(element, **kw):
@@ -1176,6 +1176,7 @@ class Executable(roles.StatementRole):
autoflush: bool = False,
synchronize_session: SynchronizeSessionArgument = ...,
dml_strategy: DMLStrategyArgument = ...,
render_nulls: bool = ...,
is_delete_using: bool = ...,
is_update_from: bool = ...,
**opt: Any,

View File

@@ -546,6 +546,9 @@ class CacheKey(NamedTuple):
def _apply_params_to_element(
self, original_cache_key: CacheKey, target_element: ClauseElement
) -> ClauseElement:
if target_element._is_immutable:
return target_element
translate = {
k.key: v.value
for k, v in zip(original_cache_key.bindparams, self.bindparams)

View File

@@ -2080,14 +2080,12 @@ class SQLCompiler(Compiled):
if parameter in self.literal_execute_params:
if escaped_name not in replacement_expressions:
value = parameters.pop(escaped_name)
replacement_expressions[
escaped_name
] = self.render_literal_bindparam(
parameter,
render_literal_value=value,
)
replacement_expressions[
escaped_name
] = self.render_literal_bindparam(
parameter,
render_literal_value=parameters.pop(escaped_name),
)
continue
if parameter in self.post_compile_params:
@@ -2742,6 +2740,7 @@ class SQLCompiler(Compiled):
except KeyError as err:
raise exc.UnsupportedCompilationError(self, operator_) from err
else:
kw["_in_operator_expression"] = True
return self._generate_delimited_list(
clauselist.clauses, opstring, **kw
)
@@ -3370,9 +3369,9 @@ class SQLCompiler(Compiled):
def _generate_generic_binary(
self, binary, opstring, eager_grouping=False, **kw
):
_in_binary = kw.get("_in_binary", False)
_in_operator_expression = kw.get("_in_operator_expression", False)
kw["_in_binary"] = True
kw["_in_operator_expression"] = True
kw["_binary_op"] = binary.operator
text = (
binary.left._compiler_dispatch(
@@ -3384,7 +3383,7 @@ class SQLCompiler(Compiled):
)
)
if _in_binary and eager_grouping:
if _in_operator_expression and eager_grouping:
text = "(%s)" % text
return text
@@ -3767,6 +3766,12 @@ class SQLCompiler(Compiled):
"""
if value is None and not type_.should_evaluate_none:
# issue #10535 - handle NULL in the compiler without placing
# this onto each type, except for "evaluate None" types
# (e.g. JSON)
return self.process(elements.Null._instance())
processor = type_._cached_literal_processor(self.dialect)
if processor:
try:
@@ -4089,7 +4094,7 @@ class SQLCompiler(Compiled):
from_linter.froms[cte._de_clone()] = cte_name
if not is_new_cte and embedded_in_current_named_cte:
return self.preparer.format_alias(cte, cte_name) # type: ignore[no-any-return] # noqa: E501
return self.preparer.format_alias(cte, cte_name)
if cte_pre_alias_name:
text = self.preparer.format_alias(cte, cte_pre_alias_name)
@@ -6688,8 +6693,6 @@ class DDLCompiler(Compiled):
text.append("NO MAXVALUE")
if identity_options.cache is not None:
text.append("CACHE %d" % identity_options.cache)
if identity_options.order is not None:
text.append("ORDER" if identity_options.order else "NO ORDER")
if identity_options.cycle is not None:
text.append("CYCLE" if identity_options.cycle else "NO CYCLE")
return " ".join(text)

View File

@@ -491,10 +491,10 @@ def _key_getters_for_crud_column(
key: Union[ColumnClause[Any], str]
) -> Union[str, Tuple[str, str]]:
str_key = c_key_role(key)
if hasattr(key, "table") and key.table in _et: # type: ignore
if hasattr(key, "table") and key.table in _et:
return (key.table.name, str_key) # type: ignore
else:
return str_key # type: ignore
return str_key
def _getattr_col_key(
col: ColumnClause[Any],
@@ -513,7 +513,7 @@ def _key_getters_for_crud_column(
return col.key
else:
_column_as_key = functools.partial( # type: ignore
_column_as_key = functools.partial(
coercions.expect_as_key, roles.DMLColumnRole
)
_getattr_col_key = _col_bind_name = operator.attrgetter("key") # type: ignore # noqa: E501
@@ -647,6 +647,9 @@ def _scan_cols(
compiler_implicit_returning = compiler.implicit_returning
# TODO - see TODO(return_defaults_columns) below
# cols_in_params = set()
for c in cols:
# scan through every column in the target table
@@ -672,6 +675,9 @@ def _scan_cols(
kw,
)
# TODO - see TODO(return_defaults_columns) below
# cols_in_params.add(c)
elif isinsert:
# no parameter is present and it's an insert.
@@ -764,6 +770,19 @@ def _scan_cols(
if c in remaining_supplemental
)
# TODO(return_defaults_columns): there can still be more columns in
# _return_defaults_columns in the case that they are from something like an
# aliased of the table. we can add them here, however this breaks other ORM
# things. so this is for another day. see
# test/orm/dml/test_update_delete_where.py -> test_update_from_alias
# if stmt._return_defaults_columns:
# compiler_implicit_returning.extend(
# set(stmt._return_defaults_columns)
# .difference(compiler_implicit_returning)
# .difference(cols_in_params)
# )
return (use_insertmanyvalues, use_sentinel_columns)
@@ -1559,7 +1578,11 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
should_implicit_return_defaults = (
implicit_returning and stmt._return_defaults
)
explicit_returning = should_implicit_return_defaults or stmt._returning
explicit_returning = (
should_implicit_return_defaults
or stmt._returning
or stmt._supplemental_returning
)
use_insertmanyvalues = (
toplevel
and compiler.for_executemany

View File

@@ -403,17 +403,14 @@ class DDL(ExecutableDDLElement):
self.context = context or {}
def __repr__(self):
parts = [repr(self.statement)]
if self.context:
parts.append(f"context={self.context}")
return "<%s@%s; %s>" % (
type(self).__name__,
id(self),
", ".join(
[repr(self.statement)]
+ [
"%s=%r" % (key, getattr(self, key))
for key in ("on", "context")
if getattr(self, key)
]
),
", ".join(parts),
)
@@ -470,7 +467,7 @@ class CreateSchema(_CreateBase):
__visit_name__ = "create_schema"
stringify_dialect = "default" # type: ignore
stringify_dialect = "default"
def __init__(
self,
@@ -491,7 +488,7 @@ class DropSchema(_DropBase):
__visit_name__ = "drop_schema"
stringify_dialect = "default" # type: ignore
stringify_dialect = "default"
def __init__(
self,

View File

@@ -211,7 +211,11 @@ class DMLState(CompileState):
primary_table = all_tables[0]
seen = {primary_table}
for crit in statement._where_criteria:
consider = statement._where_criteria
if self._dict_parameters:
consider += tuple(self._dict_parameters.values())
for crit in consider:
for item in _from_objects(crit):
if not seen.intersection(item._cloned_set):
froms.append(item)
@@ -563,7 +567,8 @@ class UpdateBase(
3. :meth:`.UpdateBase.return_defaults` can be called against any
backend. Backends that don't support RETURNING will skip the usage
of the feature, rather than raising an exception. The return value
of the feature, rather than raising an exception, *unless*
``supplemental_cols`` is passed. The return value
of :attr:`_engine.CursorResult.returned_defaults` will be ``None``
for backends that don't support RETURNING or for which the target
:class:`.Table` sets :paramref:`.Table.implicit_returning` to

View File

@@ -117,6 +117,7 @@ _NUMERIC = Union[float, Decimal]
_NUMBER = Union[float, int, Decimal]
_T = TypeVar("_T", bound="Any")
_T_co = TypeVar("_T_co", bound=Any, covariant=True)
_OPT = TypeVar("_OPT", bound="Any")
_NT = TypeVar("_NT", bound="_NUMERIC")
@@ -804,7 +805,7 @@ class CompilerColumnElement(
# SQLCoreOperations should be suiting the ExpressionElementRole
# and ColumnsClauseRole. however the MRO issues become too elaborate
# at the moment.
class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly):
__slots__ = ()
# annotations for comparison methods
@@ -873,7 +874,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
def __or__(self, other: Any) -> BooleanClauseList:
...
def __invert__(self) -> ColumnElement[_T]:
def __invert__(self) -> ColumnElement[_T_co]:
...
def __lt__(self, other: Any) -> ColumnElement[bool]:
@@ -882,6 +883,13 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
def __le__(self, other: Any) -> ColumnElement[bool]:
...
# declare also that this class has a hash method, otherwise
# it may be assumed to be None by type checkers since the
# object defines __eq__ and Python sets it to None in that case:
# https://docs.python.org/3/reference/datamodel.html#object.__hash__
def __hash__(self) -> int:
...
def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
...
@@ -900,7 +908,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
def __ge__(self, other: Any) -> ColumnElement[bool]:
...
def __neg__(self) -> UnaryExpression[_T]:
def __neg__(self) -> UnaryExpression[_T_co]:
...
def __contains__(self, other: Any) -> ColumnElement[bool]:
@@ -961,7 +969,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
def bitwise_and(self, other: Any) -> BinaryExpression[Any]:
...
def bitwise_not(self) -> UnaryExpression[_T]:
def bitwise_not(self) -> UnaryExpression[_T_co]:
...
def bitwise_lshift(self, other: Any) -> BinaryExpression[Any]:
@@ -1074,22 +1082,22 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
) -> ColumnElement[str]:
...
def desc(self) -> UnaryExpression[_T]:
def desc(self) -> UnaryExpression[_T_co]:
...
def asc(self) -> UnaryExpression[_T]:
def asc(self) -> UnaryExpression[_T_co]:
...
def nulls_first(self) -> UnaryExpression[_T]:
def nulls_first(self) -> UnaryExpression[_T_co]:
...
def nullsfirst(self) -> UnaryExpression[_T]:
def nullsfirst(self) -> UnaryExpression[_T_co]:
...
def nulls_last(self) -> UnaryExpression[_T]:
def nulls_last(self) -> UnaryExpression[_T_co]:
...
def nullslast(self) -> UnaryExpression[_T]:
def nullslast(self) -> UnaryExpression[_T_co]:
...
def collate(self, collation: str) -> CollationClause:
@@ -1100,7 +1108,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
) -> BinaryExpression[bool]:
...
def distinct(self: _SQO[_T]) -> UnaryExpression[_T]:
def distinct(self: _SQO[_T_co]) -> UnaryExpression[_T_co]:
...
def any_(self) -> CollectionAggregate[Any]:
@@ -1128,19 +1136,11 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
) -> ColumnElement[str]:
...
@overload
def __add__(self, other: Any) -> ColumnElement[Any]:
...
def __add__(self, other: Any) -> ColumnElement[Any]:
...
@overload
def __radd__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]:
...
@overload
def __radd__(self: _SQO[int], other: Any) -> ColumnElement[int]:
def __radd__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]:
...
@overload
@@ -1282,7 +1282,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
class SQLColumnExpression(
SQLCoreOperations[_T], roles.ExpressionElementRole[_T], TypingOnly
SQLCoreOperations[_T_co], roles.ExpressionElementRole[_T_co], TypingOnly
):
"""A type that may be used to indicate any SQL column element or object
that acts in place of one.
@@ -1613,12 +1613,12 @@ class ColumnElement(
*other: Any,
**kwargs: Any,
) -> ColumnElement[Any]:
return op(self.comparator, *other, **kwargs) # type: ignore[return-value,no-any-return] # noqa: E501
return op(self.comparator, *other, **kwargs) # type: ignore[no-any-return] # noqa: E501
def reverse_operate(
self, op: operators.OperatorType, other: Any, **kwargs: Any
) -> ColumnElement[Any]:
return op(other, self.comparator, **kwargs) # type: ignore[return-value,no-any-return] # noqa: E501
return op(other, self.comparator, **kwargs) # type: ignore[no-any-return] # noqa: E501
def _bind_param(
self,
@@ -3132,7 +3132,7 @@ class BooleanClauseList(ExpressionClauseList[bool]):
}, *args)'.""",
version="1.4",
)
return cls._construct_raw(operator) # type: ignore[no-any-return]
return cls._construct_raw(operator)
lcc, convert_clauses = cls._process_clauses_for_boolean(
operator,
@@ -3162,7 +3162,7 @@ class BooleanClauseList(ExpressionClauseList[bool]):
assert lcc
# just one element. return it as a single boolean element,
# not a list and discard the operator.
return convert_clauses[0] # type: ignore[no-any-return] # noqa: E501
return convert_clauses[0]
@classmethod
def _construct_for_whereclause(
@@ -4182,7 +4182,7 @@ class Over(ColumnElement[_T]):
element: ColumnElement[_T]
"""The underlying expression object to which this :class:`.Over`
object refers towards."""
object refers."""
range_: Optional[typing_Tuple[int, int]]

View File

@@ -916,6 +916,10 @@ class _FunctionGenerator:
# code within this block is **programmatically,
# statically generated** by tools/generate_sql_functions.py
@property
def aggregate_strings(self) -> Type[aggregate_strings]:
...
@property
def ansifunction(self) -> Type[AnsiFunction[Any]]:
...
@@ -1795,3 +1799,30 @@ class grouping_sets(GenericFunction[_T]):
"""
_has_args = True
inherit_cache = True
class aggregate_strings(GenericFunction[str]):
"""Implement a generic string aggregation function.
This function will concatenate non-null values into a string and
separate the values by a delimiter.
This function is compiled on a per-backend basis, into functions
such as ``group_concat()``, ``string_agg()``, or ``LISTAGG()``.
e.g. Example usage with delimiter '.'::
stmt = select(func.aggregate_strings(table.c.str_col, "."))
The return type of this function is :class:`.String`.
.. versionadded:: 2.0.21
"""
type = sqltypes.String()
_has_args = True
inherit_cache = True
def __init__(self, clause, separator):
super().__init__(clause, separator)

View File

@@ -718,7 +718,7 @@ class LinkedLambdaElement(StatementLambdaElement):
opts: Union[Type[LambdaOptions], LambdaOptions],
):
self.opts = opts
self.fn = fn # type: ignore[assignment]
self.fn = fn
self.parent_lambda = parent_lambda
self.tracker_key = parent_lambda.tracker_key + (fn.__code__,)

View File

@@ -307,7 +307,7 @@ class Operators:
)
def against(other: Any) -> Operators:
return operator(self, other) # type: ignore
return operator(self, other)
return against
@@ -569,8 +569,16 @@ class ColumnOperators(Operators):
"""
return self.operate(le, other)
# TODO: not sure why we have this
__hash__ = Operators.__hash__ # type: ignore
# ColumnOperators defines an __eq__ so it must explicitly declare also
# a hash or it's set to None by Python:
# https://docs.python.org/3/reference/datamodel.html#object.__hash__
if TYPE_CHECKING:
def __hash__(self) -> int:
...
else:
__hash__ = Operators.__hash__
def __eq__(self, other: Any) -> ColumnOperators: # type: ignore[override]
"""Implement the ``==`` operator.
@@ -2533,8 +2541,8 @@ _PRECEDENCE: Dict[OperatorType, int] = {
bitwise_and_op: 7,
bitwise_lshift_op: 7,
bitwise_rshift_op: 7,
concat_op: 6,
filter_op: 6,
concat_op: 5,
match_op: 5,
not_match_op: 5,
regexp_match_op: 5,

View File

@@ -23,6 +23,7 @@ if TYPE_CHECKING:
from .selectable import Subquery
_T = TypeVar("_T", bound=Any)
_T_co = TypeVar("_T_co", bound=Any, covariant=True)
class SQLRole:
@@ -110,7 +111,7 @@ class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole):
raise NotImplementedError()
class TypedColumnsClauseRole(Generic[_T], SQLRole):
class TypedColumnsClauseRole(Generic[_T_co], SQLRole):
"""element-typed form of ColumnsClauseRole"""
__slots__ = ()
@@ -162,7 +163,7 @@ class WhereHavingRole(OnClauseRole):
_role_name = "SQL expression for WHERE/HAVING role"
class ExpressionElementRole(TypedColumnsClauseRole[_T]):
class ExpressionElementRole(TypedColumnsClauseRole[_T_co]):
# note when using generics for ExpressionElementRole,
# the generic type needs to be in
# sqlalchemy.sql.coercions._impl_lookup mapping also.

View File

@@ -50,7 +50,6 @@ from typing import overload
from typing import Sequence as _typing_Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
@@ -1433,7 +1432,7 @@ class Table(
elif schema is None:
actual_schema = metadata.schema
else:
actual_schema = schema # type: ignore
actual_schema = schema
key = _get_table_key(name, actual_schema)
if key in metadata.tables:
util.warn(
@@ -2452,14 +2451,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
# Constraint objects plus non-constraint-bound ForeignKey objects
args: List[SchemaItem] = [
c._copy(**kw)
for c in self.constraints
if not c._type_bound # type: ignore
] + [
c._copy(**kw) # type: ignore
for c in self.foreign_keys
if not c.constraint
]
c._copy(**kw) for c in self.constraints if not c._type_bound
] + [c._copy(**kw) for c in self.foreign_keys if not c.constraint]
# ticket #5276
column_kwargs = {}
@@ -2529,6 +2522,15 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
if self.primary_key:
other.primary_key = True
if self.autoincrement != "auto" and other.autoincrement == "auto":
other.autoincrement = self.autoincrement
if self.system:
other.system = self.system
if self.info:
other.info.update(self.info)
type_ = self.type
if not type_._isnull and other.type._isnull:
if isinstance(type_, SchemaEventTarget):
@@ -2574,6 +2576,12 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
if self.index and not other.index:
other.index = True
if self.doc and other.doc is None:
other.doc = self.doc
if self.comment and other.comment is None:
other.comment = self.comment
if self.unique and not other.unique:
other.unique = True
@@ -3973,7 +3981,7 @@ class FetchedValue(SchemaEventTarget):
if for_update == self.for_update:
return self
else:
return self._clone(for_update) # type: ignore
return self._clone(for_update)
def _copy(self) -> FetchedValue:
return FetchedValue(self.for_update)
@@ -4151,7 +4159,7 @@ class Constraint(DialectKWArgs, HasConditionalDDL, SchemaItem):
"and will be removed in a future release.",
)
def copy(self, **kw: Any) -> Self:
return self._copy(**kw) # type: ignore
return self._copy(**kw)
def _copy(self, **kw: Any) -> Self:
raise NotImplementedError()
@@ -5286,35 +5294,31 @@ class Index(
)
_AllConstraints = Union[
Index,
UniqueConstraint,
CheckConstraint,
ForeignKeyConstraint,
PrimaryKeyConstraint,
]
_NamingSchemaCallable = Callable[[_AllConstraints, Table], str]
_NamingSchemaCallable = Callable[[Constraint, Table], str]
_NamingSchemaDirective = Union[str, _NamingSchemaCallable]
class _NamingSchemaTD(TypedDict, total=False):
fk: Union[str, _NamingSchemaCallable]
pk: Union[str, _NamingSchemaCallable]
ix: Union[str, _NamingSchemaCallable]
ck: Union[str, _NamingSchemaCallable]
uq: Union[str, _NamingSchemaCallable]
fk: _NamingSchemaDirective
pk: _NamingSchemaDirective
ix: _NamingSchemaDirective
ck: _NamingSchemaDirective
uq: _NamingSchemaDirective
_NamingSchemaParameter = Union[
# it seems like the TypedDict here is useful for pylance typeahead,
# and not much else
_NamingSchemaTD,
Mapping[
Union[Type[_AllConstraints], str], Union[str, _NamingSchemaCallable]
],
# there is no form that allows Union[Type[Any], str] to work in all
# cases, including breaking out Mapping[] entries for each combination
# even, therefore keys must be `Any` (see #10264)
Mapping[Any, _NamingSchemaDirective],
]
DEFAULT_NAMING_CONVENTION: _NamingSchemaParameter = util.immutabledict(
{"ix": "ix_%(column_0_label)s"} # type: ignore[arg-type]
{"ix": "ix_%(column_0_label)s"}
)
@@ -5522,7 +5526,7 @@ class MetaData(HasSchemaAttr):
def _remove_table(self, name: str, schema: Optional[str]) -> None:
key = _get_table_key(name, schema)
removed = dict.pop(self.tables, key, None) # type: ignore
removed = dict.pop(self.tables, key, None)
if removed is not None:
for fk in removed.foreign_keys:
fk._remove_from_metadata(self)
@@ -5622,7 +5626,9 @@ class MetaData(HasSchemaAttr):
bind: Union[Engine, Connection],
schema: Optional[str] = None,
views: bool = False,
only: Optional[_typing_Sequence[str]] = None,
only: Union[
_typing_Sequence[str], Callable[[str, MetaData], bool], None
] = None,
extend_existing: bool = False,
autoload_replace: bool = True,
resolve_fks: bool = True,

View File

@@ -323,9 +323,7 @@ class Selectable(ReturnsRows):
object, returning a copy of this :class:`_expression.FromClause`.
"""
return util.preloaded.sql_util.ClauseAdapter(alias).traverse( # type: ignore # noqa: E501
self
)
return util.preloaded.sql_util.ClauseAdapter(alias).traverse(self)
def corresponding_column(
self, column: KeyedColumnElement[Any], require_embedded: bool = False
@@ -1420,7 +1418,7 @@ class Join(roles.DMLTableRole, FromClause):
continue
for fk in sorted(
b.foreign_keys,
key=lambda fk: fk.parent._creation_order, # type: ignore
key=lambda fk: fk.parent._creation_order,
):
if (
consider_as_foreign_keys is not None
@@ -1441,7 +1439,7 @@ class Join(roles.DMLTableRole, FromClause):
if left is not b:
for fk in sorted(
left.foreign_keys,
key=lambda fk: fk.parent._creation_order, # type: ignore
key=lambda fk: fk.parent._creation_order,
):
if (
consider_as_foreign_keys is not None
@@ -2436,7 +2434,7 @@ class HasCTE(roles.HasCTERole, SelectsRows):
SELECT t.c1, t.c2
FROM t
Above, the "anon_1" CTE is not referred towards in the SELECT
Above, the "anon_1" CTE is not referenced in the SELECT
statement, however still accomplishes the task of running an INSERT
statement.
@@ -3151,7 +3149,7 @@ class Values(roles.InElementRole, Generative, LateralFromClause):
__visit_name__ = "values"
_data: Tuple[List[Tuple[Any, ...]], ...] = ()
_data: Tuple[Sequence[Tuple[Any, ...]], ...] = ()
_unnamed: bool
_traverse_internals: _TraverseInternalsType = [
@@ -3169,6 +3167,7 @@ class Values(roles.InElementRole, Generative, LateralFromClause):
):
super().__init__()
self._column_args = columns
if name is None:
self._unnamed = True
self.name = _anonymous_label.safe_construct(id(self), "anon")
@@ -3234,7 +3233,7 @@ class Values(roles.InElementRole, Generative, LateralFromClause):
return self
@_generative
def data(self, values: List[Tuple[Any, ...]]) -> Self:
def data(self, values: Sequence[Tuple[Any, ...]]) -> Self:
"""Return a new :class:`_expression.Values` construct,
adding the given data to the data list.
@@ -3262,6 +3261,13 @@ class Values(roles.InElementRole, Generative, LateralFromClause):
def _populate_column_collection(self) -> None:
for c in self._column_args:
if c.table is not None and c.table is not self:
_, c = c._make_proxy(self)
else:
# if the column was used in other contexts, ensure
# no memoizations of other FROM clauses.
# see test_values.py -> test_auto_proxy_select_direct_col
c._reset_memoizations()
self._columns.add(c)
c.table = self
@@ -3294,7 +3300,7 @@ class ScalarValues(roles.InElementRole, GroupedElement, ColumnElement[Any]):
def __init__(
self,
columns: Sequence[ColumnClause[Any]],
data: Tuple[List[Tuple[Any, ...]], ...],
data: Tuple[Sequence[Tuple[Any, ...]], ...],
literal_binds: bool,
):
super().__init__()
@@ -4744,7 +4750,7 @@ class SelectState(util.MemoizedSlots, CompileState):
Dict[str, ColumnElement[Any]],
]:
with_cols: Dict[str, ColumnElement[Any]] = {
c._tq_label or c.key: c # type: ignore
c._tq_label or c.key: c
for c in self.statement._all_selected_columns
if c._allow_label_resolve
}
@@ -5012,7 +5018,7 @@ class _MemoizedSelectEntities(
c.__dict__ = {k: v for k, v in self.__dict__.items()}
c._is_clone_of = self.__dict__.get("_is_clone_of", self)
return c # type: ignore
return c
@classmethod
def _generate_for_statement(cls, select_stmt: Select[Any]) -> None:
@@ -6720,7 +6726,7 @@ class Exists(UnaryExpression[bool]):
)
return e
def select_from(self, *froms: FromClause) -> Self:
def select_from(self, *froms: _FromClauseArgument) -> Self:
"""Return a new :class:`_expression.Exists` construct,
applying the given
expression to the :meth:`_expression.Select.select_from`

View File

@@ -608,14 +608,21 @@ class Numeric(HasExpressionLookup, TypeEngine[_N]):
class Float(Numeric[_N]):
"""Type representing floating point types, such as ``FLOAT`` or ``REAL``.
This type returns Python ``float`` objects by default, unless the
:paramref:`.Float.asdecimal` flag is set to True, in which case they
:paramref:`.Float.asdecimal` flag is set to ``True``, in which case they
are coerced to ``decimal.Decimal`` objects.
When a :paramref:`.Float.precision` is not provided in a
:class:`_types.Float` type some backend may compile this type as
an 8 bytes / 64 bit float datatype. To use a 4 bytes / 32 bit float
datatype a precision <= 24 can usually be provided or the
:class:`_types.REAL` type can be used.
This is known to be the case in the PostgreSQL and MSSQL dialects
that render the type as ``FLOAT``, which in both cases is an alias of
``DOUBLE PRECISION``. Other third party dialects may have similar
behavior.
"""
__visit_name__ = "float"
@@ -733,16 +740,12 @@ class _RenderISO8601NoT:
if _portion is not None:
def process(value):
if value is not None:
value = f"""'{value.isoformat().split("T")[_portion]}'"""
return value
return f"""'{value.isoformat().split("T")[_portion]}'"""
else:
def process(value):
if value is not None:
value = f"""'{value.isoformat().replace("T", " ")}'"""
return value
return f"""'{value.isoformat().replace("T", " ")}'"""
return process
@@ -1395,7 +1398,10 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
compliant enumerated type, which should then return a list of string
values to be persisted. This allows for alternate usages such as
using the string value of an enum to be persisted to the database
instead of its name.
instead of its name. The callable must return the values to be
persisted in the same order as iterating through the Enum's
``__members__`` attribute. For example
``lambda x: [i.value for i in x]``.
.. versionadded:: 1.2.3
@@ -1451,7 +1457,11 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
self._default_length = length = 0
if length_arg is not NO_ARG:
if not _disable_warnings and length_arg < length:
if (
not _disable_warnings
and length_arg is not None
and length_arg < length
):
raise ValueError(
"When provided, length must be larger or equal"
" than the length of the longest enum value. %s < %s"
@@ -1658,14 +1668,14 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
)
def as_generic(self, allow_nulltype=False):
if hasattr(self, "enums"):
try:
args = self.enums
else:
except AttributeError:
raise NotImplementedError(
"TypeEngine.as_generic() heuristic "
"is undefined for types that inherit Enum but do not have "
"an `enums` attribute."
)
) from None
return util.constructor_copy(
self, self._generic_type_affinity, *args, _disable_warnings=True
@@ -2038,8 +2048,8 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]):
"""A type for ``datetime.timedelta()`` objects.
The Interval type deals with ``datetime.timedelta`` objects. In
PostgreSQL, the native ``INTERVAL`` type is used; for others, the
value is stored as a date which is relative to the "epoch"
PostgreSQL and Oracle, the native ``INTERVAL`` type is used; for others,
the value is stored as a date which is relative to the "epoch"
(Jan. 1, 1970).
Note that the ``Interval`` type does not currently provide date arithmetic
@@ -2470,6 +2480,9 @@ class JSON(Indexable, TypeEngine[Any]):
value = int_processor(value)
elif string_processor and isinstance(value, str):
value = string_processor(value)
else:
raise NotImplementedError()
return value
return process
@@ -3692,28 +3705,20 @@ class Uuid(Emulated, TypeEngine[_UUID_RETURN]):
if not self.as_uuid:
def process(value):
if value is not None:
value = (
f"""'{value.replace("-", "").replace("'", "''")}'"""
)
return value
return f"""'{value.replace("-", "").replace("'", "''")}'"""
return process
else:
if character_based_uuid:
def process(value):
if value is not None:
value = f"""'{value.hex}'"""
return value
return f"""'{value.hex}'"""
return process
else:
def process(value):
if value is not None:
value = f"""'{str(value).replace("'", "''")}'"""
return value
return f"""'{str(value).replace("'", "''")}'"""
return process

View File

@@ -56,15 +56,15 @@ def _preconfigure_traversals(target_hierarchy: Type[Any]) -> None:
if hasattr(cls, "_generate_cache_attrs") and hasattr(
cls, "_traverse_internals"
):
cls._generate_cache_attrs() # type: ignore
cls._generate_cache_attrs()
_copy_internals.generate_dispatch(
cls, # type: ignore
cls._traverse_internals, # type: ignore
cls,
cls._traverse_internals,
"_generated_copy_internals_traversal",
)
_get_children.generate_dispatch(
cls, # type: ignore
cls._traverse_internals, # type: ignore
cls,
cls._traverse_internals,
"_generated_get_children_traversal",
)

View File

@@ -191,7 +191,7 @@ class TypeEngine(Visitable, Generic[_T]):
op_fn, addtl_kw = default_comparator.operator_lookup[op.__name__]
if kwargs:
addtl_kw = addtl_kw.union(kwargs)
return op_fn(self.expr, op, *other, **addtl_kw) # type: ignore
return op_fn(self.expr, op, *other, **addtl_kw)
@util.preload_module("sqlalchemy.sql.default_comparator")
def reverse_operate(
@@ -201,7 +201,7 @@ class TypeEngine(Visitable, Generic[_T]):
op_fn, addtl_kw = default_comparator.operator_lookup[op.__name__]
if kwargs:
addtl_kw = addtl_kw.union(kwargs)
return op_fn(self.expr, op, other, reverse=True, **addtl_kw) # type: ignore # noqa: E501
return op_fn(self.expr, op, other, reverse=True, **addtl_kw)
def _adapt_expression(
self,
@@ -354,33 +354,6 @@ class TypeEngine(Visitable, Generic[_T]):
def copy(self, **kw: Any) -> Self:
return self.adapt(self.__class__)
def compare_against_backend(
self, dialect: Dialect, conn_type: TypeEngine[Any]
) -> Optional[bool]:
"""Compare this type against the given backend type.
This function is currently not implemented for SQLAlchemy
types, and for all built in types will return ``None``. However,
it can be implemented by a user-defined type
where it can be consumed by schema comparison tools such as
Alembic autogenerate.
A future release of SQLAlchemy will potentially implement this method
for builtin types as well.
The function should return True if this type is equivalent to the
given type; the type is typically reflected from the database
so should be database specific. The dialect in use is also
passed. It can also return False to assert that the type is
not equivalent.
:param dialect: a :class:`.Dialect` that is involved in the comparison.
:param conn_type: the type object reflected from the backend.
"""
return None
def copy_value(self, value: Any) -> Any:
return value
@@ -816,7 +789,7 @@ class TypeEngine(Visitable, Generic[_T]):
best_uppercase = None
if not isinstance(self, TypeEngine):
return self.__class__ # type: ignore # mypy bug?
return self.__class__
for t in self.__class__.__mro__:
if (
@@ -2323,7 +2296,7 @@ def to_instance(
return NULLTYPE
if callable(typeobj):
return typeobj(*arg, **kw) # type: ignore # for pyright
return typeobj(*arg, **kw)
else:
return typeobj

View File

@@ -1440,7 +1440,7 @@ def _offset_or_limit_clause_asint_if_possible(
if clause is None:
return None
if hasattr(clause, "_limit_offset_value"):
value = clause._limit_offset_value # type: ignore
value = clause._limit_offset_value
return util.asint(value)
else:
return clause
@@ -1489,13 +1489,11 @@ def _make_slice(
offset_clause = 0
if start != 0:
offset_clause = offset_clause + start # type: ignore
offset_clause = offset_clause + start
if offset_clause == 0:
offset_clause = None
else:
offset_clause = _offset_or_limit_clause(
offset_clause # type: ignore
)
offset_clause = _offset_or_limit_clause(offset_clause)
return limit_clause, offset_clause # type: ignore
return limit_clause, offset_clause

View File

@@ -146,7 +146,7 @@ class Visitable:
cls._original_compiler_dispatch
) = _compiler_dispatch
def __class_getitem__(cls, key: str) -> Any:
def __class_getitem__(cls, key: Any) -> Any:
# allow generic classes in py3.9+
return cls
@@ -161,16 +161,17 @@ class InternalTraversal(Enum):
the ``_traverse_internals`` collection. Such as, the :class:`.Case`
object defines ``_traverse_internals`` as ::
_traverse_internals = [
("value", InternalTraversal.dp_clauseelement),
("whens", InternalTraversal.dp_clauseelement_tuples),
("else_", InternalTraversal.dp_clauseelement),
]
class Case(ColumnElement[_T]):
_traverse_internals = [
("value", InternalTraversal.dp_clauseelement),
("whens", InternalTraversal.dp_clauseelement_tuples),
("else_", InternalTraversal.dp_clauseelement),
]
Above, the :class:`.Case` class indicates its internal state as the
attributes named ``value``, ``whens``, and ``else_``. They each
link to an :class:`.InternalTraversal` method which indicates the type
of datastructure referred towards.
of datastructure to which each attribute refers.
Using the ``_traverse_internals`` structure, objects of type
:class:`.InternalTraversible` will have the following methods automatically