LeetCode 321. 拼接最大数

题目描述

给定长度分别为 m 和 n 的两个数组,其元素由 0-9 构成,表示两个自然数各位上的数字。现在从这两个数组中选出 k (k <= m + n) 个数字拼接成一个新的数,要求从同一个数组中取出的数字保持其在原数组中的相对顺序。

求满足该条件的最大数。结果返回一个表示该最大数的长度为 k 的数组。

说明: 请尽可能地优化你算法的时间和空间复杂度。

示例 1

输入:

nums1 = [3, 4, 6, 5]
nums2 = [9, 1, 2, 5, 8, 3]
k = 5

输出: [9, 8, 6, 5, 3]

示例 2

输入:

nums1 = [6, 7]
nums2 = [6, 0, 4]
k = 5

输出: [6, 7, 6, 0, 4]

示例 3

输入:

nums1 = [3, 9]
nums2 = [8, 9]
k = 3

输出: [9, 8, 9]

题解

写的时候,秒出了官方题解的思路,但是感觉太暴力了给否了……

整体分成一下几步:

  1. 首先需要实现一个可以获取串s的长度为k的子串(该子串应该满足int("".join())最大,后面直接称其为最大)。记该函数为getMax(s, k)
  2. 对于最终长度为k的串,分别以(0,k),(1,k-1),……,(l1, k-l1),为对应子串长度,获取对应的最大值
    1. 将两个串的最大值进行拼接
    2. 维护所有合并串的最大值
  3. 返回合并串最大值

最大子序列

先给这个概念下一个定义:对于一个由 090 \sim 9 构成的序列,按照顺序选取出其中 kk 个元素,使得其按照原序列顺序组成的新序列为所有子序列中最大的。其中,最大指按位拼接后得到的整数最大。

需要注意的是,尽管需要尽可能大,但是这里对序列的长度有要求,必须是 kk 位。因此在某些序列中,为了确保满足长度要求,可能需要妥协部分元素并非单调递减。但是无论如何,我们应该尽可能保持高位更大。

对于一个长度为 nn 的序列,从中选出 kk 个元素,那么必然有 nkn-k 个元素未被使用,相当于我们有 nkn-k 次替换已选择元素的机会。
如果发现新的元素比当前栈顶的元素更大,那么可以将其替换掉(这时不涉及位数变化,因此结果必然更大);而如果我们仍然有多余的替换机会,我们也可以与上一位也进行对比,如果也更大,那么也可以替换(因为高位变大,即使低位变小也无妨)。这样,一直替换到机会用尽,或者之前的数更大,停止替换。这里的替换严格来说是删除+插入,先把能删除的尽可能删除,然后插入当前数。
如果当前数未被使用(当前数比原本的数还小),那么就需要减少一个替换机会(或者说让一个数不被使用的机会)。

最容易迷惑的部分在于“替换机会”。如果将这个数据结构看作一个栈,我们实际上是尽可能让这个栈自低向顶保持单调递减。但是由于长度问题,单调递减可能无法满足。但是我们仍然应该确保栈底的数尽可能大。比如[8, 7, 6, 5, 4, 9, 0, 0, 0]如果要选 44 个元素,那么必然是 90009000 而非 87658765。只要可供选择数足够,就可以把之前的最优解直接抛弃。而这里替换机会实际上就是判断后面的数是否还足够使用。比如这里减去一个零,那么结果就变成 89008900 因为 00 的个数不够,不足以替换掉 88

该部分时间复杂度为 O(n)O(n),其中 nn 为序列长度

最大序列拼接

将两个序列在保持单个序列顺序的情况下拼接成一个新的序列,并且保持新的序列最大。

基本思路是将两个序列看作一个队列,每次都从队列前面选择更大的那个数加入新队列。和前面的总体思想一样,尽可能保持高位最大。

