Files
nuzlocke-tracker/backend/src/app/seeds/loader.py

401 lines
13 KiB
Python
Raw Normal View History

"""Database upsert helpers for seed data."""
from sqlalchemy import delete, select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.boss_battle import BossBattle
from app.models.boss_pokemon import BossPokemon
from app.models.evolution import Evolution
from app.models.game import Game
from app.models.pokemon import Pokemon
from app.models.route import Route
from app.models.route_encounter import RouteEncounter
from app.models.version_group import VersionGroup
async def upsert_version_groups(
    session: AsyncSession,
    vg_data: dict[str, dict],
) -> dict[str, int]:
    """Insert or update one VersionGroup row per entry of *vg_data*.

    Each group's display name is the names of its games (every
    ``"Pokemon "`` substring removed) joined with ``" / "``. When a row with
    the same ``slug`` already exists, only its name is overwritten.

    Returns:
        ``{slug: id}`` for every row in the version group table after the
        flush, not only the slugs passed in.
    """
    for slug, info in vg_data.items():
        game_names = [
            game["name"].replace("Pokemon ", "") for game in info["games"].values()
        ]
        display_name = " / ".join(game_names)

        upsert = insert(VersionGroup).values(name=display_name, slug=slug)
        upsert = upsert.on_conflict_do_update(
            index_elements=["slug"],
            set_={"name": display_name},
        )
        await session.execute(upsert)

    await session.flush()

    rows = await session.execute(select(VersionGroup.slug, VersionGroup.id))
    return {slug: vg_id for slug, vg_id in rows}
async def upsert_games(
    session: AsyncSession,
    games: list[dict],
    slug_to_vg_id: dict[str, int] | None = None,
) -> dict[str, int]:
    """Insert or update one Game row per entry of *games*.

    Rows are matched on ``slug``. On conflict, every other column written
    by the insert is overwritten with the incoming value. When
    *slug_to_vg_id* is given and contains the game's slug, ``version_group_id``
    is also written. Otherwise that column is left out of both the insert
    and the update.

    Returns:
        ``{slug: id}`` for every row in the game table after the flush.
    """
    for game in games:
        row = {
            "name": game["name"],
            "slug": game["slug"],
            "generation": game["generation"],
            "region": game["region"],
            "category": game.get("category"),
            "release_year": game.get("release_year"),
            "color": game.get("color"),
        }
        if slug_to_vg_id is not None:
            vg_id = slug_to_vg_id.get(game["slug"])
            if vg_id is not None:
                row["version_group_id"] = vg_id

        # Refresh everything except the conflict key itself.
        refreshed = {column: value for column, value in row.items() if column != "slug"}

        await session.execute(
            insert(Game)
            .values(**row)
            .on_conflict_do_update(index_elements=["slug"], set_=refreshed)
        )

    await session.flush()

    rows = await session.execute(select(Game.slug, Game.id))
    return {slug: game_id for slug, game_id in rows}
async def upsert_pokemon(
    session: AsyncSession, pokemon_list: list[dict]
) -> dict[int, int]:
    """Insert or update one Pokemon row per entry of *pokemon_list*.

    Rows are matched on ``pokeapi_id``. On conflict, ``national_dex``,
    ``name``, ``types`` and ``sprite_url`` are overwritten.

    Returns:
        ``{pokeapi_id: id}`` for every row in the pokemon table after the
        flush.
    """
    for entry in pokemon_list:
        pokeapi_id = entry["pokeapi_id"]
        mutable_fields = {
            "national_dex": entry["national_dex"],
            "name": entry["name"],
            "types": entry["types"],
            "sprite_url": entry.get("sprite_url"),
        }
        upsert = insert(Pokemon).values(pokeapi_id=pokeapi_id, **mutable_fields)
        upsert = upsert.on_conflict_do_update(
            index_elements=["pokeapi_id"],
            set_=mutable_fields,
        )
        await session.execute(upsert)

    await session.flush()

    rows = await session.execute(select(Pokemon.pokeapi_id, Pokemon.id))
    return {pokeapi_id: pk for pokeapi_id, pk in rows}
