Files
nuzlocke-tracker/backend/src/app/seeds/loader.py

487 lines
16 KiB
Python
Raw Normal View History

"""Database upsert helpers for seed data."""
from sqlalchemy import delete, select, update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.boss_battle import BossBattle
from app.models.boss_pokemon import BossPokemon
from app.models.encounter import Encounter
from app.models.evolution import Evolution
from app.models.game import Game
from app.models.pokemon import Pokemon
from app.models.route import Route
from app.models.route_encounter import RouteEncounter
from app.models.version_group import VersionGroup
async def upsert_version_groups(
    session: AsyncSession,
    vg_data: dict[str, dict],
) -> dict[str, int]:
    """Insert or update every version group in *vg_data*, keyed on its slug.

    Each key of *vg_data* is a version-group slug. Its value must hold a
    ``"games"`` mapping whose entries carry a ``"name"``. The stored display
    name is those game names joined with " / ", with every "Pokemon "
    substring removed.

    Returns:
        ``{slug: id}`` for every VersionGroup row in the database, not only
        the rows touched by this call.
    """
    for slug, info in vg_data.items():
        game_names = [
            game["name"].replace("Pokemon ", "") for game in info["games"].values()
        ]
        display_name = " / ".join(game_names)
        upsert = insert(VersionGroup).values(name=display_name, slug=slug)
        # On a slug collision only the display name is refreshed.
        upsert = upsert.on_conflict_do_update(
            index_elements=["slug"],
            set_={"name": display_name},
        )
        await session.execute(upsert)

    await session.flush()
    rows = await session.execute(select(VersionGroup.slug, VersionGroup.id))
    return {slug: vg_id for slug, vg_id in rows}
async def upsert_games(
    session: AsyncSession,
    games: list[dict],
    slug_to_vg_id: dict[str, int] | None = None,
) -> dict[str, int]:
    """Insert or update each game in *games*, keyed on its ``slug``.

    Args:
        session: Session the statements are executed in.
        games: Seed dicts. ``name``, ``slug``, ``generation`` and ``region``
            are required; ``category``, ``release_year`` and ``color`` fall
            back to None when absent.
        slug_to_vg_id: Optional lookup keyed by *game* slug. When it yields an
            id, that id is written to ``version_group_id``. Otherwise the
            column is omitted from both the insert and the conflict update.

    Returns:
        ``{slug: id}`` for every Game row in the database.
    """
    for game in games:
        row = {
            "name": game["name"],
            "slug": game["slug"],
            "generation": game["generation"],
            "region": game["region"],
            "category": game.get("category"),
            "release_year": game.get("release_year"),
            "color": game.get("color"),
        }
        vg_id = slug_to_vg_id.get(game["slug"]) if slug_to_vg_id is not None else None
        if vg_id is not None:
            row["version_group_id"] = vg_id
        # Every column except the conflict key (slug) is refreshed on conflict.
        on_conflict = {column: value for column, value in row.items() if column != "slug"}
        await session.execute(
            insert(Game)
            .values(**row)
            .on_conflict_do_update(index_elements=["slug"], set_=on_conflict)
        )

    await session.flush()
    rows = await session.execute(select(Game.slug, Game.id))
    return {slug: game_id for slug, game_id in rows}
async def upsert_pokemon(
    session: AsyncSession, pokemon_list: list[dict]
) -> dict[int, int]:
    """Insert or update each pokemon in *pokemon_list*, keyed on ``pokeapi_id``.

    ``pokeapi_id``, ``national_dex``, ``name`` and ``types`` are required keys;
    ``sprite_url`` falls back to None. On conflict every column except
    ``pokeapi_id`` is overwritten with the seed value.

    Returns:
        ``{pokeapi_id: id}`` for every Pokemon row in the database.
    """
    for entry in pokemon_list:
        payload = {
            "pokeapi_id": entry["pokeapi_id"],
            "national_dex": entry["national_dex"],
            "name": entry["name"],
            "types": entry["types"],
            "sprite_url": entry.get("sprite_url"),
        }
        changes = {column: value for column, value in payload.items() if column != "pokeapi_id"}
        upsert = insert(Pokemon).values(**payload).on_conflict_do_update(
            index_elements=["pokeapi_id"],
            set_=changes,
        )
        await session.execute(upsert)

    await session.flush()
    rows = await session.execute(select(Pokemon.pokeapi_id, Pokemon.id))
    return {pokeapi_id: pk for pokeapi_id, pk in rows}
