"""Database upsert helpers for seed data.""" from sqlalchemy import delete, select, update from sqlalchemy.dialects.postgresql import insert from sqlalchemy.ext.asyncio import AsyncSession from app.models.boss_battle import BossBattle from app.models.boss_pokemon import BossPokemon from app.models.encounter import Encounter from app.models.evolution import Evolution from app.models.game import Game from app.models.pokemon import Pokemon from app.models.route import Route from app.models.route_encounter import RouteEncounter from app.models.version_group import VersionGroup async def upsert_version_groups( session: AsyncSession, vg_data: dict[str, dict], ) -> dict[str, int]: """Upsert version group records, return {slug: id} mapping.""" for vg_slug, vg_info in vg_data.items(): vg_name = " / ".join( g["name"].replace("Pokemon ", "") for g in vg_info["games"].values() ) stmt = ( insert(VersionGroup) .values( name=vg_name, slug=vg_slug, ) .on_conflict_do_update( index_elements=["slug"], set_={"name": vg_name}, ) ) await session.execute(stmt) await session.flush() result = await session.execute(select(VersionGroup.slug, VersionGroup.id)) return {row.slug: row.id for row in result} async def upsert_games( session: AsyncSession, games: list[dict], slug_to_vg_id: dict[str, int] | None = None, ) -> dict[str, int]: """Upsert game records, return {slug: id} mapping.""" for game in games: values = { "name": game["name"], "slug": game["slug"], "generation": game["generation"], "region": game["region"], "category": game.get("category"), "release_year": game.get("release_year"), "color": game.get("color"), } update_set = { "name": game["name"], "generation": game["generation"], "region": game["region"], "category": game.get("category"), "release_year": game.get("release_year"), "color": game.get("color"), } if slug_to_vg_id is not None: vg_id = slug_to_vg_id.get(game["slug"]) if vg_id is not None: values["version_group_id"] = vg_id update_set["version_group_id"] = vg_id stmt = ( insert(Game) .values(**values) .on_conflict_do_update( index_elements=["slug"], set_=update_set, ) ) await session.execute(stmt) await session.flush() result = await session.execute(select(Game.slug, Game.id)) return {row.slug: row.id for row in result} async def upsert_pokemon( session: AsyncSession, pokemon_list: list[dict] ) -> dict[int, int]: """Upsert pokemon records, return {pokeapi_id: id} mapping.""" for poke in pokemon_list: stmt = ( insert(Pokemon) .values( pokeapi_id=poke["pokeapi_id"], national_dex=poke["national_dex"], name=poke["name"], types=poke["types"], sprite_url=poke.get("sprite_url"), ) .on_conflict_do_update( index_elements=["pokeapi_id"], set_={ "national_dex": poke["national_dex"], "name": poke["name"], "types": poke["types"], "sprite_url": poke.get("sprite_url"), }, ) ) await session.execute(stmt) await session.flush() result = await session.execute(select(Pokemon.pokeapi_id, Pokemon.id)) return {row.pokeapi_id: row.id for row in result} async def upsert_routes( session: AsyncSession, version_group_id: int, routes: list[dict], *, prune: bool = False, ) -> dict[str, int]: """Upsert route records for a version group, return {name: id} mapping. Handles hierarchical routes: routes with 'children' are parent routes, and their children get parent_route_id set accordingly. When prune is True, deletes routes not present in the seed data. """ # First pass: upsert all parent routes (without parent_route_id) for route in routes: stmt = ( insert(Route) .values( name=route["name"], version_group_id=version_group_id, order=route["order"], parent_route_id=None, # Parent routes have no parent ) .on_conflict_do_update( constraint="uq_routes_version_group_name", set_={"order": route["order"], "parent_route_id": None}, ) ) await session.execute(stmt) await session.flush() # Get mapping of parent routes result = await session.execute( select(Route.name, Route.id).where(Route.version_group_id == version_group_id) ) name_to_id = {row.name: row.id for row in result} # Second pass: upsert child routes with parent_route_id for route in routes: children = route.get("children", []) if not children: continue parent_id = name_to_id[route["name"]] for child in children: stmt = ( insert(Route) .values( name=child["name"], version_group_id=version_group_id, order=child["order"], parent_route_id=parent_id, pinwheel_zone=child.get("pinwheel_zone"), ) .on_conflict_do_update( constraint="uq_routes_version_group_name", set_={ "order": child["order"], "parent_route_id": parent_id, "pinwheel_zone": child.get("pinwheel_zone"), }, ) ) await session.execute(stmt) await session.flush() if prune: seed_names: set[str] = set() for route in routes: seed_names.add(route["name"]) for child in route.get("children", []): seed_names.add(child["name"]) # Find stale route IDs, excluding routes with user encounters in_use_subq = select(Encounter.route_id).distinct().subquery() stale_route_ids_result = await session.execute( select(Route.id).where( Route.version_group_id == version_group_id, Route.name.not_in(seed_names), Route.id.not_in(select(in_use_subq)), ) ) stale_route_ids = [row.id for row in stale_route_ids_result] if stale_route_ids: # Delete encounters referencing stale routes (no DB-level cascade) await session.execute( delete(RouteEncounter).where( RouteEncounter.route_id.in_(stale_route_ids) ) ) # Nullify boss battle references to stale routes await session.execute( update(BossBattle) .where(BossBattle.after_route_id.in_(stale_route_ids)) .values(after_route_id=None) ) # Now safe to delete the routes await session.execute(delete(Route).where(Route.id.in_(stale_route_ids))) print(f" Pruned {len(stale_route_ids)} stale route(s)") await session.flush() # Return full mapping including children result = await session.execute( select(Route.name, Route.id).where(Route.version_group_id == version_group_id) ) return {row.name: row.id for row in result} async def _upsert_single_encounter( session: AsyncSession, route_id: int, pokemon_id: int, game_id: int, method: str, encounter_rate: int, min_level: int, max_level: int, condition: str = "", ) -> None: stmt = ( insert(RouteEncounter) .values( route_id=route_id, pokemon_id=pokemon_id, game_id=game_id, encounter_method=method, encounter_rate=encounter_rate, condition=condition, min_level=min_level, max_level=max_level, ) .on_conflict_do_update( constraint="uq_route_pokemon_method_game_condition", set_={ "encounter_rate": encounter_rate, "min_level": min_level, "max_level": max_level, }, ) ) await session.execute(stmt) async def upsert_route_encounters( session: AsyncSession, route_id: int, encounters: list[dict], dex_to_id: dict[int, int], game_id: int, *, prune: bool = False, ) -> int: """Upsert encounters for a route and game, return count of upserted rows. When prune is True, deletes encounters not present in the seed data. """ seed_keys: set[tuple[int, str, str]] = set() count = 0 for enc in encounters: pokemon_id = dex_to_id.get(enc["pokeapi_id"]) if pokemon_id is None: print(f" Warning: no pokemon_id for pokeapi_id {enc['pokeapi_id']}") continue conditions = enc.get("conditions") if conditions: for condition_name, rate in conditions.items(): seed_keys.add((pokemon_id, enc["method"], condition_name)) await _upsert_single_encounter( session, route_id, pokemon_id, game_id, enc["method"], rate, enc["min_level"], enc["max_level"], condition=condition_name, ) count += 1 else: seed_keys.add((pokemon_id, enc["method"], "")) await _upsert_single_encounter( session, route_id, pokemon_id, game_id, enc["method"], enc["encounter_rate"], enc["min_level"], enc["max_level"], ) count += 1 if prune: existing = await session.execute( select(RouteEncounter).where( RouteEncounter.route_id == route_id, RouteEncounter.game_id == game_id, ) ) stale_ids = [ row.id for row in existing.scalars() if (row.pokemon_id, row.encounter_method, row.condition) not in seed_keys ] if stale_ids: await session.execute( delete(RouteEncounter).where(RouteEncounter.id.in_(stale_ids)) ) return count async def upsert_bosses( session: AsyncSession, version_group_id: int, bosses: list[dict], dex_to_id: dict[int, int], route_name_to_id: dict[str, int] | None = None, slug_to_game_id: dict[str, int] | None = None, *, prune: bool = False, ) -> int: """Upsert boss battles for a version group, return count of bosses upserted. When prune is True, deletes boss battles not present in the seed data. """ count = 0 for boss in bosses: # Resolve after_route_name to an ID after_route_id = None after_route_name = boss.get("after_route_name") if after_route_name and route_name_to_id: after_route_id = route_name_to_id.get(after_route_name) if after_route_id is None: print( f" Warning: route '{after_route_name}' not found for boss '{boss['name']}'" ) # Resolve game_slug to game_id game_id = None game_slug = boss.get("game_slug") if game_slug and slug_to_game_id: game_id = slug_to_game_id.get(game_slug) if game_id is None: print( f" Warning: game '{game_slug}' not found for boss '{boss['name']}'" ) # Upsert the boss battle on (version_group_id, order) conflict stmt = ( insert(BossBattle) .values( version_group_id=version_group_id, name=boss["name"], boss_type=boss["boss_type"], specialty_type=boss.get("specialty_type"), badge_name=boss.get("badge_name"), badge_image_url=boss.get("badge_image_url"), level_cap=boss["level_cap"], order=boss["order"], after_route_id=after_route_id, location=boss["location"], section=boss.get("section"), sprite_url=boss.get("sprite_url"), game_id=game_id, ) .on_conflict_do_update( constraint="uq_boss_battles_version_group_order", set_={ "name": boss["name"], "boss_type": boss["boss_type"], "specialty_type": boss.get("specialty_type"), "badge_name": boss.get("badge_name"), "badge_image_url": boss.get("badge_image_url"), "level_cap": boss["level_cap"], "after_route_id": after_route_id, "location": boss["location"], "section": boss.get("section"), "sprite_url": boss.get("sprite_url"), "game_id": game_id, }, ) .returning(BossBattle.id) ) result = await session.execute(stmt) boss_id = result.scalar_one() # Delete existing boss_pokemon for this boss, then re-insert await session.execute( delete(BossPokemon).where(BossPokemon.boss_battle_id == boss_id) ) for bp in boss.get("pokemon", []): pokemon_id = dex_to_id.get(bp["pokeapi_id"]) if pokemon_id is None: print(f" Warning: no pokemon_id for pokeapi_id {bp['pokeapi_id']}") continue session.add( BossPokemon( boss_battle_id=boss_id, pokemon_id=pokemon_id, level=bp["level"], order=bp["order"], condition_label=bp.get("condition_label"), ) ) count += 1 if prune: seed_orders = {boss["order"] for boss in bosses} pruned = await session.execute( delete(BossBattle) .where( BossBattle.version_group_id == version_group_id, BossBattle.order.not_in(seed_orders), ) .returning(BossBattle.id) ) pruned_count = len(pruned.all()) if pruned_count: print(f" Pruned {pruned_count} stale boss battle(s)") await session.flush() return count async def upsert_evolutions( session: AsyncSession, evolutions: list[dict], dex_to_id: dict[int, int], ) -> int: """Upsert evolution pairs, return count of upserted rows.""" await session.execute(delete(Evolution)) count = 0 for evo in evolutions: from_id = dex_to_id.get(evo["from_pokeapi_id"]) to_id = dex_to_id.get(evo["to_pokeapi_id"]) if from_id is None or to_id is None: continue evolution = Evolution( from_pokemon_id=from_id, to_pokemon_id=to_id, trigger=evo["trigger"], min_level=evo.get("min_level"), item=evo.get("item"), held_item=evo.get("held_item"), condition=evo.get("condition"), region=evo.get("region"), ) session.add(evolution) count += 1 await session.flush() return count