区間を set で管理するやつを Rust で実装

概要

この記事を読むことで「区間を set で管理するやつ」の Rust での実装例が分かります。

「区間を set で管理するやつ」とは

「区間を set で管理するやつ」は競技プログラミングで使用されることがあるデータ構造です。
IntervalSet や RangeSet という名前で実装されることが多い印象です。

「区間を set で管理するやつ」では以下のことができます。

  • 区間の追加
  • 区間の削除
  • ある区間が追加されているかどうかの判定
  • 追加されている区間の取得
  • 追加されている区間の合計の取得
  • mex の取得

実装例

本実装では区間を半開区間で扱っています。

半開区間の具体例は以下をご覧下さい。
[2, 4) は 2 を含み、 4 は含みません。
[l, r) は l を含み、 r は含みません。

IntervalSet の実装
use std::collections::BTreeSet;

pub struct IntervalSet<T>
where
    T: Ord + Copy + Default + std::ops::Sub<Output = T> + std::ops::Add<Output = T>,
{
    set: BTreeSet<(T, T)>,
    sum: T,
}

impl<T> IntervalSet<T>
where
    T: Ord + Copy + Default + std::ops::Sub<Output = T> + std::ops::Add<Output = T>,
{
    pub fn new() -> Self {
        Self {
            set: BTreeSet::new(),
            sum: T::default(),
        }
    }

    // x を含む区間 [l, r) を返す
    // 存在しない場合は None を返す
    pub fn get(&self, x: T) -> Option<(T, T)> {
        if let Some(&(l, r)) = self.set.range(..(x, x)).next_back() {
            if l <= x && x < r {
                return Some((l, r));
            }
        }
        if let Some(&(l, r)) = self.set.range((x, x)..).next() {
            if l <= x && x < r {
                return Some((l, r));
            }
        }
        None
    }

    // x より左にある区間 [l, r) を返す
    // 存在しない場合は None を返す
    pub fn get_left(&self, x: T) -> Option<(T, T)> {
        for &(l, r) in self.set.range(..(x, x)).rev() {
            if r <= x {
                return Some((l, r));
            }
        }
        None
    }

    // x より右にある区間 [l, r) を返す
    // 存在しない場合は None を返す
    pub fn get_right(&self, x: T) -> Option<(T, T)> {
        for &(l, r) in self.set.range((x, x)..) {
            if x < l {
                return Some((l, r));
            }
        }
        None
    }

    // 追加されている区間を返す
    pub fn get_all(&self) -> Vec<(T, T)> {
        self.set.iter().map(|&(l, r)| (l, r)).collect()
    }

    // 追加されている区間の合計を返す
    pub fn sum(&self) -> T {
        self.sum
    }

    // 区間の個数を返す
    pub fn len(&self) -> usize {
        self.set.len()
    }

    // [l, r) が追加されている場合は true を返す
    // そうでない場合は false を返す
    pub fn is_covered(&self, l: T, r: T) -> bool {
        assert!(l <= r);
        if l == r {
            return true;
        }
        if let Some((_a, b)) = self.get(l) {
            if r <= b {
                return true;
            }
        }
        false
    }

    // mex を返す
    pub fn mex(&self, x: T) -> T {
        if let Some((_a, b)) = self.get(x) {
            return b;
        }
        x
    }

    // 区間 [l, r) を追加する
    pub fn insert(&mut self, l: T, r: T) {
        assert!(l <= r);
        if l == r {
            return;
        }

        let mut l = l;
        let mut r = r;

        if let Some((a, b)) = self.get_left(l) {
            if b == l {
                self.set.remove(&(a, b));
                self.sum = self.sum - (b - a);
                l = a;
            }
        }
        if let Some((a, b)) = self.get(l) {
            self.set.remove(&(a, b));
            self.sum = self.sum - (b - a);
            l = a;
            r = r.max(b);
        }
        while let Some((a, b)) = self.get_right(l) {
            if a <= r {
                self.set.remove(&(a, b));
                self.sum = self.sum - (b - a);
                r = r.max(b);
            } else {
                break;
            }
        }

        self.set.insert((l, r));
        self.sum = self.sum + (r - l);
    }

    // 区間 [l, r) を削除する
    pub fn remove(&mut self, l: T, r: T) {
        assert!(l <= r);
        if l == r {
            return;
        }

        if let Some((a, b)) = self.get(l) {
            self.set.remove(&(a, b));
            self.sum = self.sum - (b - a);
            if a < l {
                self.set.insert((a, l));
                self.sum = self.sum + (l - a);
            }
            if r < b {
                self.set.insert((r, b));
                self.sum = self.sum + (b - r);
            }
        }
        while let Some((a, b)) = self.get_right(l) {
            if r <= a {
                break;
            }
            self.set.remove(&(a, b));
            self.sum = self.sum - (b - a);
            if r < b {
                self.set.insert((r, b));
                self.sum = self.sum + (b - r);
            }
        }
    }

    // 空にする
    pub fn clear(&mut self) {
        self.set.clear();
        self.sum = T::default();
    }
}
Rust

