Files
nuzlocke-tracker/backend/src/app/seeds/loader.py

191 lines
6.0 KiB
Python
Raw Normal View History

"""Database upsert helpers for seed data."""
from sqlalchemy import delete, select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.evolution import Evolution
from app.models.game import Game
from app.models.pokemon import Pokemon
from app.models.route import Route
from app.models.route_encounter import RouteEncounter
async def upsert_games(session: AsyncSession, games: list[dict]) -> dict[str, int]:
    """Insert or update each game keyed on its ``slug``; return ``{slug: id}``.

    Every entry must provide ``name``, ``slug``, ``generation`` and ``region``.
    ``release_year`` and ``color`` are optional and fall back to ``None``.
    When a row with the same slug already exists, all of its non-slug columns
    are overwritten with the incoming values.

    Returns:
        A mapping of slug to primary key for every game in the table after
        the flush (not only the games passed in).
    """
    for entry in games:
        row = {
            "name": entry["name"],
            "slug": entry["slug"],
            "generation": entry["generation"],
            "region": entry["region"],
            "release_year": entry.get("release_year"),
            "color": entry.get("color"),
        }
        # The conflict target is the slug, so everything except it is updatable.
        changes = {column: value for column, value in row.items() if column != "slug"}
        await session.execute(
            insert(Game)
            .values(**row)
            .on_conflict_do_update(index_elements=["slug"], set_=changes)
        )
    await session.flush()

    rows = await session.execute(select(Game.slug, Game.id))
    return {slug: game_id for slug, game_id in rows}
async def upsert_pokemon(session: AsyncSession, pokemon_list: list[dict]) -> dict[int, int]:
    """Insert or update each pokemon keyed on ``national_dex``; return ``{dex: id}``.

    Every entry must provide ``national_dex``, ``name`` and ``types``;
    ``sprite_url`` is optional and falls back to ``None``. ``types`` is passed
    through to the column unchanged. On a national-dex conflict, the name,
    types and sprite URL of the existing row are overwritten.

    Returns:
        A mapping of national dex number to primary key for every pokemon in
        the table after the flush (not only the ones passed in).
    """
    for entry in pokemon_list:
        row = {
            "national_dex": entry["national_dex"],
            "name": entry["name"],
            "types": entry["types"],
            "sprite_url": entry.get("sprite_url"),
        }
        # national_dex is the conflict key; every other column is refreshed.
        changes = {column: value for column, value in row.items() if column != "national_dex"}
        await session.execute(
            insert(Pokemon)
            .values(**row)
            .on_conflict_do_update(index_elements=["national_dex"], set_=changes)
        )
    await session.flush()

    rows = await session.execute(select(Pokemon.national_dex, Pokemon.id))
    return {dex: pokemon_id for dex, pokemon_id in rows}
def _route_upsert_stmt(game_id: int, name: str, order: int, parent_route_id: "int | None"):
    """Build a single-route INSERT ... ON CONFLICT (uq_routes_game_name) statement.

    On conflict only ``order`` and ``parent_route_id`` are overwritten.
    """
    return (
        insert(Route)
        .values(
            name=name,
            game_id=game_id,
            order=order,
            parent_route_id=parent_route_id,
        )
        .on_conflict_do_update(
            constraint="uq_routes_game_name",
            set_={"order": order, "parent_route_id": parent_route_id},
        )
    )


async def _route_ids_for_game(session: AsyncSession, game_id: int) -> dict[str, int]:
    """Return ``{name: id}`` for every route currently stored under ``game_id``."""
    rows = await session.execute(
        select(Route.name, Route.id).where(Route.game_id == game_id)
    )
    return {name: route_id for name, route_id in rows}


