Segment Tree

intro

Aspect          | Prefix Sum                               | Difference Array                               | Segment Tree
Primary Purpose | range sum queries                        | range updates                                  | range queries and range updates
Operation Time  | Sum Query: O(1)                          | Update: O(1)                                   | Query: O(logC) or O(logn); Update: O(logC) or O(logn)
Reconstruction  | get original array from diff of elements | get original array from prefix sum of elements | N/A
Use Case        | static arrays                            | cumulative updates                             | interval-based manipulations

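
The first two columns of the comparison can be sketched in a few lines. This is only an illustration under assumed names (nums, prefix, diff, range_sum, and range_add are made up for this example), showing the O(1) operation of each structure and how the original array is reconstructed from it.

```python
from itertools import accumulate

nums = [3, 1, 4, 1, 5]

# Prefix sum: prefix[i] holds sum(nums[:i]), so a range sum is O(1).
prefix = [0] + list(accumulate(nums))

def range_sum(left, right):
    """Sum of nums[left..right], inclusive."""
    return prefix[right + 1] - prefix[left]

# Differences of adjacent prefix elements give back the original array.
restored_from_prefix = [prefix[i + 1] - prefix[i] for i in range(len(nums))]

# Difference array: diff[i] = nums[i] - nums[i - 1], so a range update is O(1).
diff = [nums[0]] + [nums[i] - nums[i - 1] for i in range(1, len(nums))]

def range_add(left, right, delta):
    """Add delta to every element of nums[left..right], inclusive."""
    diff[left] += delta
    if right + 1 < len(diff):
        diff[right + 1] -= delta

range_add(1, 3, 10)
# The prefix sum of the difference array gives back the updated array.
updated = list(accumulate(diff))

print(range_sum(1, 3))  # 6
print(updated)          # [3, 11, 14, 11, 5]
```

Neither structure handles interleaved range queries and range updates efficiently, which is the gap a segment tree fills.
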
  • a segment tree is a binary tree in which each node represents an interval and stores an aggregate of that interval
  • the node-based (tree) implementation is easier to understand
  • building the tree from an array costs O(n)
    • alternatively, build it dynamically: create a node only when update() or query() reaches it, at O(logC) per operation (C is the largest value the tree has to cover)
  • point modification: O(logC) or O(logn)
  • range query: O(logC) or O(logn)
    • sum
    • count
    • max
    • other related aggregations
  • range modification: O(logC) or O(logn) when using lazy propagation
    • a node can carry several lazy variables, depending on how many kinds of operations are needed
  • push_down(), push_up(), update(), and query() are implemented differently depending on the type of range query and range modification
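
To illustrate the last point, here is a minimal sketch of a max type tree with point modification. The names MaxNode and MaxSegmentTree are hypothetical, chosen for this example. Compared with a sum type tree, only push_up() and the merge step in query() change; the sketch assumes a non-empty input list.

```python
class MaxNode:
    def __init__(self, start, end):
        self.start, self.end = start, end
        self.mid = (start + end) // 2
        self.left = self.right = None
        self.val = float("-inf")  # max over [start, end]

class MaxSegmentTree:
    def __init__(self, nums):
        def build(start, end):
            node = MaxNode(start, end)
            if start == end:
                node.val = nums[start]
                return node
            node.left = build(start, node.mid)
            node.right = build(node.mid + 1, end)
            self.push_up(node)
            return node

        self.root = build(0, len(nums) - 1)  # assumes nums is non-empty

    def push_up(self, node):
        # The aggregation is max instead of sum.
        node.val = max(node.left.val, node.right.val)

    def update(self, index, val, node=None):
        node = self.root if node is None else node
        if node.start == node.end:
            node.val = val
            return
        if index <= node.mid:
            self.update(index, val, node.left)
        else:
            self.update(index, val, node.right)
        self.push_up(node)

    def query(self, left, right, node=None):
        node = self.root if node is None else node
        if left <= node.start and node.end <= right:
            return node.val
        res = float("-inf")  # identity element for max
        if left <= node.mid:
            res = max(res, self.query(left, right, node.left))
        if right >= node.mid + 1:
            res = max(res, self.query(left, right, node.right))
        return res

tree = MaxSegmentTree([2, 7, 1, 8, 2])
print(tree.query(0, 4))  # 8
tree.update(3, 0)
print(tree.query(0, 4))  # 7
```
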
# sum type
# static build (not dynamic)
# point modification only (no range modification)
# no lazy propagation
class Node:
    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.mid = (start + end) // 2
        self.left = None
        self.right = None
        self.val = 0

class SegmentTree:
    def __init__(self, nums):
        # recursively build the node that covers nums[start..end]
        def build(start, end):
            node = Node(start, end)
            if start == end:
                # leaf: holds a single element
                node.val = nums[start]
                return node
            node.left = build(start, node.mid)
            node.right = build(node.mid + 1, end)
            # internal node: sum of both halves
            node.val = node.left.val + node.right.val
            return node

        # assumes nums is non-empty
        self.root = build(0, len(nums) - 1)

    def update(self, index, val):
        # set nums[index] = val, then refresh the sums on the way back up
        def helper(node, index, val):
            if node.start == node.end:
                node.val = val
                return
            if index <= node.mid:
                helper(node.left, index, val)
            else:
                helper(node.right, index, val)
            node.val = node.left.val + node.right.val

        helper(self.root, index, val)

    def query(self, left, right):
        # sum of nums[left..right], inclusive
        def helper(node, start, end):
            # node interval lies fully inside the query range
            if start <= node.start and node.end <= end:
                return node.val
            res = 0
            # descend only into the halves that overlap the query range
            if start <= node.mid:
                res += helper(node.left, start, end)
            if end >= node.mid + 1:
                res += helper(node.right, start, end)
            return res

        return helper(self.root, left, right)

# time: O(n) to build, O(logn) for each update() and query()
# space: O(n) for the tree nodes
# approach: sum type segment tree

# sum type
# with dynamic build
# with range modification
# with lazy propagation
class Node:
    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.mid = (start + end) // 2
        self.left = None
        self.right = None
        self.val = 0
        self.lazy = 0

class SegmentTree:
    def __init__(self, start, end):
        self.root = Node(start, end)

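    def push_down(self, node):
        # dynamic build: create the children the first time they are needed
        if not node.left:
            node.left = Node(node.start, node.mid)
        if not node.right:
            node.right = Node(node.mid + 1, node.end)
        # nothing pending for the children
        if node.lazy == 0:
            return
        # apply the pending add to each child's sum (add * child length)
        node.left.val += node.lazy * (node.left.end - node.left.start + 1)
        node.right.val += node.lazy * (node.right.end - node.right.start + 1)
        # the children now owe the same add to their own children
        node.left.lazy += node.lazy
        node.right.lazy += node.lazy
        node.lazy = 0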

    def push_up(self, node):
        node.val = node.left.val + node.right.val

    def update(self, node, start, end, add):
        if start <= node.start and node.end <= end:
            node.val += add * (node.end - node.start + 1)
            node.lazy += add
            return
        self.push_down(node)
        if start <= node.mid:
            self.update(node.left, start, end, add)
        if end >= node.mid + 1:
            self.update(node.right, start, end, add)
        self.push_up(node)

    def query(self, node, start, end):
        if start <= node.start and node.end <= end:
            return node.val
        self.push_down(node)
        res = 0
        if start <= node.mid:
            res += self.query(node.left, start, end)
        if end >= node.mid + 1:
            res += self.query(node.right, start, end)
        return res

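# usage: tree.update(tree.root, start, end, add) and tree.query(tree.root, start, end)
# time: O(1) to initialize, O(logC) for each update() and query()
# space: O(min(C, m * logC)) for m calls, since each call creates O(logC) nodes
# approach: sum type segment tree with dynamic build and lazy propagation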
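
The tree above needs a single lazy variable because it supports only range add. As a rough sketch of the "several lazy variables" case, the following assumes two operations, range assign and range add, on a dynamically built sum type tree. All names (LazyNode, AssignAddTree, range_assign, range_add) are hypothetical. The design choice here is that a pending assignment absorbs later adds, so a node never holds both an assignment and a separate add at the same time.

```python
class LazyNode:
    def __init__(self, start, end):
        self.start, self.end = start, end
        self.mid = (start + end) // 2
        self.left = self.right = None
        self.val = 0        # sum over [start, end]
        self.add = 0        # pending "add to every element"
        self.assign = None  # pending "set every element", None if absent

class AssignAddTree:
    def __init__(self, start, end):
        self.root = LazyNode(start, end)

    @staticmethod
    def _apply_assign(node, value):
        node.val = value * (node.end - node.start + 1)
        node.assign = value
        node.add = 0  # an assignment overrides any older pending add

    @staticmethod
    def _apply_add(node, delta):
        node.val += delta * (node.end - node.start + 1)
        if node.assign is not None:
            node.assign += delta  # fold the add into the pending assignment
        else:
            node.add += delta

    def push_down(self, node):
        if node.left is None:
            node.left = LazyNode(node.start, node.mid)
        if node.right is None:
            node.right = LazyNode(node.mid + 1, node.end)
        for child in (node.left, node.right):
            if node.assign is not None:
                self._apply_assign(child, node.assign)
            if node.add:
                self._apply_add(child, node.add)
        node.assign, node.add = None, 0

    def _update(self, node, start, end, apply, value):
        if start <= node.start and node.end <= end:
            apply(node, value)
            return
        self.push_down(node)
        if start <= node.mid:
            self._update(node.left, start, end, apply, value)
        if end >= node.mid + 1:
            self._update(node.right, start, end, apply, value)
        node.val = node.left.val + node.right.val

    def range_assign(self, start, end, value):
        self._update(self.root, start, end, self._apply_assign, value)

    def range_add(self, start, end, delta):
        self._update(self.root, start, end, self._apply_add, delta)

    def query(self, start, end, node=None):
        node = self.root if node is None else node
        if start <= node.start and node.end <= end:
            return node.val
        self.push_down(node)
        res = 0
        if start <= node.mid:
            res += self.query(start, end, node.left)
        if end >= node.mid + 1:
            res += self.query(start, end, node.right)
        return res

tree = AssignAddTree(0, 10**9)
tree.range_add(0, 9, 2)       # positions 0..9 become 2
tree.range_assign(3, 5, 10)   # positions 3..5 become 10
tree.range_add(4, 6, 1)       # positions 4..6 gain 1
print(tree.query(0, 9))       # 47
```
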

pattern

  • use sum type segment tree
  • use count type segment tree
  • use max type segment tree
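
As one possible reading of the count type pattern, the sketch below counts how many integers in a large range are covered by the intervals added so far. The names CountNode, CountSegmentTree, and cover are hypothetical, and the approach is an assumption rather than the notes' own method: a node whose count equals its length acts as its own lazy marker, so no separate lazy variable is kept. Queries are assumed to overlap the range the tree was created with.

```python
class CountNode:
    def __init__(self, start, end):
        self.start, self.end = start, end
        self.mid = (start + end) // 2
        self.left = self.right = None
        self.count = 0  # how many integers in [start, end] are covered

class CountSegmentTree:
    def __init__(self, start, end):
        self.root = CountNode(start, end)

    def cover(self, start, end, node=None):
        """Mark every integer in [start, end] as covered."""
        node = self.root if node is None else node
        size = node.end - node.start + 1
        if node.count == size:
            return  # already fully covered, nothing can change below
        if start <= node.start and node.end <= end:
            node.count = size
            return
        # dynamic build: create children only when we must descend
        if node.left is None:
            node.left = CountNode(node.start, node.mid)
        if node.right is None:
            node.right = CountNode(node.mid + 1, node.end)
        if start <= node.mid:
            self.cover(start, end, node.left)
        if end >= node.mid + 1:
            self.cover(start, end, node.right)
        node.count = node.left.count + node.right.count

    def query(self, start, end, node=None):
        """Number of covered integers in [start, end]."""
        node = self.root if node is None else node
        if node.count == 0:
            return 0
        if start <= node.start and node.end <= end:
            return node.count
        if node.count == node.end - node.start + 1:
            # fully covered node: the answer is the size of the overlap
            return min(end, node.end) - max(start, node.start) + 1
        res = 0
        if start <= node.mid:
            res += self.query(start, end, node.left)
        if end >= node.mid + 1:
            res += self.query(start, end, node.right)
        return res

tree = CountSegmentTree(1, 10**9)
tree.cover(2, 3)
tree.cover(7, 10)
tree.cover(5, 8)
print(tree.root.count)   # 8
print(tree.query(1, 6))  # 4
```
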