各メソッドの解説

ここでは IntervalSet 構造体で実装されているメソッドについて解説します。

new()

空の IntervalSet を作ります。

以下の例の様に、整数型を指定して利用することができます。

Examples
// i64 で利用する場合
let mut interval_set: IntervalSet<i64> = IntervalSet::new();
interval_set.insert(1, 5);
assert_eq!(interval_set.get_all(), vec![(1, 5)]);
assert_eq!(interval_set.sum(), 4);
assert_eq!(interval_set.len(), 1);

// usize で利用する場合
let mut interval_set: IntervalSet<usize> = IntervalSet::new();
interval_set.insert(1, 5);
assert_eq!(interval_set.get_all(), vec![(1, 5)]);
assert_eq!(interval_set.sum(), 4);
assert_eq!(interval_set.len(), 1);
Rust

get()

引数で x を渡します。
x を含む区間 [l, r) を返します。
存在しない場合は None を返します。

Examples
let mut interval_set = IntervalSet::new();
interval_set.insert(1, 2);
interval_set.insert(7, 9);
interval_set.insert(3, 5);
assert_eq!(interval_set.get_all(), vec![(1, 2), (3, 5), (7, 9)]);

assert_eq!(interval_set.get(0), None);
assert_eq!(interval_set.get(1), Some((1, 2)));
assert_eq!(interval_set.get(2), None);
assert_eq!(interval_set.get(3), Some((3, 5)));
assert_eq!(interval_set.get(4), Some((3, 5)));
assert_eq!(interval_set.get(5), None);
assert_eq!(interval_set.get(6), None);
assert_eq!(interval_set.get(7), Some((7, 9)));
assert_eq!(interval_set.get(8), Some((7, 9)));
assert_eq!(interval_set.get(9), None);
assert_eq!(interval_set.get(10), None);
Rust

get_left()

引数で x を渡します。
x より左にある区間 [l, r) の中で r が最大の区間を返します。この時、r <= x を満たします。
存在しない場合は None を返します。

Examples
let mut interval_set = IntervalSet::new();
interval_set.insert(1, 2);
interval_set.insert(7, 9);
interval_set.insert(3, 5);
assert_eq!(interval_set.get_all(), vec![(1, 2), (3, 5), (7, 9)]);

assert_eq!(interval_set.get_left(0), None);
assert_eq!(interval_set.get_left(1), None);
assert_eq!(interval_set.get_left(2), Some((1, 2)));
assert_eq!(interval_set.get_left(3), Some((1, 2)));
assert_eq!(interval_set.get_left(4), Some((1, 2)));
assert_eq!(interval_set.get_left(5), Some((3, 5)));
assert_eq!(interval_set.get_left(6), Some((3, 5)));
assert_eq!(interval_set.get_left(7), Some((3, 5)));
assert_eq!(interval_set.get_left(8), Some((3, 5)));
assert_eq!(interval_set.get_left(9), Some((7, 9)));
assert_eq!(interval_set.get_left(10), Some((7, 9)));
Rust

get_right()

引数で x を渡します。
x より右にある区間 [l, r) の中で l が最小の区間を返します。この時、x < l を満たします。
存在しない場合は None を返します。

Examples
let mut interval_set = IntervalSet::new();
interval_set.insert(1, 2);
interval_set.insert(7, 9);
interval_set.insert(3, 5);
assert_eq!(interval_set.get_all(), vec![(1, 2), (3, 5), (7, 9)]);

assert_eq!(interval_set.get_right(0), Some((1, 2)));
assert_eq!(interval_set.get_right(1), Some((3, 5)));
assert_eq!(interval_set.get_right(2), Some((3, 5)));
assert_eq!(interval_set.get_right(3), Some((7, 9)));
assert_eq!(interval_set.get_right(4), Some((7, 9)));
assert_eq!(interval_set.get_right(5), Some((7, 9)));
assert_eq!(interval_set.get_right(6), Some((7, 9)));
assert_eq!(interval_set.get_right(7), None);
assert_eq!(interval_set.get_right(8), None);
assert_eq!(interval_set.get_right(9), None);
assert_eq!(interval_set.get_right(10), None);
Rust

get_all()

追加されている区間を返します。

Examples
let mut interval_set = IntervalSet::new();
assert_eq!(interval_set.get_all(), vec![]);

interval_set.insert(1, 2);
assert_eq!(interval_set.get_all(), vec![(1, 2)]);

