0%

HDU 5730 解题报告

题意简述

你希望组合出一条长度为\(n\)的项链,给定\(a[1] \cdots a[n]\)\(a[i]\)表示长度为\(i\)的项链的种类有\(a[i]\)种,问一共能组合出多少种长度为\(n\)的项链。答案对\(313\)取模。

多组数据,数据不超过20组。

数据范围

\[1 \leq n \leq 10 ^ 5\] \[1 \leq a[i] \leq 10 ^ 7\]

题目链接

HDU 5730

题解

根据题意列出如下式子: \[f[n] = \sum _ {i = 0} ^ {n - 1} f[i]a[n - i]\] \(f[n]\)是要求的答案。

这是分治FFT的模板题,关于分治FFT的内容可以查看分治FFT 学习笔记

需要注意的是\(f和g\)不能用复数存储,且在\([l,mid]\)\([mid + 1,r]\)的贡献计算中要取模来保证精度不出错。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>

using namespace std;

typedef double d;

const double PI = acos(-1);
const int MAXN = 131080;
const int MODDER = 313;

template<typename T>
T read() {
T result = 0;int f = 1;int c = getchar();
while(c > '9' || c < '0') {if(c == '-') f *= -1;c = getchar();}
while(c <= '9' && c >= '0') {result = result * 10 + c - '0';c = getchar();}
return result * f;
}

struct C {
d r,i;

C() : r(0) , i(0) { }

C(d r,d i) : r(r) , i(i) { }

C operator + (const C &oC) const {
return C(r + oC.r,i + oC.i);
}

C operator - (const C &oC) const {
return C(r - oC.r,i - oC.i);
}

C operator * (const C &oC) const {
return C(r * oC.r - i * oC.i,r * oC.i + i * oC.r);
}
};

namespace FFT {
int bitValues[MAXN];

void init(int n) {
int bitCount = 0;
while((1 << bitCount) < n) bitCount++;
bitValues[0] = 0;
for(int i = 1;i < n;i++) {
bitValues[i] = (bitValues[i >> 1] >> 1) | ((i & 1) << (bitCount - 1));
}
}

void transform(C *a,int n,int iR) {
for(int i = 0;i < n;i++) {
if(bitValues[i] > i) {
swap(a[i],a[bitValues[i]]);
}
}
for(int l = 2;l <= n;l <<= 1) {
int mid = l >> 1;
C wn = C(cos(2 * PI / l),iR * sin(2 * PI / l));
for(C *pos = a;pos != a + n;pos += l) {
C w = C(1,0);
for(int i = 0;i < mid;i++) {
C x = pos[i],y = pos[mid + i] * w;
pos[i] = x + y;
pos[mid + i] = x - y;
w = w * wn;
}
}
}
}

void dft(C *a,int n) {
transform(a,n,1);
}

void idft(C *a,int n) {
transform(a,n,-1);
for(int i = 0;i < n;i++) {
a[i].r /= n;
}
}
}

using namespace FFT;

C A[MAXN],B[MAXN];

int f[MAXN],g[MAXN];

void cdq(int l,int r) {
if(l == r) {
return;
}
int mid = (l + r) >> 1,len = 1;
while(len < (r - l)) len <<= 1;
cdq(l,mid);
for(int i = 0;i < len;i++) {
A[i] = C();
B[i] = C();
}
for(int i = 0;i <= mid - l;i++) A[i] = C(f[i + l],0);
for(int i = 0;i <= r - l - 1;i++) B[i] = C(g[i + 1],0);
init(len);
dft(A,len);
dft(B,len);
for(int i = 0;i < len;i++) A[i] = A[i] * B[i];
idft(A,len);
for(int i = mid + 1;i <= r;i++) f[i] = (f[i] + static_cast<int>(round(A[i - l - 1].r))) % MODDER;
cdq(mid + 1,r);
}

int main() {
int n;
while(~scanf("%d",&n)) {
if(n == 0) return 0;
for(int i = 0;i <= n;i++) {
f[i] = 0;
g[i] = 0;
}
f[0] = 1;
for(int i = 1;i <= n;i++) {
g[i] = read<int>() % MODDER;
}
cdq(0,n);
printf("%d\n",f[n] % MODDER);
}
return 0;
}