async def upsert_routes(
    session: AsyncSession,
    version_group_id: int,
    routes: list[dict],
    *,
    prune: bool = False,
) -> dict[str, int]:
    """Upsert route records for a version group, return {name: id} mapping.

    Handles hierarchical routes. Every entry in *routes* is upserted as a
    top-level route with ``parent_route_id`` cleared. Entries in its
    ``"children"`` list are then upserted with ``parent_route_id`` pointing at
    it and ``pinwheel_zone`` taken from the seed. Rows are matched on the
    ``uq_routes_version_group_name`` constraint.

    When *prune* is True, routes of this version group whose names are absent
    from the seed data are deleted, except routes referenced by a user
    ``Encounter``. Before a route is deleted, its ``RouteEncounter`` rows are
    removed and any ``BossBattle.after_route_id`` pointing at it is cleared.

    Returns:
        ``{route name: route id}`` for every route of the version group,
        parents and children alike.
    """
    # First pass: upsert every top-level route (parent_route_id cleared) so the
    # children below have an id to point at.
    for route in routes:
        stmt = (
            insert(Route)
            .values(
                name=route["name"],
                version_group_id=version_group_id,
                order=route["order"],
                parent_route_id=None,
            )
            .on_conflict_do_update(
                constraint="uq_routes_version_group_name",
                set_={"order": route["order"], "parent_route_id": None},
            )
        )
        await session.execute(stmt)
    await session.flush()

    name_to_id = await _route_name_to_id(session, version_group_id)

    # Second pass: upsert child routes under their (now persisted) parent.
    for route in routes:
        children = route.get("children", [])
        if not children:
            continue
        parent_id = name_to_id[route["name"]]
        for child in children:
            stmt = (
                insert(Route)
                .values(
                    name=child["name"],
                    version_group_id=version_group_id,
                    order=child["order"],
                    parent_route_id=parent_id,
                    pinwheel_zone=child.get("pinwheel_zone"),
                )
                .on_conflict_do_update(
                    constraint="uq_routes_version_group_name",
                    set_={
                        "order": child["order"],
                        "parent_route_id": parent_id,
                        "pinwheel_zone": child.get("pinwheel_zone"),
                    },
                )
            )
            await session.execute(stmt)
    await session.flush()

    if prune:
        await _prune_stale_routes(session, version_group_id, routes)

    # Re-read so the mapping includes children and reflects any pruning.
    return await _route_name_to_id(session, version_group_id)


async def _route_name_to_id(
    session: AsyncSession, version_group_id: int
) -> dict[str, int]:
    """Return ``{name: id}`` for every route belonging to *version_group_id*."""
    result = await session.execute(
        select(Route.name, Route.id).where(Route.version_group_id == version_group_id)
    )
    return {row.name: row.id for row in result}


async def _prune_stale_routes(
    session: AsyncSession,
    version_group_id: int,
    routes: list[dict],
) -> None:
    """Delete the version group's routes that the seed data no longer names.

    Routes referenced by a user Encounter are kept. For each route that is
    deleted, its RouteEncounter rows are removed and BossBattle.after_route_id
    references to it are set to NULL first.
    """
    seed_names: set[str] = set()
    for route in routes:
        seed_names.add(route["name"])
        # `or []` tolerates an explicit "children": None, which the upsert
        # pass in upsert_routes also treats as "no children".
        seed_names.update(child["name"] for child in route.get("children") or [])

    # Route ids that have user encounters must survive pruning. NULL route_ids
    # are filtered out: a single NULL inside a NOT IN subquery makes the
    # predicate NULL for every row, which would silently disable pruning.
    in_use_route_ids = (
        select(Encounter.route_id)
        .where(Encounter.route_id.is_not(None))
        .distinct()
    )
    stale_result = await session.execute(
        select(Route.id).where(
            Route.version_group_id == version_group_id,
            Route.name.not_in(seed_names),
            Route.id.not_in(in_use_route_ids),
        )
    )
    stale_route_ids = [row.id for row in stale_result]
    if not stale_route_ids:
        return

    # NOTE(review): a stale parent route whose stale child is kept (because the
    # child is in use) is still deleted here. Whether that fails, cascades or
    # orphans the child depends on the parent_route_id FK's ON DELETE rule,
    # which is not visible in this module. Confirm.

    # Delete encounters referencing stale routes (no DB-level cascade).
    await session.execute(
        delete(RouteEncounter).where(RouteEncounter.route_id.in_(stale_route_ids))
    )
    # Nullify boss battle references to stale routes.
    await session.execute(
        update(BossBattle)
        .where(BossBattle.after_route_id.in_(stale_route_ids))
        .values(after_route_id=None)
    )
    # Now safe to delete the routes themselves.
    await session.execute(delete(Route).where(Route.id.in_(stale_route_ids)))
    print(f" Pruned {len(stale_route_ids)} stale route(s)")
    await session.flush()
