Railroad-Construction: First working version

This commit is contained in:
2025-10-21 11:02:00 -04:00
parent 97e92b8b14
commit 38117921ed

View File

@@ -8,13 +8,13 @@ Honor Code and Acknowledgments:
Comments here on your code and submission. Comments here on your code and submission.
""" """
import pprint
from collections import defaultdict from collections import defaultdict
from math import sqrt from math import sqrt
type Coordinate = tuple[float, float] type Coordinate = tuple[float, float]
type CoordMap = dict[int, Coordinate] type CoordMap = dict[int, Coordinate]
type WeightedGraph = dict[int, dict[int, float]] type WeightedGraph = dict[int, dict[int, float]]
type Edge = tuple[int, int, float]
def distance(a: Coordinate, b: Coordinate) -> float: def distance(a: Coordinate, b: Coordinate) -> float:
@@ -41,9 +41,54 @@ def create_graph(coords: CoordMap) -> WeightedGraph:
def calculate_cost(graph: WeightedGraph) -> float: def calculate_cost(graph: WeightedGraph) -> float:
cost: float = 0.0 parent: dict[int, int] = {v: v for v in graph}
rank: dict[int, int] = {v: 0 for v in graph}
return cost def find(v: int) -> int:
if parent[v] != v:
parent[v] = find(parent[v])
return parent[v]
def union(u: int, v: int) -> None:
root_u, root_v = find(u), find(v)
if root_u == root_v:
return
if rank[root_u] < rank[root_v]:
parent[root_u] = root_v
elif rank[root_u] > rank[root_v]:
parent[root_v] = root_u
else:
parent[root_v] = root_u
rank[root_u] += 1
num_components = len(graph)
mst_edges: list[Edge] = []
while num_components > 1:
cheapest: dict[int, Edge] = {}
for u in graph:
for v, w in graph[u].items():
set_u = find(u)
set_v = find(v)
if set_u == set_v:
continue
if set_u not in cheapest or cheapest[set_u][2] > w:
cheapest[set_u] = (u, v, w)
if set_v not in cheapest or cheapest[set_v][2] > w:
cheapest[set_v] = (v, u, w)
for edge in cheapest.values():
u, v, w = edge
set_u = find(u)
set_v = find(v)
if set_u == set_v:
continue
mst_edges.append(edge)
union(set_u, set_v)
num_components -= 1
return sum(map(lambda x: x[2], mst_edges))
# All modules for CS 412 must include a main method that allows it # All modules for CS 412 must include a main method that allows it
@@ -59,9 +104,9 @@ def main():
weighted_graph = create_graph(coords) weighted_graph = create_graph(coords)
pprint.pprint(weighted_graph) cost = calculate_cost(weighted_graph)
# print(f"${cost:.1f}M") print(f"${cost:.1f}M")
if __name__ == "__main__": if __name__ == "__main__":