"""Database upsert helpers for seed data.

Every helper issues one PostgreSQL ``INSERT ... ON CONFLICT DO UPDATE``
statement per input record through the given ``AsyncSession``.
"""

from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.game import Game
from app.models.pokemon import Pokemon
from app.models.route import Route
from app.models.route_encounter import RouteEncounter


def _upsert_stmt(model, values: dict, update_keys: tuple[str, ...], **conflict_target):
    """Build a single-row ``INSERT ... ON CONFLICT DO UPDATE`` for *model*.

    *values* supplies the inserted columns. On conflict, the columns named in
    *update_keys* are overwritten with the same values taken from *values*.
    *conflict_target* is forwarded verbatim to ``on_conflict_do_update``
    (``index_elements=...`` or ``constraint=...``).
    """
    return insert(model).values(**values).on_conflict_do_update(
        set_={key: values[key] for key in update_keys},
        **conflict_target,
    )


async def upsert_games(session: AsyncSession, games: list[dict]) -> dict[str, int]:
    """Upsert game records keyed on ``slug``.

    A conflicting slug refreshes name, generation, region and release_year.
    After flushing, returns a ``{slug: id}`` mapping built from every row in
    the games table, not only the rows passed in.
    """
    for record in games:
        row_values = {
            "name": record["name"],
            "slug": record["slug"],
            "generation": record["generation"],
            "region": record["region"],
            "release_year": record.get("release_year"),
        }
        await session.execute(
            _upsert_stmt(
                Game,
                row_values,
                ("name", "generation", "region", "release_year"),
                index_elements=["slug"],
            )
        )
    await session.flush()

    rows = await session.execute(select(Game.slug, Game.id))
    return {slug: game_id for slug, game_id in rows}


async def upsert_pokemon(session: AsyncSession, pokemon_list: list[dict]) -> dict[int, int]:
    """Upsert pokemon records keyed on ``national_dex``.

    A conflicting dex number refreshes name, types and sprite_url. After
    flushing, returns a ``{national_dex: id}`` mapping for every stored pokemon.
    """
    for record in pokemon_list:
        row_values = {
            "national_dex": record["national_dex"],
            "name": record["name"],
            "types": record["types"],
            "sprite_url": record.get("sprite_url"),
        }
        await session.execute(
            _upsert_stmt(
                Pokemon,
                row_values,
                ("name", "types", "sprite_url"),
                index_elements=["national_dex"],
            )
        )
    await session.flush()

    rows = await session.execute(select(Pokemon.national_dex, Pokemon.id))
    return {dex: pokemon_id for dex, pokemon_id in rows}


async def upsert_routes(
    session: AsyncSession,
    game_id: int,
    routes: list[dict],
) -> dict[str, int]:
    """Upsert the routes that belong to the game identified by *game_id*.

    A conflict on the ``uq_routes_game_name`` constraint refreshes only
    ``order``. After flushing, returns a ``{name: id}`` mapping of every route
    stored for *game_id*.
    """
    for entry in routes:
        row_values = {
            "name": entry["name"],
            "game_id": game_id,
            "order": entry["order"],
        }
        await session.execute(
            _upsert_stmt(
                Route,
                row_values,
                ("order",),
                constraint="uq_routes_game_name",
            )
        )
    await session.flush()

    rows = await session.execute(
        select(Route.name, Route.id).where(Route.game_id == game_id)
    )
    return {name: route_id for name, route_id in rows}


async def upsert_route_encounters(
    session: AsyncSession,
    route_id: int,
    encounters: list[dict],
    dex_to_id: dict[int, int],
) -> int:
    """Upsert the encounters of a single route and return how many were written.

    Encounters whose ``national_dex`` is missing from *dex_to_id* are skipped
    with a printed warning. A conflict on ``uq_route_pokemon_method`` refreshes
    encounter_rate, min_level and max_level. Unlike the other helpers, this
    one does not flush the session.
    """
    upserted = 0
    for enc in encounters:
        dex = enc["national_dex"]
        pokemon_id = dex_to_id.get(dex)
        if pokemon_id is None:
            print(f" Warning: no pokemon_id for dex {dex}")
            continue

        row_values = {
            "route_id": route_id,
            "pokemon_id": pokemon_id,
            "encounter_method": enc["method"],
            "encounter_rate": enc["encounter_rate"],
            "min_level": enc["min_level"],
            "max_level": enc["max_level"],
        }
        await session.execute(
            _upsert_stmt(
                RouteEncounter,
                row_values,
                ("encounter_rate", "min_level", "max_level"),
                constraint="uq_route_pokemon_method",
            )
        )
        upserted += 1
    return upserted