存在一个需要特别注意的情况:两个元素相同
如对于[4,9][4,0],尽管第一个数都是 44,但是由于第一个序列的第二个数是 99,因此应该优先使用第一个序列(大数应该尽可能在高位)

这样,只需要根据比较判断每次需要使用哪个序列的元素即可。如果相同,则比较后面的。

该部分时间复杂度为 O(l1+l2)O(l_1+l_2),其中 l1l_1l2l_2 分别是两个序列的长度。
但是如果每一位都相同,最坏时间复杂度将是 O((l1+l2)×min(l1,l2))O((l_1+l_2) \times min(l_1,l_2))
又因为这里 l1+l2=kl_1+l_2=k,因此最后的时间复杂度为 O(k2)O(k^2)


再看上面的总体思路,还需要对两个序列长度通过 kk 进行分割,事件复杂度为 O(k)O(k)。分割后需要分别获取两个序列的最大子序列并拼接。拼接完成后需要将其与原结果进行比较大小。
因此整体的时间复杂度应该是 O(k×(l1+l2+k2+k))=O(k(l1+l2+k2))O(k\times(l_1+l_2+k^2 + k))=O(k(l_1+l_2+k^2))

代码

class Solution:   
    def maxNumber(self, nums1: List[int], nums2: List[int], k: int) -> List[int]:
        res = [0 for i in range(k)]
        
        m, n = len(nums1), len(nums2)
        start = max(0, k - n)
        end = min(k, m)
        
        for i in range(start, end + 1):
            s1 = self.getMax(nums1, i)
            s2 = self.getMax(nums2, k - i)
            s = self.merge(s1, s2)
            if self.compare(s, res) > 0:
                res = s

        return res

    def getMax(self, nums: List[int], k: int) -> List[int]:
        '''
        获取从 nums 选取 k 个数能拼成的最大序列
        '''
        stack = [0] * k
        top = -1
        unuse = len(nums) - k # 可以不使用的数目个数

        for num in nums: 
            while top >= 0 and stack[top] < num and unuse > 0:
                # 把栈顶比当前数小的都删掉
                # 尽可能保持高位最大(贪心)
                top -= 1
                unuse -= 1
            if top < k - 1:
                # 插入当前数
                top += 1
                stack[top] = num
            else:
                # 当前数未被使用
                unuse -= 1
        
        return stack


    def merge(self, s1: List[int], s2: List[int]) -> List[int]:
        '''
        将两个序列拼接
        '''
        l1 = len(s1)
        l2 = len(s2)
        i1 = 0
        i2 = 0
        pos = 0

        res = [0 for i in range(l1+l2)]
        while pos < l1+l2:
            if self.use1(s1,s2,i1,i2):
                res[pos] = s1[i1]
                i1 += 1
            else:
                res[pos] = s2[i2]
                i2 += 1
            pos +=1 
        return res
    
    def use1(self, s1: List[int], s2:List[int], i1: int, i2:int)->bool:
        '''
        在拼接序列时,是否应该从 s1 选取下一个值
        '''
        l1 = len(s1)
        l2 = len(s2)

        if i1>=l1 and i2>=l2:
            # s1 s2 都已经用完了,随便用谁返回下
            return True
        elif i1<l1 and i2>=l2:
            # s1 没用完,s2 用完,所以用 s1
            return True
        elif i1>=l1 and i2<l2:
            # s1 用完了,s2 没用完,所以用 s2
            return False

        if s1[i1] == s2[i2]:
            # 两个值一样,比较下一个
            return self.use1(s1, s2, i1+1, i2+1)
        elif s1[i1] > s2[i2]:
            # s1 的值更大,用 s1
            return True
        else:
            # s2 的值更大,用 s2
            return False

    def compare(self, s1: List[int], s2: List[int]) -> bool:
        '''
        比较序列大小
        '''
        for i in range(len(s1)):
            c1 = s1[i]
            c2 = s2[i]
            if c1 != c2:
                return c1 - c2
        return 0