debugging 有人能帮忙找出这个算法的错误边缘情况吗?

byqmnocz  于 2023-01-05  发布在  其他
关注(0)|答案(1)|浏览(103)

I 'm solving 'Non overlap intervals' problem on leetcode [https://leetcode.com/problems/non-overlapping-intervals/]简而言之,我们需要定义要删除的间隔的最小数量,以创建它们的非重叠集(要删除的数量是请求的结果)。
我的解决方案是从所有区间中构建增广区间树([https://en.wikipedia.org/wiki/Interval_tree#Augmented_tree])(时间复杂度为O((n log n)),然后(第二次遍历区间)测量每个给定区间与多少个其他区间相交(时间复杂度也为O((n log n))(它也给出+1自相交,但我只将其用作相对度量),并根据这个“其他区间的相交数”度量对所有区间进行排序。在最后一步,我只是从排序后的区间中一个接一个地得到区间,如上所述,列出并创建非重叠集(使用区间树的另一个示例显式检查非重叠),形成应该删除的结果集。
下面我给予了完整的代码所描述的解决方案,以发挥leetcode。
这种方法工作足够快,但有时我会出错,相差1个结果。Leetcode没有给予太多反馈,我的反馈是“预期810”,而不是“811”。所以我仍在调试挖掘811间隔......:)

即使知道这个问题的其他解决方案,我还是希望找到所描述的方法失败的情况(它本身就是有用的边缘情况)。因此,如果有人看到类似的问题,或者只是能用一些“新鲜的眼睛”发现它-这将是最受欢迎的!

提前感谢任何建设性的意见和想法!
解决方案代码:

class Interval:
    def __init__(self, lo: int, hi: int):
        self.lo = lo
        self.hi = hi

class Node:
    def __init__(self, interval: Interval, left: 'Node' = None, right: 'Node' = None):
        self.left = left
        self.right = right
        self.interval = interval
        self.max_hi = interval.hi

class IntervalTree:
    def __init__(self):
        self.root = None

    def __add(self, interval: Interval, node:Node) -> Node:
        if node is None:
            node = Node(interval)
            node.max_hi = interval.hi
            return node

        if node.interval.lo > interval.lo:
            node.left = self.__add(interval, node.left)
        else:
            node.right = self.__add(interval, node.right)
        node.max_hi = max(node.left.max_hi if node.left else 0, node.right.max_hi if node.right else 0, node.interval.hi)
        return node

    def add(self, lo: int, hi: int):
        interval = Interval(lo, hi)
        self.root = self.__add(interval, self.root)

    def __is_intersect(self, interval: Interval, node: Node) -> bool:
        if node is None:
            return False
        if not (node.interval.lo >= interval.hi or node.interval.hi <= interval.lo):
            # print(f'{interval.lo}-{interval.hi} intersects {node.interval.lo}-{node.interval.hi}')
            return True
        if node.left and node.left.max_hi > interval.lo:
            return self.__is_intersect(interval, node.left)
        return self.__is_intersect(interval, node.right)

    def is_intersect(self, lo: int, hi: int) -> bool:
        interval = Interval(lo, hi)
        return self.__is_intersect(interval, self.root)

    def __all_intersect(self, interval: Interval, node: Node) -> Iterable[Interval]:
        if node is None:
            yield from ()
        else:
            if not (node.interval.lo >= interval.hi or node.interval.hi <= interval.lo):
                # print(f'{interval.lo}-{interval.hi} intersects {node.interval.lo}-{node.interval.hi}')
                yield node.interval
            if node.left and node.left.max_hi > interval.lo:
                yield from self.__all_intersect(interval, node.left)
            yield from self.__all_intersect(interval, node.right)

    def all_intersect(self, lo: int, hi: int) -> Iterable[Interval]:
        interval = Interval(lo, hi)
        yield from self.__all_intersect(interval, self.root)

class Solution:
    def eraseOverlapIntervals(self, intervals: List[List[int]]) -> int:
        ranged_intervals = []

        interval_tree = IntervalTree()
        for interval in intervals:
            interval_tree.add(interval[0], interval[1])
        for interval in intervals:
            c = interval_tree.all_intersect(interval[0], interval[1])
            ranged_intervals.append((len(list(c))-1, interval))  # decrement intersection to account self intersection

        interval_tree = IntervalTree()
        res = []
        ranged_intervals.sort(key=lambda t: t[0], reverse=True)
        while ranged_intervals:
            _, interval = ranged_intervals.pop()
            if not interval_tree.is_intersect(interval[0], interval[1]):
                interval_tree.add(interval[0], interval[1])
            else:
                res.append(interval)

        return len(res)
ggazkfy8

ggazkfy81#

为了给你的算法做一个反例,你可以构造一个问题,在这个问题中,选择交集最少的线段会破坏解,如下所示:

[----][----][----][----]
[-------][----][-------]
[-------]      [-------]
[-------]      [-------]
[-------]      [-------]

您的算法将首先选择中心间隔,这与最优解不兼容:

[----][----][----][----]

当存在任何重叠时,一个确实有效的算法是:
1.查找最左侧的重叠点
1.拾取与该点重叠的任意两个间隔,并删除向右延伸最远的间隔。
这个算法实现起来也很简单,你可以在一个遍历区间列表的过程中完成,区间列表按起点排序:

class Solution:
    def eraseOverlapIntervals(self, intervals: List[List[int]]) -> int:
        intervals.sort()
        extent = None
        deletes = 0
        for interval in intervals:
            if extent == None or extent <= interval[0]:
                extent = interval[1]
            else:
                deletes += 1
                extent = min(extent, interval[1])
        return deletes

相关问题