async def upsert_routes(
    session: AsyncSession,
    version_group_id: int,
    routes: list[dict],
) -> dict[str, int]:
    """Insert or update the routes of one version group.

    Two passes are made over *routes*:

    1. Every top-level entry is upserted with ``parent_route_id`` forced to
       ``None``.
    2. Every entry under a top-level route's ``"children"`` key is upserted
       with ``parent_route_id`` set to that route's id, plus its optional
       ``pinwheel_zone``.

    Conflicts are resolved on the ``uq_routes_version_group_name``
    constraint.

    Returns:
        ``{route name: route id}`` for all routes stored under
        *version_group_id* after both passes.
    """

    async def route_ids_by_name() -> dict[str, int]:
        # Every route currently stored for this version group.
        rows = await session.execute(
            select(Route.name, Route.id).where(
                Route.version_group_id == version_group_id
            )
        )
        return {name: route_id for name, route_id in rows}

    # Pass 1: top-level routes, which never have a parent.
    for top_level in routes:
        await session.execute(
            insert(Route)
            .values(
                name=top_level["name"],
                version_group_id=version_group_id,
                order=top_level["order"],
                parent_route_id=None,
            )
            .on_conflict_do_update(
                constraint="uq_routes_version_group_name",
                set_={"order": top_level["order"], "parent_route_id": None},
            )
        )
    await session.flush()

    parent_ids = await route_ids_by_name()

    # Pass 2: children, linked to the parent ids resolved above.
    for top_level in routes:
        kids = top_level.get("children", [])
        if not kids:
            continue
        parent_id = parent_ids[top_level["name"]]
        for kid in kids:
            await session.execute(
                insert(Route)
                .values(
                    name=kid["name"],
                    version_group_id=version_group_id,
                    order=kid["order"],
                    parent_route_id=parent_id,
                    pinwheel_zone=kid.get("pinwheel_zone"),
                )
                .on_conflict_do_update(
                    constraint="uq_routes_version_group_name",
                    set_={
                        "order": kid["order"],
                        "parent_route_id": parent_id,
                        "pinwheel_zone": kid.get("pinwheel_zone"),
                    },
                )
            )
    await session.flush()

    return await route_ids_by_name()
async def _upsert_single_encounter(
    session: AsyncSession,
    route_id: int,
    pokemon_id: int,
    game_id: int,
    method: str,
    encounter_rate: int,
    min_level: int,
    max_level: int,
    condition: str = "",
) -> None:
    """Insert one RouteEncounter row, or refresh it if it already exists.

    Uniqueness is decided by the ``uq_route_pokemon_method_game_condition``
    constraint. On conflict only ``encounter_rate``, ``min_level`` and
    ``max_level`` are overwritten. The statement is executed but the
    session is not flushed here.
    """
    refreshed = {
        "encounter_rate": encounter_rate,
        "min_level": min_level,
        "max_level": max_level,
    }
    base_insert = insert(RouteEncounter).values(
        route_id=route_id,
        pokemon_id=pokemon_id,
        game_id=game_id,
        encounter_method=method,
        encounter_rate=encounter_rate,
        condition=condition,
        min_level=min_level,
        max_level=max_level,
    )
    await session.execute(
        base_insert.on_conflict_do_update(
            constraint="uq_route_pokemon_method_game_condition",
            set_=refreshed,
        )
    )
async def upsert_route_encounters(
    session: AsyncSession,
    route_id: int,
    encounters: list[dict],
    dex_to_id: dict[int, int],
    game_id: int,
) -> int:
    """Upsert the encounters of one route for one game.

    An entry with a non-empty ``"conditions"`` mapping produces one row per
    ``{condition name: rate}`` pair. Any other entry produces a single row
    using its ``"encounter_rate"`` and the empty condition ``""``. Entries
    whose ``pokeapi_id`` is missing from *dex_to_id* are skipped with a
    printed warning.

    Returns:
        The number of encounter rows upserted.
    """
    upserted = 0
    for enc in encounters:
        pokeapi_id = enc["pokeapi_id"]
        pokemon_id = dex_to_id.get(pokeapi_id)
        if pokemon_id is None:
            print(f" Warning: no pokemon_id for pokeapi_id {pokeapi_id}")
            continue

        method = enc["method"]
        conditions = enc.get("conditions")
        if conditions:
            variants = conditions.items()
        else:
            # No per-condition rates: one row under the empty condition.
            variants = [("", enc["encounter_rate"])]

        for condition_name, rate in variants:
            await _upsert_single_encounter(
                session,
                route_id,
                pokemon_id,
                game_id,
                method,
                rate,
                enc["min_level"],
                enc["max_level"],
                condition=condition_name,
            )
            upserted += 1
    return upserted
def _lookup_or_warn(
    kind: str,
    key: str | None,
    key_to_id: dict[str, int] | None,
    boss_name: str,
) -> int | None:
    """Resolve *key* through *key_to_id*, printing a warning when it is unknown.

    Returns ``None`` silently when *key* is empty or no mapping was supplied.
    Returns ``None`` with a warning when a mapping was supplied (even an
    empty one) but does not contain *key*.
    """
    if not key or key_to_id is None:
        return None
    resolved = key_to_id.get(key)
    if resolved is None:
        print(f" Warning: {kind} '{key}' not found for boss '{boss_name}'")
    return resolved