async def _upsert_single_encounter(
    session: AsyncSession,
    route_id: int,
    pokemon_id: int,
    game_id: int,
    method: str,
    encounter_rate: int,
    min_level: int,
    max_level: int,
    condition: str = "",
) -> None:
    """Insert one RouteEncounter row, or refresh its rate and level range.

    Rows are matched on the ``uq_route_pokemon_method_game_condition``
    constraint. On conflict only ``encounter_rate``, ``min_level`` and
    ``max_level`` are overwritten. *condition* defaults to the empty string
    for unconditioned encounters.
    """
    refreshed = {
        "encounter_rate": encounter_rate,
        "min_level": min_level,
        "max_level": max_level,
    }
    base = insert(RouteEncounter).values(
        route_id=route_id,
        pokemon_id=pokemon_id,
        game_id=game_id,
        encounter_method=method,
        condition=condition,
        **refreshed,
    )
    await session.execute(
        base.on_conflict_do_update(
            constraint="uq_route_pokemon_method_game_condition",
            set_=refreshed,
        )
    )
async def upsert_route_encounters(
    session: AsyncSession,
    route_id: int,
    encounters: list[dict],
    dex_to_id: dict[int, int],
    game_id: int,
    *,
    prune: bool = False,
) -> int:
    """Upsert a route's encounters for one game; return how many rows were upserted.

    An entry with a non-empty ``"conditions"`` mapping produces one row per
    ``{condition: rate}`` pair. Any other entry produces a single row using
    ``"encounter_rate"`` and the empty condition ``""``. Entries whose
    ``pokeapi_id`` is missing from *dex_to_id* are skipped with a warning.

    When *prune* is True, RouteEncounter rows for this route and game whose
    ``(pokemon_id, encounter_method, condition)`` is not produced by the seed
    data are deleted.
    """
    seed_keys: set[tuple[int, str, str]] = set()
    upserted = 0
    for enc in encounters:
        pokemon_id = dex_to_id.get(enc["pokeapi_id"])
        if pokemon_id is None:
            print(f" Warning: no pokemon_id for pokeapi_id {enc['pokeapi_id']}")
            continue
        conditions = enc.get("conditions")
        # A plain entry behaves like a single variant with condition "".
        variants = conditions.items() if conditions else [("", enc["encounter_rate"])]
        for condition_name, rate in variants:
            seed_keys.add((pokemon_id, enc["method"], condition_name))
            await _upsert_single_encounter(
                session,
                route_id,
                pokemon_id,
                game_id,
                enc["method"],
                rate,
                enc["min_level"],
                enc["max_level"],
                condition=condition_name,
            )
            upserted += 1

    if not prune:
        return upserted

    current = await session.execute(
        select(RouteEncounter).where(
            RouteEncounter.route_id == route_id,
            RouteEncounter.game_id == game_id,
        )
    )
    stale_ids = [
        encounter.id
        for encounter in current.scalars()
        if (encounter.pokemon_id, encounter.encounter_method, encounter.condition)
        not in seed_keys
    ]
    if stale_ids:
        await session.execute(
            delete(RouteEncounter).where(RouteEncounter.id.in_(stale_ids))
        )
    return upserted
