c/c++语言开发共享洛谷P4007 小 Y 和恐怖的奴隶主(期望dp 矩阵乘法)

题意 “题目链接” Sol 首先不难想到一种暴力dp,设$f[i][a][b][c]$表示还有$i$轮没打,场上有$a$个1血,$b$个2血,$c$个三血 发现状态数只有$s = 166$个,复杂度为$O(ns)$ 矩乘优化一下复杂度为$O(s^3 logn T)$,还是过不去。 因为每次询问都是独 …


题意

题目链接

sol

首先不难想到一种暴力dp,设(f[i][a][b][c])表示还有(i)轮没打,场上有(a)个1血,(b)个2血,(c)个三血

发现状态数只有(s = 166)个,复杂度为(o(ns))

矩乘优化一下复杂度为(o(s^3 logn t)),还是过不去。

因为每次询问都是独立的,那么可以预处理出(2^i)的转移矩阵,回答询问只需要拿一个行向量去乘log个矩阵

构造矩阵的时候可以加一个列向量表示期望

#include<bits/stdc++.h> #define ll long long  using namespace std; const int b = 60, mod = 998244353; template <typename a, typename b> inline bool chmin(a &a, b b){if(a > b) {a = b; return 1;} return 0;} template <typename a, typename b> inline bool chmax(a &a, b b){if(a < b) {a = b; return 1;} return 0;} template <typename a, typename b> inline ll add(a x, b y) {if(x + y < 0) return x + y + mod; return x + y >= mod ? x + y - mod : x + y;} template <typename a, typename b> inline void add2(a &x, b y) {if(x + y < 0) x = x + y + mod; else x = (x + y >= mod ? x + y - mod : x + y);} ll mul(int x, int y) {return 1ll * x * y % mod;}  inline ll read() {     char c = getchar(); ll x = 0, f = 1;     while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}     while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();     return x * f; } int fp(int a, int p) {     int base = 1;     while(p) {         if(p & 1) base = mul(base, a);         a = mul(a, a); p >>= 1;     }     return base; } int t, m, k;   namespace s3 {     int id[11][11][11], cnt, lim;     int ans[168];     ll inv[11];          struct ma {         int m[168][168];         ma() {             memset(m, 0, sizeof(m));             }         void init() {             for(int i = 0; i <= lim; i++) m[i][i] = 1;         }         void print() {             for(int i = 1; i <= lim; i++, puts(""))                 for(int j = 1; j <= lim; j++)                     printf("%d ", m[i][j]);         }         ma operator * (const ma &rhs) const {             ma gg = {};             for(int i = 1; i <= lim; i++)                 for(int j = 1; j <= lim; j++) {                     __int128 tmp = 0;                     for(int k = 1; k <= lim; k++)                          tmp += mul(m[i][k], rhs.m[k][j]);                     tmp %= mod;                     gg.m[i][j] = tmp;                 }                                      return gg;         }     }f[b + 1];     void pre() {         for(int i = 1; i <= k + 1; i++) inv[i] = fp(i, mod - 2);         for(int a = 0; a <= k; a++)              for(int b = 0; a + b <= k; b++)                 for(int c = 0; a + b + c <= k; c++)                     id[a][b][c] = ++cnt;         for(int a = 0; a <= k; a++)              for(int b = 0; a + b <= k; b++)                 for(int c = 0; a + b + c <= k; c++) {                     int down = inv[a + b + c + 1], tag = (a + b + c < k), now = id[a][b][c];                     if(a) f[0].m[now][id[a - 1][b][c]] = mul(a, down);                     if(b) f[0].m[now][id[a + 1][b - 1][c + tag]] = mul(b, down);                     if(c) f[0].m[now][id[a][b + 1][c - 1 + tag]] = mul(c, down);                     f[0].m[now][now] = down;                     f[0].m[now][cnt + 1] = down;                 }         f[0].m[cnt + 1][cnt + 1] = 1;         lim = cnt + 1;         for(int i = 1; i <= b; i++) f[i] = f[i - 1] * f[i - 1];     }     int tmp[168];     void mul(ma a) {         memset(tmp, 0, sizeof(tmp));         for(int j = 1; j <= lim; j++)             for(int i = 1; i <= lim; i++)                 add2(tmp[j], 1ll * ans[i] * a.m[i][j] % mod);         memcpy(ans, tmp, sizeof(tmp));     }     void matrixpow(ll p) {         for(int i = 0; p; p >>= 1, i++)             if(p & 1)                  mul(f[i]);     }        void work() {         pre();         while(t--) {             ll n = read();             memset(ans, 0, sizeof(ans)); ans[id[0][0][1]] = 1;             matrixpow(n);             cout << ans[cnt + 1] << 'n';         }        } }   namespace s2 {     int id[11][11], cnt, lim;     int ans[168];     ll inv[11];          struct ma {         int m[168][168];         ma() {             memset(m, 0, sizeof(m));             }         void init() {             for(int i = 0; i <= lim; i++) m[i][i] = 1;         }         void print() {             for(int i = 1; i <= lim; i++, puts(""))                 for(int j = 1; j <= lim; j++)                     printf("%d ", m[i][j]);         }         ma operator * (const ma &rhs) const {             ma gg = {};             for(int i = 1; i <= lim; i++)                 for(int j = 1; j <= lim; j++) {                     __int128 tmp = 0;                     for(int k = 1; k <= lim; k++)                          tmp += mul(m[i][k], rhs.m[k][j]);                     tmp %= mod;                     gg.m[i][j] = tmp;                 }                                      return gg;         }     }f[b + 1];     void pre() {         for(int i = 1; i <= k + 1; i++) inv[i] = fp(i, mod - 2);         for(int a = 0; a <= k; a++)              for(int b = 0; a + b <= k; b++)                 id[a][b] = ++cnt;         for(int a = 0; a <= k; a++)              for(int b = 0; a + b <= k; b++) {                 int down = inv[a + b + 1], tag = (a + b < k), now = id[a][b];                 if(a) f[0].m[now][id[a - 1][b]] = mul(a, down);                 if(b) f[0].m[now][id[a + 1][b - 1 + tag]] = mul(b, down);                 f[0].m[now][now] = down;                 f[0].m[now][cnt + 1] = down;             }         f[0].m[cnt + 1][cnt + 1] = 1;         lim = cnt + 1;         for(int i = 1; i <= b; i++) f[i] = f[i - 1] * f[i - 1];     }     int tmp[168];     void mul(ma a) {         memset(tmp, 0, sizeof(tmp));         for(int j = 1; j <= lim; j++)             for(int i = 1; i <= lim; i++)                 add2(tmp[j], 1ll * ans[i] * a.m[i][j] % mod);         memcpy(ans, tmp, sizeof(tmp));     }     void matrixpow(ll p) {         for(int i = 0; p; p >>= 1, i++)             if(p & 1)                  mul(f[i]);     }        void work() {         pre();         while(t--) {             ll n = read();             memset(ans, 0, sizeof(ans)); ans[id[0][1]] = 1;             matrixpow(n);             cout << ans[cnt + 1] << 'n';         }        } }  namespace s1 {     int n,  f[12][9][9][9];      int inv(int a) {         return fp(a, mod - 2);     }     void work() {         n = 11;         for(int i = 1; i <= n; i++) {             for(int a = 0; a <= k; a++) {                 for(int b = 0; a + b <= k; b++) {                     for(int c = 0; a + b + c <= k; c++) {                         int down = a + b + c + 1;                         if(a) add2(f[i][a][b][c], mul(mul(a, inv(down)), f[i - 1][a - 1][b][c]));                         if(b) {                             if(down <= k) add2(f[i][a][b][c], mul(mul(b, inv(down)), f[i - 1][a + 1][b - 1 + (m == 2)][c + (m == 3)]));                             else add2(f[i][a][b][c], mul(mul(b, inv(down)), f[i - 1][a + 1][b - 1][c]));                         }                         if(c) {                             if(down <= k) add2(f[i][a][b][c], mul(mul(c, inv(down)), f[i - 1][a][b + 1 + (m == 2)][c - 1 + (m == 3)]));                             else add2(f[i][a][b][c], mul(mul(c, inv(down)), f[i - 1][a][b + 1][c - 1]));                         }                         add2(f[i][a][b][c], mul(inv(down), f[i - 1][a][b][c] + 1));                     }                 }             }         }         while(t--) {             int n = read();             printf("%dn", f[n][m == 1][m == 2][m == 3]);         }     } }  int main() {     t = read(); m = read(); k = read();     if(m == 1) s1::work();     else if(m == 2) s2::work();     else s3::work();      return 0; }

本文来自网络收集,不代表计算机技术网立场,如涉及侵权请联系管理员删除。

ctvol管理联系方式QQ:251552304

本文章地址:https://www.ctvol.com/c-cdevelopment/605223.html

(0)
上一篇 2021年5月13日
下一篇 2021年5月13日

精彩推荐