翻转对

Posted by Jason on Saturday, April 25, 2020

TOC

题目

给定一个数组 nums ,如果 i < j 且 nums[i] > 2*nums[j] 我们就将 (i, j) 称作一个重要翻转对。 你需要返回给定数组中的重要翻转对的数量。

示例 1:

输入: [1,3,2,3,1]
输出: 2

示例 2:

输入: [2,4,3,5,1]
输出: 3

注意:

给定数组的长度不会超过50000。 输入数组中的所有数字都在32位整数的表示范围内。

来源:力扣(LeetCode) 链接:https://leetcode-cn.com/problems/reverse-pairs 著作权归领扣网络所有。商业转载请联系官方授权,非商业转载请注明出处。

解题思路

如果使用暴力解法,两个循环嵌套能够解决这个问题,复杂度为 O(N2)。思考一下,如果数组有序,通过二分法查找,可以将复杂度降到 O(NlogN)。针对这个思路,我们有两种优化:

归并

使用分治的思想,将原始问题化解为多个子问题。那么如何分解子问题呢?

  • 如果我们把原始数组等分成 2 块 N1 和 N2。前后两块假设都有序的话,可以按任意一块来遍历,遍历 N/2 次我们就可以求出 N1 大于 N2 中翻转对总数。
  • 进一步,我们将 N1 和 N2 继续拆分,最终拆分成只包含一个数字的 N1’ 和 N2’,求解子问题。
  • 最终等效成归并的思想,将子问题累加就是题目的解。
		[2,4,3,5,1]
	    [2,4,3]	[5,1]
         [2,4]   [3]   [5]  [1]
        [2] [4]
	------------ 合并 -------------
     [2|0]  [4|0]		// [num|sum] 表示 num 表示数字集合,sum 表示该集合中翻转对总和
        [2,4|0] [3|0]  [5|0] [1|0]
           [2,3,4|0]    [1,5|1] 
               [1,2,3,4,5|3]

func merge(nums []int, left int, right int) (cnt int) {
	if right <= left {
		return
	}

	mid := (left + right) >> 1
	cnt += merge(nums, left, mid)
	cnt += merge(nums, mid+1, right)

	// 这里优化到了 O(n)
	l := left
	for r := mid + 1; r <= right; r++ {
		for ; l <= mid; l++ {
			if nums[l] > 2*nums[r] {
				cnt += mid - l + 1
				break
			}
		}
	}

	r := mid + 1
	l = left
	for i := left; i <= right; i++ {
		if l <= mid && (r > right || nums[r] > nums[l]) {
			numscopy[i] = nums[l]
			l++
		} else if r <= right {
			numscopy[i] = nums[r]
			r++
		}
	}

	copy(nums[left:right+1], numscopy[left:right+1])
	return
}

树状数组

归并是将问题分解为子问题,我们可以换个思路。

  • 首先,将原始数组排序生成一个有序数组 sort_nums。
  • 然后,从左至右遍历原始数组 nums,对于每个数字 num 在 sort_nums 中查找 2*num+1 索引位,索引位即为对于 num 的翻转对。

但是,上面存在一个问题,sort_nums 中可能存在不是 num 左侧的元素,所以需要标记某个数字是否已遍历,但是这样找到 2*num+1 后还需要继续向右遍历,来过滤掉未遍历过的数字。那么我们可以进一步优化,使用 bit 数组记录 sort_nums 中遍历过的数字记录为1,未遍历过为0。这样 getSum(i) = sum(bit[i], bit[i+1], … , bit[n])。 所以,上面第二部的过程可以简化成 getSum 和 update两个操作。而这两个操作就是典型的树状数组结构。

nums=[2,4,3,5,1], sort_nums=[1,2,3,4,5], binary_indexed_tree=[0,0,0,0,0]

num 2*nums+1 in sort_nums binary_indexed_tree reverse num
2 5 [0,1,0,0,0] 0
4 5 [0,1,0,1,0] 0
3 5 [0,2,1,1,0] 0
5 5 [0,2,1,2,1] 0
1 3 [1,2,1,2,1] getSum(i) = atree[3] + atree[4]
func reversePairs(nums []int) (cnt int) {
	numscopy := make([]int, len(nums))
	copy(numscopy, nums)
	sort.Sort(sortNums(numscopy))

	atree := ArrayTree{
		bit: make([]int, len(nums)+1),
	}

	for _, num := range nums {
		// query reverse num index
		idx := binarySearch(numscopy, 2*num+1)
		cnt += atree.getSum(idx + 1)

		// update num index in binary indexed tree
		idx = binarySearch(numscopy, num)
		atree.update(idx+1, 1)
	}

	return
}

type sortNums []int

func (sn sortNums) Len() int {
	return len(sn)
}

func (sn sortNums) Swap(i, j int) {
	sn[i], sn[j] = sn[j], sn[i]
}

func (sn sortNums) Less(i, j int) bool {
	return sn[i] < sn[j]
}

type ArrayTree struct {
	bit []int
}

func binarySearch(nums []int, val int) int {
	left, right := 0, len(nums)-1
	mid := left + (right-left)>>1

	for left <= right {
		if nums[mid] >= val {
			right = mid - 1
		} else {
			left = mid + 1
		}

		mid = left + (right-left)>>1
	}

	return left
}

func (at *ArrayTree) update(index, num int) {
	for index > 0 {
		at.bit[index] += num
		index -= index & (-index)
	}
}

func (at *ArrayTree) getSum(index int) int {
	var sum int

	for index < len(at.bit) {
		sum += at.bit[index]
		index += index & (-index)
	}

	return sum
}

「真诚赞赏,手留余香」

Jason Blog

真诚赞赏,手留余香

使用微信扫描二维码完成支付


comments powered by Disqus