async def upsert_bosses(
    session: AsyncSession,
    version_group_id: int,
    bosses: list[dict],
    dex_to_id: dict[int, int],
    route_name_to_id: dict[str, int] | None = None,
    slug_to_game_id: dict[str, int] | None = None,
    *,
    prune: bool = False,
) -> int:
    """Upsert boss battles for a version group, return count of bosses upserted.

    Each boss is matched on the ``uq_boss_battles_version_group_order``
    constraint, i.e. by ``(version_group_id, order)``.

    ``after_route_name`` and ``game_slug`` are resolved through
    *route_name_to_id* and *slug_to_game_id*. A name that cannot be resolved
    prints a warning and is stored as NULL.

    The boss's team is replaced wholesale: its existing BossPokemon rows are
    deleted and the seed's ``"pokemon"`` entries are re-added. Entries whose
    ``pokeapi_id`` is missing from *dex_to_id* are skipped with a warning.

    When *prune* is True, boss battles of this version group whose ``order``
    is absent from the seed data are deleted together with their BossPokemon
    rows.
    """
    count = 0
    for boss in bosses:
        # Resolve after_route_name to an ID.
        after_route_id = None
        after_route_name = boss.get("after_route_name")
        if after_route_name and route_name_to_id:
            after_route_id = route_name_to_id.get(after_route_name)
            if after_route_id is None:
                print(
                    f" Warning: route '{after_route_name}' not found for boss '{boss['name']}'"
                )

        # Resolve game_slug to game_id.
        game_id = None
        game_slug = boss.get("game_slug")
        if game_slug and slug_to_game_id:
            game_id = slug_to_game_id.get(game_slug)
            if game_id is None:
                print(
                    f" Warning: game '{game_slug}' not found for boss '{boss['name']}'"
                )

        # Columns written on insert and refreshed on conflict. version_group_id
        # and order form the conflict key, so they are only supplied on insert.
        columns = {
            "name": boss["name"],
            "boss_type": boss["boss_type"],
            "specialty_type": boss.get("specialty_type"),
            "badge_name": boss.get("badge_name"),
            "badge_image_url": boss.get("badge_image_url"),
            "level_cap": boss["level_cap"],
            "after_route_id": after_route_id,
            "location": boss["location"],
            "section": boss.get("section"),
            "sprite_url": boss.get("sprite_url"),
            "game_id": game_id,
        }
        stmt = (
            insert(BossBattle)
            .values(version_group_id=version_group_id, order=boss["order"], **columns)
            .on_conflict_do_update(
                constraint="uq_boss_battles_version_group_order",
                set_=columns,
            )
            .returning(BossBattle.id)
        )
        result = await session.execute(stmt)
        boss_id = result.scalar_one()

        # Replace the boss's team: delete existing BossPokemon, then re-insert.
        await session.execute(
            delete(BossPokemon).where(BossPokemon.boss_battle_id == boss_id)
        )
        for bp in boss.get("pokemon", []):
            pokemon_id = dex_to_id.get(bp["pokeapi_id"])
            if pokemon_id is None:
                print(f" Warning: no pokemon_id for pokeapi_id {bp['pokeapi_id']}")
                continue
            session.add(
                BossPokemon(
                    boss_battle_id=boss_id,
                    pokemon_id=pokemon_id,
                    level=bp["level"],
                    order=bp["order"],
                    condition_label=bp.get("condition_label"),
                )
            )
        count += 1

    if prune:
        seed_orders = {boss["order"] for boss in bosses}
        stale_result = await session.execute(
            select(BossBattle.id).where(
                BossBattle.version_group_id == version_group_id,
                BossBattle.order.not_in(seed_orders),
            )
        )
        stale_boss_ids = [row.id for row in stale_result]
        if stale_boss_ids:
            # Remove the stale bosses' teams explicitly. This module cannot see
            # whether BossPokemon.boss_battle_id has an ON DELETE CASCADE, and
            # without one the boss delete below would violate the FK.
            await session.execute(
                delete(BossPokemon).where(
                    BossPokemon.boss_battle_id.in_(stale_boss_ids)
                )
            )
            await session.execute(
                delete(BossBattle).where(BossBattle.id.in_(stale_boss_ids))
            )
            print(f" Pruned {len(stale_boss_ids)} stale boss battle(s)")
    await session.flush()
    return count
async def upsert_evolutions(
    session: AsyncSession,
    evolutions: list[dict],
    dex_to_id: dict[int, int],
) -> int:
    """Replace all Evolution rows with *evolutions*; return how many were added.

    Despite the name this is a full refresh: every existing Evolution row is
    deleted first. Pairs whose ``from_pokeapi_id`` or ``to_pokeapi_id`` is not
    in *dex_to_id* are skipped without a warning. ``trigger`` is required; the
    other detail fields fall back to None.
    """
    await session.execute(delete(Evolution))
    added = 0
    for evo in evolutions:
        source_id = dex_to_id.get(evo["from_pokeapi_id"])
        target_id = dex_to_id.get(evo["to_pokeapi_id"])
        if source_id is not None and target_id is not None:
            session.add(
                Evolution(
                    from_pokemon_id=source_id,
                    to_pokemon_id=target_id,
                    trigger=evo["trigger"],
                    min_level=evo.get("min_level"),
                    item=evo.get("item"),
                    held_item=evo.get("held_item"),
                    condition=evo.get("condition"),
                    region=evo.get("region"),
                )
            )
            added += 1
    await session.flush()
    return added