某岛

… : "…アッカリ~ン . .. . " .. .
June 2, 2023

换根 dp (rerooting DP)

#include <lastweapon/number>
using namespace lastweapon;

// Capacity of the per-vertex arrays (vertex count plus a little slack).
const int N = 100000 + 9;
// Undirected adjacency lists of the tree, 0-indexed vertices.
VI adj[N];
// f[u][i]: colourings of u's subtree in which u takes colour i+1 and adjacent
// vertices differ (filled by dfs). NOTE(review): Int is the lastweapon number
// type — presumably modular; confirm.
Int f[N][3];
// c[u]: pre-assigned colour of u (expected 1..3), or 0 when u is unconstrained.
int c[N];
// n: number of vertices; k: number of (vertex, colour) constraints read in main.
int n, k;

// Sum of f[u][0..2]: all valid colourings of u's subtree, whatever colour u takes.
Int s(int u) {
    Int total = f[u][0];
    total = total + f[u][1];
    total = total + f[u][2];
    return total;
}

void dfs(int u = 0, int p = -1) {

    REP(i, 3) if (!c[u] || (c[u]-1) == i) f[u][i] = 1;

    for (auto v: adj[u]) if (v != p) {
        dfs(v, u);
        REP(i, 3) f[u][i] *= s(v) - f[v][i];
    }
}

int main() {

#ifndef ONLINE_JUDGE
    freopen("barnpainting.in", "r", stdin);
    freopen("barnpainting.out", "w", stdout);
#endif

    RD(n, k);
    DO(n-1) {
        int a, b; RD(a, b); --a, --b;
        adj[a].PB(b);
        adj[b].PB(a);
    }
    DO(k) {
        int a; RD(a); --a;
        RD(c[a]);
    }
    dfs();
    cout << s(0) << endl;
}

#include <lastweapon/number>
using namespace lastweapon;

// Capacity of the per-vertex arrays (vertex count plus a little slack).
const int N = 300000 + 9;
// Undirected adjacency lists of the tree, 0-indexed vertices.
VI adj[N];
// f[u]: best (a[w] - dist(u, w), -w) over vertices w strictly below u (dfs1).
PII f[N];
// g[u]: best (a[w] - dist(u, w), -w) over vertices w outside u's subtree (dfs2).
PII g[N];
// h[u]: runner-up to f[u], taken from a different child branch (dfs1).
PII h[N];
// a[u]: per-vertex value read from input.
int a[N];
// m[v]: first walk step at which v was reached (0 = not yet reached).
int m[N];
// s[i]: vertex reached at walk step i.
int s[N];
// n: number of vertices.
int n;
// k: number of walk steps to simulate.
LL k;

void dfs1(int u, int p) {
    for(int& v : adj[u]) {
        if(v == p) continue;
        dfs1(v, u);
        PII val = max(f[v], PII{a[v], -v});
        val.first--;
        if (val > f[u]) {
            h[u] = f[u];
            f[u] = val;
        } else if (val > h[u]) {
            h[u] = val;
        }
    }
}

void dfs2(int u, int p) {
    for(int& v : adj[u]) {
        if(v == p) continue;
        g[v] = max(g[u], (f[u] == max(PII{f[v].first-1, f[v].second}, {a[v]-1, -v}) ? h[u] : f[u]), PII{a[u], -u});
        g[v].first--;
        dfs2(v, u);
    }
}

// Reads n, k, the values a[0..n-1] and n-1 edges (1-indexed), builds the rerooting
// tables, then walks from vertex 0.
// Each step jumps from j to the vertex w != j maximising (a[w] - dist(j, w), -w).
// It prints the (1-indexed) vertex reached after k steps, shortcutting through the first
// detected cycle.
int main() {

#ifndef ONLINE_JUDGE
    freopen("in.txt", "r", stdin);
    //freopen("out.txt", "w", stdout);
#else
    // `#else`, not a bare `#elif`: `#elif` requires a condition and is a hard
    // preprocessing error as soon as ONLINE_JUDGE is defined. No file I/O on the judge.
#endif

    RD(n, k);
    REP(i, n) RD(a[i]);
    DO(n-1) {
        // Deliberately not named `a`: that would shadow the global value array a[].
        int x, y;
        RD(x, y);
        --x, --y;
        adj[x].PB(y);
        adj[y].PB(x);
    }
    // Sentinel below every real (value, -index) candidate.
    REP(i, n) {
        f[i] = {-INT_MAX, -INT_MAX};
        g[i] = {-INT_MAX, -INT_MAX};
        h[i] = {-INT_MAX, -INT_MAX};
    }
    dfs1(0, 0);
    dfs2(0, 0);

    //REP(i, n) cout << i << '/' << f[i] << '/' << g[i] << endl;

    // m[v] = first step at which v was reached (0 = never); s[i] = vertex at step i.
    // Assuming REP_1(i, n+1) runs i = 1..n+1, the pigeonhole principle guarantees a repeat.
    // NOTE(review): for n == 1 both f[0] and g[0] stay at the sentinel and j becomes
    // INT_MAX (out of bounds) — confirm the problem guarantees n >= 2.
    int j = 0;
    REP_1(i, n+1) {
        j = -max(f[j], g[j]).second;
        if (m[j]) {
            // Steps m[j] .. i-1 form a cycle of length i - m[j].
            int ans;
            if (k <= m[j]) {
                ans = k;
            } else {
                ans = ((k-m[j]) % (i-m[j])) + m[j];
            }
            cout << s[ans]+1 << endl;
            break;
        }
        m[j] = i;
        s[i] = j;
    }
}