"""Make Alembic ``op.create_table`` calls idempotent.

Every migration in ``migrations/versions`` gets each statement-level
``op.create_table(...)`` wrapped in a check that skips creation when the
table already exists in the connected database.
"""
import os
import re
from pathlib import Path

# Matches a line whose statement *begins* with ``op.create_table(`` and whose
# table name (single- or double-quoted) is on that same line. Anchoring at the
# statement start deliberately skips comments, strings, and assignments such as
# ``t = op.create_table(...)``: wrapping an assignment in an ``if`` would leave
# the name undefined for later statements (e.g. ``op.bulk_insert(t, ...)``).
_CREATE_TABLE_RE = re.compile(
    r"""^(?P<pad>[ \t]*)op\.create_table\(\s*(?P<q>['"])(?P<table>\w+)(?P=q)"""
)

# Also serves as the "already processed" marker for a migration file.
_INSPECT_IMPORT = 'from sqlalchemy import inspect'


def _wrap_create_table_calls(content):
    """Guard each matching ``op.create_table`` line with a table-existence check.

    Only the opening line of the call is re-indented. Continuation lines sit
    inside the call's parentheses, where Python ignores indentation, so they
    can stay as they are. Each line is visited exactly once, which means a table
    created in several places is wrapped once per call site, never nested.

    Returns:
        A ``(new_content, wrapped_count)`` tuple.
    """
    out = []
    wrapped = 0
    for line in content.split('\n'):
        match = _CREATE_TABLE_RE.match(line)
        if match is None:
            out.append(line)
            continue
        pad = match.group('pad')
        table = match.group('table')
        out.extend([
            '',
            f'{pad}conn = op.get_bind()',
            f'{pad}inspector = inspect(conn)',
            f'{pad}tables = inspector.get_table_names()',
            f"{pad}if '{table}' not in tables:",
            # Reuse the original leading whitespace (tabs or spaces) and nest
            # one level deeper. Keep the WHOLE original line, not just the
            # ``op.create_table('name'`` prefix, so trailing arguments survive.
            f'{pad}    {line[len(pad):]}',
        ])
        wrapped += 1
    return '\n'.join(out), wrapped


def _add_inspect_import(content):
    """Insert ``from sqlalchemy import inspect`` after a known import line.

    The import goes after ``import sqlalchemy as sa`` when present, otherwise
    after ``from alembic import op``. If neither anchor exists, *content* is
    returned unchanged.
    """
    for anchor in ('import sqlalchemy as sa', 'from alembic import op'):
        if anchor in content:
            return content.replace(anchor, f'{anchor}\n{_INSPECT_IMPORT}', 1)
    return content


def update_migration_file(file_path):
    """Wrap ``op.create_table`` calls in *file_path* with an existence check.

    Files that already contain ``from sqlalchemy import inspect`` are treated
    as processed and left alone. Only statements that start with
    ``op.create_table(`` and name the table on the same line are wrapped.
    Related statements (e.g. ``op.create_index``) are not guarded.

    Args:
        file_path: Path to an Alembic revision file (str or ``Path``).

    Returns:
        True if the file was rewritten. False if it was skipped: already
        processed, no ``create_table`` call to wrap, or no anchor import to
        attach the ``inspect`` import to.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    if _INSPECT_IMPORT in content:
        return False

    new_content, wrapped = _wrap_create_table_calls(content)
    if wrapped == 0:
        # Nothing to guard; leave the file byte-for-byte untouched.
        return False

    new_content = _add_inspect_import(new_content)
    if _INSPECT_IMPORT not in new_content:
        # Writing guards without the import would make the migration fail
        # with NameError on ``inspect`` at run time.
        return False

    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(new_content)
    return True


def main(migrations_dir='migrations/versions'):
    """Patch every migration under *migrations_dir* and report what changed.

    Args:
        migrations_dir: Directory holding Alembic revision files. Defaults to
            ``migrations/versions`` relative to the current working directory.

    Raises:
        SystemExit: If *migrations_dir* is not an existing directory.
    """
    versions = Path(migrations_dir)
    if not versions.is_dir():
        # glob() on a missing directory silently yields nothing, which would
        # otherwise be misreported as "No files needed updating."
        raise SystemExit(f"Migrations directory not found: {versions}")

    updated_files = []
    # Sorted so the report order does not depend on the filesystem.
    for path in sorted(versions.glob('*.py')):
        if path.name == '__init__.py':
            continue
        if update_migration_file(path):
            updated_files.append(path.name)

    if updated_files:
        print("Updated the following migration files:")
        for name in updated_files:
            print(f"- {name}")
    else:
        print("No files needed updating.")


if __name__ == '__main__':
    main()