0%

Luogu P4721 解题报告

题意简述

给定 \(g[1] \cdots g[n - 1]\),求 \(f[0] \cdots f[n - 1]\),其中 \[f[i] = \sum _ {j = 1} ^ i f[i - j]g[j]\] 边界为 \(f[0] = 1\)。答案对 \(998244353\) 取模。

数据范围

\[2 \leq n \leq 10 ^ 5\] \[0 \leq g[i] < 998244353\]

时间限制:1s 空间限制:128MB

题目链接

Luogu P4721

题解

变换一下式子,可以得到 \[f[i] = \sum _ {j = 0} ^ {i - 1} f[j]g[i - j]\] 由于 \(f[i]\) 依赖于所有下标更小的 \(f[j]\),无法直接用一次卷积求出,需要用分治的方式:先求出左半区间的 \(f\),再用一次卷积把它对右半区间的贡献加上去。这是分治FFT的模板题,关于分治FFT的内容可以查看《分治FFT 学习笔记》。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>

using namespace std;

/// 64-bit signed integer used for every modular computation in this file.
using ll = long long;

/// Modulus under which all answers are reported (the prime 998244353).
constexpr int MODDER = 998244353;
/// Base from which NTT::transform derives its roots of unity (3 is a primitive root of MODDER).
constexpr int G = 3;
/// Capacity of every statically sized array in this file.
constexpr int MAXN = 262150;

/// Reads one decimal integer of type T from standard input.
///
/// Every character that is not a digit is skipped; each '-' met while
/// skipping flips the sign of the result. The character that terminates the
/// digit run is consumed as well.
///
/// @return the parsed value, or 0 if end-of-file is reached before any digit
///         (previously the skip loop never terminated in that situation).
template<typename T>
T read() {
    T result = 0;
    int sign = 1;
    int c = std::getchar();
    // Stop skipping at EOF: EOF compares below '0', so without this check
    // the loop would spin forever once the input is exhausted.
    while (c != EOF && (c > '9' || c < '0')) {
        if (c == '-') sign *= -1;
        c = std::getchar();
    }
    // EOF also fails the digit test, so an exhausted stream leaves result at 0.
    while (c <= '9' && c >= '0') {
        result = result * 10 + (c - '0');
        c = std::getchar();
    }
    return result * sign;
}

/// Computes a^b modulo MODDER by binary exponentiation.
///
/// @param a base; any 64-bit value is accepted, it is reduced into
///          [0, MODDER) before the first multiplication so that
///          base * base can no longer overflow for large or negative input.
/// @param b exponent; expected to be non-negative.
/// @return a^b mod MODDER, in [0, MODDER).
ll quickPow(ll a, ll b) {
    ll result = 1;
    ll base = a % MODDER;
    // In C++ the remainder keeps the sign of the dividend; shift it up.
    if (base < 0) base += MODDER;
    while (b) {
        if (b & 1) result = (result * base) % MODDER;
        base = (base * base) % MODDER;
        b >>= 1;
    }
    return result % MODDER;
}

/// Modular inverse of `value` modulo MODDER, computed as value^(MODDER - 2)
/// (Fermat's little theorem; MODDER is prime).
///
/// @return the inverse when value is not a multiple of MODDER; for a
///         multiple of MODDER no inverse exists and the power evaluates to 0.
ll inv(ll value) {
    const ll fermatExponent = MODDER - 2;
    return quickPow(value, fermatExponent);
}

/// Number-theoretic transform over the integers modulo MODDER.
///
/// All transform lengths are expected to be powers of two that divide
/// MODDER - 1 and do not exceed MAXN. Array entries are expected to be
/// already reduced into [0, MODDER).
namespace NTT {
    /// bitValue[i] is i with its low log2(size) bits reversed, for the size
    /// most recently passed to init().
    int bitValue[MAXN];

    /// Builds the bit-reversal permutation table for transforms of `size`
    /// elements. Must be called before dft/idft whenever the length changes.
    void init(int size) {
        int bitCount = 0;
        bitValue[0] = 0;
        while ((1 << bitCount) < size) bitCount++;
        for (int i = 1; i < size; i++) {
            // Reverse of i = reverse of (i / 2) shifted down by one, with the
            // lowest bit of i moved to the top position.
            bitValue[i] = (bitValue[i >> 1] >> 1) | ((i & 1) << (bitCount - 1));
        }
    }