async def _replace_boss_pokemon(
    session: AsyncSession,
    boss_id: int,
    team: list[dict],
    dex_to_id: dict[int, int],
) -> None:
    """Delete every BossPokemon of *boss_id*, then stage *team* for insertion.

    Team members whose ``pokeapi_id`` is missing from *dex_to_id* are
    skipped with a printed warning. New rows are only added to the session;
    the caller is responsible for flushing.
    """
    await session.execute(
        delete(BossPokemon).where(BossPokemon.boss_battle_id == boss_id)
    )
    for bp in team:
        pokemon_id = dex_to_id.get(bp["pokeapi_id"])
        if pokemon_id is None:
            print(f" Warning: no pokemon_id for pokeapi_id {bp['pokeapi_id']}")
            continue
        session.add(
            BossPokemon(
                boss_battle_id=boss_id,
                pokemon_id=pokemon_id,
                level=bp["level"],
                order=bp["order"],
                condition_label=bp.get("condition_label"),
            )
        )


async def upsert_bosses(
    session: AsyncSession,
    version_group_id: int,
    bosses: list[dict],
    dex_to_id: dict[int, int],
    route_name_to_id: dict[str, int] | None = None,
    slug_to_game_id: dict[str, int] | None = None,
) -> int:
    """Upsert boss battles for a version group, return count of bosses upserted.

    Each boss row is matched on the ``uq_boss_battles_version_group_order``
    constraint, i.e. (version_group_id, order). On conflict every other
    column is overwritten. The boss's team is then fully replaced: existing
    BossPokemon rows for that boss are deleted and the entries under
    ``"pokemon"`` are re-added.

    ``after_route_name`` and ``game_slug`` are resolved through
    *route_name_to_id* / *slug_to_game_id* when those mappings are given.
    An unresolvable name is stored as ``NULL`` and reported with a printed
    warning.
    """
    count = 0
    for boss in bosses:
        boss_name = boss["name"]
        after_route_id = _lookup_or_warn(
            "route", boss.get("after_route_name"), route_name_to_id, boss_name
        )
        game_id = _lookup_or_warn(
            "game", boss.get("game_slug"), slug_to_game_id, boss_name
        )

        # Columns written on insert and overwritten on conflict. The conflict
        # key columns (version_group_id, order) are added to the insert only.
        fields = {
            "name": boss_name,
            "boss_type": boss["boss_type"],
            "specialty_type": boss.get("specialty_type"),
            "badge_name": boss.get("badge_name"),
            "badge_image_url": boss.get("badge_image_url"),
            "level_cap": boss["level_cap"],
            "after_route_id": after_route_id,
            "location": boss["location"],
            "section": boss.get("section"),
            "sprite_url": boss.get("sprite_url"),
            "game_id": game_id,
        }
        stmt = (
            insert(BossBattle)
            .values(version_group_id=version_group_id, order=boss["order"], **fields)
            .on_conflict_do_update(
                constraint="uq_boss_battles_version_group_order",
                set_=fields,
            )
            # RETURNING yields the row id for both the insert and update paths.
            .returning(BossBattle.id)
        )
        result = await session.execute(stmt)
        boss_id = result.scalar_one()

        # Tolerate an explicit null "pokemon" entry as well as a missing key.
        await _replace_boss_pokemon(
            session, boss_id, boss.get("pokemon") or [], dex_to_id
        )
        count += 1

    await session.flush()
    return count
async def upsert_evolutions(
    session: AsyncSession,
    evolutions: list[dict],
    dex_to_id: dict[int, int],
) -> int:
    """Replace every Evolution row with *evolutions*, return rows added.

    This is a full refresh, not a per-row upsert. All existing ``Evolution``
    rows are deleted first, across all pokemon. Then one row is added per
    entry whose ``from_pokeapi_id`` and ``to_pokeapi_id`` both resolve
    through *dex_to_id*. Entries that do not resolve are skipped and
    reported in a single summary warning. The session is flushed but not
    committed here.
    """
    # Delete everything up front (presumably so pairs removed from the seed
    # data do not linger between runs).
    await session.execute(delete(Evolution))

    count = 0
    skipped = 0
    for evo in evolutions:
        from_id = dex_to_id.get(evo["from_pokeapi_id"])
        to_id = dex_to_id.get(evo["to_pokeapi_id"])
        if from_id is None or to_id is None:
            # Counted rather than printed per entry to keep output readable.
            skipped += 1
            continue
        session.add(
            Evolution(
                from_pokemon_id=from_id,
                to_pokemon_id=to_id,
                trigger=evo["trigger"],
                min_level=evo.get("min_level"),
                item=evo.get("item"),
                held_item=evo.get("held_item"),
                condition=evo.get("condition"),
                region=evo.get("region"),
            )
        )
        count += 1

    await session.flush()
    if skipped:
        print(f" Warning: skipped {skipped} evolution(s) with unknown pokeapi_id")
    return count