【算法】从n个数里面随机取m个数

题目简介

最近面试Two Sigma,电话面试问了一个很有趣的问题:给一个函数getRandom(n)能随机返回[0,n)中的一个数,求实现一个RandomNumberGenerator(n)数据结构,里面有一个generate()接口,要求每次调用随机返回[0,n)中的一个数,同时不能返回已经返回过的数。

题目分析

这个题换一种说法其实就是从n个数里面随机取m个数。工作中经常遇到类似的问题,一般调用已有库函数就可以实现,所以从来没有想过这个方法的具体实现。

最直接解法

用一个Set来保存已经出现的数,generate()函数不停地调用getRandom(n),直到生成的数字不在set里面,返回这个数,同时把这个数加进Set

public class RandomNumberGenerator {
	private final int n;
	private final Set<Integer> blacklist;
	
	public int generate() {
		if (blacklist.size() == n) {
			return -1;
		}
		while (true) {
			int next = getRandom(n);
			if (!blacklist.contains(next)) {
				blacklist.add(next);
				return next;
			}
		}
	} 
}

这个解法最坏情况可能导致无限死循环,假设blacklist里面数字个数为m,那么每次成功的概率为(n – m) / n,如果n和m十分接近,则这个概率会很小。

改进版直接解法

上面方法的问题在于过于随机性,我们需要一个运行时间是确定的解法。

可以稍微调整一下思路,我们可以用一个数组list来保存所有没有被访问过的数,generate()函数只需要调用getRandom(list.size())一次,得到返回数的index,通过index将这个数从数组移除,同时返回这个数。

public class RandomNumberGenerator {
	private final List<Integer> numbers;
	
	public int generate() {
		if (numbers.size() == 0) {
			return -1;
		}
		int index = getRandom(numbers.size());
		return numbers.remove(index);
	} 
}

时间复杂度是O(n),因为每次remove后需要调整。

空间复杂度是O(n)。

如果输入的n很大浪费时间空间。

稍微优化算法

可以先考虑优化时间复杂度。可以看到上一个解法时间复杂度之所以是O(n),主要是由于remove后需要整体调整导致的,那有什么办法可以减少这里的操作呢?其实我们不需要remove后整体往前移动,只需要swap交换被选中的index元素和List最后没有被交换的元素即可,这样交换只需要O(1)

public class RandomNumberGenerator {
	private final List<Integer> numbers;
	// init cur = n
	private int cur;
	
	public int generate() {
		if (cur == 0) {
			return -1;
		}
		int index = getRandom(cur);
		int result = numbers.get(index);
		// swap index number with the last number
		numbers.set(index, numbers.get(cur - 1));
		numbers.set(cur - 1, result);
		cur--;
		return result;
	} 
}

这个算法时间复杂度缩减为O(1),但是空间复杂度还是O(n)。

优化算法

通常从n个数里面随机取m个数,m都远小于n,所以对于上面的解法我们还可以对空间复杂度进行优化。

5个数里面随机取3个
5个数里面随机取3个

我们如果研究下上面的numbers数组,可以发现如果一个index从来没有被选过,那么他所对应的数与index想同,所以我们其实只需要一个数据结构来存index对应值与index不相同的index与其值(很自然相当key-value)

public class RandomNumberGenerator {
	private final Map<Integer, Integer> indexToNumber;
	// init cur = n
	private int cur;
	
	public int generate() {
		if (cur == 0) {
			return -1;
		}
		int index = getRandom(cur);
		int result = indexToNumber.getOrDefault(index, index);
		numbers.set(index, indexToNumber.getOrDefault(cur - 1, cur - 1));
		cur--;
		return result;
	} 
}

算法时间复杂度是O(1),空间复杂度是O(m)