interval_set.insert(7, 9);
assert_eq!(interval_set.get_all(), vec![(1, 2), (7, 9)]);

interval_set.insert(3, 5);
assert_eq!(interval_set.get_all(), vec![(1, 2), (3, 5), (7, 9)]);
Rust

sum()

追加されている区間の合計を返します。

Examples
let mut interval_set = IntervalSet::new();
assert_eq!(interval_set.get_all(), vec![]);
assert_eq!(interval_set.sum(), 0);

interval_set.insert(1, 2);
assert_eq!(interval_set.get_all(), vec![(1, 2)]);
assert_eq!(interval_set.sum(), 1);

interval_set.insert(7, 9);
assert_eq!(interval_set.get_all(), vec![(1, 2), (7, 9)]);
assert_eq!(interval_set.sum(), 3);

interval_set.insert(3, 5);
assert_eq!(interval_set.get_all(), vec![(1, 2), (3, 5), (7, 9)]);
assert_eq!(interval_set.sum(), 5);
Rust

len()

区間の個数を返します。

Examples
let mut interval_set = IntervalSet::new();
assert_eq!(interval_set.get_all(), vec![]);
assert_eq!(interval_set.len(), 0);

interval_set.insert(1, 2);
assert_eq!(interval_set.get_all(), vec![(1, 2)]);
assert_eq!(interval_set.len(), 1);

interval_set.insert(7, 9);
assert_eq!(interval_set.get_all(), vec![(1, 2), (7, 9)]);
assert_eq!(interval_set.len(), 2);

interval_set.insert(3, 5);
assert_eq!(interval_set.get_all(), vec![(1, 2), (3, 5), (7, 9)]);
assert_eq!(interval_set.len(), 3);
Rust

is_covered()

引数で l, r を渡します。
区間 [l, r) が追加されている場合は true を返します。
そうでない場合は false を返します。

Examples
let mut interval_set = IntervalSet::new();
interval_set.insert(1, 2);
interval_set.insert(7, 9);
interval_set.insert(3, 5);
assert_eq!(interval_set.get_all(), vec![(1, 2), (3, 5), (7, 9)]);

assert_eq!(interval_set.is_covered(0, 2), false);
assert_eq!(interval_set.is_covered(1, 2), true);
assert_eq!(interval_set.is_covered(1, 3), false);
assert_eq!(interval_set.is_covered(2, 3), false);
assert_eq!(interval_set.is_covered(2, 4), false);
assert_eq!(interval_set.is_covered(3, 4), true);
assert_eq!(interval_set.is_covered(3, 5), true);
assert_eq!(interval_set.is_covered(4, 5), true);
assert_eq!(interval_set.is_covered(6, 6), true);
Rust

mex()

引数で x を渡します。
mex を返します。

Examples
let mut interval_set = IntervalSet::new();
interval_set.insert(1, 2);
interval_set.insert(7, 9);
interval_set.insert(3, 5);
assert_eq!(interval_set.get_all(), vec![(1, 2), (3, 5), (7, 9)]);

assert_eq!(interval_set.mex(0), 0);
assert_eq!(interval_set.mex(1), 2);
assert_eq!(interval_set.mex(2), 2);
assert_eq!(interval_set.mex(3), 5);
assert_eq!(interval_set.mex(4), 5);
assert_eq!(interval_set.mex(5), 5);
assert_eq!(interval_set.mex(6), 6);
assert_eq!(interval_set.mex(7), 9);
assert_eq!(interval_set.mex(8), 9);
assert_eq!(interval_set.mex(9), 9);
assert_eq!(interval_set.mex(10), 10);
Rust

insert()

引数で l, r を渡します。
区間 [l, r) が追加されます。

Examples
let mut interval_set = IntervalSet::new();
assert_eq!(interval_set.get_all(), vec![]);
assert_eq!(interval_set.sum(), 0);
assert_eq!(interval_set.len(), 0);

interval_set.insert(1, 2);
assert_eq!(interval_set.get_all(), vec![(1, 2)]);
assert_eq!(interval_set.sum(), 1);
assert_eq!(interval_set.len(), 1);

interval_set.insert(7, 9);
assert_eq!(interval_set.get_all(), vec![(1, 2), (7, 9)]);
assert_eq!(interval_set.sum(), 3);
assert_eq!(interval_set.len(), 2);

interval_set.insert(3, 5);
assert_eq!(interval_set.get_all(), vec![(1, 2), (3, 5), (7, 9)]);
assert_eq!(interval_set.sum(), 5);
assert_eq!(interval_set.len(), 3);

interval_set.insert(4, 6);
assert_eq!(interval_set.get_all(), vec![(1, 2), (3, 6), (7, 9)]);
assert_eq!(interval_set.sum(), 6);
assert_eq!(interval_set.len(), 3);

