h_nosonの日記

競プロなど

yukicoder No.335 門松宝くじ

問題
No.335 門松宝くじ - yukicoder
宝くじが2枚か3枚与えられ,それぞれ数字がN個書かれている.
当選日になると1つの宝くじにつきランダムに2つの数字が選ばれる.1つ数字を自由に選ぶことができ,その3つの数字で門松列が出来れば3つの数字の最大値が当選金額となる.
得られる当選金額の期待値が最大になる宝くじはどれか.

解法
愚直にやるとO(N^3)かかる.
2つの数字a,b
a{<}bの時,aより左の最大値,a,bの間の最大値最小値,bより右の最小値を比べる
a{>}bの時,aより左の最小値,a,bの間の最大値最小値,bより右の最大値を比べる
とすると最適な3つ目の数字が選べる.
よって,区間の最大値,最小値がわかればいいのでセグメント木を使う.
セグメント木を使えばO(N^2\log N)で解ける.

ソースコード

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

#define RREP(i,s,e) for (i = s; i >= e; i--)
#define rrep(i,n) RREP(i,(int)(n)-1,0)
#define REP(i,s,e) for (i = s; i <= e; i++)
#define rep(i,n) REP(i,0,(int)(n)-1)
#define INF 100000000

typedef long long ll;

int segmax[2048];
int segmin[2048];
int sz;

void update_max(int k, int a) {
    k += sz - 1;
    segmax[k] = a;
    while (k > 0) {
        k = (k - 1) / 2;
        segmax[k] = max(segmax[2*k+1],segmax[2*k+2]);
    }
}

void update_min(int k, int a) {
    k += sz - 1;
    segmax[k] = a;
    while (k > 0) {
        k = (k - 1) / 2;
        segmin[k] = min(segmin[2*k+1],segmin[2*k+2]);
    }
}

int query_max(int a, int b, int k, int l, int r) {
    if (r <= a || b <= l)
        return 0;
    if (a <= l && r <= b)
        return segmax[k];
    else
        return max(query_max(a,b,2*k+1,l,(l+r)/2),query_max(a,b,2*k+2,(l+r)/2,r));
}

int query_min(int a, int b, int k, int l, int r) {
    if (r <= a || b <= l)
        return INF;
    if (a <= l && r <= b)
        return segmin[k];
    else
        return min(query_min(a,b,2*k+1,l,(l+r)/2),query_min(a,b,2*k+2,(l+r)/2,r));
}

int main() {
    int i, j, k, l, n, m, ans;
    double mx;
    int e[3][800];
    cin >> n >> m;
    rep (i,m) rep (j,n) cin >> e[i][j];
    mx = ans = 0;
    sz = 1;
    while (sz < n) sz *= 2;
    rep (i,m) {
        rep (j,2*sz-1) {
            segmax[j] = 0;
            segmin[j] = INF;
        }
        rep (j,n) {
            update_max(j,e[i][j]);
            update_min(j,e[i][j]);
        }
        double sum = 0, cnt = 0;
        rep (j,n) REP (k,j+1,n-1) {
            int x, p = 0;
            if (e[i][j] < e[i][k]) {
                x = query_max(0,j,0,0,sz);
                if (x > e[i][j])
                    p = max(x,e[i][k]);
                x = query_max(j+1,k,0,0,sz);
                if (x > e[i][k])
                    p = max(p,x);
                x = query_min(j+1,k,0,0,sz);
                if (x < e[i][j])
                    p = max(p,e[i][k]);
                x = query_min(k+1,n,0,0,sz);
                if (x < e[i][k])
                    p = max(p,e[i][k]);
            }
            else {
                x = query_min(0,j,0,0,sz);
                if (x < e[i][j])
                    p = e[i][j];
                x = query_max(j+1,k,0,0,sz);
                if (x > e[i][j])
                    p = max(p,x);
                x = query_min(j+1,k,0,0,sz);
                if (x < e[i][k])
                    p = max(p,e[i][j]);
                x = query_max(k+1,n,0,0,sz);
                if (x > e[i][k])
                    p = max({p,e[i][k],x});
            }
            sum += p;
            cnt++;
        }
        sum /= cnt;
        if (mx < sum) {
            mx = sum;
            ans = i;
        }
    }
    cout << ans << endl;
    return 0;
}