题目简介
最近面试Two Sigma,电话面试问了一个很有趣的问题:给一个函数getRandom(n)能随机返回[0,n)中的一个数,求实现一个RandomNumberGenerator(n)数据结构,里面有一个generate()接口,要求每次调用随机返回[0,n)中的一个数,同时不能返回已经返回过的数。
题目分析
这个题换一种说法其实就是从n个数里面随机取m个数。工作中经常遇到类似的问题,一般调用已有库函数就可以实现,所以从来没有想过这个方法的具体实现。
最直接解法
用一个Set来保存已经出现的数,generate()函数不停地调用getRandom(n),直到生成的数字不在set里面,返回这个数,同时把这个数加进Set
public class RandomNumberGenerator {
private final int n;
private final Set<Integer> blacklist;
public int generate() {
if (blacklist.size() == n) {
return -1;
}
while (true) {
int next = getRandom(n);
if (!blacklist.contains(next)) {
blacklist.add(next);
return next;
}
}
}
}
这个解法最坏情况可能导致无限死循环,假设blacklist里面数字个数为m,那么每次成功的概率为(n – m) / n,如果n和m十分接近,则这个概率会很小。
改进版直接解法
上面方法的问题在于过于随机性,我们需要一个运行时间是确定的解法。
可以稍微调整一下思路,我们可以用一个数组list来保存所有没有被访问过的数,generate()函数只需要调用getRandom(list.size())一次,得到返回数的index,通过index将这个数从数组移除,同时返回这个数。
public class RandomNumberGenerator {
private final List<Integer> numbers;
public int generate() {
if (numbers.size() == 0) {
return -1;
}
int index = getRandom(numbers.size());
return numbers.remove(index);
}
}
时间复杂度是O(n),因为每次remove后需要调整。
空间复杂度是O(n)。
如果输入的n很大浪费时间空间。
稍微优化算法
可以先考虑优化时间复杂度。可以看到上一个解法时间复杂度之所以是O(n),主要是由于remove后需要整体调整导致的,那有什么办法可以减少这里的操作呢?其实我们不需要remove后整体往前移动,只需要swap交换被选中的index元素和List最后没有被交换的元素即可,这样交换只需要O(1)
public class RandomNumberGenerator {
private final List<Integer> numbers;
// init cur = n
private int cur;
public int generate() {
if (cur == 0) {
return -1;
}
int index = getRandom(cur);
int result = numbers.get(index);
// swap index number with the last number
numbers.set(index, numbers.get(cur - 1));
numbers.set(cur - 1, result);
cur--;
return result;
}
}
这个算法时间复杂度缩减为O(1),但是空间复杂度还是O(n)。
优化算法
通常从n个数里面随机取m个数,m都远小于n,所以对于上面的解法我们还可以对空间复杂度进行优化。

我们如果研究下上面的numbers数组,可以发现如果一个index从来没有被选过,那么他所对应的数与index想同,所以我们其实只需要一个数据结构来存index对应值与index不相同的index与其值(很自然相当key-value)
public class RandomNumberGenerator {
private final Map<Integer, Integer> indexToNumber;
// init cur = n
private int cur;
public int generate() {
if (cur == 0) {
return -1;
}
int index = getRandom(cur);
int result = indexToNumber.getOrDefault(index, index);
numbers.set(index, indexToNumber.getOrDefault(cur - 1, cur - 1));
cur--;
return result;
}
}
算法时间复杂度是O(1),空间复杂度是O(m)