[백준 19242] 행렬 합

문제 링크

https://www.acmicpc.net/problem/19242

풀이

부분행렬의 합을 구하기 위해서 나이브한 방법 대신 A의 누적 합 배열인 S를 만들어 놓으면 $[x_1,y_1]$ ~ $[x_2,y_2]$ 부분행렬의 구간합을 $S[x_2][y_2] - S[x_2][y_1-1] - S[x_1-1][y_2] + S[x_1-1][y_1-1]$라는 식을 이용해 $O(1)$만에 구할 수 있습니다.

$y_1$와 $y_2$를 고정한다면 $x_1≤x_2$인 $x_1$과 $x_2$의 모든 쌍에 대한 구간이 존재합니다. $x_1$과 $x_2$의 모든 쌍을 검사하면 구간합이 $X$이하인 부분행렬의 갯수를 찾을 수 있겠지만 검사하기 위해서는 $O(N^2)$이라는 시간이 소요됩니다. $S[x_2][y_2]$ $-S[x_2][y_1-1]$ $-S[x_1-1][y_2]$ $+S[x_1-1][y_1-1]$ $≤$ $X$에서 $x_2$와 관련된 식을 넘긴다면 $-S[x_1-1][y_2]$ $+S[x_1-1][y_1-1]$ $≤$ $X-S[x_2][y_2]$ $+S[x_2][y_1-1]$로 $x_1$과 $x_2$를 분리해서 생각할 수 있습니다. $x_1 ≤ x_2$인 모든 $x_1$의 $-S[x_1-1][y_2]$ $+S[x_1-1][y_1-1]$을 저장하면서 저장한 값들 중에서 $X$ $-S[x_2][y_2]$ $+S[x_2][y_1-1]$이하인 것의 개수를 빠르게 구할 수 있으면 시간을 줄일 수 있습니다. $x_2$를 $1$부터 $N$까지 보면서 $K$이하인 것의 개수를 찾은 다음 $-S[x_2-1][y_2]$ $+S[x_2-1][y_1-1]$를 저장한다면 다음 연산에 사용할 수 있어 다시 볼 필요가 없습니다. 저장 연산을 $O(\log N)$, $K$이하인 값을 찾는데 $O(\log N)$인 자료형을 이용해 구현한다면 $O(M^2N \log N)$이라는 시간 복잡도를 가지게 됩니다.

두 연산을 지원하는 pbds를 이용한다면 상수차이로 시간초과가 나기 때문에 상수가 적은 좌표 압축 + 윅 트리를 이용해 구현해야 합니다. 만약 $X$ $-S[x_2][y_2]$ $+S[x_2][y_1-1]$도 좌표압축에 넣는다면 시간초과가 나기 때문에 $-S[x_1-1][y_2]$ $+S[x_1-1][y_1-1]$만 넣어서 좌표압축을 해야합니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include<bits/stdc++.h>
using namespace std;
using ll = long long;

#define f(a) lower_bound(v.begin(), v.end(), a) - v.begin() + 1

int tc, n, m, x, t[333];
ll a[303][202];

void update(int bit, int val){
    while(bit <= n+2){
        t[bit] += val;
        bit += bit & (-bit);
    }
}

int find(int bit){
    int re = 0;
    while(bit){
        re += t[bit];
        bit -= bit & (-bit);
    }
    return re;
}

int main(){
    ios_base::sync_with_stdio(0);cin.tie(0);
    cin >> tc;
    while(tc--){
        cin >> n >> m >> x;
        for(int i=1;i<=n;i++) for(int j=1;j<=m;j++) cin >> a[i][j];
        for(int i=1;i<=n;i++) for(int j=1;j<=m;j++) a[i][j] = a[i][j] + a[i-1][j] + a[i][j-1] - a[i-1][j-1];

        int ans = 0;

        for(int j=1;j<=m;j++){
            for(int k=1;k<=j;k++){
                memset(t, 0, sizeof(t));

                vector<ll> v(n+1);
                for(int i=1;i<=n;i++) v[i] = -a[i][j] + a[i][k-1];
                sort(v.begin(), v.end());
                v.erase(unique(v.begin(), v.end()), v.end());

                update(f(0), 1);
                for(int i=1;i<=n;i++){
                    ans += find(f(x - (a[i][j] - a[i][k-1]) + 1) - 1);
                    update(f(-a[i][j] + a[i][k-1]), 1);
                }
            }
        }

        cout << ans << "\n";
    }
}