interval_set.insert(6, 7);
assert_eq!(interval_set.get_all(), vec![(1, 2), (3, 9)]);
assert_eq!(interval_set.sum(), 7);
assert_eq!(interval_set.len(), 2);

interval_set.insert(-1, 1);
assert_eq!(interval_set.get_all(), vec![(-1, 2), (3, 9)]);
assert_eq!(interval_set.sum(), 9);
assert_eq!(interval_set.len(), 2);

interval_set.insert(-10, 10);
assert_eq!(interval_set.get_all(), vec![(-10, 10)]);
assert_eq!(interval_set.sum(), 20);
assert_eq!(interval_set.len(), 1);
Rust

remove()

引数で l, r を渡します。
区間 [l, r) が削除されます。

Examples
let mut interval_set = IntervalSet::new();
assert_eq!(interval_set.get_all(), vec![]);
assert_eq!(interval_set.sum(), 0);
assert_eq!(interval_set.len(), 0);

interval_set.insert(-10, 10);
assert_eq!(interval_set.get_all(), vec![(-10, 10)]);
assert_eq!(interval_set.sum(), 20);
assert_eq!(interval_set.len(), 1);

interval_set.remove(-20, -1);
interval_set.remove(2, 3);
interval_set.remove(9, 10);
assert_eq!(interval_set.get_all(), vec![(-1, 2), (3, 9)]);
assert_eq!(interval_set.sum(), 9);
assert_eq!(interval_set.len(), 2);

interval_set.remove(-1, 1);
assert_eq!(interval_set.get_all(), vec![(1, 2), (3, 9)]);
assert_eq!(interval_set.sum(), 7);
assert_eq!(interval_set.len(), 2);

interval_set.remove(6, 7);
assert_eq!(interval_set.get_all(), vec![(1, 2), (3, 6), (7, 9)]);
assert_eq!(interval_set.sum(), 6);
assert_eq!(interval_set.len(), 3);

interval_set.remove(5, 6);
assert_eq!(interval_set.get_all(), vec![(1, 2), (3, 5), (7, 9)]);
assert_eq!(interval_set.sum(), 5);
assert_eq!(interval_set.len(), 3);

interval_set.remove(2, 7);
assert_eq!(interval_set.get_all(), vec![(1, 2), (7, 9)]);
assert_eq!(interval_set.sum(), 3);
assert_eq!(interval_set.len(), 2);

interval_set.remove(4, 20);
assert_eq!(interval_set.get_all(), vec![(1, 2)]);
assert_eq!(interval_set.sum(), 1);
assert_eq!(interval_set.len(), 1);

interval_set.remove(-10, 10);
assert_eq!(interval_set.get_all(), vec![]);
assert_eq!(interval_set.sum(), 0);
assert_eq!(interval_set.len(), 0);
Rust

clear()

追加されている区間を空にします。

Examples
let mut interval_set = IntervalSet::new();
interval_set.insert(1, 2);
interval_set.insert(7, 9);
interval_set.insert(3, 5);
assert_eq!(interval_set.get_all(), vec![(1, 2), (3, 5), (7, 9)]);

interval_set.clear();
assert_eq!(interval_set.get_all(), vec![]);
assert_eq!(interval_set.sum(), 0);
assert_eq!(interval_set.len(), 0);
Rust

使用例

使用例
fn main() {
    let mut interval_set: IntervalSet<i64> = IntervalSet::new();
    assert_eq!(interval_set.get_all(), vec![]);
    assert_eq!(interval_set.sum(), 0);
    assert_eq!(interval_set.len(), 0);

    // 区間 [1, 2) を追加
    interval_set.insert(1, 2);
    assert_eq!(interval_set.get_all(), vec![(1, 2)]);
    assert_eq!(interval_set.sum(), 1);
    assert_eq!(interval_set.len(), 1);

    // 区間 [7, 9) を追加
    interval_set.insert(7, 9);
    assert_eq!(interval_set.get_all(), vec![(1, 2), (7, 9)]);
    assert_eq!(interval_set.sum(), 3);
    assert_eq!(interval_set.len(), 2);

    // 区間 [3, 5) を追加
    interval_set.insert(3, 5);
    assert_eq!(interval_set.get_all(), vec![(1, 2), (3, 5), (7, 9)]);
    assert_eq!(interval_set.sum(), 5);
    assert_eq!(interval_set.len(), 3);

    // 区間 [4, 10) を追加
    interval_set.insert(4, 10);
    assert_eq!(interval_set.get_all(), vec![(1, 2), (3, 10)]);
    assert_eq!(interval_set.sum(), 8);
    assert_eq!(interval_set.len(), 2);
}
Rust

問題例

ここでは「区間を set で管理するやつ」を使うことで楽に解くことができる問題をいくつか紹介します。