async def upsert_routes(
    session: AsyncSession,
    game_id: int,
    routes: list[dict],
) -> dict[str, int]:
    """Upsert a game's routes, linking child routes to their parent.

    Each top-level entry needs ``name`` and ``order`` and may carry a
    ``children`` list of entries with their own ``name`` and ``order``.
    Top-level routes are written with ``parent_route_id=None``, which also
    detaches a row that was previously stored as a child. Children are then
    written with ``parent_route_id`` set to their parent's id. Only one level
    of nesting is processed; a child's own ``children`` key is ignored.

    Returns:
        ``{name: id}`` for every route of ``game_id`` after both passes,
        including routes that were not part of ``routes``.
    """
    # Pass 1: every top-level route, with no parent.
    for top in routes:
        await session.execute(
            _route_upsert_stmt(game_id, top["name"], top["order"], None)
        )
    await session.flush()

    # Parents now have ids, which pass 2 needs for the foreign key.
    ids_by_name = await _route_ids_for_game(session, game_id)

    # Pass 2: children, pointing at the parent they were nested under.
    for top in routes:
        children = top.get("children", [])
        if children:
            parent_id = ids_by_name[top["name"]]
            for child in children:
                await session.execute(
                    _route_upsert_stmt(game_id, child["name"], child["order"], parent_id)
                )
    await session.flush()

    return await _route_ids_for_game(session, game_id)
async def upsert_route_encounters(
session: AsyncSession,
route_id: int,
encounters: list[dict],
dex_to_id: dict[int, int],
) -> int:
"""Upsert encounters for a route, return count of upserted rows."""
count = 0
for enc in encounters:
pokemon_id = dex_to_id.get(enc["national_dex"])
if pokemon_id is None:
print(f" Warning: no pokemon_id for dex {enc['national_dex']}")
continue
stmt = insert(RouteEncounter).values(
route_id=route_id,
pokemon_id=pokemon_id,
encounter_method=enc["method"],
encounter_rate=enc["encounter_rate"],
min_level=enc["min_level"],
max_level=enc["max_level"],
).on_conflict_do_update(
constraint="uq_route_pokemon_method",
set_={
"encounter_rate": enc["encounter_rate"],
"min_level": enc["min_level"],
"max_level": enc["max_level"],
},
)
await session.execute(stmt)
count += 1
return count
async def upsert_evolutions(
    session: AsyncSession,
    evolutions: list[dict],
    dex_to_id: dict[int, int],
) -> int:
    """Replace all stored evolutions with ``evolutions``; return rows inserted.

    Despite the ``upsert_`` name (kept for caller compatibility), this is a
    full replace rather than a per-row upsert. Every existing ``Evolution`` row
    is deleted first, then each given pair is inserted fresh. This is simpler
    than a conflict-based upsert. The function only flushes; it does not
    commit, so the delete and the inserts land in the caller's transaction.

    Args:
        session: Async session the delete and inserts are issued on.
        evolutions: Entries with ``from_national_dex``, ``to_national_dex`` and
            ``trigger``, plus optional ``min_level``, ``item``, ``held_item``
            and ``condition`` (each defaulting to ``None``).
        dex_to_id: Mapping from national dex number to pokemon primary key.

    Returns:
        The number of ``Evolution`` rows added. Pairs where either dex number
        is missing from ``dex_to_id`` are skipped with a printed warning
        instead of being dropped silently.
    """
    await session.execute(delete(Evolution))

    added = 0
    for evo in evolutions:
        from_dex = evo["from_national_dex"]
        to_dex = evo["to_national_dex"]
        from_id = dex_to_id.get(from_dex)
        to_id = dex_to_id.get(to_dex)
        if from_id is None or to_id is None:
            # Surface unresolved pairs, as upserting route encounters does.
            print(f" Warning: no pokemon_id for evolution {from_dex} -> {to_dex}")
            continue
        session.add(
            Evolution(
                from_pokemon_id=from_id,
                to_pokemon_id=to_id,
                trigger=evo["trigger"],
                min_level=evo.get("min_level"),
                item=evo.get("item"),
                held_item=evo.get("held_item"),
                condition=evo.get("condition"),
            )
        )
        added += 1
    await session.flush()
    return added