Weighted Random in Log(N)

Problem is here: https://leetcode.com/problems/random-pick-with-weight/

Given an array w of positive integers, where w[i] describes the weight of index i, write a function pickIndex which randomly picks an index in proportion to its weight.
Note:
  1. 1 <= w.length <= 10000
  2. 1 <= w[i] <= 10^5
  3. pickIndex will be called at most 10000 times.
Example 1:
Input: 
["Solution","pickIndex"]
[[[1]],[]]
Output: [null,0]
Example 2:
Input: 
["Solution","pickIndex","pickIndex","pickIndex","pickIndex","pickIndex"]
[[[1,3]],[],[],[],[],[]]
Output: [null,0,1,1,1,0]
Explanation of Input Syntax:
The input is two lists: the subroutines called and their arguments. Solution's constructor has one argument, the array w. pickIndex has no arguments. Arguments are always wrapped with a list, even if there aren't any.
A straightforward linear scan gives an O(N) pick, but the approach described here is O(Log(N)) per pick. The first step is to build a cumulative (prefix-sum) array from the input array. Hence, if the input is:

3, 17, 4, 1

The cumulative array will be:

3, 20, 24, 25

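Building it is O(N), but that happens only once, in the constructor, so we treat it as part of reading the input.
Next, pick a random integer uniformly from [1..25], i.e. from 1 up to the total weight (the last cumulative value). Call it randomGuess.
At this point, binary search the cumulative array for the first index whose value is greater than or equal to randomGuess. For example, if randomGuess = 12, it falls between index 0 (3) and index 1 (20), in which case you always pick the rightmost index, here 1. This is what makes the pick proportional to the weights: index i covers exactly w[i] of the 25 possible values (index 1 covers 4..20, which is 17 values), so it is chosen with probability w[i]/25.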
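As a minimal sketch of just the prefix-sum step, here is a standalone C# helper that can be run on its own. The class `PrefixSums` and its `Build` method are hypothetical names for illustration, not part of the original solution. Under the stated constraints the total is at most 10,000 × 10^5 = 10^9, which is below int.MaxValue (2,147,483,647), so `int` is enough; larger inputs would need `long`.

```csharp
// Hypothetical helper (not from the original post): builds the cumulative
// array once, in O(N), where prefix[i] = w[0] + ... + w[i].
public static class PrefixSums
{
    public static int[] Build(int[] w)
    {
        var prefix = new int[w.Length];
        int running = 0;
        for (int i = 0; i < w.Length; i++)
        {
            running += w[i];     // running sum of w[0..i]
            prefix[i] = running;
        }
        return prefix;
    }
}
```

For the input above, `PrefixSums.Build(new[] { 3, 17, 4, 1 })` returns `{ 3, 20, 24, 25 }`.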
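The search step can also be written as a standard lower-bound binary search, which needs no separate exact-match or "step right" cases. The sketch below is that swapped-in formulation, not the post's own `BinarySearch`; `WeightedSearch`, `LowerBound` and `CountHits` are hypothetical names. `CountHits` maps every possible random value 1..total to an index, which makes the "rightmost index" rule and its proportionality easy to check by hand.

```csharp
// Hypothetical sketch (not the author's code): lower-bound search over the
// cumulative array, plus an exhaustive count of which index each value hits.
public static class WeightedSearch
{
    // Returns the smallest index i with prefix[i] >= target.
    // Assumes prefix is strictly increasing (true when all weights are positive)
    // and 1 <= target <= prefix[prefix.Length - 1].
    public static int LowerBound(int[] prefix, int target)
    {
        int left = 0, right = prefix.Length - 1;
        while (left < right)
        {
            int mid = left + (right - left) / 2;
            if (prefix[mid] < target) left = mid + 1; // answer lies right of mid
            else right = mid;                         // mid may itself be the answer
        }
        return left;
    }

    // For every value 1..total, records which index LowerBound picks.
    public static int[] CountHits(int[] prefix)
    {
        var hits = new int[prefix.Length];
        int total = prefix[prefix.Length - 1];
        for (int value = 1; value <= total; value++)
        {
            hits[LowerBound(prefix, value)]++;
        }
        return hits;
    }
}
```

With `prefix = { 3, 20, 24, 25 }`, `LowerBound(prefix, 12)` is 1, and `CountHits(prefix)` gives `{ 3, 17, 4, 1 }`, exactly the original weights.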

That's it. It beats quite a few other submissions. Code is below. Cheers, ACC.

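    using System;

    public class Solution
    {
        // wl[i] = w[0] + ... + w[i] (cumulative weights)
        private int[] wl = null;
        private Random rd = new Random();
        
        public Solution(int[] w)
        {
            // Build the cumulative array once, in O(N).
            wl = new int[w.Length];

            wl[0] = w[0];
            for (int i = 1; i < wl.Length; i++)
            {
                wl[i] = w[i] + wl[i - 1];
            }
        }

        public int PickIndex()
        {
            // Uniform value in [1..total]; Random.Next's upper bound is exclusive.
            int value = rd.Next(1, wl[wl.Length - 1] + 1);
            return BinarySearch(value);
        }

        // Returns the smallest index i such that wl[i] >= value.
        private int BinarySearch(int value)
        {
            int left = 0;
            int right = wl.Length - 1;

            while (left < right)
            {
                int mid = (left + right) / 2;
                if (wl[mid] == value) return mid;          // value sits exactly on a boundary
                else if (wl[mid] < value) left = mid + 1;  // answer is to the right of mid
                else right = mid - 1;                      // wl[mid] > value: mid may still be the answer
            }

            // left can stop one slot short of the answer (wl[left] < value);
            // in that case the answer is the next, rightmost index.
            if (left < wl.Length - 1 && value > wl[left]) return left + 1;
            return left;
        }
    }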
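
An alternative sketch lets .NET's `Array.BinarySearch` do the searching: when the value is not found, it returns the bitwise complement of the index of the first larger element, so `~result` is the lower bound. Because the weights are positive, the cumulative array has no duplicates and this gives the same index as the hand-written search. `WeightedPicker`, its optional `seed` parameter and `IndexFor` are hypothetical names for illustration; LeetCode itself expects the `Solution`/`PickIndex` shape above.

```csharp
using System;

// Hypothetical variant (not the author's code) using Array.BinarySearch.
public class WeightedPicker
{
    private readonly int[] prefix;
    private readonly Random random;

    // seed is a hypothetical optional parameter so test runs are reproducible.
    public WeightedPicker(int[] w, int? seed = null)
    {
        prefix = new int[w.Length];
        int running = 0;
        for (int i = 0; i < w.Length; i++)
        {
            running += w[i];
            prefix[i] = running;
        }
        random = seed.HasValue ? new Random(seed.Value) : new Random();
    }

    public int PickIndex()
    {
        // Random.Next(min, max) excludes max, hence + 1 to include the total.
        int value = random.Next(1, prefix[prefix.Length - 1] + 1);
        return IndexFor(value);
    }

    // Smallest index i with prefix[i] >= value (value assumed in 1..total).
    public int IndexFor(int value)
    {
        int found = Array.BinarySearch(prefix, value);
        return found >= 0 ? found : ~found; // ~found = index of first larger element
    }
}
```

For Example 2 (`w = [1,3]`), repeated calls to `PickIndex` should return index 1 about three times as often as index 0.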
