71 lines
2.6 KiB
Python
71 lines
2.6 KiB
Python
import os
|
|
import re
|
|
from pathlib import Path
|
|
|
|
def update_migration_file(file_path):
    """Guard each ``op.create_table(...)`` call in an Alembic migration file.

    Every statement-level ``op.create_table('<name>', ...)`` call is wrapped
    in an ``if '<name>' not in tables:`` block. Single- or double-quoted names
    are accepted, on the same line as ``op.create_table(`` or the next one.
    Each guard is preceded by a fresh ``inspect(op.get_bind())`` lookup, and
    ``from sqlalchemy import inspect`` is added after the first top-level
    ``import sqlalchemy as sa`` (preferred) or ``from alembic import op`` line.
    The file is rewritten in place.

    Only the ``create_table`` call itself is moved into the guard; follow-up
    statements (e.g. ``op.create_index`` on the same table) are not.
    Assignment forms such as ``t = op.create_table(...)`` are not matched and
    stay unguarded, since wrapping them would leave ``t`` undefined for later
    statements.

    Args:
        file_path: Path (``str`` or ``pathlib.Path``) of the migration file.

    Returns:
        True if the file was rewritten. False if it was left untouched
        because it already imports ``inspect``, contains no guardable
        ``op.create_table`` call, or has no anchor import line to place the
        new import after.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    # Skip if already has the pattern
    if 'from sqlalchemy import inspect' in content:
        return False

    # Opening of a create_table call at the start of a line:
    # group 1 = leading whitespace, group 2 = quote char, group 3 = table name.
    # ``\s*`` may span a newline so ``op.create_table(\n    "name",`` matches.
    create_table_re = re.compile(
        r'^([ \t]*)op\.create_table\(\s*([\'"])(\w+)\2', re.MULTILINE)
    if not create_table_re.search(content):
        # Nothing to guard: don't add an unused import or report a change.
        return False

    # Insert the import on its own line after a *whole* top-level import line.
    # Matching the full line (instead of a substring replace) keeps lines like
    # ``from alembic import op, context`` intact and edits only one place.
    import_line = '\nfrom sqlalchemy import inspect'
    for anchor in (r'^import sqlalchemy as sa\b.*$',
                   r'^from alembic import op\b.*$'):
        patched, n_subs = re.subn(anchor, lambda m: m.group(0) + import_line,
                                  content, count=1, flags=re.MULTILINE)
        if n_subs:
            content = patched
            break
    else:
        # Without the import the generated ``inspect(conn)`` would raise
        # NameError at migration time, so leave the file alone.
        return False

    def guard(match):
        """Return the existence check followed by the re-indented call head."""
        indent, table = match.group(1), match.group(3)
        # Reuse the matched call text verbatim (minus its old indent) so that
        # everything after the table name -- e.g. the trailing comma -- and the
        # original quote style are preserved.
        call_head = match.group(0)[len(indent):]
        return (
            f"{indent}conn = op.get_bind()\n"
            f"{indent}inspector = inspect(conn)\n"
            f"{indent}tables = inspector.get_table_names()\n"
            f"\n"
            f"{indent}if '{table}' not in tables:\n"
            f"{indent}    {call_head}"
        )

    # Argument lines of the call sit inside parentheses, so they need no
    # re-indentation to remain valid inside the new ``if`` block.
    content = create_table_re.sub(guard, content)

    # Write back the updated content
    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(content)

    return True
|
|
|
|
def main(migrations_dir='migrations/versions'):
    """Add table-existence guards to every Alembic migration in a directory.

    Each ``*.py`` file in *migrations_dir* other than ``__init__.py`` is
    passed to ``update_migration_file``. The names of the files it reports
    as rewritten are then printed.

    Args:
        migrations_dir: Directory holding the migration files. Defaults to
            ``migrations/versions`` relative to the current working directory.

    Raises:
        SystemExit: If *migrations_dir* is not an existing directory. This
            keeps a run from the wrong working directory from being reported
            as "No files needed updating."
    """
    versions_dir = Path(migrations_dir)
    if not versions_dir.is_dir():
        raise SystemExit(f"Migrations directory not found: {versions_dir}")

    updated_files = []
    # Sorted so the processing and report order is stable across filesystems.
    for file in sorted(versions_dir.glob('*.py')):
        if file.name != '__init__.py' and update_migration_file(file):
            updated_files.append(file.name)

    if updated_files:
        print("Updated the following migration files:")
        for name in updated_files:
            print(f"- {name}")
    else:
        print("No files needed updating.")
|
|
|
|
if __name__ == '__main__':
    # Run the batch conversion only when executed as a script, not on import.
    main()