    /// Permutes a[0..n) into bit-reversed order using the table from init().
    void bitReverse(ll *a, int n) {
        for (int i = 0; i < n; i++) {
            // Swap each pair only once (from its larger index).
            if (bitValue[i] < i) {
                swap(a[i], a[bitValue[i]]);
            }
        }
    }

    /// In-place iterative butterfly transform of a[0..n).
    ///
    /// @param isReverse when true the inverse root of unity is used; the
    ///        final division by n is NOT performed here (see idft).
    void transform(ll *a, int n, bool isReverse) {
        bitReverse(a, n);
        // Primitive n-th root of unity modulo MODDER.
        ll baseW = quickPow(G, (MODDER - 1) / n);
        if (isReverse) {
            baseW = inv(baseW);
        }
        for (int length = 2; length <= n; length <<= 1) {
            int mid = length / 2;
            // Primitive length-th root of unity for this layer.
            ll wn = quickPow(baseW, n / length);
            for (ll *pos = a; pos != a + n; pos += length) {
                ll w = 1;
                for (int i = 0; i < mid; i++) {
                    ll x = pos[i], y = pos[mid + i] * w % MODDER;
                    pos[i] = (x + y) % MODDER;
                    // + MODDER keeps the difference non-negative.
                    pos[mid + i] = (x - y + MODDER) % MODDER;
                    w = (w * wn) % MODDER;
                }
            }
        }
    }

    /// Forward transform of a[0..n), in place.
    void dft(ll *a, int n) {
        transform(a, n, false);
    }

    /// Inverse transform of a[0..n), in place, including the 1/n scaling.
    void idft(ll *a, int n) {
        transform(a, n, true);
        ll x = inv(n);
        for (int i = 0; i < n; i++) {
            a[i] = (a[i] * x) % MODDER;
        }
    }

    /// Cyclic convolution of a and b of size `length`; the result is stored
    /// in a. b is overwritten with its forward transform.
    ///
    /// The bit-reversal table is (re)built here so the function no longer
    /// depends on the caller having invoked init(length) beforehand.
    void multiply(ll *a, ll *b, int length) {
        init(length);
        dft(a, length);
        dft(b, length);
        for (int i = 0; i < length; i++) {
            a[i] = (a[i] * b[i]) % MODDER;
        }
        idft(a, length);
    }
}

// Make init/dft/idft/multiply from NTT callable without qualification.
using namespace NTT;

// f: answer sequence filled in by cdq; g: input coefficients read in main;
// A, B: scratch buffers cdq loads before each convolution.
ll f[MAXN],g[MAXN],A[MAXN],B[MAXN];

// n: sequence length read in main. m is not referenced in the visible code.
int n,m;

void cdq(int l,int r) {
if(l == r) {
return;
}
int mid = (l + r) >> 1,length = 1;
while(length < (r - l)) length <<= 1;
cdq(l,mid);
for(int i = 0;i < length;i++) {
A[i] = 0;
B[i] = 0;
}
for(int i = 0;i <= mid - l;i++) A[i] = f[i + l];
for(int i = 0;i <= r - l - 1;i++) B[i] = g[i + 1];
init(length);
dft(A,length);
dft(B,length);
for(int i = 0;i < length;i++) A[i] = (A[i] * B[i]) % MODDER;
idft(A,length);
for(int i = mid + 1;i <= r;i++) f[i] = (f[i] + A[i - l - 1]) % MODDER;
cdq(mid + 1,r);
}

/// Reads n and g[1..n-1] from stdin, evaluates f[0..n-1] with cdq and prints
/// the values separated by single spaces.
int main() {
    n = read<int>();
    // g[0] is never read; input supplies g[1] .. g[n - 1].
    int idx = 1;
    while (idx < n) {
        g[idx] = read<ll>();
        ++idx;
    }
    // Boundary condition of the recurrence.
    f[0] = 1;
    cdq(0, n - 1);
    for (int pos = 0; pos < n; ++pos) {
        printf("%lld ", f[pos]);
    }
    return 0;
}