Heap & Priority Queue: K Closest Points
The K Closest Points pattern involves finding the K nearest points to a given reference point (often the origin), typically using Euclidean distance. This is a classic "top K" problem that showcases heap usage for geometric algorithms and distance-based optimization.
This pattern is fundamental for spatial algorithms, recommendation systems, and nearest neighbor problems.
Common Interview Questions
K Closest Points to Origin (LC 973)
- Problem: Given an array of `points` where `points[i] = [xi, yi]` represents a point on the X-Y plane and an integer `k`, return the `k` closest points to the origin `(0, 0)`. The distance between two points is the Euclidean distance.
- Heap Signal: This is a classic top K problem with custom distance comparison.
Find K Closest Elements (LC 658)
- Problem: Given a sorted integer array `arr` and two integers `k` and `x`, return the `k` closest integers to `x` in the array. The result should be sorted in ascending order.
- Heap Signal: Similar pattern but with 1D distances and additional sorting requirements.
Kth Smallest Element in a Sorted Matrix (LC 378)
- Problem: Given an `n x n` matrix where each row and column is sorted in ascending order, find the `kth` smallest element in the matrix.
- Heap Signal: Uses heap to efficiently explore a sorted 2D space.
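Since this variant is not worked through below, here is a minimal sketch of the heap exploration: seed a min-heap with the first element of each row, then pop the smallest and push its right neighbor k-1 times. The function name `kth_smallest_sorted_matrix` is chosen for illustration.

```python
import heapq

def kth_smallest_sorted_matrix(matrix, k):
    n = len(matrix)
    # Seed the min-heap with the first element of each row (at most k rows matter)
    heap = [(matrix[r][0], r, 0) for r in range(min(n, k))]
    heapq.heapify(heap)
    # Pop the smallest k-1 times; each pop pushes that cell's right neighbor
    for _ in range(k - 1):
        _, r, c = heapq.heappop(heap)
        if c + 1 < n:
            heapq.heappush(heap, (matrix[r][c + 1], r, c + 1))
    return heap[0][0]
```

Each pop pushes at most one replacement, so the heap never exceeds min(n, k) entries, giving roughly O(k log k) time.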
Core Logic & Approach
The key insight is to use distance as the comparison criterion while maintaining the K closest points efficiently.
Distance calculation:
- Euclidean distance: √((x₁-x₂)² + (y₁-y₂)²)
- For comparison purposes, we can skip the square root
- Distance from origin: x² + y²
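The square-root shortcut can be sanity-checked in a couple of lines; the points below are hypothetical, chosen only to show that both keys produce the same ordering.

```python
import math

points = [(3, 4), (1, 1), (6, 8)]

# Sorting by true Euclidean distance from the origin...
by_euclidean = sorted(points, key=lambda p: math.sqrt(p[0] ** 2 + p[1] ** 2))
# ...matches sorting by squared distance, because sqrt is monotonically increasing
by_squared = sorted(points, key=lambda p: p[0] ** 2 + p[1] ** 2)

assert by_euclidean == by_squared == [(1, 1), (3, 4), (6, 8)]
```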
Heap strategies:
- Max-heap of size K: Keep K closest points, remove furthest when size exceeds K
- Min-heap with all points: Add all points, extract K smallest
- Quick select: O(n) average case solution
Approach 1: Max-Heap of Size K
Maintain a max-heap of the K closest points. When the heap size exceeds K, remove the point with maximum distance.
Core Template:
import heapq

def k_closest_max_heap(points, k):
    heap = []
    for x, y in points:
        # Calculate squared distance (avoid sqrt for efficiency)
        dist = x * x + y * y
        # Push as max-heap (negate distance)
        heapq.heappush(heap, (-dist, x, y))
        # Maintain heap size <= k
        if len(heap) > k:
            heapq.heappop(heap)  # Remove furthest point
    # Extract points (ignore distances)
    return [[x, y] for neg_dist, x, y in heap]

# Time: O(n log k)
# Space: O(k)

Example Walkthrough:
points = [[1, 1], [2, 2], [3, 3]], k = 2

# Process [1,1]: dist = 2,  heap = [(-2, 1, 1)]
# Process [2,2]: dist = 8,  heap = [(-8, 2, 2), (-2, 1, 1)]
# Process [3,3]: dist = 18, heap = [(-18, 3, 3), (-2, 1, 1), (-8, 2, 2)]
#   len(heap) > k, so pop (-18, 3, 3)
#   heap = [(-8, 2, 2), (-2, 1, 1)]
# Result: [[2, 2], [1, 1]]

Approach 2: Min-Heap with All Points
Add all points to a min-heap, then extract the K smallest.
Core Template:
import heapq

def k_closest_min_heap(points, k):
    # Build min-heap with all points
    heap = []
    for x, y in points:
        dist = x * x + y * y
        heapq.heappush(heap, (dist, x, y))
    # Extract k closest
    result = []
    for _ in range(k):
        dist, x, y = heapq.heappop(heap)
        result.append([x, y])
    return result

# Time: O(n log n)
# Space: O(n)

Approach 3: Using heapq.nsmallest
Python's built-in function provides a clean solution:
Core Template:
import heapq

def k_closest_builtin(points, k):
    # Use a custom key function for distance
    def distance(point):
        return point[0] ** 2 + point[1] ** 2
    return heapq.nsmallest(k, points, key=distance)

# Time: O(n log k)
# Space: O(k)

Approach 4: Quick Select (Optimal)
For the best average-case performance, use quick select with custom partitioning:
Core Template:
import random

def k_closest_quickselect(points, k):
    def distance_squared(point):
        return point[0] ** 2 + point[1] ** 2

    def partition(left, right, pivot_idx):
        pivot_dist = distance_squared(points[pivot_idx])
        # Move pivot to end
        points[pivot_idx], points[right] = points[right], points[pivot_idx]
        store_idx = left
        for i in range(left, right):
            if distance_squared(points[i]) < pivot_dist:
                points[store_idx], points[i] = points[i], points[store_idx]
                store_idx += 1
        # Move pivot to its final position
        points[right], points[store_idx] = points[store_idx], points[right]
        return store_idx

    def quickselect(left, right, k):
        if left == right:
            return
        # Choose a random pivot to avoid worst-case inputs
        pivot_idx = random.randint(left, right)
        pivot_idx = partition(left, right, pivot_idx)
        if k == pivot_idx:
            return
        elif k < pivot_idx:
            quickselect(left, pivot_idx - 1, k)
        else:
            quickselect(pivot_idx + 1, right, k)

    quickselect(0, len(points) - 1, k - 1)
    return points[:k]

# Time: O(n) average, O(n²) worst case
# Space: O(1) extra (in-place; ignoring the recursion stack)

Advanced: K Closest Elements in 1D Array
For the 1D version with additional constraints:
Core Template:
import heapq

def find_closest_elements(arr, k, x):
    # Method 1: Max-heap approach
    heap = []
    for num in arr:
        distance = abs(num - x)
        # Max-heap by distance; on ties, the larger value is popped first,
        # keeping the smaller element (matching LC 658's tie-breaking rule)
        heapq.heappush(heap, (-distance, -num))
        if len(heap) > k:
            heapq.heappop(heap)
    # Extract and sort
    result = [-neg_num for neg_dist, neg_num in heap]
    return sorted(result)

# Method 2: Two-pointer approach (more efficient for a sorted array)
def find_closest_elements_two_pointer(arr, k, x):
    left, right = 0, len(arr) - k
    while left < right:
        # Compare distances from the window boundaries
        if x - arr[left] > arr[right + k - 1] - x:
            left += 1
        else:
            right -= 1
    return arr[left:left + k]

# Method 3: Binary search + expand
def find_closest_elements_binary_search(arr, k, x):
    # Find the insertion point of x
    left, right = 0, len(arr)
    while left < right:
        mid = (left + right) // 2
        if arr[mid] < x:
            left = mid + 1
        else:
            right = mid
    # Expand around the insertion point
    left -= 1          # Start just before the insertion point
    right = left + 1   # right sits at the insertion point
    for _ in range(k):
        if left < 0:
            right += 1
        elif right >= len(arr):
            left -= 1
        elif x - arr[left] <= arr[right] - x:
            left -= 1
        else:
            right += 1
    return arr[left + 1:right]

Approach Comparison
| Approach | Time Complexity | Space Complexity | Best Use Case |
|---|---|---|---|
| Max-heap (size k) | O(n log k) | O(k) | When k << n, memory-efficient |
| Min-heap (all points) | O(n log n) | O(n) | When k is close to n |
| heapq.nsmallest | O(n log k) | O(k) | Clean code, good performance |
| Quick select | O(n) avg, O(n²) worst | O(1) | Best average performance |
| Two-pointer (1D sorted) | O(n) | O(1) | When array is sorted |
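A quick cross-check that two of the tabulated approaches agree on the same input; the templates from above are repeated here so the snippet is self-contained.

```python
import heapq

def k_closest_max_heap(points, k):
    # Size-k max-heap: negate the squared distance
    heap = []
    for x, y in points:
        heapq.heappush(heap, (-(x * x + y * y), x, y))
        if len(heap) > k:
            heapq.heappop(heap)
    return [[x, y] for _, x, y in heap]

def k_closest_builtin(points, k):
    return heapq.nsmallest(k, points, key=lambda p: p[0] ** 2 + p[1] ** 2)

points = [[3, 3], [5, -1], [-2, 4]]
a = sorted(map(tuple, k_closest_max_heap(points, 2)))
b = sorted(map(tuple, k_closest_builtin(points, 2)))
assert a == b == [(-2, 4), (3, 3)]  # both keep [3,3] and [-2,4]
```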
Distance Calculation Optimizations
Skip Square Root:
# Instead of: sqrt(x² + y²)
# Use:        x² + y²
# Since sqrt is monotonically increasing, the ordering is preserved
def distance_squared(point):
    return point[0] ** 2 + point[1] ** 2
# This avoids expensive floating-point operations

Custom Distance Functions:
# Manhattan distance
def manhattan_distance(p1, p2):
    return abs(p1[0] - p2[0]) + abs(p1[1] - p2[1])

# Chebyshev distance (max of coordinate differences)
def chebyshev_distance(p1, p2):
    return max(abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))

# General Minkowski distance
def minkowski_distance(p1, p2, p):
    return (abs(p1[0] - p2[0]) ** p + abs(p1[1] - p2[1]) ** p) ** (1 / p)

Example Walkthrough
Let's trace through K Closest Points to Origin with points = [[3,3],[5,-1],[-2,4]], k = 2:
Max-Heap Approach:
Process [3,3]:
- Distance² = 3² + 3² = 18
- heap = [(-18, 3, 3)]
Process [5,-1]:
- Distance² = 5² + (-1)² = 26
- heap = [(-26, 5, -1), (-18, 3, 3)]
Process [-2,4]:
- Distance² = (-2)² + 4² = 20
- Push: heap = [(-26, 5, -1), (-18, 3, 3), (-20, -2, 4)]
- Size > k, so pop maximum: (-26, 5, -1)
- Final: heap = [(-20, -2, 4), (-18, 3, 3)]
Extract results (in heap order): [[-2, 4], [3, 3]]; any order is accepted by the problem.
Quick Select Approach:
Calculate distances: [18, 26, 20] for points [[3,3], [5,-1], [-2,4]]
Partition around k=2:
- Choose pivot (e.g., index 1, distance 26)
- Partition: [18, 20] | [26]
- k=2 falls in left partition
Second partition:
- Pivot distance 18
- Result: points with distances [18, 20] = [[3,3], [-2,4]]
Result: [[3,3], [-2,4]]
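The same trace can be reproduced by running quick select on this input; below is an iterative rewrite of the partition logic from Approach 4 (a sketch; the ordering within the first k slots is not guaranteed).

```python
import random

def k_closest_quickselect(points, k):
    def dist2(p):
        return p[0] ** 2 + p[1] ** 2

    def partition(left, right, pivot_idx):
        # Move pivot to the end, sweep smaller-distance points left of store
        pivot_dist = dist2(points[pivot_idx])
        points[pivot_idx], points[right] = points[right], points[pivot_idx]
        store = left
        for i in range(left, right):
            if dist2(points[i]) < pivot_dist:
                points[store], points[i] = points[i], points[store]
                store += 1
        points[right], points[store] = points[store], points[right]
        return store

    left, right = 0, len(points) - 1
    while left < right:
        pivot_idx = partition(left, right, random.randint(left, right))
        if pivot_idx == k - 1:
            break            # first k slots now hold the k closest points
        elif pivot_idx < k - 1:
            left = pivot_idx + 1
        else:
            right = pivot_idx - 1
    return points[:k]

result = k_closest_quickselect([[3, 3], [5, -1], [-2, 4]], 2)
assert sorted(map(tuple, result)) == [(-2, 4), (3, 3)]
```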
Tricky Parts & Decision Points
- Choosing approach based on constraints:
if k << n:
    use_max_heap_of_size_k()    # Most memory-efficient
elif n <= 10000:
    use_builtin_nsmallest()     # Clean and fast enough
else:
    use_quickselect()           # Best average performance

- Handling floating-point precision:
# Avoid floating-point error by comparing squared distances
# Instead of: sqrt(x² + y²)
distance_squared = x * x + y * y

- Tie-breaking rules:
# If distances are equal, the problem might specify a tie-breaking rule
# Tuple comparison gives a consistent ordering
heapq.heappush(heap, (-dist, x, y))    # Lexicographic tie-breaking
# or
heapq.heappush(heap, (-dist, -x, -y))  # Reverse lexicographic

- Result ordering requirements:
# The problem might require a specific output ordering
result = k_closest_points(points, k)
# Sort by original indices if needed
# Sort by distance if needed
# Or return in heap order (often acceptable)

- Memory optimization for large datasets:
# For streaming data or memory constraints, the size-k max-heap works online:
# each point is processed once and at most k points are ever held in memory
def k_closest_streaming(point_stream, k):
    heap = []
    for x, y in point_stream:
        heapq.heappush(heap, (-(x * x + y * y), x, y))
        if len(heap) > k:
            heapq.heappop(heap)
    return [[x, y] for _, x, y in heap]

- Edge cases:
# Handle edge cases up front
if k >= len(points):
    return points[:]
if k == 0:
    return []
if not points:
    return []

- 3D and higher dimensions:
def distance_nd(point1, point2):
    return sum((a - b) ** 2 for a, b in zip(point1, point2))
# Works with any number of dimensions
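As an illustration, plugging `distance_nd` into the size-k max-heap pattern handles 3D points unchanged; `k_closest_nd` and the sample points are hypothetical.

```python
import heapq

def distance_nd(p1, p2):
    # Squared Euclidean distance in any number of dimensions
    return sum((a - b) ** 2 for a, b in zip(p1, p2))

def k_closest_nd(points, k, origin):
    # Same size-k max-heap pattern, with the dimension-agnostic key
    heap = []
    for p in points:
        heapq.heappush(heap, (-distance_nd(p, origin), p))
        if len(heap) > k:
            heapq.heappop(heap)
    return [p for _, p in heap]

pts = [(1, 1, 1), (5, 5, 5), (0, 0, 2)]
closest = k_closest_nd(pts, 2, origin=(0, 0, 0))
assert sorted(closest) == [(0, 0, 2), (1, 1, 1)]
```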