
Heap & Priority Queue: K Closest Points

The K Closest Points pattern involves finding the K nearest points to a given reference point (often the origin), typically using Euclidean distance. This is a classic "top K" problem that showcases heap usage for geometric algorithms and distance-based optimization.

This pattern is fundamental for spatial algorithms, recommendation systems, and nearest neighbor problems.


Common Interview Questions

  1. K Closest Points to Origin (LC 973)

    • Problem: Given an array of points where points[i] = [xi, yi] represents a point on the X-Y plane and an integer k, return the k closest points to the origin (0, 0). The distance between two points is the Euclidean distance.
    • Heap Signal: This is a classic top K problem with custom distance comparison.
  2. Find K Closest Elements (LC 658)

    • Problem: Given a sorted integer array arr, two integers k and x, return the k closest integers to x in the array. The result should be sorted in ascending order.
    • Heap Signal: Similar pattern but with 1D distances and additional sorting requirements.
  3. Kth Smallest Element in a Sorted Matrix (LC 378)

    • Problem: Given an n x n matrix where each row and column is sorted in ascending order, find the kth smallest element in the matrix.
    • Heap Signal: Uses heap to efficiently explore a sorted 2D space.
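The sorted-matrix variant can be sketched with a min-heap that walks along each row; a minimal illustration of the LC 378 idea (function name is mine, assuming an n x n matrix with row- and column-sorted values):

```python
import heapq

def kth_smallest(matrix, k):
    n = len(matrix)
    # Seed the heap with the first element of the first min(k, n) rows;
    # rows beyond the k-th can never contribute to the answer
    heap = [(matrix[i][0], i, 0) for i in range(min(k, n))]
    heapq.heapify(heap)
    val = None
    for _ in range(k):
        # Pop the current smallest, then advance within its row
        val, row, col = heapq.heappop(heap)
        if col + 1 < n:
            heapq.heappush(heap, (matrix[row][col + 1], row, col + 1))
    return val
```

Each of the k pops does O(log k) work, giving O(k log k) overall.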

Core Logic & Approach

The key insight is to use distance as the comparison criterion while maintaining the K closest points efficiently.

Distance calculation:

  • Euclidean distance: √((x₁-x₂)² + (y₁-y₂)²)
  • For comparison purposes, we can skip the square root
  • Distance from origin: x² + y²

Heap strategies:

  1. Max-heap of size K: Keep K closest points, remove furthest when size exceeds K
  2. Min-heap with all points: Add all points, extract K smallest
  3. Quick select: O(n) average case solution

Approach 1: Max-Heap of Size K

Maintain a max-heap of the K closest points. When the heap size exceeds K, remove the point with maximum distance.

Core Template:

python
def k_closest_max_heap(points, k):
    import heapq
    
    heap = []
    
    for x, y in points:
        # Calculate squared distance (avoid sqrt for efficiency)
        dist = x * x + y * y
        
        # Push as max-heap (negate distance)
        heapq.heappush(heap, (-dist, x, y))
        
        # Maintain heap size <= k
        if len(heap) > k:
            heapq.heappop(heap)  # Remove furthest point
    
    # Extract points (ignore distances)
    return [[x, y] for neg_dist, x, y in heap]

# Time: O(n log k)
# Space: O(k)

Example Walkthrough:

python
points, k = [[1,1],[2,2],[3,3]], 2

# Process [1,1]: dist = 2, heap = [(-2, 1, 1)]
# Process [2,2]: dist = 8, heap = [(-8, 2, 2), (-2, 1, 1)]  
# Process [3,3]: dist = 18, heap = [(-18, 3, 3), (-2, 1, 1), (-8, 2, 2)]
#                len(heap) > k, so pop (-18, 3, 3)
#                heap = [(-8, 2, 2), (-2, 1, 1)]

# Result: [[2, 2], [1, 1]]

Approach 2: Min-Heap with All Points

Add all points to a min-heap, then extract the K smallest.

Core Template:

python
def k_closest_min_heap(points, k):
    import heapq
    
    # Build min-heap with all points
    heap = []
    for x, y in points:
        dist = x * x + y * y
        heapq.heappush(heap, (dist, x, y))
    
    # Extract k closest
    result = []
    for _ in range(k):
        dist, x, y = heapq.heappop(heap)
        result.append([x, y])
    
    return result

# Time: O(n log n)  
# Space: O(n)

Approach 3: Using heapq.nsmallest

Python's built-in function provides a clean solution:

Core Template:

python
def k_closest_builtin(points, k):
    import heapq
    
    # Use custom key function for distance
    def distance(point):
        return point[0] ** 2 + point[1] ** 2
    
    return heapq.nsmallest(k, points, key=distance)

# Time: O(n log k)
# Space: O(k)

Approach 4: Quick Select (Optimal)

For the best average-case performance, use quick select with custom partitioning:

Core Template:

python
def k_closest_quickselect(points, k):
    def distance_squared(point):
        return point[0] ** 2 + point[1] ** 2
    
    def partition(left, right, pivot_idx):
        pivot_dist = distance_squared(points[pivot_idx])
        # Move pivot to end
        points[pivot_idx], points[right] = points[right], points[pivot_idx]
        
        store_idx = left
        for i in range(left, right):
            if distance_squared(points[i]) < pivot_dist:
                points[store_idx], points[i] = points[i], points[store_idx]
                store_idx += 1
        
        # Move pivot to final position
        points[right], points[store_idx] = points[store_idx], points[right]
        return store_idx
    
    def quickselect(left, right, k):
        if left == right:
            return
        
        # Choose random pivot
        import random
        pivot_idx = random.randint(left, right)
        pivot_idx = partition(left, right, pivot_idx)
        
        if k == pivot_idx:
            return
        elif k < pivot_idx:
            quickselect(left, pivot_idx - 1, k)
        else:
            quickselect(pivot_idx + 1, right, k)
    
    quickselect(0, len(points) - 1, k - 1)
    return points[:k]

# Time: O(n) average, O(n²) worst case
# Space: O(1) extra (in-place partitioning; recursion stack is O(log n) on average)

Advanced: K Closest Elements in 1D Array

For the 1D version with additional constraints:

Core Template:

python
def find_closest_elements(arr, k, x):
    import heapq
    
    # Method 1: Max-heap approach
    heap = []
    
    for num in arr:
        distance = abs(num - x)
        # Negate the value too: on distance ties, the larger number is evicted first
        heapq.heappush(heap, (-distance, -num))
        
        if len(heap) > k:
            heapq.heappop(heap)
    
    # Extract and sort
    result = [-neg_num for neg_dist, neg_num in heap]
    return sorted(result)

# Method 2: Two-pointer approach — O(n - k) when the array is already sorted
def find_closest_elements_two_pointer(arr, k, x):
    left, right = 0, len(arr) - k
    
    while left < right:
        # Compare distances from window boundaries
        if x - arr[left] > arr[right + k - 1] - x:
            left += 1
        else:
            right -= 1
    
    return arr[left:left + k]

# Method 3: Binary search + expand
def find_closest_elements_binary_search(arr, k, x):
    # Find insertion point
    left, right = 0, len(arr)
    while left < right:
        mid = (left + right) // 2
        if arr[mid] < x:
            left = mid + 1
        else:
            right = mid
    
    # Expand around insertion point
    left -= 1  # Start just before insertion point
    right = left + 1
    
    for _ in range(k):
        if left < 0:
            right += 1
        elif right >= len(arr):
            left -= 1
        elif x - arr[left] <= arr[right] - x:
            left -= 1
        else:
            right += 1
    
    return arr[left + 1:right]
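Method 3's hand-rolled insertion-point search can also lean on the standard library's bisect module; a sketch under the same expand-around-the-gap logic (function name is mine):

```python
import bisect

def find_closest_elements_bisect(arr, k, x):
    # bisect_left finds the insertion point in O(log n)
    right = bisect.bisect_left(arr, x)
    left = right - 1
    # Expand the empty window (left, right) outward k times,
    # preferring the left side on ties (smaller values win)
    for _ in range(k):
        if left < 0:
            right += 1
        elif right >= len(arr) or x - arr[left] <= arr[right] - x:
            left -= 1
        else:
            right += 1
    return arr[left + 1:right]
```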

Approach Comparison

| Approach                | Time Complexity       | Space Complexity | Best Use Case                 |
| ----------------------- | --------------------- | ---------------- | ----------------------------- |
| Max-heap (size k)       | O(n log k)            | O(k)             | When k << n, memory-efficient |
| Min-heap (all points)   | O(n log n)            | O(n)             | When k is close to n          |
| heapq.nsmallest         | O(n log k)            | O(k)             | Clean code, good performance  |
| Quick select            | O(n) avg, O(n²) worst | O(1) in place    | Best average performance      |
| Two-pointer (1D sorted) | O(n)                  | O(1)             | When array is sorted          |

Distance Calculation Optimizations

Skip Square Root:

python
# Instead of: sqrt(x² + y²)
# Use: x² + y²
# Since sqrt is monotonic, ordering is preserved

def distance_squared(point):
    return point[0] ** 2 + point[1] ** 2

# This avoids expensive floating-point operations
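A quick sanity check that squared distance yields the same ordering as the true distance (sample points are illustrative):

```python
import math

pts = [(3, 3), (5, -1), (-2, 4), (0, 0)]
by_true = sorted(pts, key=lambda p: math.sqrt(p[0] ** 2 + p[1] ** 2))
by_squared = sorted(pts, key=lambda p: p[0] ** 2 + p[1] ** 2)
# sqrt is monotonic, so both sorts agree
```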

Custom Distance Functions:

python
# Manhattan distance
def manhattan_distance(p1, p2):
    return abs(p1[0] - p2[0]) + abs(p1[1] - p2[1])

# Chebyshev distance (max of coordinate differences)
def chebyshev_distance(p1, p2):
    return max(abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))

# General Minkowski distance
def minkowski_distance(p1, p2, p):
    return (abs(p1[0] - p2[0])**p + abs(p1[1] - p2[1])**p)**(1/p)

Example Walkthrough

Let's trace through K Closest Points to Origin with points = [[3,3],[5,-1],[-2,4]], k = 2:

Max-Heap Approach:

  1. Process [3,3]:

    • Distance² = 3² + 3² = 18
    • heap = [(-18, 3, 3)]
  2. Process [5,-1]:

    • Distance² = 5² + (-1)² = 26
    • heap = [(-26, 5, -1), (-18, 3, 3)]
  3. Process [-2,4]:

    • Distance² = (-2)² + 4² = 20
    • Push: heap = [(-26, 5, -1), (-18, 3, 3), (-20, -2, 4)]
    • Size > k, so pop maximum: (-26, 5, -1)
    • Final: heap = [(-20, -2, 4), (-18, 3, 3)]
  4. Extract results: [[-2, 4], [3, 3]] (heap order; any ordering of the k closest is accepted)

Quick Select Approach:

  1. Calculate distances: [18, 26, 20] for points [[3,3], [5,-1], [-2,4]]

  2. First partition (target index k-1 = 1):

    • Choose pivot (e.g., index 1, distance 26)
    • Partition: [18, 20] | 26 — the pivot lands at index 2
    • Target index 1 < 2, so recurse on the left partition
  3. Second partition:

    • Pivot distance 18 lands at index 0; target index 1 lies just to its right
    • Result: points with distances [18, 20] = [[3,3], [-2,4]]

Result: [[3,3], [-2,4]]
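The max-heap trace above can be reproduced as a short runnable check:

```python
import heapq

points, k = [[3, 3], [5, -1], [-2, 4]], 2
heap = []
for x, y in points:
    heapq.heappush(heap, (-(x * x + y * y), x, y))
    if len(heap) > k:
        heapq.heappop(heap)  # evicts (-26, 5, -1)

# Sort only to make the comparison order-independent
closest = sorted([x, y] for _, x, y in heap)
```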


Tricky Parts & Decision Points

  1. Choosing approach based on constraints:
python
# Pseudocode — "k << n" here means "k much smaller than n", not a bit shift
if k << n:
    use_max_heap_of_size_k()  # Most memory-efficient
elif n <= 10000:
    use_builtin_nsmallest()   # Clean and fast enough
else:
    use_quickselect()         # Best average performance
  1. Handling floating-point precision:
python
# Avoid floating-point by using squared distances
# Instead of: sqrt(x² + y²)
distance_squared = x * x + y * y
  1. Tie-breaking rules:
python
# If distances are equal, problem might specify tie-breaking
# Use tuple comparison for consistent ordering
heapq.heappush(heap, (-dist, x, y))  # Lexicographic tie-breaking
# or
heapq.heappush(heap, (-dist, -x, -y))  # Reverse lexicographic
  1. Result ordering requirements:
python
# Problem might require specific output ordering
result = k_closest_points(points, k)

# Sort by original indices if needed
# Sort by distance if needed
# Or return in heap order (often acceptable)
  1. Memory optimization for large datasets:
python
# For streaming data or memory constraints: keep only k candidates at a time
import heapq

def k_closest_streaming(point_stream, k):
    heap = []
    for x, y in point_stream:
        heapq.heappush(heap, (-(x * x + y * y), x, y))
        if len(heap) > k:
            heapq.heappop(heap)
    return [[x, y] for _, x, y in heap]
  1. Edge cases:
python
# Handle edge cases
if k >= len(points):
    return points[:]

if k == 0:
    return []

if not points:
    return []
  1. 3D and higher dimensions:
python
def distance_nd(point1, point2):
    return sum((a - b) ** 2 for a, b in zip(point1, point2))

# Works with any number of dimensions
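Paired with heapq.nsmallest, the same top-K machinery carries over unchanged to higher dimensions; a hypothetical 3D example (sample points are mine):

```python
import heapq

def distance_nd(point1, point2):
    # Squared Euclidean distance in any number of dimensions
    return sum((a - b) ** 2 for a, b in zip(point1, point2))

points_3d = [(1, 0, 0), (2, 2, 2), (0, 1, 1)]
origin = (0, 0, 0)
# Squared distances are 1, 12, 2
closest = heapq.nsmallest(2, points_3d, key=lambda p: distance_nd(p, origin))
```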